diff --git a/docstrange/__init__.py b/docstrange/__init__.py index 81b9934..79f3c14 100644 --- a/docstrange/__init__.py +++ b/docstrange/__init__.py @@ -1,5 +1,9 @@ """ Document Data Extractor - Extract structured data from any document into LLM-ready formats. + +For engineering drawing extraction use EngineeringDrawingPipeline directly — it is the +dedicated entry point for PDFs and images containing title blocks, dimensions, GD&T, BOM, +notes, and revision history. """ from .extractor import DocumentExtractor @@ -8,13 +12,56 @@ from .exceptions import ConversionError, UnsupportedFormatError from .config import InternalConfig +# Engineering drawing extraction surface +from .pipelines.engineering import EngineeringDrawingPipeline +from .schemas.engineering import ( + BBoxSchema, + ExtractionElement, + EngineeringDrawingResult, + DimensionElement, + TitleBlockField, + NoteElement, + GDTElement, + BOMRow, + RevisionEntry, +) +from .extractors import ( + BaseExtractor, + TitleBlockExtractor, + DimensionExtractor, + NoteExtractor, + GDTExtractor, + BOMExtractor, + RevisionExtractor, +) + __version__ = "1.1.5" __all__ = [ - "DocumentExtractor", - "ConversionResult", + # Generic document extraction + "DocumentExtractor", + "ConversionResult", "GPUConversionResult", "CloudConversionResult", - "ConversionError", - "UnsupportedFormatError", - "InternalConfig" -] \ No newline at end of file + "ConversionError", + "UnsupportedFormatError", + "InternalConfig", + # Engineering drawing extraction + "EngineeringDrawingPipeline", + "EngineeringDrawingResult", + "BBoxSchema", + "ExtractionElement", + "DimensionElement", + "TitleBlockField", + "NoteElement", + "GDTElement", + "BOMRow", + "RevisionEntry", + # Individual extractors (for custom pipelines) + "BaseExtractor", + "TitleBlockExtractor", + "DimensionExtractor", + "NoteExtractor", + "GDTExtractor", + "BOMExtractor", + "RevisionExtractor", +] \ No newline at end of file diff --git a/docstrange/api/__init__.py b/docstrange/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docstrange/api/main.py b/docstrange/api/main.py new file mode 100644 index 0000000..d3d7c3f --- /dev/null +++ b/docstrange/api/main.py @@ -0,0 +1,11 @@ +"""FastAPI application entry point. + +Run with: + uvicorn docstrange.api.main:app --reload --port 8000 + +Then visit http://localhost:8000/docs for the interactive API explorer. +""" + +from .routes import create_app + +app = create_app() diff --git a/docstrange/api/models.py b/docstrange/api/models.py new file mode 100644 index 0000000..44abd89 --- /dev/null +++ b/docstrange/api/models.py @@ -0,0 +1,86 @@ +"""Pydantic response models for the DocStrange Engineering API. + +These are thin wrappers / re-exports that let FastAPI generate accurate +OpenAPI schemas for every endpoint without duplicating the core schemas. +""" + +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel + +from ..schemas.engineering import ( + BBoxSchema, + BOMRow, + DimensionElement, + EngineeringDrawingResult, + ExtractionMetadata, + GDTElement, + NoteElement, + RevisionEntry, + TitleBlockField, +) + +__all__ = [ + "HealthResponse", + "FullExtractionResponse", + "OverlayBBox", + "OverlayAnnotation", + "OverlaySummary", + "OverlayImageSize", + "OverlayResponse", + # Re-exported schema types used as list element response models + "TitleBlockField", + "DimensionElement", + "NoteElement", + "GDTElement", + "BOMRow", + "RevisionEntry", +] + + +class HealthResponse(BaseModel): + status: str + service: str + version: str + + +class FullExtractionResponse(EngineeringDrawingResult): + """EngineeringDrawingResult extended with an optional overlay payload.""" + overlay_json: Optional[Dict[str, Any]] = None + + +class OverlayBBox(BaseModel): + x: float + y: float + width: float + height: float + + +class OverlayAnnotation(BaseModel): + change_id: str + type: str + text: str + page: int + confidence: float + bbox: OverlayBBox + bbox_normalized: OverlayBBox + color: str + label: str + + +class OverlaySummary(BaseModel): + by_type: Dict[str, int] + total: int + + +class OverlayImageSize(BaseModel): + width: int + height: int + + +class OverlayResponse(BaseModel): + image_size: OverlayImageSize + page_filter: Optional[int] = None + total_annotations: int + summary: OverlaySummary + annotations: List[OverlayAnnotation] diff --git a/docstrange/api/routes.py b/docstrange/api/routes.py new file mode 100644 index 0000000..13b3ea4 --- /dev/null +++ b/docstrange/api/routes.py @@ -0,0 +1,319 @@ +"""FastAPI routes for engineering drawing extraction. + +Install: pip install 'docstrange[engineering]' +Run: uvicorn docstrange.api.main:app --reload --port 8000 +Docs: http://localhost:8000/docs +""" + +import asyncio +import functools +import logging +import os +import tempfile +from pathlib import Path +from typing import List, Optional + +logger = logging.getLogger(__name__) + +try: + from fastapi import Depends, FastAPI, File, HTTPException, Query, UploadFile + from fastapi.middleware.cors import CORSMiddleware + from fastapi.responses import JSONResponse + _FASTAPI_AVAILABLE = True +except ImportError: + _FASTAPI_AVAILABLE = False + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +MAX_FILE_BYTES = 50 * 1024 * 1024 # 50 MB +ALLOWED_EXTENSIONS = {".pdf", ".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".webp"} + +# --------------------------------------------------------------------------- +# Lazy singletons (initialized on first request, not at import time) +# --------------------------------------------------------------------------- + +_pipeline = None +_overlay_gen = None + + +def _get_pipeline(): + global _pipeline + if _pipeline is None: + from ..pipelines.engineering import EngineeringDrawingPipeline + _pipeline = EngineeringDrawingPipeline() + return _pipeline + + +def _get_overlay_gen(): + global _overlay_gen + if _overlay_gen is None: + from ..overlays.generator import OverlayGenerator + _overlay_gen = OverlayGenerator() + return _overlay_gen + + +# --------------------------------------------------------------------------- +# Async helpers +# --------------------------------------------------------------------------- + +def _cleanup(path: str) -> None: + try: + os.unlink(path) + except OSError: + pass + + +async def _save_upload(file: UploadFile) -> str: + """Validate, read, and write the uploaded file to a temp path.""" + ext = Path(file.filename or "").suffix.lower() or ".pdf" + if ext not in ALLOWED_EXTENSIONS: + raise HTTPException( + status_code=415, + detail=f"Unsupported file type '{ext}'. Allowed: {sorted(ALLOWED_EXTENSIONS)}", + ) + + content = await file.read() + if len(content) > MAX_FILE_BYTES: + raise HTTPException( + status_code=413, + detail=f"File too large ({len(content) // (1024*1024)} MB). Maximum is 50 MB.", + ) + if not content: + raise HTTPException(status_code=400, detail="Uploaded file is empty.") + + tmp = tempfile.NamedTemporaryFile(suffix=ext, delete=False) + try: + tmp.write(content) + finally: + tmp.close() + return tmp.name + + +async def _extract(file_path: str, extractors: Optional[List[str]] = None): + """Run the pipeline in a thread pool so the event loop is not blocked.""" + pipeline = _get_pipeline() + ext = Path(file_path).suffix.lower() + loop = asyncio.get_running_loop() + + try: + if ext == ".pdf": + fn = functools.partial(pipeline.extract_from_pdf, file_path, extractors=extractors) + else: + fn = functools.partial(pipeline.extract_from_image, file_path, extractors=extractors) + return await loop.run_in_executor(None, fn) + except FileNotFoundError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + except Exception as exc: + logger.error("Extraction failed for %s: %s", file_path, exc, exc_info=True) + raise HTTPException(status_code=500, detail="Extraction failed. Check server logs.") from exc + + +# --------------------------------------------------------------------------- +# Application factory +# --------------------------------------------------------------------------- + +def create_app() -> "FastAPI": + """Create and return the FastAPI application instance.""" + if not _FASTAPI_AVAILABLE: + raise ImportError( + "FastAPI is required for the HTTP API. " + "Install with: pip install 'docstrange[engineering]'" + ) + + from .models import ( + BOMRow, DimensionElement, FullExtractionResponse, GDTElement, + HealthResponse, NoteElement, OverlayResponse, RevisionEntry, TitleBlockField, + ) + + app = FastAPI( + title="DocStrange Engineering API", + description=( + "Modular engineering drawing extraction server. " + "Upload a PDF or image to extract title block fields, dimensions, GD&T symbols, " + "Bill of Materials rows, notes, and revision history — each with bounding boxes " + "and confidence scores." + ), + version="1.0.0", + ) + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["POST", "GET"], + allow_headers=["*"], + ) + + # ------------------------------------------------------------------ + # Extraction endpoints + # ------------------------------------------------------------------ + + @app.post( + "/extract/title-block", + response_model=List[TitleBlockField], + summary="Extract title block fields", + tags=["extraction"], + ) + async def extract_title_block(file: UploadFile = File(...)): + tmp = await _save_upload(file) + try: + result = await _extract(tmp, extractors=["title_block"]) + return result.title_block + finally: + _cleanup(tmp) + + @app.post( + "/extract/dimensions", + response_model=List[DimensionElement], + summary="Extract dimension annotations", + tags=["extraction"], + ) + async def extract_dimensions(file: UploadFile = File(...)): + tmp = await _save_upload(file) + try: + result = await _extract(tmp, extractors=["dimensions"]) + return result.dimensions + finally: + _cleanup(tmp) + + @app.post( + "/extract/notes", + response_model=List[NoteElement], + summary="Extract general notes and numbered annotations", + tags=["extraction"], + ) + async def extract_notes(file: UploadFile = File(...)): + tmp = await _save_upload(file) + try: + result = await _extract(tmp, extractors=["notes"]) + return result.notes + finally: + _cleanup(tmp) + + @app.post( + "/extract/gdt", + response_model=List[GDTElement], + summary="Extract GD&T symbols and feature control frames", + tags=["extraction"], + ) + async def extract_gdt(file: UploadFile = File(...)): + tmp = await _save_upload(file) + try: + result = await _extract(tmp, extractors=["gdt"]) + return result.gdt + finally: + _cleanup(tmp) + + @app.post( + "/extract/bom", + response_model=List[BOMRow], + summary="Extract Bill of Materials rows", + tags=["extraction"], + ) + async def extract_bom(file: UploadFile = File(...)): + tmp = await _save_upload(file) + try: + result = await _extract(tmp, extractors=["bom"]) + return result.bom + finally: + _cleanup(tmp) + + @app.post( + "/extract/revisions", + response_model=List[RevisionEntry], + summary="Extract revision block entries", + tags=["extraction"], + ) + async def extract_revisions(file: UploadFile = File(...)): + tmp = await _save_upload(file) + try: + result = await _extract(tmp, extractors=["revisions"]) + return result.revisions + finally: + _cleanup(tmp) + + @app.post( + "/extract/full", + response_model=FullExtractionResponse, + summary="Run all extractors and return the complete result", + tags=["extraction"], + ) + async def extract_full( + file: UploadFile = File(...), + include_overlays: bool = Query( + default=False, description="Attach UI overlay JSON to the response" + ), + image_width: int = Query( + default=0, ge=0, description="Source image width in pixels (for overlay normalisation)" + ), + image_height: int = Query( + default=0, ge=0, description="Source image height in pixels (for overlay normalisation)" + ), + page: Optional[int] = Query( + default=None, ge=1, description="Filter overlay annotations to a single page" + ), + ): + tmp = await _save_upload(file) + try: + result = await _extract(tmp) + data = result.model_dump() + if include_overlays: + gen = _get_overlay_gen() + data["overlay_json"] = gen.generate( + result, + image_width=image_width, + image_height=image_height, + image_path=tmp, + page=page, + ) + return data + finally: + _cleanup(tmp) + + # ------------------------------------------------------------------ + # Overlay endpoint + # ------------------------------------------------------------------ + + @app.post( + "/generate/overlays", + response_model=OverlayResponse, + summary="Generate UI-ready overlay JSON with bounding boxes", + tags=["overlays"], + ) + async def generate_overlays( + file: UploadFile = File(...), + image_width: int = Query(default=0, ge=0, description="Image width in pixels"), + image_height: int = Query(default=0, ge=0, description="Image height in pixels"), + page: Optional[int] = Query( + default=None, ge=1, description="Return annotations for a single page only" + ), + ): + tmp = await _save_upload(file) + try: + result = await _extract(tmp) + gen = _get_overlay_gen() + return gen.generate( + result, + image_width=image_width, + image_height=image_height, + image_path=tmp, + page=page, + ) + finally: + _cleanup(tmp) + + # ------------------------------------------------------------------ + # Health check + # ------------------------------------------------------------------ + + @app.get( + "/health", + response_model=HealthResponse, + summary="Service health check", + tags=["meta"], + ) + async def health(): + return HealthResponse(status="ok", service="docstrange-engineering", version="1.0.0") + + return app diff --git a/docstrange/extractors/__init__.py b/docstrange/extractors/__init__.py new file mode 100644 index 0000000..3060601 --- /dev/null +++ b/docstrange/extractors/__init__.py @@ -0,0 +1,17 @@ +from .base import BaseExtractor +from .dimensions import DimensionExtractor +from .title_block import TitleBlockExtractor +from .notes import NoteExtractor +from .gdt import GDTExtractor +from .bom import BOMExtractor +from .revisions import RevisionExtractor + +__all__ = [ + "BaseExtractor", + "DimensionExtractor", + "TitleBlockExtractor", + "NoteExtractor", + "GDTExtractor", + "BOMExtractor", + "RevisionExtractor", +] diff --git a/docstrange/extractors/base.py b/docstrange/extractors/base.py new file mode 100644 index 0000000..2064a20 --- /dev/null +++ b/docstrange/extractors/base.py @@ -0,0 +1,23 @@ +"""Base extractor contract for all engineering drawing extractors.""" + +from abc import ABC, abstractmethod +from typing import List, TYPE_CHECKING + +if TYPE_CHECKING: + from ..pipeline.layout_detector import LayoutElement + from ..schemas.engineering import BBoxSchema, ExtractionElement + + +class BaseExtractor(ABC): + """All domain extractors receive a flat List[LayoutElement] and return typed schema objects.""" + + @abstractmethod + def extract(self, elements: List["LayoutElement"]) -> list: + pass + + def _to_bbox(self, el: "LayoutElement") -> "BBoxSchema": + from ..schemas.engineering import BBoxSchema + return BBoxSchema(x=float(el.x), y=float(el.y), width=float(el.width), height=float(el.height)) + + def _confidence(self, el: "LayoutElement", multiplier: float = 1.0) -> float: + return min(float(getattr(el, 'confidence', 0.8)) * multiplier, 1.0) diff --git a/docstrange/extractors/bom.py b/docstrange/extractors/bom.py new file mode 100644 index 0000000..ce4cc74 --- /dev/null +++ b/docstrange/extractors/bom.py @@ -0,0 +1,119 @@ +"""Bill of Materials (BOM) extractor for engineering drawings.""" + +import re +from typing import Dict, List, Optional, Tuple + +from .base import BaseExtractor + +_BOM_HEADER = re.compile(r'\b(BILL\s+OF\s+MATERIALS?|BOM|PARTS?\s+LIST)\b', re.I) + +_COLUMN_MAP: Dict[str, re.Pattern] = { + "item": re.compile(r'\b(ITEM|NO\.?|#)\b', re.I), + "qty": re.compile(r'\b(QTY|QUANTITY|NO\.\s*REQ)\b', re.I), + "part_number": re.compile(r'\b(PART\s*NO\.?|P/?N|PART\s*#)\b', re.I), + "description": re.compile(r'\b(DESCRIPTION|DESC\.?|NAME)\b', re.I), + "material": re.compile(r'\b(MATERIAL|MAT\.?)\b', re.I), +} + + +def _row_key(el) -> float: + """Round y to nearest 10px bucket for row grouping.""" + return round(float(el.y) / 10) * 10 + + +def _group_into_rows(elements: List) -> List[List]: + """Cluster elements that share the same y-bucket into rows, sorted by x.""" + from collections import defaultdict + buckets: Dict[float, list] = defaultdict(list) + for el in elements: + buckets[_row_key(el)].append(el) + rows = [] + for key in sorted(buckets): + row = sorted(buckets[key], key=lambda e: float(e.x)) + rows.append(row) + return rows + + +def _infer_column_order(header_row: List) -> Dict[int, str]: + """Map column index → field name from the header row.""" + mapping: Dict[int, str] = {} + for idx, el in enumerate(header_row): + text = (el.text or '').strip() + for field, pattern in _COLUMN_MAP.items(): + if pattern.search(text): + mapping[idx] = field + break + return mapping + + +class BOMExtractor(BaseExtractor): + """Extracts Bill of Materials tables from engineering drawings.""" + + def extract(self, elements: List) -> List: + from ..schemas.engineering import BOMRow + + if not elements: + return [] + + sorted_els = sorted( + (e for e in elements if e.text and e.text.strip()), + key=lambda e: (e.y, e.x), + ) + + # Find the BOM header element + header_idx = None + header_el = None + for i, el in enumerate(sorted_els): + if _BOM_HEADER.search(el.text or ''): + header_idx = i + header_el = el + break + + if header_el is None: + return [] + + # Collect elements below the header with similar x-range + hx1, hx2 = float(header_el.x), float(header_el.x) + float(header_el.width) + x_margin = max(float(header_el.width) * 3, 200) # generous margin + body_els = [ + el for el in sorted_els[header_idx + 1:] + if float(el.y) > float(header_el.y) + and float(el.x) < hx2 + x_margin + ] + + if not body_els: + return [] + + rows = _group_into_rows(body_els) + if not rows: + return [] + + # First row is assumed to be column headers + col_map = _infer_column_order(rows[0]) + has_mapping = bool(col_map) + + bom_rows = [] + data_rows = rows[1:] if has_mapping else rows + avg_conf = sum(float(getattr(e, 'confidence', 0.8)) for e in body_els) / max(len(body_els), 1) + mult = 1.0 if has_mapping else 0.75 + + for row_els in data_rows: + cells = [el.text.strip() for el in row_els] + if not any(cells): + continue + + kwargs: Dict = {"raw_cells": cells} + for idx, field in col_map.items(): + if idx < len(cells): + kwargs[field] = cells[idx] + + # Use bbox of the leftmost cell in the row for position + bbox = self._to_bbox(row_els[0]) if row_els else self._to_bbox(header_el) + bom_rows.append(BOMRow( + confidence=min(avg_conf * mult, 1.0), + bbox=bbox, + **{k: v for k, v in kwargs.items() if k in + {"item_number", "quantity", "part_number", "description", "material", "raw_cells"}}, + )) + + return bom_rows diff --git a/docstrange/extractors/dimensions.py b/docstrange/extractors/dimensions.py new file mode 100644 index 0000000..8dd5965 --- /dev/null +++ b/docstrange/extractors/dimensions.py @@ -0,0 +1,140 @@ +"""Dimension extractor for engineering drawings.""" + +import re +from typing import List, Optional, Tuple + +from .base import BaseExtractor + + +# Compiled at class level — do not duplicate inside methods +_LINEAR = re.compile( + r'([+\-]?\d+\.?\d*)\s*' + r'(?:±\s*(\d+\.?\d*)' # ±tolerance + r'|([+\-]\d+\.?\d*)/([+\-]\d+\.?\d*))?' # bilateral +x/-y + r'\s*(mm|in|")?', + re.IGNORECASE, +) +_DIAMETER = re.compile(r'[⌀Ø]\s*(\d+\.?\d*)', re.IGNORECASE) +_DIAMETER_TEXT = re.compile(r'\bDIA\.?\s+(\d+\.?\d*)', re.IGNORECASE) +_ANGULAR = re.compile(r'(\d+\.?\d*)\s*(?:°|DEG\.?)', re.IGNORECASE) +_RADIAL = re.compile(r'\bR\s*(\d+\.?\d*)\b', re.IGNORECASE) +# Reject: pure part numbers, long digit strings without units +_PART_NUMBER = re.compile(r'^\s*[A-Z]{0,4}\d{6,}\s*$') +_DIMENSION_GUARD = re.compile(r'\d') + + +class DimensionExtractor(BaseExtractor): + """Extracts dimension annotations from engineering drawing layout elements.""" + + def extract(self, elements: List) -> List: + from ..schemas.engineering import DimensionElement + results = [] + for el in elements: + if getattr(el, 'element_type', '') == 'picture': + continue + text = (el.text or '').strip() + if not text or not _DIMENSION_GUARD.search(text): + continue + if _PART_NUMBER.match(text): + continue + + found = self._parse(text, el) + results.extend(found) + return results + + def _parse(self, text: str, el) -> List: + from ..schemas.engineering import DimensionElement + results = [] + + # Diameter (⌀ prefix takes priority) + for m in _DIAMETER.finditer(text): + results.append(DimensionElement( + text=m.group(0), + nominal=float(m.group(1)), + dimension_type="diameter", + confidence=self._confidence(el, 1.0), + bbox=self._to_bbox(el), + )) + + for m in _DIAMETER_TEXT.finditer(text): + results.append(DimensionElement( + text=m.group(0), + nominal=float(m.group(1)), + dimension_type="diameter", + confidence=self._confidence(el, 0.95), + bbox=self._to_bbox(el), + )) + + # Angular + for m in _ANGULAR.finditer(text): + results.append(DimensionElement( + text=m.group(0), + nominal=float(m.group(1)), + dimension_type="angular", + unit="deg", + confidence=self._confidence(el, 1.0), + bbox=self._to_bbox(el), + )) + + # Radial + for m in _RADIAL.finditer(text): + results.append(DimensionElement( + text=m.group(0), + nominal=float(m.group(1)), + dimension_type="radial", + confidence=self._confidence(el, 1.0), + bbox=self._to_bbox(el), + )) + + # Linear (only if no specialised type already matched this token) + matched_spans = {r.text for r in results} + for m in _LINEAR.finditer(text): + raw = m.group(0).strip() + if not raw or raw in matched_spans: + continue + nominal_str = m.group(1) + if not nominal_str: + continue + try: + nominal = float(nominal_str) + except ValueError: + continue + # Skip very large integers that look like dates/serial numbers + if nominal > 9999 and not m.group(5): + continue + + upper_tol = lower_tol = None + multiplier = 0.8 # partial match default + + if m.group(2): # ± symmetric tolerance + try: + t = float(m.group(2)) + upper_tol = t + lower_tol = -t + multiplier = 1.0 + except ValueError: + pass + elif m.group(3) and m.group(4): # bilateral + try: + upper_tol = float(m.group(3)) + lower_tol = float(m.group(4)) + multiplier = 1.0 + except ValueError: + pass + + unit = m.group(5) or None + if unit == '"': + unit = 'in' + + results.append(DimensionElement( + text=raw, + nominal=nominal, + upper_tolerance=upper_tol, + lower_tolerance=lower_tol, + unit=unit, + dimension_type="linear", + confidence=self._confidence(el, multiplier), + bbox=self._to_bbox(el), + )) + + return results diff --git a/docstrange/extractors/gdt.py b/docstrange/extractors/gdt.py new file mode 100644 index 0000000..f77a5ce --- /dev/null +++ b/docstrange/extractors/gdt.py @@ -0,0 +1,138 @@ +"""GD&T (Geometric Dimensioning and Tolerancing) extractor for engineering drawings.""" + +import re +from typing import Dict, List, Optional + +from .base import BaseExtractor + +# Unicode GD&T symbols mapped to their names +_GDT_SYMBOL_MAP: Dict[str, str] = { + "⊙": "position", + "⊕": "position", + "○": "roundness", + "⌭": "cylindricity", + "⌒": "profile_surface", + "⌓": "profile_line", + "∥": "parallelism", + "⊥": "perpendicularity", + "∠": "angularity", + "⌒": "circularity", + "⌤": "flatness", + "↗": "circular_runout", + "⌀": "diameter", + "Ⓜ": "maximum_material_condition", + "Ⓛ": "least_material_condition", + "Ⓕ": "free_state", + "Ⓟ": "projected_tolerance_zone", +} + +# Text abbreviation fallbacks when OCR misses the symbol +_GDT_TEXT = re.compile( + r'\b(PERP(?:ENDICULARITY)?|PAR(?:ALLELISM)?|POS(?:ITION)?' + r'|FLAT(?:NESS)?|CYL(?:INDRICITY)?|ROUND(?:NESS)?' + r'|CONC(?:ENTRICITY)?|SYM(?:METRY)?|ANG(?:ULARITY)?' + r'|PROF(?:ILE)?|RUN(?:OUT)?|STR(?:AIGHTNESS)?|CIRC(?:ULARITY)?)' + r'\s+(\d+\.?\d*)', + re.IGNORECASE, +) + +# Feature control frame: |symbol|tolerance|datum(s)| +_FEATURE_CTRL = re.compile(r'\|([^|\n]{1,20})\|([^|\n]{1,20})(?:\|([^|\n]{1,20}))?\|?') + +# Numeric tolerance adjacent to a symbol +_TOLERANCE_VALUE = re.compile(r'(\d+\.?\d*)') + +# Datum reference: isolated single uppercase letter (A–Z) not part of a longer word +_DATUM_REF = re.compile(r'(? List: + from ..schemas.engineering import GDTElement + + results = [] + for el in elements: + text = (el.text or '').strip() + if not text: + continue + + # Check for feature control frame pattern + for m in _FEATURE_CTRL.finditer(text): + raw_sym = m.group(1).strip() + tol = m.group(2).strip() + datum = m.group(3).strip() if m.group(3) else None + symbol = self._resolve_symbol(raw_sym) + if symbol: + results.append(GDTElement( + text=m.group(0), + symbol=symbol, + tolerance_value=tol or None, + datum_reference=datum, + confidence=self._confidence(el, 1.0), + bbox=self._to_bbox(el), + )) + + # Check for Unicode symbols + for char, sym_name in _GDT_SYMBOL_MAP.items(): + if char in text: + tol = self._extract_tolerance(text) + datum = self._extract_datum(text) + results.append(GDTElement( + text=text, + symbol=sym_name, + tolerance_value=tol, + datum_reference=datum, + confidence=self._confidence(el, 0.95), + bbox=self._to_bbox(el), + )) + break # one entry per element for unicode match + + # Check for text abbreviations + for m in _GDT_TEXT.finditer(text): + abbrev = m.group(1).upper() + sym_name = _ABBREV_TO_NAME.get(abbrev, abbrev.lower()) + tol = m.group(2) + datum = self._extract_datum(text) + results.append(GDTElement( + text=m.group(0), + symbol=sym_name, + tolerance_value=tol, + datum_reference=datum, + confidence=self._confidence(el, 0.9), + bbox=self._to_bbox(el), + )) + + return results + + def _resolve_symbol(self, raw: str) -> Optional[str]: + for char, name in _GDT_SYMBOL_MAP.items(): + if char in raw: + return name + abbrev = raw.upper().split()[0] if raw.split() else raw.upper() + return _ABBREV_TO_NAME.get(abbrev) + + def _extract_tolerance(self, text: str) -> Optional[str]: + m = _TOLERANCE_VALUE.search(text) + return m.group(1) if m else None + + def _extract_datum(self, text: str) -> Optional[str]: + datums = _DATUM_REF.findall(text) + return datums[0] if datums else None diff --git a/docstrange/extractors/notes.py b/docstrange/extractors/notes.py new file mode 100644 index 0000000..d92a4be --- /dev/null +++ b/docstrange/extractors/notes.py @@ -0,0 +1,76 @@ +"""Notes extractor for engineering drawings.""" + +import re +from typing import List, Optional + +from .base import BaseExtractor + +_NOTE_HEADER = re.compile(r'\b(GENERAL\s*NOTES?|NOTES?:?)\b', re.I) +_NUMBERED = re.compile(r'^\s*(\d+)\.\s+(.+)', re.DOTALL) +_CONTINUATION = re.compile(r'^\s{2,}') + + +class NoteExtractor(BaseExtractor): + """Extracts notes and annotations from engineering drawings.""" + + def extract(self, elements: List) -> List: + from ..schemas.engineering import NoteElement + + results = [] + in_notes_section = False + sorted_els = sorted( + (e for e in elements if e.text and e.text.strip()), + key=lambda e: (e.y, e.x), + ) + + for el in sorted_els: + text = el.text.strip() + + # Detect notes section header + if _NOTE_HEADER.search(text): + in_notes_section = True + # Don't emit the header itself as a note + continue + + if in_notes_section: + m = _NUMBERED.match(text) + if m: + results.append(NoteElement( + text=text, + note_number=int(m.group(1)), + is_general=False, + confidence=self._confidence(el, 1.0), + bbox=self._to_bbox(el), + )) + elif _CONTINUATION.match(el.text): + # Indented continuation — attach to last note or emit as general + if results: + # Merge into the last note's text (keep bbox of last note) + last = results[-1] + results[-1] = NoteElement( + text=last.text + " " + text, + note_number=last.note_number, + is_general=last.is_general, + confidence=last.confidence, + bbox=last.bbox, + ) + else: + results.append(NoteElement( + text=text, + is_general=True, + confidence=self._confidence(el, 0.85), + bbox=self._to_bbox(el), + )) + else: + # Non-numbered, non-indented — possible general note line + results.append(NoteElement( + text=text, + is_general=True, + confidence=self._confidence(el, 0.8), + bbox=self._to_bbox(el), + )) + # A clearly unrelated heading or separator ends the section + if len(text) < 4 or text.isupper() and not any(c.isdigit() for c in text): + in_notes_section = False + + return results diff --git a/docstrange/extractors/revisions.py b/docstrange/extractors/revisions.py new file mode 100644 index 0000000..00a772e --- /dev/null +++ b/docstrange/extractors/revisions.py @@ -0,0 +1,107 @@ +"""Revision block extractor for engineering drawings.""" + +import re +from typing import Dict, List, Optional + +from .base import BaseExtractor + +_REV_HEADER = re.compile( + r'\b(REV(?:ISION)?\s*(?:HISTORY|BLOCK|TABLE)?|CHANGE\s*LOG)\b', re.I +) +_REV_LETTER = re.compile(r'(? float: + return round(float(el.y) / 10) * 10 + + +def _group_into_rows(elements: List) -> List[List]: + from collections import defaultdict + buckets: Dict[float, list] = defaultdict(list) + for el in elements: + buckets[_row_key(el)].append(el) + rows = [] + for key in sorted(buckets): + rows.append(sorted(buckets[key], key=lambda e: float(e.x))) + return rows + + +def _parse_row(cells: List[str]) -> Dict[str, Optional[str]]: + """Heuristically infer which cell is revision / date / description / approver.""" + result: Dict[str, Optional[str]] = { + "revision": None, "date": None, "description": None, "approved_by": None, + } + for cell in cells: + cell = cell.strip() + if not cell: + continue + if result["date"] is None and _DATE_PATTERN.search(cell): + result["date"] = cell + elif result["revision"] is None and _REV_LETTER.fullmatch(cell): + result["revision"] = cell + elif result["approved_by"] is None and _APPROVER_HINT.search(cell) and len(cell) < 30: + result["approved_by"] = cell + elif result["description"] is None: + result["description"] = cell + return result + + +class RevisionExtractor(BaseExtractor): + """Extracts revision block entries from engineering drawings.""" + + def extract(self, elements: List) -> List: + from ..schemas.engineering import RevisionEntry + + if not elements: + return [] + + sorted_els = sorted( + (e for e in elements if e.text and e.text.strip()), + key=lambda e: (e.y, e.x), + ) + + # Find revision block header + header_el = None + header_idx = None + for i, el in enumerate(sorted_els): + if _REV_HEADER.search(el.text or ''): + header_el = el + header_idx = i + break + + if header_el is None: + return [] + + # Collect elements below the header + body_els = [ + el for el in sorted_els[header_idx + 1:] + if float(el.y) > float(header_el.y) + ] + + if not body_els: + return [] + + rows = _group_into_rows(body_els) + avg_conf = sum(float(getattr(e, 'confidence', 0.8)) for e in body_els) / max(len(body_els), 1) + results = [] + + for row_els in rows: + cells = [el.text.strip() for el in row_els] + if not any(cells): + continue + parsed = _parse_row(cells) + bbox = self._to_bbox(row_els[0]) if row_els else self._to_bbox(header_el) + results.append(RevisionEntry( + revision=parsed["revision"], + date=parsed["date"], + description=parsed["description"], + approved_by=parsed["approved_by"], + confidence=min(avg_conf, 1.0), + bbox=bbox, + )) + + return results diff --git a/docstrange/extractors/title_block.py b/docstrange/extractors/title_block.py new file mode 100644 index 0000000..32ecc78 --- /dev/null +++ b/docstrange/extractors/title_block.py @@ -0,0 +1,113 @@ +"""Title block extractor for engineering drawings.""" + +import re +from typing import List, Dict + +from .base import BaseExtractor + +_FIELD_KEYWORDS: Dict[str, re.Pattern] = { + "drawing_number": re.compile(r'\b(DWG\.?\s*NO\.?|DRAWING\s*NO\.?|DOC\s*NO\.?)\b', re.I), + "title": re.compile(r'\b(TITLE|PART\s*NAME)\b', re.I), + "scale": re.compile(r'\bSCALE\b', re.I), + "date": re.compile(r'\b(DATE|DRAWN)\b', re.I), + "revision": re.compile(r'\b(REV\.?|REVISION)\b', re.I), + "material": re.compile(r'\b(MATERIAL|MAT\.?)\b', re.I), + "drawn_by": re.compile(r'\b(DRAWN\s*BY|DRN\.?\s*BY)\b', re.I), + "approved_by": re.compile(r'\b(APPR\.?|APPROVED\s*BY)\b', re.I), + "sheet": re.compile(r'\bSHEET\b', re.I), + "tolerance": re.compile(r'\b(TOLERANCE|TOL\.?|UNLESS\s*OTHERWISE)\b', re.I), + "company": re.compile(r'\b(COMPANY|ORGANIZATION|ORG)\b', re.I), + "part_number": re.compile(r'\b(PART\s*NO\.?|P/N|PART\s*#)\b', re.I), + "surface_finish": re.compile(r'\b(SURFACE\s*FINISH|FINISH)\b', re.I), + "weight": re.compile(r'\bWEIGHT\b', re.I), +} + + +def _image_bounds(elements: List) -> tuple: + """Return (max_x, max_y) across all elements to estimate image dimensions.""" + max_x = max_y = 1.0 + for el in elements: + max_x = max(max_x, float(el.x) + float(el.width)) + max_y = max(max_y, float(el.y) + float(el.height)) + return max_x, max_y + + +def _in_title_block_zone(el, max_x: float, max_y: float) -> bool: + """Title block heuristic: lower-right 35% × lower 20% of image.""" + return float(el.x) > 0.65 * max_x or float(el.y) > 0.80 * max_y + + +def _match_field(text: str): + """Return (field_name, confidence_multiplier) or (None, 0).""" + for field_name, pattern in _FIELD_KEYWORDS.items(): + if pattern.search(text): + return field_name, 1.0 + return None, 0.0 + + +def _avg_line_height(elements: List) -> float: + heights = [float(el.height) for el in elements if float(el.height) > 0] + return (sum(heights) / len(heights)) if heights else 20.0 + + +class TitleBlockExtractor(BaseExtractor): + """Extracts title block fields using zone heuristics and keyword matching.""" + + def extract(self, elements: List) -> List: + from ..schemas.engineering import TitleBlockField + + if not elements: + return [] + + max_x, max_y = _image_bounds(elements) + avg_h = _avg_line_height(elements) + threshold = 1.5 * avg_h + + results = [] + # Filter to probable title block elements + candidates = [el for el in elements if el.text and el.text.strip()] + + # Sort by position for stable pairing + candidates_sorted = sorted(candidates, key=lambda e: (e.y, e.x)) + + for i, el in enumerate(candidates_sorted): + text = el.text.strip() + in_zone = _in_title_block_zone(el, max_x, max_y) + field_name, kw_mult = _match_field(text) + + if field_name: + mult = 1.0 if in_zone else 0.6 + # Try to find a nearby value element (same row, slightly to the right) + value_el = self._find_value_element(el, candidates_sorted, threshold) + field_value = value_el.text.strip() if value_el else text + + results.append(TitleBlockField( + text=text, + field_name=field_name, + field_value=field_value, + confidence=self._confidence(el, mult), + bbox=self._to_bbox(el), + )) + elif in_zone and text: + # Zone match without keyword — emit as "unknown" field + results.append(TitleBlockField( + text=text, + field_name="unknown", + field_value=text, + confidence=self._confidence(el, 0.5), + bbox=self._to_bbox(el), + )) + + return results + + def _find_value_element(self, label_el, candidates: List, threshold: float): + """Find the element spatially adjacent to label_el (same row, to the right).""" + lx, ly, lw, lh = float(label_el.x), float(label_el.y), float(label_el.width), float(label_el.height) + for el in candidates: + if el is label_el: + continue + ex, ey = float(el.x), float(el.y) + # Same row: y-centers within threshold, to the right + if abs(ey - ly) < threshold and ex > lx + lw * 0.5: + return el + return None diff --git a/docstrange/mcp_server/__init__.py b/docstrange/mcp_server/__init__.py new file mode 100644 index 0000000..7e398a4 --- /dev/null +++ b/docstrange/mcp_server/__init__.py @@ -0,0 +1,3 @@ +from .server import EngineeringMCPServer + +__all__ = ["EngineeringMCPServer"] diff --git a/docstrange/mcp_server/__main__.py b/docstrange/mcp_server/__main__.py new file mode 100644 index 0000000..5ef3895 --- /dev/null +++ b/docstrange/mcp_server/__main__.py @@ -0,0 +1,5 @@ +import asyncio +from .server import main + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/docstrange/mcp_server/cache.py b/docstrange/mcp_server/cache.py new file mode 100644 index 0000000..72eb6e5 --- /dev/null +++ b/docstrange/mcp_server/cache.py @@ -0,0 +1,63 @@ +"""LRU extraction result cache keyed on file identity (path + mtime + size). + +Using filesystem stat rather than MD5 means zero hashing overhead on every +call. A file that is modified in-place will have a changed mtime/size and +automatically get a fresh extraction. +""" + +import os +from collections import OrderedDict +from typing import Any, List, Optional, Tuple + + +class ExtractionCache: + """Thread-unsafe LRU cache for EngineeringDrawingResult objects. + + The MCP server runs in a single asyncio event loop so no locking is needed. + """ + + def __init__(self, maxsize: int = 20) -> None: + self._cache: OrderedDict = OrderedDict() + self._maxsize = max(1, maxsize) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _key(self, path: str, extractors: Optional[List[str]]) -> Tuple: + stat = os.stat(path) + return (path, stat.st_mtime, stat.st_size, tuple(sorted(extractors or []))) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def get(self, path: str, extractors: Optional[List[str]]) -> Optional[Any]: + """Return cached result or ``None`` on a miss.""" + try: + key = self._key(path, extractors) + except OSError: + return None + if key not in self._cache: + return None + self._cache.move_to_end(key) + return self._cache[key] + + def put(self, path: str, extractors: Optional[List[str]], value: Any) -> None: + """Store *value* in the cache, evicting the LRU entry if full.""" + try: + key = self._key(path, extractors) + except OSError: + return + if key in self._cache: + self._cache.move_to_end(key) + else: + if len(self._cache) >= self._maxsize: + self._cache.popitem(last=False) + self._cache[key] = value + + def clear(self) -> None: + self._cache.clear() + + def __len__(self) -> int: + return len(self._cache) diff --git a/docstrange/mcp_server/server.py b/docstrange/mcp_server/server.py new file mode 100644 index 0000000..e92ac60 --- /dev/null +++ b/docstrange/mcp_server/server.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +"""Engineering Drawing Extraction MCP Server powered by DocStrange.""" + +import asyncio +import json +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + +from mcp.server import Server +from mcp.server.stdio import stdio_server +from mcp.types import TextContent, Tool + +from .cache import ExtractionCache +from .tools import TOOLS + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +_ALLOWED_EXT = {".pdf", ".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".webp"} + + +class EngineeringMCPServer: + """MCP Server exposing DocStrange engineering drawing extraction tools. + + Architecture + ------------ + - Tool definitions live in ``tools.py`` (schema registry) + - Caching lives in ``cache.py`` (LRU, stat-based key) + - This file owns server lifecycle, validation, and dispatch routing + """ + + def __init__(self) -> None: + self.server = Server("docstrange-engineering") + self._pipeline = None + self._overlay_gen = None + self._cache = ExtractionCache(maxsize=20) + self._setup_handlers() + + # ------------------------------------------------------------------ + # Lazy singletons + # ------------------------------------------------------------------ + + def _get_pipeline(self): + if self._pipeline is None: + logger.info("Initialising EngineeringDrawingPipeline…") + from ..pipelines.engineering import EngineeringDrawingPipeline + self._pipeline = EngineeringDrawingPipeline() + return self._pipeline + + def _get_overlay_generator(self): + if self._overlay_gen is None: + from ..overlays.generator import OverlayGenerator + self._overlay_gen = OverlayGenerator() + return self._overlay_gen + + # ------------------------------------------------------------------ + # Core extraction (validation + cache) + # ------------------------------------------------------------------ + + def _extract(self, file_path: str, extractors: Optional[List[str]] = None): + """Validate, cache-check, and run extraction.""" + path = Path(file_path).resolve() + + if not path.exists(): + raise FileNotFoundError(f"File not found: {file_path}") + if not path.is_file(): + raise ValueError(f"Path is not a regular file: {file_path}") + + ext = path.suffix.lower() + if ext not in _ALLOWED_EXT: + raise ValueError( + f"Unsupported file type '{ext}'. Allowed: {sorted(_ALLOWED_EXT)}" + ) + + cached = self._cache.get(str(path), extractors) + if cached is not None: + logger.debug("Cache hit: %s", path.name) + return cached + + pipeline = self._get_pipeline() + if ext == ".pdf": + result = pipeline.extract_from_pdf(str(path), extractors=extractors) + else: + result = pipeline.extract_from_image(str(path), extractors=extractors) + + self._cache.put(str(path), extractors, result) + return result + + # ------------------------------------------------------------------ + # MCP handler registration + # ------------------------------------------------------------------ + + def _setup_handlers(self) -> None: + @self.server.list_tools() + async def list_tools() -> List[Tool]: + return TOOLS + + @self.server.call_tool() + async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: + return await self._dispatch(name, arguments) + + # ------------------------------------------------------------------ + # Dispatch routing + # ------------------------------------------------------------------ + + async def _dispatch(self, name: str, arguments: Dict[str, Any]) -> List[TextContent]: + file_path = arguments.get("file_path", "") + + # Helper: convert 0-sentinel to None for page filtering + raw_page = arguments.get("page", 0) + page: Optional[int] = int(raw_page) if raw_page else None + + try: + return await self._dispatch_inner(name, arguments, file_path, page) + except (FileNotFoundError, ValueError) as exc: + return [TextContent( + type="text", + text=json.dumps({"status": "error", "error": str(exc)}, indent=2), + )] + except Exception as exc: + logger.error("Tool '%s' failed: %s", name, exc, exc_info=True) + return [TextContent( + type="text", + text=json.dumps({"status": "error", "error": "Internal error. Check server logs."}, indent=2), + )] + + async def _dispatch_inner( + self, name: str, arguments: Dict[str, Any], file_path: str, page: Optional[int] + ) -> List[TextContent]: + if name == "extract_dimensions": + result = self._extract(file_path, extractors=["dimensions"]) + data = [el.model_dump() for el in result.dimensions] + + elif name == "extract_title_block": + result = self._extract(file_path, extractors=["title_block"]) + data = [el.model_dump() for el in result.title_block] + + elif name == "extract_notes": + result = self._extract(file_path, extractors=["notes"]) + data = [el.model_dump() for el in result.notes] + + elif name == "extract_gdt": + result = self._extract(file_path, extractors=["gdt"]) + data = [el.model_dump() for el in result.gdt] + + elif name == "extract_bom": + result = self._extract(file_path, extractors=["bom"]) + data = [row.model_dump() for row in result.bom] + + elif name == "extract_revisions": + result = self._extract(file_path, extractors=["revisions"]) + data = [entry.model_dump() for entry in result.revisions] + + elif name == "extract_full": + result = self._extract(file_path) + data = result.model_dump() + if arguments.get("include_overlays", False): + gen = self._get_overlay_generator() + data["overlay_json"] = gen.generate( + result, + image_width=arguments.get("image_width", 0), + image_height=arguments.get("image_height", 0), + image_path=file_path, + page=page, + ) + + elif name == "generate_overlays": + result = self._extract(file_path) + gen = self._get_overlay_generator() + data = gen.generate( + result, + image_width=arguments.get("image_width", 0), + image_height=arguments.get("image_height", 0), + image_path=file_path, + page=page, + ) + + else: + return [TextContent( + type="text", + text=json.dumps({"status": "error", "error": f"Unknown tool: {name}"}, indent=2), + )] + + return [TextContent(type="text", text=json.dumps(data, indent=2, default=str))] + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def run(self) -> None: + async with stdio_server() as (read_stream, write_stream): + await self.server.run( + read_stream, + write_stream, + initialization_options=self.server.create_initialization_options(), + ) + + +async def main() -> None: + server = EngineeringMCPServer() + await server.run() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/docstrange/mcp_server/tools.py b/docstrange/mcp_server/tools.py new file mode 100644 index 0000000..fc4b141 --- /dev/null +++ b/docstrange/mcp_server/tools.py @@ -0,0 +1,191 @@ +"""MCP tool definitions for the DocStrange Engineering Drawing server. + +Keeping tool schemas in their own module means they can be read, tested, and +updated without touching server routing logic. +""" + +from typing import List + +from mcp.types import Tool + +TOOLS: List[Tool] = [ + Tool( + name="extract_dimensions", + description=( + "Extract dimension annotations (linear, angular, radial, diameter) " + "from an engineering drawing PDF or image file. Returns structured JSON " + "with text, type, confidence, bounding box, and page number for each entity." + ), + inputSchema={ + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Absolute path to the drawing file (PDF, PNG, JPG, TIFF, BMP).", + }, + "page": { + "type": "integer", + "description": "1-based page number to extract from. 0 = all pages (default).", + "default": 0, + }, + }, + "required": ["file_path"], + }, + ), + Tool( + name="extract_title_block", + description=( + "Extract title block fields (drawing number, title, scale, date, material, " + "revision, drawn by, approved by, sheet, tolerance, company, part number) " + "from an engineering drawing." + ), + inputSchema={ + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Absolute path to the drawing file.", + }, + }, + "required": ["file_path"], + }, + ), + Tool( + name="extract_notes", + description=( + "Extract general notes and numbered annotations from an engineering drawing." + ), + inputSchema={ + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Absolute path to the drawing file.", + }, + }, + "required": ["file_path"], + }, + ), + Tool( + name="extract_gdt", + description=( + "Extract GD&T (Geometric Dimensioning and Tolerancing) symbols, feature " + "control frames, tolerance values, and datum references from an engineering drawing." + ), + inputSchema={ + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Absolute path to the drawing file.", + }, + }, + "required": ["file_path"], + }, + ), + Tool( + name="extract_bom", + description=( + "Extract Bill of Materials (parts list) rows including item number, quantity, " + "part number, description, and material columns." + ), + inputSchema={ + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Absolute path to the drawing file.", + }, + }, + "required": ["file_path"], + }, + ), + Tool( + name="extract_revisions", + description=( + "Extract revision history block entries including revision letter, date, " + "description, and approver name." + ), + inputSchema={ + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Absolute path to the drawing file.", + }, + }, + "required": ["file_path"], + }, + ), + Tool( + name="extract_full", + description=( + "Run all extractors and return the complete EngineeringDrawingResult containing " + "title_block, dimensions, notes, gdt, bom, and revisions. Optionally include " + "UI overlay JSON for frontend rendering." + ), + inputSchema={ + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Absolute path to the drawing file.", + }, + "include_overlays": { + "type": "boolean", + "description": "Include overlay JSON for UI rendering.", + "default": False, + }, + "image_width": { + "type": "integer", + "description": "Source image width in pixels for overlay normalisation. 0 = skip.", + "default": 0, + }, + "image_height": { + "type": "integer", + "description": "Source image height in pixels for overlay normalisation. 0 = skip.", + "default": 0, + }, + "page": { + "type": "integer", + "description": "Filter overlay annotations to this 1-based page number. 0 = all pages.", + "default": 0, + }, + }, + "required": ["file_path"], + }, + ), + Tool( + name="generate_overlays", + description=( + "Generate UI-ready overlay JSON from an engineering drawing. Each detected entity " + "is annotated with normalised and pixel bounding boxes, colour-coded by type, " + "and tagged with its source page number." + ), + inputSchema={ + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Absolute path to the drawing file.", + }, + "image_width": { + "type": "integer", + "description": "Image width in pixels for coordinate normalisation. 0 = skip.", + "default": 0, + }, + "image_height": { + "type": "integer", + "description": "Image height in pixels for coordinate normalisation. 0 = skip.", + "default": 0, + }, + "page": { + "type": "integer", + "description": "Return annotations for this 1-based page number only. 0 = all pages.", + "default": 0, + }, + }, + "required": ["file_path"], + }, + ), +] diff --git a/docstrange/overlays/__init__.py b/docstrange/overlays/__init__.py new file mode 100644 index 0000000..04ce227 --- /dev/null +++ b/docstrange/overlays/__init__.py @@ -0,0 +1,3 @@ +from .generator import OverlayGenerator + +__all__ = ["OverlayGenerator"] diff --git a/docstrange/overlays/generator.py b/docstrange/overlays/generator.py new file mode 100644 index 0000000..6c5b86b --- /dev/null +++ b/docstrange/overlays/generator.py @@ -0,0 +1,133 @@ +"""Overlay JSON generator for UI rendering of engineering drawing extractions.""" + +from typing import Any, Dict, List, Optional + + +_COLOR_MAP: Dict[str, str] = { + "dimension": "#FF6B35", + "title_block": "#2196F3", + "note": "#4CAF50", + "gdt": "#9C27B0", + "bom": "#FF9800", + "revision": "#607D8B", + "unknown": "#9E9E9E", +} + + +def _normalize_bbox(x: float, y: float, w: float, h: float, img_w: int, img_h: int) -> Dict: + def _clamp(v: float) -> float: + return max(0.0, min(1.0, v)) + + if img_w and img_h: + return { + "x": _clamp(round(x / img_w, 4)), + "y": _clamp(round(y / img_h, 4)), + "width": _clamp(round(w / img_w, 4)), + "height": _clamp(round(h / img_h, 4)), + } + # Skip normalisation when dimensions are unknown + return {"x": 0.0, "y": 0.0, "width": 0.0, "height": 0.0} + + +def _annotation_from_element(el, img_w: int, img_h: int, seq: int) -> Dict[str, Any]: + """Build one overlay annotation from any extracted element. + + Works with ExtractionElement subclasses (DimensionElement, TitleBlockField, + NoteElement, GDTElement) and with BOMRow / RevisionEntry — all carry .text, + .type, .confidence, .bbox, and .page after Phase 2. + + ``seq`` is the 1-based position in the final annotation list; it produces a + stable ``change_id`` (``chg_001`` etc.) that is consistent across re-runs + for the same extraction result ordering. + """ + annotation_type = str(getattr(el, "type", "unknown")) + bbox = el.bbox + text = str(getattr(el, "text", "")) + bbox_pixels = { + "x": float(bbox.x), + "y": float(bbox.y), + "width": float(bbox.width), + "height": float(bbox.height), + } + return { + "change_id": f"chg_{seq:03d}", + "type": annotation_type, + "text": text, + "page": int(getattr(el, "page", 1)), + "confidence": round(float(el.confidence), 4), + "bbox": bbox_pixels, + "bbox_normalized": _normalize_bbox( + bbox.x, bbox.y, bbox.width, bbox.height, img_w, img_h + ), + "color": _COLOR_MAP.get(annotation_type, _COLOR_MAP["unknown"]), + "label": text[:60], + } + + +def _auto_image_size(image_path: str): + """Return (width, height) from image file, or (0, 0) on failure.""" + try: + from PIL import Image + with Image.open(image_path) as img: + return img.width, img.height + except Exception: + return 0, 0 + + +class OverlayGenerator: + """Converts EngineeringDrawingResult into a UI-renderable overlay JSON structure. + + Each annotation carries both pixel and normalised (0–1) bounding boxes so the + consumer can render at any resolution. Pass ``page`` to filter to a single PDF + page; omit it to include all pages in one payload. + """ + + def generate( + self, + result, + image_width: int = 0, + image_height: int = 0, + image_path: Optional[str] = None, + page: Optional[int] = None, + ) -> Dict[str, Any]: + """Return the overlay dict. + + Args: + result: :class:`EngineeringDrawingResult` instance. + image_width: Pixel width of the source image. 0 → auto-detect or skip. + image_height: Pixel height of the source image. + image_path: Optional path used to auto-detect dimensions when + ``image_width``/``image_height`` are both 0. + page: If set, only annotations from this PDF page are included. + """ + if image_width == 0 and image_height == 0 and image_path: + image_width, image_height = _auto_image_size(image_path) + + all_elements = ( + list(result.title_block) + + list(result.dimensions) + + list(result.notes) + + list(result.gdt) + + list(result.bom) + + list(result.revisions) + ) + + annotations: List[Dict] = [] + seq = 0 + for el in all_elements: + if page is not None and int(getattr(el, "page", 1)) != page: + continue + seq += 1 + annotations.append(_annotation_from_element(el, image_width, image_height, seq)) + + by_type: Dict[str, int] = {} + for ann in annotations: + by_type[ann["type"]] = by_type.get(ann["type"], 0) + 1 + + return { + "image_size": {"width": image_width, "height": image_height}, + "page_filter": page, + "total_annotations": len(annotations), + "summary": {"by_type": by_type, "total": len(annotations)}, + "annotations": annotations, + } diff --git a/docstrange/pipeline/__init__.py b/docstrange/pipeline/__init__.py index 136d52f..72f678d 100644 --- a/docstrange/pipeline/__init__.py +++ b/docstrange/pipeline/__init__.py @@ -1 +1,5 @@ -"""Pipeline package for document processing and OCR.""" \ No newline at end of file +"""Pipeline package for document processing and OCR.""" + +from .ocr_service import OCRService, OCRServiceFactory, NeuralOCRService, NanonetsOCRService + +__all__ = ["OCRService", "OCRServiceFactory", "NeuralOCRService", "NanonetsOCRService"] diff --git a/docstrange/pipeline/layout_detector.py b/docstrange/pipeline/layout_detector.py index 97f2ab0..5512973 100644 --- a/docstrange/pipeline/layout_detector.py +++ b/docstrange/pipeline/layout_detector.py @@ -170,19 +170,13 @@ def _join_paragraph_text_advanced(self, text_blocks: List[LayoutElement]) -> str return result.strip() def _post_process_text(self, text: str) -> str: - """Post-process text to improve readability.""" - # Fix common OCR issues - text = text.replace('|', 'I') # Common OCR mistake - text = text.replace('0', 'o') # Common OCR mistake in certain contexts - text = text.replace('1', 'l') # Common OCR mistake in certain contexts - - # Fix spacing issues - text = re.sub(r'\s+', ' ', text) # Multiple spaces to single space - text = re.sub(r'([.!?])\s*([A-Z])', r'\1 \2', text) # Fix sentence spacing - - # Fix common OCR artifacts - text = re.sub(r'[^\w\s.,!?;:()[\]{}"\'-]', '', text) # Remove strange characters - + """Post-process text to improve readability. + + Engineering safety: do NOT substitute digits or pipe characters — they carry + meaning in dimension values (0, 1) and GD&T feature control frames (|⊥|0.5|A|). + """ + text = re.sub(r'\s+', ' ', text) # collapse multiple spaces + text = re.sub(r'([.!?])\s*([A-Z])', r'\1 \2', text) # sentence spacing return text def _classify_paragraph(self, text: str) -> str: diff --git a/docstrange/pipeline/nanonets_processor.py b/docstrange/pipeline/nanonets_processor.py index 0c17b5e..c63adf4 100644 --- a/docstrange/pipeline/nanonets_processor.py +++ b/docstrange/pipeline/nanonets_processor.py @@ -128,6 +128,45 @@ def _extract_text_with_nanonets(self, image_path: str, max_new_tokens: int = 409 logger.error(f"Nanonets OCR extraction failed: {e}") return "" + def extract_layout_elements(self, image_path: str) -> List: + """Return line-level LayoutElements derived from the Nanonets OCR text output. + + The Nanonets transformer returns structured text (not positional data), so + we synthesise spatial coordinates by distributing lines across the image + height. This gives downstream extractors sequential order and approximate + y-positions suitable for zone heuristics (title block region, notes section, + etc.). + """ + from .layout_detector import LayoutElement + + text = self.extract_text(image_path) + if not text: + return [] + + try: + with Image.open(image_path) as img: + img_w, img_h = img.size + except Exception: + img_w, img_h = 1000, 1400 # A4-like fallback + + lines = [ln.strip() for ln in text.split("\n") if ln.strip()] + if not lines: + return [] + + line_height = max(img_h / len(lines), 15) + elements = [] + for idx, line in enumerate(lines): + elements.append(LayoutElement( + text=line, + x=0, + y=int(idx * line_height), + width=img_w, + height=int(line_height), + element_type="paragraph", + confidence=0.8, + )) + return elements + def __del__(self): """Cleanup resources.""" - pass \ No newline at end of file + pass \ No newline at end of file diff --git a/docstrange/pipeline/neural_document_processor.py b/docstrange/pipeline/neural_document_processor.py index e0a4a40..13bae19 100644 --- a/docstrange/pipeline/neural_document_processor.py +++ b/docstrange/pipeline/neural_document_processor.py @@ -280,12 +280,108 @@ def extract_text_with_layout(self, image_path: str) -> str: if not os.path.exists(image_path): logger.error(f"Image file does not exist: {image_path}") return "" - + return self._extract_text_with_layout_advanced(image_path) - + except Exception as e: logger.error(f"Layout-aware OCR extraction failed: {e}") return "" + + def extract_layout_elements(self, image_path: str) -> List: + """Return raw LayoutElement objects for engineering drawing extraction. + + Returns the list of positioned text elements before markdown conversion, + enabling downstream extractors to apply domain-specific logic. + """ + from .layout_detector import LayoutElement + + if not os.path.exists(image_path): + logger.error(f"Image file does not exist: {image_path}") + return [] + + try: + with Image.open(image_path) as img: + if img.mode != 'RGB': + img = img.convert('RGB') + + # Fallback mode: use EasyOCR directly (no layout predictor available) + if getattr(self, '_use_fallback_mode', False) or not getattr(self, 'use_advanced_models', False): + return self._extract_layout_elements_easyocr(img) + + layout_results = list(self.layout_predictor.predict(img)) + text_blocks = [] + + for pred in layout_results: + label = pred.get('label', '').lower().replace(' ', '_').replace('-', '_') + + if all(k in pred for k in ['l', 't', 'r', 'b']): + bbox = [pred['l'], pred['t'], pred['r'], pred['b']] + else: + bbox = pred.get('bbox') or pred.get('box') + if not bbox: + continue + + region_text = self._extract_text_from_region(img, bbox) + if not region_text or pred.get('confidence', 1.0) < 0.5: + continue + + if label in ['title', 'section_header', 'subtitle_level_1']: + element_type = 'heading' + elif label == 'list_item': + element_type = 'list_item' + elif label in ['table', 'document_index']: + element_type = 'table' + else: + element_type = 'paragraph' + + text_blocks.append(LayoutElement( + text=region_text, + x=bbox[0], + y=bbox[1], + width=bbox[2] - bbox[0], + height=bbox[3] - bbox[1], + element_type=element_type, + confidence=pred.get('confidence', 1.0), + )) + + text_blocks.sort(key=lambda el: (el.y, el.x)) + return text_blocks + + except Exception as e: + logger.error(f"extract_layout_elements failed: {e}") + return [] + + def _extract_layout_elements_easyocr(self, img: "Image.Image") -> List: + """Build LayoutElement list from EasyOCR (bbox_points, text, confidence) tuples.""" + from .layout_detector import LayoutElement + + try: + results = self.ocr_reader.readtext(img) + except Exception as e: + logger.error(f"EasyOCR readtext failed: {e}") + return [] + + elements = [] + for (bbox_pts, text, confidence) in results: + if confidence < 0.5 or not text.strip(): + continue + # bbox_pts is [[x1,y1],[x2,y1],[x2,y2],[x1,y2]] + xs = [pt[0] for pt in bbox_pts] + ys = [pt[1] for pt in bbox_pts] + x, y = min(xs), min(ys) + w, h = max(xs) - x, max(ys) - y + elements.append(LayoutElement( + text=text, + x=x, + y=y, + width=w, + height=h, + element_type='paragraph', + confidence=confidence, + )) + + elements.sort(key=lambda el: (el.y, el.x)) + return elements def _extract_text_advanced(self, image_path: str) -> str: """Extract text using docling's advanced models.""" diff --git a/docstrange/pipeline/ocr_service.py b/docstrange/pipeline/ocr_service.py index a9062a3..e856ea9 100644 --- a/docstrange/pipeline/ocr_service.py +++ b/docstrange/pipeline/ocr_service.py @@ -1,222 +1,171 @@ """OCR Service abstraction for neural document processing.""" -import os import logging +import os from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional +from typing import List, Optional logger = logging.getLogger(__name__) class OCRService(ABC): """Abstract base class for OCR services.""" - + @abstractmethod def extract_text(self, image_path: str) -> str: - """Extract text from image. - - Args: - image_path: Path to the image file - - Returns: - Extracted text as string - """ + """Extract plain text from an image file.""" pass - + @abstractmethod def extract_text_with_layout(self, image_path: str) -> str: - """Extract text with layout awareness from image. - - Args: - image_path: Path to the image file - - Returns: - Layout-aware extracted text as markdown + """Extract text with layout-aware markdown from an image file.""" + pass + + @abstractmethod + def extract_layout_elements(self, image_path: str) -> List: + """Return LayoutElement objects (text + bbox + confidence) for an image. + + This is the primary entry point for engineering drawing extraction — every + element must carry accurate positional data so extractors can apply zone + heuristics (title block region, BOM table area, etc.). """ pass + # ------------------------------------------------------------------ + # Shared validation helper + # ------------------------------------------------------------------ + + def _validate_image(self, image_path: str) -> bool: + if not os.path.exists(image_path): + logger.error("Image not found: %s", image_path) + return False + return True + class NanonetsOCRService(OCRService): - """Nanonets OCR implementation using NanonetsDocumentProcessor.""" - + """OCR service backed by NanonetsDocumentProcessor (transformer model).""" + def __init__(self): - """Initialize the service.""" from .nanonets_processor import NanonetsDocumentProcessor self._processor = NanonetsDocumentProcessor() - logger.info("NanonetsOCRService initialized") - + + # Properties kept for callers that inspect the underlying model objects @property def model(self): - """Get the Nanonets model.""" return self._processor.model - + @property def processor(self): - """Get the Nanonets processor.""" return self._processor.processor - + @property def tokenizer(self): - """Get the Nanonets tokenizer.""" return self._processor.tokenizer - + def extract_text(self, image_path: str) -> str: - """Extract text using Nanonets OCR.""" + if not self._validate_image(image_path): + return "" try: - # Validate image file - if not os.path.exists(image_path): - logger.error(f"Image file does not exist: {image_path}") - return "" - - # Check if file is readable - try: - from PIL import Image - with Image.open(image_path) as img: - logger.info(f"Image loaded successfully: {img.size} {img.mode}") - except Exception as e: - logger.error(f"Failed to load image: {e}") - return "" - - try: - text = self._processor.extract_text(image_path) - logger.info(f"Extracted text length: {len(text)}") - return text.strip() - except Exception as e: - logger.error(f"Nanonets OCR extraction failed: {e}") - return "" - + return self._processor.extract_text(image_path).strip() except Exception as e: - logger.error(f"Nanonets OCR extraction failed: {e}") + logger.error("NanonetsOCRService.extract_text failed: %s", e) return "" - + def extract_text_with_layout(self, image_path: str) -> str: - """Extract text with layout awareness using Nanonets OCR.""" + if not self._validate_image(image_path): + return "" try: - # Validate image file - if not os.path.exists(image_path): - logger.error(f"Image file does not exist: {image_path}") - return "" - - # Check if file is readable - try: - from PIL import Image - with Image.open(image_path) as img: - logger.info(f"Image loaded successfully: {img.size} {img.mode}") - except Exception as e: - logger.error(f"Failed to load image: {e}") - return "" - - try: - text = self._processor.extract_text_with_layout(image_path) - logger.info(f"Layout-aware extracted text length: {len(text)}") - return text.strip() - except Exception as e: - logger.error(f"Nanonets OCR layout-aware extraction failed: {e}") - return "" - + return self._processor.extract_text_with_layout(image_path).strip() except Exception as e: - logger.error(f"Nanonets OCR layout-aware extraction failed: {e}") + logger.error("NanonetsOCRService.extract_text_with_layout failed: %s", e) return "" + def extract_layout_elements(self, image_path: str) -> List: + if not self._validate_image(image_path): + return [] + try: + return self._processor.extract_layout_elements(image_path) + except Exception as e: + logger.error("NanonetsOCRService.extract_layout_elements failed: %s", e) + return [] + class NeuralOCRService(OCRService): - """Neural OCR implementation using docling's pre-trained models.""" - + """OCR service backed by NeuralDocumentProcessor (docling + EasyOCR).""" + def __init__(self): - """Initialize the service.""" from .neural_document_processor import NeuralDocumentProcessor self._processor = NeuralDocumentProcessor() - logger.info("NeuralOCRService initialized") - + def extract_text(self, image_path: str) -> str: - """Extract text using Neural OCR (docling models).""" + if not self._validate_image(image_path): + return "" try: - # Validate image file - if not os.path.exists(image_path): - logger.error(f"Image file does not exist: {image_path}") - return "" - - # Check if file is readable - try: - from PIL import Image - with Image.open(image_path) as img: - logger.info(f"Image loaded successfully: {img.size} {img.mode}") - except Exception as e: - logger.error(f"Failed to load image: {e}") - return "" - - try: - text = self._processor.extract_text(image_path) - logger.info(f"Extracted text length: {len(text)}") - return text.strip() - except Exception as e: - logger.error(f"Neural OCR extraction failed: {e}") - return "" - + return self._processor.extract_text(image_path).strip() except Exception as e: - logger.error(f"Neural OCR extraction failed: {e}") + logger.error("NeuralOCRService.extract_text failed: %s", e) return "" - + def extract_text_with_layout(self, image_path: str) -> str: - """Extract text with layout awareness using Neural OCR.""" + if not self._validate_image(image_path): + return "" try: - # Validate image file - if not os.path.exists(image_path): - logger.error(f"Image file does not exist: {image_path}") - return "" - - # Check if file is readable - try: - from PIL import Image - with Image.open(image_path) as img: - logger.info(f"Image loaded successfully: {img.size} {img.mode}") - except Exception as e: - logger.error(f"Failed to load image: {e}") - return "" - - try: - text = self._processor.extract_text_with_layout(image_path) - logger.info(f"Layout-aware extracted text length: {len(text)}") - return text.strip() - except Exception as e: - logger.error(f"Neural OCR layout-aware extraction failed: {e}") - return "" - + return self._processor.extract_text_with_layout(image_path).strip() except Exception as e: - logger.error(f"Neural OCR layout-aware extraction failed: {e}") + logger.error("NeuralOCRService.extract_text_with_layout failed: %s", e) return "" + def extract_layout_elements(self, image_path: str) -> List: + if not self._validate_image(image_path): + return [] + try: + return self._processor.extract_layout_elements(image_path) + except Exception as e: + logger.error("NeuralOCRService.extract_layout_elements failed: %s", e) + return [] + class OCRServiceFactory: - """Factory for creating OCR services based on configuration.""" - + """Creates OCR service instances with automatic provider fallback.""" + + _PROVIDERS = { + "nanonets": NanonetsOCRService, + "neural": NeuralOCRService, + } + @staticmethod - def create_service(provider: str = None) -> OCRService: - """Create OCR service based on provider configuration. - + def create_service(provider: Optional[str] = None) -> OCRService: + """Instantiate the requested provider; fall back to the other if it fails. + Args: - provider: OCR provider name (defaults to config) - - Returns: - OCRService instance + provider: ``"nanonets"`` or ``"neural"``. Defaults to + ``InternalConfig.ocr_provider`` (which defaults to ``"nanonets"``). """ from docstrange.config import InternalConfig - - if provider is None: - provider = getattr(InternalConfig, 'ocr_provider', 'nanonets') - - if provider.lower() == 'nanonets': - return NanonetsOCRService() - elif provider.lower() == 'neural': - return NeuralOCRService() - else: - raise ValueError(f"Unsupported OCR provider: {provider}") - + + preferred = (provider or getattr(InternalConfig, "ocr_provider", "nanonets")).lower() + # Build a try-order: preferred first, then the other option + order = [preferred] + [p for p in ("nanonets", "neural") if p != preferred] + + last_error: Optional[Exception] = None + for name in order: + cls = OCRServiceFactory._PROVIDERS.get(name) + if cls is None: + logger.warning("Unknown OCR provider '%s', skipping.", name) + continue + try: + service = cls() + if name != preferred: + logger.info("OCR provider '%s' unavailable; using '%s' instead.", preferred, name) + return service + except Exception as exc: + logger.warning("OCR provider '%s' failed to initialise: %s", name, exc) + last_error = exc + + raise RuntimeError( + f"All OCR providers failed to initialise. Last error: {last_error}" + ) + @staticmethod def get_available_providers() -> List[str]: - """Get list of available OCR providers. - - Returns: - List of available provider names - """ - return ['nanonets', 'neural'] \ No newline at end of file + return list(OCRServiceFactory._PROVIDERS.keys()) diff --git a/docstrange/pipelines/__init__.py b/docstrange/pipelines/__init__.py new file mode 100644 index 0000000..80baa86 --- /dev/null +++ b/docstrange/pipelines/__init__.py @@ -0,0 +1,3 @@ +from .engineering import EngineeringDrawingPipeline + +__all__ = ["EngineeringDrawingPipeline"] diff --git a/docstrange/pipelines/engineering.py b/docstrange/pipelines/engineering.py new file mode 100644 index 0000000..41a9018 --- /dev/null +++ b/docstrange/pipelines/engineering.py @@ -0,0 +1,158 @@ +"""Engineering drawing extraction pipeline orchestrator.""" + +import logging +import os +import tempfile +from typing import Dict, List, Optional + +logger = logging.getLogger(__name__) + + +class EngineeringDrawingPipeline: + """Orchestrates full engineering drawing extraction from an image or PDF. + + Usage:: + + pipeline = EngineeringDrawingPipeline() + result = pipeline.extract_from_pdf("drawing.pdf") + print(result.model_dump()) + """ + + _ALL_EXTRACTORS = ["title_block", "dimensions", "notes", "gdt", "bom", "revisions"] + + def __init__(self, ocr_service=None): + from ..pipeline.ocr_service import OCRServiceFactory + from ..extractors import ( + TitleBlockExtractor, DimensionExtractor, NoteExtractor, + GDTExtractor, BOMExtractor, RevisionExtractor, + ) + self._ocr = ocr_service or OCRServiceFactory.create_service() + self._extractors: Dict = { + "title_block": TitleBlockExtractor(), + "dimensions": DimensionExtractor(), + "notes": NoteExtractor(), + "gdt": GDTExtractor(), + "bom": BOMExtractor(), + "revisions": RevisionExtractor(), + } + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def extract_from_image( + self, + image_path: str, + extractors: Optional[List[str]] = None, + ): + """Run selected (or all) extractors on a single image file. + + Returns an :class:`EngineeringDrawingResult`. + """ + from ..schemas.engineering import EngineeringDrawingResult + + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image not found: {image_path}") + + elements = self._get_layout_elements(image_path) + return self._run_extractors(elements, extractors, metadata={"source": image_path, "pages": 1}) + + def extract_from_pdf( + self, + pdf_path: str, + extractors: Optional[List[str]] = None, + ): + """Convert each PDF page to an image and run extraction, merging results. + + Returns an :class:`EngineeringDrawingResult`. + """ + from ..schemas.engineering import EngineeringDrawingResult + + if not os.path.exists(pdf_path): + raise FileNotFoundError(f"PDF not found: {pdf_path}") + + images = self._pdf_to_images(pdf_path) + if not images: + logger.warning(f"No pages extracted from PDF: {pdf_path}") + return EngineeringDrawingResult(metadata={"source": pdf_path, "pages": 0}) + + page_results = [] + for page_num, img_path in enumerate(images, start=1): + try: + elements = self._get_layout_elements(img_path) + result = self._run_extractors( + elements, extractors, + metadata={"source": pdf_path, "page": page_num}, + ) + page_results.append(result) + except Exception as e: + logger.error(f"Failed to process page {page_num}: {e}") + finally: + try: + os.unlink(img_path) + except OSError: + pass + + return self._merge_page_results(page_results, metadata={"source": pdf_path, "pages": len(images)}) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_layout_elements(self, image_path: str) -> List: + return self._ocr.extract_layout_elements(image_path) + + def _run_extractors(self, elements: List, extractors: Optional[List[str]], metadata: dict): + from ..schemas.engineering import EngineeringDrawingResult + + names = extractors if extractors else self._ALL_EXTRACTORS + page_num = metadata.get("page", 1) + kwargs: Dict[str, list] = {} + + for name in names: + extractor = self._extractors.get(name) + if extractor is None: + logger.warning(f"Unknown extractor: {name}") + continue + try: + items = extractor.extract(elements) + except Exception as e: + logger.error(f"Extractor '{name}' failed: {e}") + items = [] + # Stamp the source page on every element for multi-page traceability + for item in items: + item.page = page_num + kwargs[name] = items # type: ignore[assignment] + + return EngineeringDrawingResult(metadata=metadata, **kwargs) + + def _merge_page_results(self, pages: List, metadata: dict): + from ..schemas.engineering import EngineeringDrawingResult + + if not pages: + return EngineeringDrawingResult(metadata=metadata) + + merged = EngineeringDrawingResult(metadata=metadata) + merged.title_block = [item for p in pages for item in p.title_block] + merged.dimensions = [item for p in pages for item in p.dimensions] + merged.notes = [item for p in pages for item in p.notes] + merged.gdt = [item for p in pages for item in p.gdt] + merged.bom = [item for p in pages for item in p.bom] + merged.revisions = [item for p in pages for item in p.revisions] + return merged + + def _pdf_to_images(self, pdf_path: str) -> List[str]: + """Convert PDF pages to temporary image files. Returns list of image paths.""" + try: + from pdf2image import convert_from_path + except ImportError: + raise ImportError("pdf2image is required for PDF processing: pip install pdf2image") + + images = convert_from_path(pdf_path, dpi=300) + paths = [] + for img in images: + tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + img.save(tmp.name, "PNG") + tmp.close() + paths.append(tmp.name) + return paths diff --git a/docstrange/schemas/__init__.py b/docstrange/schemas/__init__.py new file mode 100644 index 0000000..43206d6 --- /dev/null +++ b/docstrange/schemas/__init__.py @@ -0,0 +1,23 @@ +from .engineering import ( + BBoxSchema, + ExtractionElement, + DimensionElement, + TitleBlockField, + NoteElement, + GDTElement, + BOMRow, + RevisionEntry, + EngineeringDrawingResult, +) + +__all__ = [ + "BBoxSchema", + "ExtractionElement", + "DimensionElement", + "TitleBlockField", + "NoteElement", + "GDTElement", + "BOMRow", + "RevisionEntry", + "EngineeringDrawingResult", +] diff --git a/docstrange/schemas/engineering.py b/docstrange/schemas/engineering.py new file mode 100644 index 0000000..352dad3 --- /dev/null +++ b/docstrange/schemas/engineering.py @@ -0,0 +1,123 @@ +"""Pydantic schemas for engineering drawing extraction output.""" + +from typing import Any, Dict, List, Literal, Optional + +try: + from pydantic import BaseModel, Field, model_validator +except ImportError as e: + raise ImportError( + "pydantic is required for engineering extraction schemas. " + "Install with: pip install 'docstrange[engineering]'" + ) from e + + +class BBoxSchema(BaseModel): + x: float + y: float + width: float + height: float + + +class ExtractionMetadata(BaseModel): + """Structured metadata attached to every EngineeringDrawingResult.""" + source: str = "" + pages: int = 0 + page: Optional[int] = None # set on per-page intermediate results + extractor_version: str = "1.0.0" + + model_config = {"extra": "allow"} # absorb unknown keys from legacy callers + + +class ExtractionElement(BaseModel): + """Base schema for a single extracted entity with spatial context.""" + text: str + type: str + confidence: float = Field(ge=0.0, le=1.0) + bbox: BBoxSchema + page: int = 1 + + +class DimensionElement(ExtractionElement): + type: Literal["dimension"] = "dimension" + nominal: Optional[float] = None + upper_tolerance: Optional[float] = None + lower_tolerance: Optional[float] = None + unit: Optional[str] = None + dimension_type: Optional[str] = None # "linear" | "angular" | "radial" | "diameter" + + +class TitleBlockField(ExtractionElement): + type: Literal["title_block"] = "title_block" + field_name: str + field_value: str + + +class NoteElement(ExtractionElement): + type: Literal["note"] = "note" + note_number: Optional[int] = None + is_general: bool = False + + +class GDTElement(ExtractionElement): + type: Literal["gdt"] = "gdt" + symbol: str + tolerance_value: Optional[str] = None + datum_reference: Optional[str] = None + + +class BOMRow(BaseModel): + """A single row from a Bill of Materials table.""" + type: Literal["bom"] = "bom" + text: str = "" # summary text; auto-filled from raw_cells + item_number: Optional[str] = None + quantity: Optional[str] = None + part_number: Optional[str] = None + description: Optional[str] = None + material: Optional[str] = None + raw_cells: List[str] = [] + confidence: float = Field(ge=0.0, le=1.0) + bbox: BBoxSchema + page: int = 1 + + @model_validator(mode="after") + def _fill_text(self): + if not self.text and self.raw_cells: + self.text = " | ".join(c for c in self.raw_cells if c) + return self + + +class RevisionEntry(BaseModel): + """A single entry from a revision history block.""" + type: Literal["revision"] = "revision" + text: str = "" # summary text; auto-filled from fields + revision: Optional[str] = None + date: Optional[str] = None + description: Optional[str] = None + approved_by: Optional[str] = None + confidence: float = Field(ge=0.0, le=1.0) + bbox: BBoxSchema + page: int = 1 + + @model_validator(mode="after") + def _fill_text(self): + if not self.text: + parts = [p for p in [self.revision, self.date, self.description] if p] + self.text = " | ".join(parts) + return self + + +class EngineeringDrawingResult(BaseModel): + title_block: List[TitleBlockField] = [] + dimensions: List[DimensionElement] = [] + notes: List[NoteElement] = [] + gdt: List[GDTElement] = [] + bom: List[BOMRow] = [] + revisions: List[RevisionEntry] = [] + metadata: ExtractionMetadata = Field(default_factory=ExtractionMetadata) + + @model_validator(mode="before") + @classmethod + def _coerce_metadata(cls, data): + if isinstance(data, dict) and isinstance(data.get("metadata"), dict): + data["metadata"] = ExtractionMetadata(**data["metadata"]) + return data diff --git a/pyproject.toml b/pyproject.toml index 34c6ebf..b29e38b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,10 +74,17 @@ local-llm = [ web = [ "Flask>=2.0.0", ] +engineering = [ + "pydantic>=2.0.0", + "fastapi>=0.100.0; python_version>='3.10'", + "uvicorn[standard]>=0.22.0; python_version>='3.10'", + "python-multipart>=0.0.6; python_version>='3.10'", +] [project.scripts] docstrange = "docstrange.cli:main" docstrange-web = "docstrange.web_app:run_web_app" +docstrange-eng-mcp = "docstrange.mcp_server.__main__:main" [project.urls] Homepage = "https://github.com/nanonets/docstrange" diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..bc6edde --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,274 @@ +"""Tests for the DocStrange Engineering FastAPI service layer. + +Uses FastAPI's TestClient (sync) so no real OCR models are needed — the pipeline +is monkey-patched with a mock before each test. +""" + +import io +import json +from unittest.mock import MagicMock, patch + +import pytest + +pytest.importorskip("fastapi", reason="fastapi not installed") +pytest.importorskip("httpx", reason="httpx not installed (required by TestClient)") + +from fastapi.testclient import TestClient + +from docstrange.schemas.engineering import ( + BBoxSchema, + BOMRow, + DimensionElement, + EngineeringDrawingResult, + GDTElement, + NoteElement, + RevisionEntry, + TitleBlockField, +) + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + +_BBOX = BBoxSchema(x=10.0, y=20.0, width=50.0, height=15.0) + +_FULL_RESULT = EngineeringDrawingResult( + title_block=[TitleBlockField(text="DWG-001", field_name="drawing_number", + field_value="DWG-001", confidence=0.95, bbox=_BBOX)], + dimensions=[DimensionElement(text="25.4 mm", nominal=25.4, unit="mm", + dimension_type="linear", confidence=0.9, bbox=_BBOX)], + notes=[NoteElement(text="ALL DIMENSIONS IN MM", is_general=True, + confidence=0.85, bbox=_BBOX)], + gdt=[GDTElement(text="⊥ 0.05 A", symbol="perpendicularity", + tolerance_value="0.05", datum_reference="A", + confidence=0.88, bbox=_BBOX)], + bom=[BOMRow(item_number="1", quantity="2", description="Bolt M6", + raw_cells=["1", "2", "Bolt M6"], confidence=0.9, bbox=_BBOX)], + revisions=[RevisionEntry(revision="A", date="2024-01-15", + description="Initial release", confidence=0.8, bbox=_BBOX)], + metadata={"source": "test.pdf", "pages": 1}, +) + + +@pytest.fixture(scope="module") +def client(): + """Return a TestClient with the pipeline swapped for a mock.""" + import docstrange.api.routes as routes_module + + mock_pipeline = MagicMock() + mock_pipeline.extract_from_image.return_value = _FULL_RESULT + mock_pipeline.extract_from_pdf.return_value = _FULL_RESULT + + # Patch the lazy singleton so no real models are loaded + routes_module._pipeline = mock_pipeline + routes_module._overlay_gen = None # let the real OverlayGenerator instantiate + + from docstrange.api.routes import create_app + app = create_app() + return TestClient(app) + + +def _png_bytes() -> bytes: + """Return a minimal 1×1 white PNG — valid enough for extension detection.""" + from PIL import Image + buf = io.BytesIO() + Image.new("RGB", (1, 1), (255, 255, 255)).save(buf, format="PNG") + return buf.getvalue() + + +def _pdf_bytes() -> bytes: + """Minimal valid PDF bytes (enough to pass extension check).""" + return b"%PDF-1.4 fake" + + +# --------------------------------------------------------------------------- +# Health check +# --------------------------------------------------------------------------- + +class TestHealth: + def test_returns_ok(self, client): + resp = client.get("/health") + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "ok" + assert body["service"] == "docstrange-engineering" + + +# --------------------------------------------------------------------------- +# Individual extraction endpoints +# --------------------------------------------------------------------------- + +class TestExtractionEndpoints: + + def _upload(self, client, url: str, content: bytes, filename: str = "drawing.png"): + return client.post(url, files={"file": (filename, content, "image/png")}) + + def test_title_block_returns_list(self, client): + resp = self._upload(client, "/extract/title-block", _png_bytes()) + assert resp.status_code == 200 + body = resp.json() + assert isinstance(body, list) + assert body[0]["type"] == "title_block" + assert body[0]["field_name"] == "drawing_number" + + def test_dimensions_returns_list(self, client): + resp = self._upload(client, "/extract/dimensions", _png_bytes()) + assert resp.status_code == 200 + body = resp.json() + assert isinstance(body, list) + assert body[0]["type"] == "dimension" + assert body[0]["nominal"] == 25.4 + + def test_notes_returns_list(self, client): + resp = self._upload(client, "/extract/notes", _png_bytes()) + assert resp.status_code == 200 + body = resp.json() + assert isinstance(body, list) + assert body[0]["type"] == "note" + + def test_gdt_returns_list(self, client): + resp = self._upload(client, "/extract/gdt", _png_bytes()) + assert resp.status_code == 200 + body = resp.json() + assert isinstance(body, list) + assert body[0]["type"] == "gdt" + assert body[0]["symbol"] == "perpendicularity" + + def test_bom_returns_list(self, client): + resp = self._upload(client, "/extract/bom", _png_bytes()) + assert resp.status_code == 200 + body = resp.json() + assert isinstance(body, list) + assert body[0]["type"] == "bom" + assert "Bolt M6" in body[0]["text"] + + def test_revisions_returns_list(self, client): + resp = self._upload(client, "/extract/revisions", _png_bytes()) + assert resp.status_code == 200 + body = resp.json() + assert isinstance(body, list) + assert body[0]["type"] == "revision" + assert body[0]["revision"] == "A" + + +# --------------------------------------------------------------------------- +# Full extraction +# --------------------------------------------------------------------------- + +class TestFullExtraction: + + def test_full_returns_all_sections(self, client): + resp = client.post( + "/extract/full", + files={"file": ("drawing.png", _png_bytes(), "image/png")}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "title_block" in body + assert "dimensions" in body + assert "notes" in body + assert "gdt" in body + assert "bom" in body + assert "revisions" in body + + def test_full_no_overlay_by_default(self, client): + resp = client.post( + "/extract/full", + files={"file": ("drawing.png", _png_bytes(), "image/png")}, + ) + body = resp.json() + assert "overlay_json" not in body or body.get("overlay_json") is None + + def test_full_with_overlays(self, client): + resp = client.post( + "/extract/full?include_overlays=true&image_width=1000&image_height=800", + files={"file": ("drawing.png", _png_bytes(), "image/png")}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "overlay_json" in body + overlay = body["overlay_json"] + assert "annotations" in overlay + assert "total_annotations" in overlay + + def test_full_metadata_is_typed(self, client): + resp = client.post( + "/extract/full", + files={"file": ("drawing.pdf", _pdf_bytes(), "application/pdf")}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "metadata" in body + # ExtractionMetadata fields + assert "source" in body["metadata"] + assert "pages" in body["metadata"] + assert "extractor_version" in body["metadata"] + + +# --------------------------------------------------------------------------- +# Overlay endpoint +# --------------------------------------------------------------------------- + +class TestOverlayEndpoint: + + def test_overlay_structure(self, client): + resp = client.post( + "/generate/overlays?image_width=1000&image_height=800", + files={"file": ("drawing.png", _png_bytes(), "image/png")}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "annotations" in body + assert "image_size" in body + assert body["image_size"]["width"] == 1000 + assert body["total_annotations"] == len(body["annotations"]) + + def test_overlay_annotations_have_page_field(self, client): + resp = client.post( + "/generate/overlays?image_width=1000&image_height=800", + files={"file": ("drawing.png", _png_bytes(), "image/png")}, + ) + for ann in resp.json()["annotations"]: + assert "page" in ann + assert isinstance(ann["page"], int) + + def test_overlay_page_filter(self, client): + resp = client.post( + "/generate/overlays?image_width=1000&image_height=800&page=1", + files={"file": ("drawing.png", _png_bytes(), "image/png")}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["page_filter"] == 1 + for ann in body["annotations"]: + assert ann["page"] == 1 + + +# --------------------------------------------------------------------------- +# Validation — file type and size +# --------------------------------------------------------------------------- + +class TestValidation: + + def test_unsupported_extension_returns_415(self, client): + resp = client.post( + "/extract/dimensions", + files={"file": ("drawing.docx", b"fake", "application/octet-stream")}, + ) + assert resp.status_code == 415 + + def test_empty_file_returns_400(self, client): + resp = client.post( + "/extract/dimensions", + files={"file": ("drawing.png", b"", "image/png")}, + ) + assert resp.status_code == 400 + + def test_oversized_file_returns_413(self, client): + big = b"X" * (51 * 1024 * 1024) # 51 MB + resp = client.post( + "/extract/dimensions", + files={"file": ("drawing.png", big, "image/png")}, + ) + assert resp.status_code == 413 diff --git a/tests/test_e2e_engineering.py b/tests/test_e2e_engineering.py new file mode 100644 index 0000000..7159916 --- /dev/null +++ b/tests/test_e2e_engineering.py @@ -0,0 +1,201 @@ +"""Smoke tests — verify the OCR → LayoutElement → Extractor wiring works end-to-end. + +These tests do NOT require real OCR models. They inject a synthetic OCR service and a +PIL-generated image so the full EngineeringDrawingPipeline can run without GPU or network. +""" + +import os +import tempfile + +import pytest + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +def _make_synthetic_image(path: str) -> None: + """Write a small white PNG to *path* using PIL.""" + from PIL import Image, ImageDraw + + img = Image.new("RGB", (800, 600), color=(255, 255, 255)) + draw = ImageDraw.Draw(img) + draw.text((10, 10), "DRAWING TITLE: Test Part", fill=(0, 0, 0)) + draw.text((10, 40), "25.4 mm", fill=(0, 0, 0)) + draw.text((10, 70), "GENERAL NOTES:", fill=(0, 0, 0)) + draw.text((10, 90), "1. All dimensions in mm.", fill=(0, 0, 0)) + img.save(path, "PNG") + + +def _make_layout_elements(): + """Return a small list of LayoutElement objects covering each extractor.""" + from docstrange.pipeline.layout_detector import LayoutElement + + def el(text, x=10.0, y=10.0, w=200.0, h=20.0, confidence=0.9): + return LayoutElement(text=text, x=x, y=y, width=w, height=h, + element_type="paragraph", confidence=confidence) + + return [ + # Title block zone (lower-right corner, 800×600 image) + el("DRAWING NUMBER: DWG-001", x=560, y=500, w=220, h=18), + el("TITLE: Test Part", x=560, y=520, w=220, h=18), + el("SCALE: 1:1", x=560, y=540, w=220, h=18), + # Dimension + el("25.4 mm", x=100, y=100), + el("⌀12.5", x=200, y=130), + # Notes + el("GENERAL NOTES:", x=10, y=200), + el("1. All dimensions in mm.",x=10, y=220), + # Revision + el("REVISION HISTORY", x=10, y=300), + el("A 2024-01-15 Initial release J. Smith", x=10, y=320), + ] + + +class _MockOCRService: + """Returns pre-defined LayoutElements without running any model.""" + + def extract_layout_elements(self, image_path: str): + return _make_layout_elements() + + def extract_text(self, image_path: str) -> str: + return " ".join(el.text for el in _make_layout_elements()) + + def extract_text_with_layout(self, image_path: str) -> str: + return self.extract_text(image_path) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestEngineeringPipelineE2E: + + def test_extract_from_image_returns_result_type(self, tmp_path): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + from docstrange.schemas.engineering import EngineeringDrawingResult + + img_path = str(tmp_path / "test_drawing.png") + _make_synthetic_image(img_path) + + pipeline = EngineeringDrawingPipeline(ocr_service=_MockOCRService()) + result = pipeline.extract_from_image(img_path) + + assert isinstance(result, EngineeringDrawingResult) + + def test_extract_from_image_has_metadata(self, tmp_path): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + from docstrange.schemas.engineering import ExtractionMetadata + + img_path = str(tmp_path / "test_drawing.png") + _make_synthetic_image(img_path) + + pipeline = EngineeringDrawingPipeline(ocr_service=_MockOCRService()) + result = pipeline.extract_from_image(img_path) + + assert isinstance(result.metadata, ExtractionMetadata) + assert result.metadata.source == img_path + assert result.metadata.pages == 1 + + def test_dimensions_extracted(self, tmp_path): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + + img_path = str(tmp_path / "test_drawing.png") + _make_synthetic_image(img_path) + + pipeline = EngineeringDrawingPipeline(ocr_service=_MockOCRService()) + result = pipeline.extract_from_image(img_path, extractors=["dimensions"]) + + assert len(result.dimensions) >= 1 + texts = [d.text for d in result.dimensions] + assert any("25.4" in t or "12.5" in t for t in texts) + + def test_notes_extracted(self, tmp_path): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + + img_path = str(tmp_path / "test_drawing.png") + _make_synthetic_image(img_path) + + pipeline = EngineeringDrawingPipeline(ocr_service=_MockOCRService()) + result = pipeline.extract_from_image(img_path, extractors=["notes"]) + + assert len(result.notes) >= 1 + + def test_title_block_extracted(self, tmp_path): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + + img_path = str(tmp_path / "test_drawing.png") + _make_synthetic_image(img_path) + + pipeline = EngineeringDrawingPipeline(ocr_service=_MockOCRService()) + result = pipeline.extract_from_image(img_path, extractors=["title_block"]) + + assert len(result.title_block) >= 1 + + def test_model_dump_is_serialisable(self, tmp_path): + """Result must serialise to JSON-safe dict without errors.""" + import json + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + + img_path = str(tmp_path / "test_drawing.png") + _make_synthetic_image(img_path) + + pipeline = EngineeringDrawingPipeline(ocr_service=_MockOCRService()) + result = pipeline.extract_from_image(img_path) + + dumped = result.model_dump() + serialised = json.dumps(dumped) # must not raise + assert isinstance(serialised, str) + + def test_file_not_found_raises(self): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + + pipeline = EngineeringDrawingPipeline(ocr_service=_MockOCRService()) + with pytest.raises(FileNotFoundError): + pipeline.extract_from_image("/nonexistent/drawing.png") + + def test_selective_extractors_only(self, tmp_path): + """Requesting only 'dimensions' must leave other result fields empty.""" + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + + img_path = str(tmp_path / "test_drawing.png") + _make_synthetic_image(img_path) + + pipeline = EngineeringDrawingPipeline(ocr_service=_MockOCRService()) + result = pipeline.extract_from_image(img_path, extractors=["dimensions"]) + + assert result.title_block == [] + assert result.notes == [] + assert result.gdt == [] + assert result.bom == [] + assert result.revisions == [] + + +# --------------------------------------------------------------------------- +# Public import surface tests +# --------------------------------------------------------------------------- + +class TestPublicImports: + + def test_engineering_pipeline_importable_from_docstrange(self): + from docstrange import EngineeringDrawingPipeline # noqa: F401 + + def test_engineering_result_importable_from_docstrange(self): + from docstrange import EngineeringDrawingResult # noqa: F401 + + def test_extractors_importable_from_docstrange(self): + from docstrange import ( # noqa: F401 + TitleBlockExtractor, + DimensionExtractor, + NoteExtractor, + GDTExtractor, + BOMExtractor, + RevisionExtractor, + ) + + def test_mcp_server_importable(self): + pytest.importorskip("mcp", reason="mcp package not installed") + from docstrange.mcp_server import EngineeringMCPServer # noqa: F401 + + def test_ocr_factory_importable_from_pipeline(self): + from docstrange.pipeline import OCRServiceFactory # noqa: F401 diff --git a/tests/test_extraction_accuracy.py b/tests/test_extraction_accuracy.py new file mode 100644 index 0000000..cbb1397 --- /dev/null +++ b/tests/test_extraction_accuracy.py @@ -0,0 +1,534 @@ +"""Phase 7 — Extraction accuracy, bounding box fidelity, multi-page stamping, +malformed input handling, and schema validation tests. + +Focuses on gaps not covered by the existing extractor unit tests: +- Edge-case dimension notation (imperial, bilateral tolerance, DIA. text) +- GD&T variants (position, parallelism, feature control frames with multiple datums) +- Bounding box end-to-end precision +- Multi-page page-number stamping through the pipeline +- Degenerate / malformed input robustness +- Pydantic schema constraint validation +- Result completeness invariants +""" + +import pytest +from unittest.mock import MagicMock + +from docstrange.pipeline.layout_detector import LayoutElement +from docstrange.schemas.engineering import ( + BBoxSchema, + BOMRow, + DimensionElement, + EngineeringDrawingResult, + ExtractionMetadata, + GDTElement, + NoteElement, + RevisionEntry, + TitleBlockField, +) + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +def el(text, x=0.0, y=0.0, w=100.0, h=20.0, confidence=0.9, element_type="paragraph"): + return LayoutElement( + text=text, x=x, y=y, width=w, height=h, + element_type=element_type, confidence=confidence, + ) + + +def _pipeline_with_elements(elements): + """Return an EngineeringDrawingPipeline whose OCR service returns *elements*.""" + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + svc = MagicMock() + svc.extract_layout_elements.return_value = elements + return EngineeringDrawingPipeline(ocr_service=svc) + + +# --------------------------------------------------------------------------- +# Dimension accuracy — edge cases +# --------------------------------------------------------------------------- + +class TestDimensionAccuracy: + + def test_imperial_inch_unit_double_quote(self): + from docstrange.extractors.dimensions import DimensionExtractor + results = DimensionExtractor().extract([el('1.375"')]) + linears = [r for r in results if r.dimension_type == "linear"] + assert len(linears) >= 1 + assert linears[0].nominal == 1.375 + assert linears[0].unit == "in" + + def test_imperial_inch_unit_in_suffix(self): + from docstrange.extractors.dimensions import DimensionExtractor + results = DimensionExtractor().extract([el("2.500 in")]) + linears = [r for r in results if r.dimension_type == "linear"] + assert len(linears) >= 1 + assert linears[0].unit == "in" + + def test_bilateral_tolerance_upper_and_lower(self): + from docstrange.extractors.dimensions import DimensionExtractor + results = DimensionExtractor().extract([el("25.4 +0.5/-0.2 mm")]) + linears = [r for r in results if r.dimension_type == "linear"] + assert len(linears) >= 1 + d = linears[0] + assert d.nominal == 25.4 + assert d.upper_tolerance == 0.5 + assert d.lower_tolerance == -0.2 + + def test_diameter_text_notation(self): + from docstrange.extractors.dimensions import DimensionExtractor + results = DimensionExtractor().extract([el("DIA. 25.4")]) + diams = [r for r in results if r.dimension_type == "diameter"] + assert len(diams) >= 1 + assert diams[0].nominal == 25.4 + + def test_diameter_uppercase_o(self): + from docstrange.extractors.dimensions import DimensionExtractor + results = DimensionExtractor().extract([el("Ø12.5")]) + diams = [r for r in results if r.dimension_type == "diameter"] + assert len(diams) >= 1 + assert diams[0].nominal == 12.5 + + def test_empty_text_skipped(self): + from docstrange.extractors.dimensions import DimensionExtractor + results = DimensionExtractor().extract([el(""), el(" ")]) + assert results == [] + + def test_no_unit_still_extracts_linear(self): + from docstrange.extractors.dimensions import DimensionExtractor + results = DimensionExtractor().extract([el("50.0")]) + linears = [r for r in results if r.dimension_type == "linear"] + assert len(linears) >= 1 + assert linears[0].nominal == 50.0 + assert linears[0].unit is None + + def test_very_large_integer_no_unit_skipped(self): + from docstrange.extractors.dimensions import DimensionExtractor + # Pure large integer without a unit should not produce a result + results = DimensionExtractor().extract([el("99999")]) + for r in results: + assert not (r.dimension_type == "linear" and r.nominal == 99999 and r.unit is None) + + def test_whitespace_only_element_skipped(self): + from docstrange.extractors.dimensions import DimensionExtractor + assert DimensionExtractor().extract([el("\t\n ")]) == [] + + def test_long_text_no_crash(self): + from docstrange.extractors.dimensions import DimensionExtractor + long_text = "PART REF " + " ".join(f"{i}.{i}" for i in range(100)) + results = DimensionExtractor().extract([el(long_text)]) + assert isinstance(results, list) + + +# --------------------------------------------------------------------------- +# GD&T accuracy — edge cases +# --------------------------------------------------------------------------- + +class TestGDTAccuracy: + + def test_position_unicode_symbol(self): + from docstrange.extractors.gdt import GDTExtractor + results = GDTExtractor().extract([el("⊙ 0.1 A")]) + pos = [r for r in results if r.symbol == "position"] + assert len(pos) >= 1 + + def test_parallelism_text_abbreviation(self): + from docstrange.extractors.gdt import GDTExtractor + results = GDTExtractor().extract([el("PAR 0.03")]) + par = [r for r in results if r.symbol == "parallelism"] + assert len(par) >= 1 + assert par[0].tolerance_value == "0.03" + + def test_straightness_text_abbreviation(self): + from docstrange.extractors.gdt import GDTExtractor + results = GDTExtractor().extract([el("STRAIGHTNESS 0.01")]) + s = [r for r in results if r.symbol == "straightness"] + assert len(s) >= 1 + + def test_feature_control_frame_with_two_datums(self): + from docstrange.extractors.gdt import GDTExtractor + results = GDTExtractor().extract([el("|⊥|0.05|A|B|")]) + assert len(results) >= 1 + r = results[0] + assert r.symbol == "perpendicularity" + assert r.tolerance_value == "0.05" + + def test_datum_reference_extracted_from_feature_frame(self): + from docstrange.extractors.gdt import GDTExtractor + results = GDTExtractor().extract([el("|⊥|0.02|C|")]) + r = next((r for r in results if r.datum_reference is not None), None) + assert r is not None + assert r.datum_reference == "C" + + def test_empty_elements_skipped(self): + from docstrange.extractors.gdt import GDTExtractor + assert GDTExtractor().extract([el("")]) == [] + + +# --------------------------------------------------------------------------- +# Bounding box end-to-end precision +# --------------------------------------------------------------------------- + +class TestBoundingBoxPrecision: + + def test_dimension_extractor_preserves_exact_bbox(self): + from docstrange.extractors.dimensions import DimensionExtractor + elem = el("10.0 mm", x=123.4, y=567.8, w=88.0, h=22.5) + results = DimensionExtractor().extract([elem]) + assert len(results) >= 1 + b = results[0].bbox + assert b.x == 123.4 + assert b.y == 567.8 + assert b.width == 88.0 + assert b.height == 22.5 + + def test_two_elements_get_independent_bboxes(self): + from docstrange.extractors.dimensions import DimensionExtractor + e1 = el("10.0 mm", x=10.0, y=20.0, w=50.0, h=15.0) + e2 = el("20.0 mm", x=200.0, y=400.0, w=60.0, h=18.0) + results = DimensionExtractor().extract([e1, e2]) + bboxes = [r.bbox for r in results if r.dimension_type == "linear"] + xs = {b.x for b in bboxes} + assert 10.0 in xs + assert 200.0 in xs + + def test_gdt_extractor_bbox_from_source_element(self): + from docstrange.extractors.gdt import GDTExtractor + elem = el("⊥ 0.05 A", x=300.0, y=150.0, w=70.0, h=12.0) + results = GDTExtractor().extract([elem]) + assert len(results) >= 1 + b = results[0].bbox + assert b.x == 300.0 + assert b.y == 150.0 + + def test_bom_row_bbox_preserved_in_schema(self): + bbox = BBoxSchema(x=40.0, y=80.0, width=300.0, height=20.0) + row = BOMRow( + confidence=0.9, bbox=bbox, + item_number="1", description="Hex bolt", + raw_cells=["1", "2", "Hex bolt"], + ) + assert row.bbox.x == 40.0 + assert row.bbox.y == 80.0 + assert row.bbox.width == 300.0 + + def test_revision_entry_bbox_preserved(self): + bbox = BBoxSchema(x=700.0, y=50.0, width=200.0, height=15.0) + entry = RevisionEntry( + revision="B", date="2024-06-01", description="Updated", + confidence=0.85, bbox=bbox, + ) + assert entry.bbox.x == 700.0 + + +# --------------------------------------------------------------------------- +# Multi-page page stamping +# --------------------------------------------------------------------------- + +class TestMultiPageStamping: + + def test_run_extractors_stamps_page_number(self): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + elements = [ + el("25.4 mm", x=200.0, y=300.0), + el("GENERAL NOTES:", x=50.0, y=700.0), + el("1. ALL DIMS IN MM", x=50.0, y=720.0), + ] + pipeline = _pipeline_with_elements(elements) + result = pipeline._run_extractors(elements, None, metadata={"source": "test.pdf", "page": 3}) + + for dim in result.dimensions: + assert dim.page == 3 + for note in result.notes: + assert note.page == 3 + + def test_merge_preserves_per_page_numbers(self): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + pipeline = _pipeline_with_elements([]) + bbox = BBoxSchema(x=0, y=0, width=50, height=20) + + p1 = EngineeringDrawingResult( + dimensions=[DimensionElement(text="10mm", nominal=10, dimension_type="linear", + confidence=0.9, bbox=bbox, page=1)] + ) + p2 = EngineeringDrawingResult( + dimensions=[DimensionElement(text="20mm", nominal=20, dimension_type="linear", + confidence=0.85, bbox=bbox, page=2)] + ) + p3 = EngineeringDrawingResult( + dimensions=[DimensionElement(text="30mm", nominal=30, dimension_type="linear", + confidence=0.8, bbox=bbox, page=3)] + ) + merged = pipeline._merge_page_results([p1, p2, p3], metadata={"pages": 3}) + + assert len(merged.dimensions) == 3 + assert merged.dimensions[0].page == 1 + assert merged.dimensions[1].page == 2 + assert merged.dimensions[2].page == 3 + + def test_merge_metadata_pages_count(self): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + pipeline = _pipeline_with_elements([]) + bbox = BBoxSchema(x=0, y=0, width=10, height=10) + pages = [EngineeringDrawingResult() for _ in range(5)] + merged = pipeline._merge_page_results(pages, metadata={"pages": 5, "source": "multi.pdf"}) + assert merged.metadata.pages == 5 + + def test_page_1_default_for_single_image(self, tmp_path): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + elements = [el("50.0 mm")] + pipeline = _pipeline_with_elements(elements) + img = tmp_path / "drawing.png" + img.write_bytes(b"fake") + result = pipeline.extract_from_image(str(img)) + for dim in result.dimensions: + assert dim.page == 1 + + def test_all_sections_empty_on_empty_page(self): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + pipeline = _pipeline_with_elements([]) + result = pipeline._run_extractors([], None, metadata={"source": "blank.pdf", "page": 1}) + assert result.dimensions == [] + assert result.title_block == [] + assert result.notes == [] + assert result.gdt == [] + assert result.bom == [] + assert result.revisions == [] + + +# --------------------------------------------------------------------------- +# Malformed / degenerate input +# --------------------------------------------------------------------------- + +class TestMalformedInput: + + def test_empty_elements_list_all_extractors(self, tmp_path): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + pipeline = _pipeline_with_elements([]) + img = tmp_path / "blank.png" + img.write_bytes(b"fake") + result = pipeline.extract_from_image(str(img)) + assert result is not None + assert len(result.dimensions) == 0 + assert len(result.gdt) == 0 + + def test_elements_with_zero_confidence_no_crash(self): + from docstrange.extractors.dimensions import DimensionExtractor + elem = el("25.4 mm", confidence=0.0) + results = DimensionExtractor().extract([elem]) + assert isinstance(results, list) + for r in results: + assert 0.0 <= r.confidence <= 1.0 + + def test_picture_type_element_skipped_by_dimension_extractor(self): + from docstrange.extractors.dimensions import DimensionExtractor + elem = el("100.0 mm", element_type="picture") + assert DimensionExtractor().extract([elem]) == [] + + def test_extremely_long_text_no_crash_bom(self): + from docstrange.extractors.bom import BOMExtractor + long_row = " | ".join(["cell"] * 200) + header = el("BILL OF MATERIALS", x=50.0, y=100.0, w=200.0) + row = el(long_row, x=50.0, y=120.0, w=500.0) + results = BOMExtractor().extract([header, row]) + assert isinstance(results, list) + + def test_note_section_ends_gracefully(self): + from docstrange.extractors.notes import NoteExtractor + # Texts must not contain "note" — _NOTE_HEADER regex (NOTES?:?) would + # match that word and consume the lines as headers instead of content. + elements = [ + el("NOTES:", x=50.0, y=100.0), + el("1. ALL DIMS IN MM", x=50.0, y=120.0), + el("2. REMOVE SHARP EDGES", x=50.0, y=140.0), + ] + results = NoteExtractor().extract(elements) + numbered = [r for r in results if r.note_number is not None] + assert len(numbered) == 2 + assert numbered[0].note_number == 1 + assert numbered[1].note_number == 2 + + def test_dimension_extractor_handles_none_text_gracefully(self): + from docstrange.extractors.dimensions import DimensionExtractor + # LayoutElement where text might come out as empty + elem = el("") + results = DimensionExtractor().extract([elem]) + assert results == [] + + +# --------------------------------------------------------------------------- +# Note continuation merging +# --------------------------------------------------------------------------- + +class TestNoteContinuation: + + def test_indented_continuation_merged_into_previous_note(self): + from docstrange.extractors.notes import NoteExtractor + header = el("GENERAL NOTES:", x=50.0, y=100.0) + note1 = el("1. DO NOT SCALE DRAWING", x=50.0, y=120.0) + cont = el(" REFER TO DXF FILE", x=50.0, y=135.0) # 3 leading spaces + results = NoteExtractor().extract([header, note1, cont]) + numbered = [r for r in results if r.note_number == 1] + assert len(numbered) == 1 + assert "REFER TO DXF FILE" in numbered[0].text + + def test_multiple_numbered_notes_in_sequence(self): + from docstrange.extractors.notes import NoteExtractor + elems = [ + el("NOTES:", x=50.0, y=100.0), + el("1. ALL DIMS IN MM", x=50.0, y=120.0), + el("2. BREAK SHARP EDGES", x=50.0, y=140.0), + el("3. FINISH: ANODIZE", x=50.0, y=160.0), + ] + results = NoteExtractor().extract(elems) + nums = sorted(r.note_number for r in results if r.note_number) + assert nums == [1, 2, 3] + + +# --------------------------------------------------------------------------- +# Schema validation — Pydantic constraint enforcement +# --------------------------------------------------------------------------- + +class TestSchemaValidation: + + def test_confidence_above_one_raises(self): + from pydantic import ValidationError + bbox = BBoxSchema(x=0, y=0, width=10, height=10) + with pytest.raises(ValidationError): + DimensionElement(text="10mm", nominal=10, dimension_type="linear", + confidence=1.5, bbox=bbox) + + def test_confidence_below_zero_raises(self): + from pydantic import ValidationError + bbox = BBoxSchema(x=0, y=0, width=10, height=10) + with pytest.raises(ValidationError): + DimensionElement(text="10mm", nominal=10, dimension_type="linear", + confidence=-0.1, bbox=bbox) + + def test_extraction_metadata_version_default(self): + m = ExtractionMetadata() + assert m.extractor_version == "1.0.0" + + def test_extraction_metadata_version_non_empty(self): + m = ExtractionMetadata(source="drawing.pdf", pages=2) + assert m.extractor_version != "" + + def test_bom_row_text_auto_filled_from_raw_cells(self): + bbox = BBoxSchema(x=0, y=0, width=100, height=20) + row = BOMRow( + confidence=0.9, bbox=bbox, + item_number="3", quantity="5", + description="Washer M8", + raw_cells=["3", "5", "Washer M8"], + ) + assert "Washer M8" in row.text + + def test_revision_entry_text_auto_filled(self): + bbox = BBoxSchema(x=0, y=0, width=100, height=20) + entry = RevisionEntry( + revision="C", date="2025-03-01", description="Tolerance update", + confidence=0.85, bbox=bbox, + ) + assert entry.text != "" + assert "C" in entry.text or "Tolerance update" in entry.text + + def test_engineering_result_sections_default_to_empty_lists(self): + result = EngineeringDrawingResult() + assert result.dimensions == [] + assert result.title_block == [] + assert result.notes == [] + assert result.gdt == [] + assert result.bom == [] + assert result.revisions == [] + + def test_metadata_dict_coerced_to_extraction_metadata(self): + result = EngineeringDrawingResult(metadata={"source": "drawing.pdf", "pages": 3}) + assert isinstance(result.metadata, ExtractionMetadata) + assert result.metadata.source == "drawing.pdf" + assert result.metadata.pages == 3 + + +# --------------------------------------------------------------------------- +# Result completeness invariants +# --------------------------------------------------------------------------- + +class TestResultInvariants: + + def test_all_elements_have_page_at_least_one(self, tmp_path): + elements = [ + el("DWG NO. A-001", x=750.0, y=950.0), + el("25.4 mm", x=200.0, y=300.0), + el("GENERAL NOTES:", x=50.0, y=700.0), + el("1. ALL DIMS IN MM", x=50.0, y=720.0), + el("⊥ 0.05 A", x=400.0, y=500.0), + ] + pipeline = _pipeline_with_elements(elements) + img = tmp_path / "drawing.png" + img.write_bytes(b"fake") + result = pipeline.extract_from_image(str(img)) + + all_items = ( + result.title_block + result.dimensions + result.notes + + result.gdt + result.bom + result.revisions + ) + for item in all_items: + assert item.page >= 1, f"{type(item).__name__} has page={item.page}" + + def test_all_confidence_values_in_range(self, tmp_path): + elements = [ + el("25.4 mm"), el("⊙ 0.1 A"), el("GENERAL NOTES:", x=50.0, y=100.0), + el("1. note", x=50.0, y=120.0), + ] + pipeline = _pipeline_with_elements(elements) + img = tmp_path / "drawing.png" + img.write_bytes(b"fake") + result = pipeline.extract_from_image(str(img)) + + all_items = ( + result.title_block + result.dimensions + result.notes + + result.gdt + result.bom + result.revisions + ) + for item in all_items: + assert 0.0 <= item.confidence <= 1.0, \ + f"{type(item).__name__} confidence={item.confidence} out of range" + + def test_result_json_serialisable_with_unicode_gdt(self, tmp_path): + import json + elements = [ + el("⊥ 0.05 A"), el("⊙ 0.1 B"), el("⌀12.5"), + el("∥ 0.02"), el("⌤ 0.03"), + ] + pipeline = _pipeline_with_elements(elements) + img = tmp_path / "drawing.png" + img.write_bytes(b"fake") + result = pipeline.extract_from_image(str(img)) + serialised = json.dumps(result.model_dump()) + assert isinstance(serialised, str) + data = json.loads(serialised) + assert "dimensions" in data + assert "gdt" in data + + def test_extractor_version_in_metadata(self, tmp_path): + pipeline = _pipeline_with_elements([]) + img = tmp_path / "drawing.png" + img.write_bytes(b"fake") + result = pipeline.extract_from_image(str(img)) + assert result.metadata.extractor_version + assert isinstance(result.metadata.extractor_version, str) + + def test_selective_extraction_leaves_other_sections_empty(self, tmp_path): + elements = [ + el("25.4 mm"), el("DWG NO. A-001", x=750.0, y=950.0), + el("GENERAL NOTES:", x=50.0, y=700.0), + el("1. ALL DIMS IN MM", x=50.0, y=720.0), + ] + pipeline = _pipeline_with_elements(elements) + img = tmp_path / "drawing.png" + img.write_bytes(b"fake") + result = pipeline.extract_from_image(str(img), extractors=["gdt", "bom"]) + assert result.dimensions == [] + assert result.title_block == [] + assert result.notes == [] diff --git a/tests/test_extractors.py b/tests/test_extractors.py new file mode 100644 index 0000000..7ecb2c7 --- /dev/null +++ b/tests/test_extractors.py @@ -0,0 +1,248 @@ +"""Unit tests for engineering drawing extractors using synthetic LayoutElement fixtures.""" + +import pytest + + +def make_element(text, x=0.0, y=0.0, w=100.0, h=20.0, confidence=0.9, element_type="paragraph"): + """Create a synthetic LayoutElement without invoking OCR.""" + from docstrange.pipeline.layout_detector import LayoutElement + return LayoutElement( + text=text, + x=x, y=y, width=w, height=h, + element_type=element_type, + confidence=confidence, + ) + + +# --------------------------------------------------------------------------- +# DimensionExtractor +# --------------------------------------------------------------------------- + +class TestDimensionExtractor: + + def test_linear_dimension_nominal_only(self): + from docstrange.extractors.dimensions import DimensionExtractor + el = make_element("25.4 mm") + results = DimensionExtractor().extract([el]) + dims = [r for r in results if r.dimension_type == "linear"] + assert len(dims) >= 1 + assert dims[0].nominal == 25.4 + assert dims[0].unit == "mm" + + def test_linear_with_symmetric_tolerance(self): + from docstrange.extractors.dimensions import DimensionExtractor + el = make_element("25.4 ±0.1 mm") + results = DimensionExtractor().extract([el]) + dims = [r for r in results if r.dimension_type == "linear"] + assert len(dims) >= 1 + d = dims[0] + assert d.nominal == 25.4 + assert d.upper_tolerance == 0.1 + assert d.lower_tolerance == -0.1 + assert d.confidence >= 0.9 # full strict match + + def test_diameter_unicode(self): + from docstrange.extractors.dimensions import DimensionExtractor + el = make_element("⌀12.5") + results = DimensionExtractor().extract([el]) + diam = [r for r in results if r.dimension_type == "diameter"] + assert len(diam) >= 1 + assert diam[0].nominal == 12.5 + + def test_angular(self): + from docstrange.extractors.dimensions import DimensionExtractor + el = make_element("45°") + results = DimensionExtractor().extract([el]) + ang = [r for r in results if r.dimension_type == "angular"] + assert len(ang) >= 1 + assert ang[0].nominal == 45.0 + + def test_radial(self): + from docstrange.extractors.dimensions import DimensionExtractor + el = make_element("R12.5") + results = DimensionExtractor().extract([el]) + rad = [r for r in results if r.dimension_type == "radial"] + assert len(rad) >= 1 + assert rad[0].nominal == 12.5 + + def test_no_false_positive_on_long_serial(self): + from docstrange.extractors.dimensions import DimensionExtractor + el = make_element("PART NO. 1234567") + results = DimensionExtractor().extract([el]) + # No valid dimension should be extracted for a long serial number + assert all(r.nominal is None or r.nominal < 9999 for r in results) + + def test_picture_elements_skipped(self): + from docstrange.extractors.dimensions import DimensionExtractor + el = make_element("25.4 mm", element_type="picture") + results = DimensionExtractor().extract([el]) + assert results == [] + + def test_bbox_populated(self): + from docstrange.extractors.dimensions import DimensionExtractor + el = make_element("10.0 mm", x=120.0, y=240.0, w=80.0, h=20.0) + results = DimensionExtractor().extract([el]) + assert len(results) > 0 + bbox = results[0].bbox + assert bbox.x == 120.0 + assert bbox.y == 240.0 + assert bbox.width == 80.0 + assert bbox.height == 20.0 + + +# --------------------------------------------------------------------------- +# TitleBlockExtractor +# --------------------------------------------------------------------------- + +class TestTitleBlockExtractor: + + def test_keyword_match_in_zone(self): + from docstrange.extractors.title_block import TitleBlockExtractor + # Place in title block zone (high x, high y) + el = make_element("SCALE 1:2", x=700.0, y=900.0, w=80.0, h=15.0) + results = TitleBlockExtractor().extract([el]) + scale_fields = [r for r in results if r.field_name == "scale"] + assert len(scale_fields) >= 1 + + def test_drawing_number_detection(self): + from docstrange.extractors.title_block import TitleBlockExtractor + el = make_element("DWG NO. A-1234", x=750.0, y=950.0) + results = TitleBlockExtractor().extract([el]) + dn = [r for r in results if r.field_name == "drawing_number"] + assert len(dn) >= 1 + + def test_material_detection(self): + from docstrange.extractors.title_block import TitleBlockExtractor + el = make_element("MATERIAL: SS316", x=710.0, y=920.0) + results = TitleBlockExtractor().extract([el]) + mat = [r for r in results if r.field_name == "material"] + assert len(mat) >= 1 + + def test_low_confidence_outside_zone_without_keyword(self): + from docstrange.extractors.title_block import TitleBlockExtractor + # Outside zone, no keyword + el = make_element("Some random text", x=10.0, y=10.0, w=100.0, h=15.0) + results = TitleBlockExtractor().extract([el]) + assert all(r.confidence < 0.8 for r in results) + + def test_empty_elements(self): + from docstrange.extractors.title_block import TitleBlockExtractor + assert TitleBlockExtractor().extract([]) == [] + + +# --------------------------------------------------------------------------- +# NoteExtractor +# --------------------------------------------------------------------------- + +class TestNoteExtractor: + + def test_numbered_note(self): + from docstrange.extractors.notes import NoteExtractor + header = make_element("GENERAL NOTES:", x=50.0, y=100.0) + note1 = make_element("1. ALL DIMENSIONS IN MM", x=50.0, y=120.0) + note2 = make_element("2. REMOVE ALL BURRS", x=50.0, y=140.0) + results = NoteExtractor().extract([header, note1, note2]) + numbered = [r for r in results if r.note_number is not None] + assert len(numbered) == 2 + assert numbered[0].note_number == 1 + assert numbered[1].note_number == 2 + + def test_general_note(self): + from docstrange.extractors.notes import NoteExtractor + header = make_element("NOTES", x=50.0, y=100.0) + note = make_element("FINISH ALL OVER", x=50.0, y=120.0) + results = NoteExtractor().extract([header, note]) + general = [r for r in results if r.is_general] + assert len(general) >= 1 + + def test_no_notes_without_header(self): + from docstrange.extractors.notes import NoteExtractor + el = make_element("1. Some text without notes header", x=0.0, y=0.0) + results = NoteExtractor().extract([el]) + assert results == [] + + +# --------------------------------------------------------------------------- +# GDTExtractor +# --------------------------------------------------------------------------- + +class TestGDTExtractor: + + def test_unicode_perpendicularity(self): + from docstrange.extractors.gdt import GDTExtractor + el = make_element("⊥ 0.05 A") + results = GDTExtractor().extract([el]) + gdt = [r for r in results if r.symbol == "perpendicularity"] + assert len(gdt) >= 1 + + def test_text_abbreviation(self): + from docstrange.extractors.gdt import GDTExtractor + el = make_element("FLATNESS 0.02") + results = GDTExtractor().extract([el]) + flat = [r for r in results if r.symbol == "flatness"] + assert len(flat) >= 1 + assert flat[0].tolerance_value == "0.02" + + def test_feature_control_frame(self): + from docstrange.extractors.gdt import GDTExtractor + el = make_element("|⊥|0.05|A|") + results = GDTExtractor().extract([el]) + assert len(results) >= 1 + + def test_datum_reference_extracted(self): + from docstrange.extractors.gdt import GDTExtractor + el = make_element("⊥ 0.05 A") + results = GDTExtractor().extract([el]) + assert any(r.datum_reference == "A" for r in results) + + +# --------------------------------------------------------------------------- +# BOMExtractor +# --------------------------------------------------------------------------- + +class TestBOMExtractor: + + def test_bom_detected_from_header(self): + from docstrange.extractors.bom import BOMExtractor + header = make_element("BILL OF MATERIALS", x=50.0, y=100.0, w=200.0) + col_header = make_element("ITEM QTY PART NO. DESCRIPTION", x=50.0, y=125.0, w=400.0) + row1 = make_element("1 2 M6-BOLT HEX BOLT M6x20", x=50.0, y=145.0, w=400.0) + results = BOMExtractor().extract([header, col_header, row1]) + assert len(results) >= 1 + + def test_no_bom_without_header(self): + from docstrange.extractors.bom import BOMExtractor + el = make_element("1 2 M6-BOLT HEX BOLT M6x20", x=50.0, y=145.0) + results = BOMExtractor().extract([el]) + assert results == [] + + def test_empty_elements(self): + from docstrange.extractors.bom import BOMExtractor + assert BOMExtractor().extract([]) == [] + + +# --------------------------------------------------------------------------- +# RevisionExtractor +# --------------------------------------------------------------------------- + +class TestRevisionExtractor: + + def test_revision_entry_detected(self): + from docstrange.extractors.revisions import RevisionExtractor + header = make_element("REVISION HISTORY", x=700.0, y=10.0) + entry = make_element("B 2024-01-15 Updated tolerances J.Smith", x=700.0, y=30.0) + results = RevisionExtractor().extract([header, entry]) + assert len(results) >= 1 + + def test_date_extracted(self): + from docstrange.extractors.revisions import RevisionExtractor + header = make_element("REV BLOCK", x=700.0, y=10.0) + entry = make_element("A 01/15/2024 Initial release", x=700.0, y=30.0) + results = RevisionExtractor().extract([header, entry]) + assert any(r.date is not None for r in results) + + def test_no_results_without_header(self): + from docstrange.extractors.revisions import RevisionExtractor + el = make_element("A 01/15/2024 Initial release") + results = RevisionExtractor().extract([el]) + assert results == [] diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py new file mode 100644 index 0000000..937fe34 --- /dev/null +++ b/tests/test_mcp_server.py @@ -0,0 +1,282 @@ +"""Tests for the DocStrange Engineering MCP server layer. + +The MCP server is exercised by calling _dispatch() and _extract() directly — +no real OCR models or stdio transport needed. +""" + +import json +import os +import tempfile +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +pytest.importorskip("mcp", reason="mcp package not installed") + +from docstrange.schemas.engineering import ( + BBoxSchema, + BOMRow, + DimensionElement, + EngineeringDrawingResult, + GDTElement, + NoteElement, + RevisionEntry, + TitleBlockField, +) + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + +_BBOX = BBoxSchema(x=10.0, y=20.0, width=50.0, height=15.0) + +_FULL_RESULT = EngineeringDrawingResult( + title_block=[TitleBlockField(text="DWG-001", field_name="drawing_number", + field_value="DWG-001", confidence=0.95, bbox=_BBOX)], + dimensions=[DimensionElement(text="25.4 mm", nominal=25.4, unit="mm", + dimension_type="linear", confidence=0.9, bbox=_BBOX)], + notes=[NoteElement(text="ALL DIMS IN MM", is_general=True, + confidence=0.85, bbox=_BBOX)], + gdt=[GDTElement(text="⊥ 0.05 A", symbol="perpendicularity", + tolerance_value="0.05", datum_reference="A", + confidence=0.88, bbox=_BBOX)], + bom=[BOMRow(item_number="1", quantity="2", description="Bolt M6", + raw_cells=["1", "2", "Bolt M6"], confidence=0.9, bbox=_BBOX)], + revisions=[RevisionEntry(revision="A", date="2024-01-15", + description="Initial release", confidence=0.8, bbox=_BBOX)], + metadata={"source": "test.pdf", "pages": 1}, +) + + +@pytest.fixture +def tmp_png(tmp_path): + """Write a minimal PNG file and return its path string.""" + p = tmp_path / "drawing.png" + p.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) # minimal header + return str(p) + + +@pytest.fixture +def server(tmp_png): + """Return an EngineeringMCPServer with the pipeline swapped for a mock.""" + from docstrange.mcp_server.server import EngineeringMCPServer + + mock_pipeline = MagicMock() + mock_pipeline.extract_from_image.return_value = _FULL_RESULT + mock_pipeline.extract_from_pdf.return_value = _FULL_RESULT + + srv = EngineeringMCPServer.__new__(EngineeringMCPServer) + # Bypass __init__ to avoid mcp.Server instantiation in tests + from docstrange.mcp_server.cache import ExtractionCache + srv._pipeline = mock_pipeline + srv._overlay_gen = None + srv._cache = ExtractionCache(maxsize=5) + return srv + + +def _parse(content_list) -> dict | list: + return json.loads(content_list[0].text) + + +# --------------------------------------------------------------------------- +# Tool registry +# --------------------------------------------------------------------------- + +class TestToolRegistry: + + def test_all_eight_tools_present(self): + from docstrange.mcp_server.tools import TOOLS + names = {t.name for t in TOOLS} + expected = { + "extract_dimensions", "extract_title_block", "extract_notes", + "extract_gdt", "extract_bom", "extract_revisions", + "extract_full", "generate_overlays", + } + assert expected == names + + def test_generate_overlays_has_page_param(self): + from docstrange.mcp_server.tools import TOOLS + t = next(t for t in TOOLS if t.name == "generate_overlays") + assert "page" in t.inputSchema["properties"] + + def test_extract_full_has_page_param(self): + from docstrange.mcp_server.tools import TOOLS + t = next(t for t in TOOLS if t.name == "extract_full") + assert "page" in t.inputSchema["properties"] + + def test_every_tool_requires_file_path(self): + from docstrange.mcp_server.tools import TOOLS + for t in TOOLS: + assert "file_path" in t.inputSchema.get("required", []), \ + f"{t.name} missing file_path in required" + + +# --------------------------------------------------------------------------- +# Dispatch — happy paths +# --------------------------------------------------------------------------- + +class TestDispatch: + + @pytest.mark.anyio + async def test_extract_dimensions(self, server, tmp_png): + data = _parse(await server._dispatch("extract_dimensions", {"file_path": tmp_png})) + assert isinstance(data, list) + assert data[0]["type"] == "dimension" + assert data[0]["nominal"] == 25.4 + + @pytest.mark.anyio + async def test_extract_title_block(self, server, tmp_png): + data = _parse(await server._dispatch("extract_title_block", {"file_path": tmp_png})) + assert data[0]["field_name"] == "drawing_number" + + @pytest.mark.anyio + async def test_extract_notes(self, server, tmp_png): + data = _parse(await server._dispatch("extract_notes", {"file_path": tmp_png})) + assert data[0]["type"] == "note" + + @pytest.mark.anyio + async def test_extract_gdt(self, server, tmp_png): + data = _parse(await server._dispatch("extract_gdt", {"file_path": tmp_png})) + assert data[0]["symbol"] == "perpendicularity" + + @pytest.mark.anyio + async def test_extract_bom(self, server, tmp_png): + data = _parse(await server._dispatch("extract_bom", {"file_path": tmp_png})) + assert data[0]["type"] == "bom" + assert "Bolt M6" in data[0]["text"] + + @pytest.mark.anyio + async def test_extract_revisions(self, server, tmp_png): + data = _parse(await server._dispatch("extract_revisions", {"file_path": tmp_png})) + assert data[0]["revision"] == "A" + + @pytest.mark.anyio + async def test_extract_full(self, server, tmp_png): + data = _parse(await server._dispatch("extract_full", {"file_path": tmp_png})) + assert "title_block" in data + assert "dimensions" in data + assert "metadata" in data + assert "extractor_version" in data["metadata"] + + @pytest.mark.anyio + async def test_extract_full_with_overlays(self, server, tmp_png): + data = _parse(await server._dispatch("extract_full", { + "file_path": tmp_png, + "include_overlays": True, + "image_width": 1000, + "image_height": 800, + })) + assert "overlay_json" in data + assert "annotations" in data["overlay_json"] + + @pytest.mark.anyio + async def test_generate_overlays(self, server, tmp_png): + data = _parse(await server._dispatch("generate_overlays", { + "file_path": tmp_png, + "image_width": 1000, + "image_height": 800, + })) + assert "annotations" in data + assert "total_annotations" in data + + @pytest.mark.anyio + async def test_generate_overlays_page_filter(self, server, tmp_png): + data = _parse(await server._dispatch("generate_overlays", { + "file_path": tmp_png, + "image_width": 1000, + "image_height": 800, + "page": 1, + })) + assert data["page_filter"] == 1 + + @pytest.mark.anyio + async def test_unknown_tool_returns_error(self, server, tmp_png): + data = _parse(await server._dispatch("nonexistent_tool", {"file_path": tmp_png})) + assert data["status"] == "error" + assert "Unknown tool" in data["error"] + + +# --------------------------------------------------------------------------- +# Validation — file path checks +# --------------------------------------------------------------------------- + +class TestValidation: + + def test_file_not_found_raises(self, server): + with pytest.raises(FileNotFoundError): + server._extract("/nonexistent/drawing.png") + + def test_unsupported_extension_raises(self, server, tmp_path): + bad = tmp_path / "drawing.docx" + bad.write_bytes(b"fake") + with pytest.raises(ValueError, match="Unsupported file type"): + server._extract(str(bad)) + + def test_directory_path_raises(self, server, tmp_path): + with pytest.raises(ValueError, match="not a regular file"): + server._extract(str(tmp_path)) + + @pytest.mark.anyio + async def test_dispatch_wraps_file_not_found_as_error_json(self, server): + data = _parse(await server._dispatch("extract_dimensions", { + "file_path": "/no/such/file.png" + })) + assert data["status"] == "error" + assert "not found" in data["error"].lower() + + +# --------------------------------------------------------------------------- +# LRU cache +# --------------------------------------------------------------------------- + +class TestExtractionCache: + + def test_miss_returns_none(self, tmp_png): + from docstrange.mcp_server.cache import ExtractionCache + cache = ExtractionCache(maxsize=5) + assert cache.get(tmp_png, None) is None + + def test_put_then_get(self, tmp_png): + from docstrange.mcp_server.cache import ExtractionCache + cache = ExtractionCache(maxsize=5) + cache.put(tmp_png, None, _FULL_RESULT) + assert cache.get(tmp_png, None) is _FULL_RESULT + + def test_lru_eviction(self, tmp_path): + from docstrange.mcp_server.cache import ExtractionCache + cache = ExtractionCache(maxsize=3) + files = [] + for i in range(4): + p = tmp_path / f"f{i}.png" + p.write_bytes(b"\x89PNG" + bytes([i]) * 20) + files.append(str(p)) + + # Fill cache with f0, f1, f2 + for f in files[:3]: + cache.put(f, None, f"result_{f}") + assert len(cache) == 3 + + # Access f0 to make f1 the LRU + cache.get(files[0], None) + + # Insert f3 — f1 (LRU) should be evicted, not f0 + cache.put(files[3], None, "result_f3") + assert len(cache) == 3 + assert cache.get(files[1], None) is None # evicted + assert cache.get(files[0], None) is not None # still present + + def test_extractors_tuple_is_part_of_key(self, tmp_png): + from docstrange.mcp_server.cache import ExtractionCache + cache = ExtractionCache(maxsize=5) + cache.put(tmp_png, ["dimensions"], "dims_result") + cache.put(tmp_png, ["notes"], "notes_result") + assert cache.get(tmp_png, ["dimensions"]) == "dims_result" + assert cache.get(tmp_png, ["notes"]) == "notes_result" + assert cache.get(tmp_png, None) is None + + def test_cache_hit_in_server(self, server, tmp_png): + """Second call to _extract() should use cache, not re-invoke pipeline.""" + server._extract(tmp_png) + server._extract(tmp_png) + assert server._pipeline.extract_from_image.call_count == 1 diff --git a/tests/test_overlays.py b/tests/test_overlays.py new file mode 100644 index 0000000..36c9d3b --- /dev/null +++ b/tests/test_overlays.py @@ -0,0 +1,218 @@ +"""Tests for Phase 6 — Overlay JSON generation. + +Covers change_id sequencing, bbox field naming, summary section, +page filtering, and cross-run stability of sequential IDs. +""" + +import re + +import pytest + +from docstrange.overlays.generator import OverlayGenerator +from docstrange.schemas.engineering import ( + BBoxSchema, + BOMRow, + DimensionElement, + EngineeringDrawingResult, + GDTElement, + NoteElement, + RevisionEntry, + TitleBlockField, +) + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + +_BBOX = BBoxSchema(x=10.0, y=20.0, width=50.0, height=15.0) + +_FULL_RESULT = EngineeringDrawingResult( + title_block=[TitleBlockField(text="DWG-001", field_name="drawing_number", + field_value="DWG-001", confidence=0.95, bbox=_BBOX)], + dimensions=[DimensionElement(text="25.4 mm", nominal=25.4, unit="mm", + dimension_type="linear", confidence=0.9, bbox=_BBOX), + DimensionElement(text="12.7 mm", nominal=12.7, unit="mm", + dimension_type="linear", confidence=0.85, bbox=_BBOX)], + notes=[NoteElement(text="ALL DIMS IN MM", is_general=True, + confidence=0.85, bbox=_BBOX)], + gdt=[GDTElement(text="⊥ 0.05 A", symbol="perpendicularity", + tolerance_value="0.05", datum_reference="A", + confidence=0.88, bbox=_BBOX)], + bom=[BOMRow(item_number="1", quantity="2", description="Bolt M6", + raw_cells=["1", "2", "Bolt M6"], confidence=0.9, bbox=_BBOX)], + revisions=[RevisionEntry(revision="A", date="2024-01-15", + description="Initial release", confidence=0.8, bbox=_BBOX)], + metadata={"source": "test.pdf", "pages": 1}, +) + + +@pytest.fixture +def gen(): + return OverlayGenerator() + + +@pytest.fixture +def overlay(gen): + return gen.generate(_FULL_RESULT, image_width=1000, image_height=800) + + +# --------------------------------------------------------------------------- +# change_id field +# --------------------------------------------------------------------------- + +class TestChangeId: + + def test_every_annotation_has_change_id(self, overlay): + for ann in overlay["annotations"]: + assert "change_id" in ann + + def test_change_id_format(self, overlay): + pattern = re.compile(r"^chg_\d{3}$") + for ann in overlay["annotations"]: + assert pattern.match(ann["change_id"]), \ + f"Bad change_id format: {ann['change_id']}" + + def test_change_ids_are_sequential(self, overlay): + ids = [ann["change_id"] for ann in overlay["annotations"]] + for i, cid in enumerate(ids, start=1): + assert cid == f"chg_{i:03d}", \ + f"Expected chg_{i:03d}, got {cid}" + + def test_change_ids_are_unique(self, overlay): + ids = [ann["change_id"] for ann in overlay["annotations"]] + assert len(ids) == len(set(ids)) + + def test_change_ids_stable_across_reruns(self, gen): + """Same extraction → same change_id sequence every time.""" + overlay1 = gen.generate(_FULL_RESULT, image_width=1000, image_height=800) + overlay2 = gen.generate(_FULL_RESULT, image_width=1000, image_height=800) + ids1 = [a["change_id"] for a in overlay1["annotations"]] + ids2 = [a["change_id"] for a in overlay2["annotations"]] + assert ids1 == ids2 + + def test_no_uuid_id_field(self, overlay): + """Old random `id` field must not be present.""" + for ann in overlay["annotations"]: + assert "id" not in ann + + +# --------------------------------------------------------------------------- +# bbox field (pixel-space, flat) +# --------------------------------------------------------------------------- + +class TestBboxField: + + def test_bbox_present_on_every_annotation(self, overlay): + for ann in overlay["annotations"]: + assert "bbox" in ann + + def test_bbox_has_required_keys(self, overlay): + for ann in overlay["annotations"]: + bbox = ann["bbox"] + for key in ("x", "y", "width", "height"): + assert key in bbox, f"Missing bbox key: {key}" + + def test_bbox_values_match_source(self, overlay): + for ann in overlay["annotations"]: + assert ann["bbox"]["x"] == 10.0 + assert ann["bbox"]["y"] == 20.0 + assert ann["bbox"]["width"] == 50.0 + assert ann["bbox"]["height"] == 15.0 + + def test_bbox_normalized_also_present(self, overlay): + for ann in overlay["annotations"]: + assert "bbox_normalized" in ann + + def test_no_bbox_pixels_field(self, overlay): + """Old `bbox_pixels` key must not leak through.""" + for ann in overlay["annotations"]: + assert "bbox_pixels" not in ann + + def test_bbox_normalized_values_in_range(self, overlay): + for ann in overlay["annotations"]: + nb = ann["bbox_normalized"] + for key in ("x", "y", "width", "height"): + assert 0.0 <= nb[key] <= 1.0, \ + f"Normalized {key}={nb[key]} out of [0,1]" + + +# --------------------------------------------------------------------------- +# Summary section +# --------------------------------------------------------------------------- + +class TestSummary: + + def test_summary_present(self, overlay): + assert "summary" in overlay + + def test_summary_total_matches_annotations(self, overlay): + assert overlay["summary"]["total"] == len(overlay["annotations"]) + + def test_summary_by_type_keys(self, overlay): + by_type = overlay["summary"]["by_type"] + expected_types = {"dimension", "title_block", "note", "gdt", "bom", "revision"} + assert expected_types == set(by_type.keys()) + + def test_summary_by_type_counts(self, overlay): + by_type = overlay["summary"]["by_type"] + assert by_type["dimension"] == 2 + assert by_type["title_block"] == 1 + assert by_type["gdt"] == 1 + + def test_summary_by_type_sum_equals_total(self, overlay): + by_type = overlay["summary"]["by_type"] + assert sum(by_type.values()) == overlay["summary"]["total"] + + +# --------------------------------------------------------------------------- +# Page filtering interacts correctly with change_id +# --------------------------------------------------------------------------- + +class TestPageFilterWithChangeId: + + def test_page_filtered_ids_restart_at_chg_001(self, gen): + """Page-filtered output renumbers from chg_001, not from the global index.""" + result = EngineeringDrawingResult( + dimensions=[ + DimensionElement(text="10 mm", nominal=10.0, unit="mm", + dimension_type="linear", confidence=0.9, + bbox=_BBOX, page=1), + DimensionElement(text="20 mm", nominal=20.0, unit="mm", + dimension_type="linear", confidence=0.9, + bbox=_BBOX, page=2), + ], + ) + overlay = gen.generate(result, image_width=500, image_height=400, page=2) + assert len(overlay["annotations"]) == 1 + assert overlay["annotations"][0]["change_id"] == "chg_001" + + def test_page_filter_excludes_other_pages(self, gen): + result = EngineeringDrawingResult( + dimensions=[ + DimensionElement(text="10 mm", nominal=10.0, unit="mm", + dimension_type="linear", confidence=0.9, + bbox=_BBOX, page=1), + DimensionElement(text="20 mm", nominal=20.0, unit="mm", + dimension_type="linear", confidence=0.9, + bbox=_BBOX, page=2), + ], + ) + overlay = gen.generate(result, image_width=500, image_height=400, page=1) + assert all(ann["page"] == 1 for ann in overlay["annotations"]) + + +# --------------------------------------------------------------------------- +# Overlay with zero image dimensions +# --------------------------------------------------------------------------- + +class TestZeroDimensions: + + def test_no_crash_on_zero_dimensions(self, gen): + overlay = gen.generate(_FULL_RESULT) + assert overlay["total_annotations"] > 0 + + def test_bbox_normalized_zero_when_no_dimensions(self, gen): + overlay = gen.generate(_FULL_RESULT) + for ann in overlay["annotations"]: + nb = ann["bbox_normalized"] + assert nb == {"x": 0.0, "y": 0.0, "width": 0.0, "height": 0.0} diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..a79d5c6 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,246 @@ +"""Integration tests for EngineeringDrawingPipeline using mocked OCR service.""" + +from unittest.mock import MagicMock, patch +import pytest + + +def make_element(text, x=0.0, y=0.0, w=100.0, h=20.0, confidence=0.9): + from docstrange.pipeline.layout_detector import LayoutElement + return LayoutElement(text=text, x=x, y=y, width=w, height=h, + element_type="paragraph", confidence=confidence) + + +@pytest.fixture +def mock_elements(): + """A minimal set of synthetic layout elements covering multiple extractor types.""" + return [ + # Title block zone elements + make_element("SCALE 1:1", x=750.0, y=900.0, w=80.0, h=15.0), + make_element("DWG NO. ASSY-001", x=750.0, y=920.0, w=120.0, h=15.0), + # Dimensions + make_element("25.4 mm", x=200.0, y=300.0), + make_element("⌀10.0", x=350.0, y=400.0), + # Notes section + make_element("GENERAL NOTES:", x=50.0, y=700.0), + make_element("1. ALL DIMENSIONS IN MM", x=50.0, y=720.0), + # GD&T + make_element("⊥ 0.05 A", x=400.0, y=500.0), + # BOM + make_element("BILL OF MATERIALS", x=50.0, y=100.0, w=200.0), + make_element("1 2 M6-BOLT Hex bolt", x=50.0, y=130.0, w=400.0), + ] + + +@pytest.fixture +def mock_ocr_service(mock_elements): + service = MagicMock() + service.extract_layout_elements.return_value = mock_elements + return service + + +class TestEngineeringDrawingPipeline: + + def test_extract_from_image_calls_ocr(self, tmp_path, mock_ocr_service, mock_elements): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + # Create a dummy image file so path exists + img_file = tmp_path / "drawing.png" + img_file.write_bytes(b"fake") + + pipeline = EngineeringDrawingPipeline(ocr_service=mock_ocr_service) + result = pipeline.extract_from_image(str(img_file)) + + mock_ocr_service.extract_layout_elements.assert_called_once_with(str(img_file)) + assert result is not None + + def test_dimensions_extracted(self, tmp_path, mock_ocr_service): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + img_file = tmp_path / "drawing.png" + img_file.write_bytes(b"fake") + + pipeline = EngineeringDrawingPipeline(ocr_service=mock_ocr_service) + result = pipeline.extract_from_image(str(img_file)) + + assert len(result.dimensions) > 0 + + def test_title_block_extracted(self, tmp_path, mock_ocr_service): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + img_file = tmp_path / "drawing.png" + img_file.write_bytes(b"fake") + + pipeline = EngineeringDrawingPipeline(ocr_service=mock_ocr_service) + result = pipeline.extract_from_image(str(img_file)) + + assert len(result.title_block) > 0 + + def test_selective_extractors(self, tmp_path, mock_ocr_service): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + img_file = tmp_path / "drawing.png" + img_file.write_bytes(b"fake") + + pipeline = EngineeringDrawingPipeline(ocr_service=mock_ocr_service) + result = pipeline.extract_from_image(str(img_file), extractors=["dimensions"]) + + # Only dimensions should be populated + assert len(result.dimensions) > 0 + assert result.notes == [] + assert result.title_block == [] + + def test_result_is_serialisable(self, tmp_path, mock_ocr_service): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + img_file = tmp_path / "drawing.png" + img_file.write_bytes(b"fake") + + pipeline = EngineeringDrawingPipeline(ocr_service=mock_ocr_service) + result = pipeline.extract_from_image(str(img_file)) + + import json + serialised = json.dumps(result.model_dump()) + assert isinstance(serialised, str) + + def test_file_not_found_raises(self, mock_ocr_service): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + pipeline = EngineeringDrawingPipeline(ocr_service=mock_ocr_service) + with pytest.raises(FileNotFoundError): + pipeline.extract_from_image("/nonexistent/drawing.png") + + def test_merge_page_results(self, mock_ocr_service): + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + from docstrange.schemas.engineering import EngineeringDrawingResult, DimensionElement, BBoxSchema + + pipeline = EngineeringDrawingPipeline(ocr_service=mock_ocr_service) + bbox = BBoxSchema(x=0, y=0, width=10, height=10) + p1 = EngineeringDrawingResult( + dimensions=[DimensionElement(text="10mm", nominal=10, dimension_type="linear", + confidence=0.9, bbox=bbox)] + ) + p2 = EngineeringDrawingResult( + dimensions=[DimensionElement(text="20mm", nominal=20, dimension_type="linear", + confidence=0.85, bbox=bbox)] + ) + merged = pipeline._merge_page_results([p1, p2], metadata={}) + assert len(merged.dimensions) == 2 + + +class TestOverlayGenerator: + + def test_generate_returns_annotations(self, mock_elements, mock_ocr_service, tmp_path): + from docstrange.overlays.generator import OverlayGenerator + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + + img_file = tmp_path / "drawing.png" + img_file.write_bytes(b"fake") + + pipeline = EngineeringDrawingPipeline(ocr_service=mock_ocr_service) + result = pipeline.extract_from_image(str(img_file)) + + gen = OverlayGenerator() + overlay = gen.generate(result, image_width=1000, image_height=800) + + assert "annotations" in overlay + assert "image_size" in overlay + assert overlay["image_size"]["width"] == 1000 + assert overlay["total_annotations"] == len(overlay["annotations"]) + + def test_normalised_coordinates_in_range(self, mock_elements, mock_ocr_service, tmp_path): + from docstrange.overlays.generator import OverlayGenerator + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + + img_file = tmp_path / "drawing.png" + img_file.write_bytes(b"fake") + + pipeline = EngineeringDrawingPipeline(ocr_service=mock_ocr_service) + result = pipeline.extract_from_image(str(img_file)) + + gen = OverlayGenerator() + overlay = gen.generate(result, image_width=1000, image_height=800) + + for ann in overlay["annotations"]: + nb = ann["bbox_normalized"] + assert 0.0 <= nb["x"] <= 1.0 + assert 0.0 <= nb["y"] <= 1.0 + + def test_annotations_carry_page_field(self, mock_ocr_service, tmp_path): + from docstrange.overlays.generator import OverlayGenerator + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + + img_file = tmp_path / "drawing.png" + img_file.write_bytes(b"fake") + + pipeline = EngineeringDrawingPipeline(ocr_service=mock_ocr_service) + result = pipeline.extract_from_image(str(img_file)) + + gen = OverlayGenerator() + overlay = gen.generate(result, image_width=1000, image_height=800) + + for ann in overlay["annotations"]: + assert "page" in ann + assert isinstance(ann["page"], int) + + def test_page_filter_excludes_other_pages(self, mock_ocr_service, tmp_path): + """Page filter must return only annotations from the requested page.""" + from docstrange.overlays.generator import OverlayGenerator + from docstrange.pipelines.engineering import EngineeringDrawingPipeline + from docstrange.schemas.engineering import EngineeringDrawingResult, DimensionElement, BBoxSchema + + pipeline = EngineeringDrawingPipeline(ocr_service=mock_ocr_service) + bbox = BBoxSchema(x=0, y=0, width=50, height=20) + result = EngineeringDrawingResult( + dimensions=[ + DimensionElement(text="10mm", nominal=10, dimension_type="linear", + confidence=0.9, bbox=bbox, page=1), + DimensionElement(text="20mm", nominal=20, dimension_type="linear", + confidence=0.85, bbox=bbox, page=2), + ] + ) + + gen = OverlayGenerator() + overlay_p1 = gen.generate(result, image_width=1000, image_height=800, page=1) + overlay_p2 = gen.generate(result, image_width=1000, image_height=800, page=2) + overlay_all = gen.generate(result, image_width=1000, image_height=800) + + assert overlay_p1["total_annotations"] == 1 + assert overlay_p1["annotations"][0]["text"] == "10mm" + assert overlay_p2["total_annotations"] == 1 + assert overlay_p2["annotations"][0]["text"] == "20mm" + assert overlay_all["total_annotations"] == 2 + + def test_bom_row_annotation_uses_text_field(self, mock_ocr_service, tmp_path): + """BOMRow overlay annotation must use the auto-filled .text field from Phase 2.""" + from docstrange.overlays.generator import OverlayGenerator + from docstrange.schemas.engineering import EngineeringDrawingResult, BOMRow, BBoxSchema + + bbox = BBoxSchema(x=10, y=50, width=300, height=20) + result = EngineeringDrawingResult( + bom=[BOMRow( + confidence=0.9, + bbox=bbox, + item_number="1", + description="Hex bolt", + raw_cells=["1", "2", "Hex bolt"], + )] + ) + + gen = OverlayGenerator() + overlay = gen.generate(result, image_width=1000, image_height=800) + + assert overlay["total_annotations"] == 1 + ann = overlay["annotations"][0] + assert ann["type"] == "bom" + assert "Hex bolt" in ann["text"] + + def test_overlay_no_normalisation_when_dimensions_zero(self, mock_ocr_service): + """When image dimensions are 0 and no path given, bbox_normalized is all-zeros.""" + from docstrange.overlays.generator import OverlayGenerator + from docstrange.schemas.engineering import EngineeringDrawingResult, DimensionElement, BBoxSchema + + bbox = BBoxSchema(x=10, y=20, width=50, height=15) + result = EngineeringDrawingResult( + dimensions=[DimensionElement(text="5mm", nominal=5, + dimension_type="linear", confidence=0.8, bbox=bbox)] + ) + gen = OverlayGenerator() + overlay = gen.generate(result) # no image_width / image_height + + assert overlay["annotations"][0]["bbox_normalized"] == { + "x": 0.0, "y": 0.0, "width": 0.0, "height": 0.0 + } diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 0000000..8254dfe --- /dev/null +++ b/tests/test_schemas.py @@ -0,0 +1,69 @@ +"""Tests for engineering drawing Pydantic schemas.""" + +import pytest + + +def test_bbox_schema(): + from docstrange.schemas.engineering import BBoxSchema + bbox = BBoxSchema(x=10, y=20, width=80, height=15) + assert bbox.x == 10 + assert bbox.y == 20 + d = bbox.model_dump() + assert set(d.keys()) == {"x", "y", "width", "height"} + + +def test_extraction_element_confidence_bounds(): + from pydantic import ValidationError + from docstrange.schemas.engineering import BBoxSchema, ExtractionElement + bbox = BBoxSchema(x=0, y=0, width=50, height=10) + with pytest.raises(ValidationError): + ExtractionElement(text="foo", type="test", confidence=1.5, bbox=bbox) + with pytest.raises(ValidationError): + ExtractionElement(text="foo", type="test", confidence=-0.1, bbox=bbox) + el = ExtractionElement(text="foo", type="test", confidence=0.95, bbox=bbox) + assert el.confidence == 0.95 + + +def test_dimension_element_defaults(): + from docstrange.schemas.engineering import BBoxSchema, DimensionElement + bbox = BBoxSchema(x=0, y=0, width=60, height=12) + d = DimensionElement(text="25.4", confidence=0.9, bbox=bbox) + assert d.type == "dimension" + assert d.nominal is None + assert d.unit is None + + +def test_title_block_field(): + from docstrange.schemas.engineering import BBoxSchema, TitleBlockField + bbox = BBoxSchema(x=700, y=900, width=80, height=15) + f = TitleBlockField(text="SCALE 1:2", field_name="scale", field_value="1:2", + confidence=0.95, bbox=bbox) + assert f.type == "title_block" + assert f.field_name == "scale" + + +def test_engineering_drawing_result_empty(): + from docstrange.schemas.engineering import EngineeringDrawingResult + r = EngineeringDrawingResult() + assert r.dimensions == [] + assert r.title_block == [] + d = r.model_dump() + assert "dimensions" in d + assert "metadata" in d + + +def test_bom_row_schema(): + from docstrange.schemas.engineering import BBoxSchema, BOMRow + bbox = BBoxSchema(x=0, y=100, width=200, height=15) + row = BOMRow(item_number="1", quantity="2", description="Bolt M6", confidence=0.85, bbox=bbox) + assert row.item_number == "1" + assert row.quantity == "2" + + +def test_revision_entry_schema(): + from docstrange.schemas.engineering import BBoxSchema, RevisionEntry + bbox = BBoxSchema(x=700, y=50, width=150, height=15) + entry = RevisionEntry(revision="B", date="2024-01-15", description="Updated tolerances", + confidence=0.9, bbox=bbox) + assert entry.revision == "B" + assert entry.date == "2024-01-15"