diff --git a/.dockerignore b/.dockerignore index 33be4b6ee1..4988174bbb 100644 --- a/.dockerignore +++ b/.dockerignore @@ -88,6 +88,13 @@ stirling/ customFiles/ configs/ +# Local dev runtime dirs created by bootRun under the module dir (hold the locked H2 DB, +# downloaded models, logs); never build input. Root-level patterns above don't match these. +**/configs/ +app/core/customFiles/ +app/core/logs/ +app/core/pipeline/ + # Claude Code workspace .claude/ diff --git a/app/common/src/main/java/stirling/software/SPDF/config/EndpointConfiguration.java b/app/common/src/main/java/stirling/software/SPDF/config/EndpointConfiguration.java index a154ed2a7a..f80297dda9 100644 --- a/app/common/src/main/java/stirling/software/SPDF/config/EndpointConfiguration.java +++ b/app/common/src/main/java/stirling/software/SPDF/config/EndpointConfiguration.java @@ -536,6 +536,7 @@ public void init() { addEndpointToGroup("Java", "pdf-to-epub"); addEndpointToGroup("Java", "eml-to-pdf"); addEndpointToGroup("Java", "handleData"); + addEndpointToGroup("Java", "form-detection"); addEndpointToGroup("rar", "pdf-to-cbr"); // Javascript diff --git a/app/common/src/main/java/stirling/software/common/configuration/RuntimePathConfig.java b/app/common/src/main/java/stirling/software/common/configuration/RuntimePathConfig.java index b1ecc1020e..58f6b8fb65 100644 --- a/app/common/src/main/java/stirling/software/common/configuration/RuntimePathConfig.java +++ b/app/common/src/main/java/stirling/software/common/configuration/RuntimePathConfig.java @@ -41,6 +41,9 @@ public class RuntimePathConfig { // Tesseract data path private final String tessDataPath; + // Auto Form Detection model directory + private final String formDetectionModelPath; + private final List unoServerEndpoints; // Pipeline paths @@ -131,6 +134,22 @@ public RuntimePathConfig(ApplicationProperties properties) { log.info("Using Tesseract data path: {}", this.tessDataPath); + // Auto Form Detection model directory (kept under <configPath> so it 
survives + // restarts/updates) + String configuredModelDir = + properties.getFormDetection() != null + ? properties.getFormDetection().getModelDir() + : null; + this.formDetectionModelPath = + StringUtils.isNotBlank(configuredModelDir) + ? configuredModelDir + : Path.of( + InstallationPathConfig.getConfigPath(), + "models", + "form-detection") + .toString(); + log.info("Using Auto Form Detection model path: {}", this.formDetectionModelPath); + ApplicationProperties.ProcessExecutor processExecutor = properties.getProcessExecutor(); int libreOfficeLimit = 1; if (processExecutor != null && processExecutor.getSessionLimit() != null) { diff --git a/app/common/src/main/java/stirling/software/common/model/ApplicationProperties.java b/app/common/src/main/java/stirling/software/common/model/ApplicationProperties.java index 067be54eb6..0fbdc6eeeb 100644 --- a/app/common/src/main/java/stirling/software/common/model/ApplicationProperties.java +++ b/app/common/src/main/java/stirling/software/common/model/ApplicationProperties.java @@ -77,6 +77,7 @@ public class ApplicationProperties { private ProcessExecutor processExecutor = new ProcessExecutor(); private PdfEditor pdfEditor = new PdfEditor(); private AiEngine aiEngine = new AiEngine(); + private FormDetection formDetection = new FormDetection(); private Mcp mcp = new Mcp(); private InternalApi internalApi = new InternalApi(); private Cluster cluster = new Cluster(); @@ -297,6 +298,37 @@ public static class AiEngine { private int longRunningTimeoutSeconds = 600; } + /** + * Auto Form Detection settings. The model itself is downloaded on demand by an admin (see + * {@code /api/v1/ai/form-detection-model/*}); only lightweight pointers are persisted here. + */ + @Data + public static class FormDetection { + /** Master on/off switch for the whole feature (admin-controlled). 
*/ + private boolean enabled = true; + + /** + * Where detection runs: {@code auto} (browser first, server fallback), {@code browser} + * (in-browser WASM only - the PDF never leaves the device), or {@code server} (backend + * inference). Read by the frontend tool to choose its pipeline. + */ + private String executionMode = "auto"; + + /** Id of the installed model; blank means none installed. */ + private String activeModelId = ""; + + /** Optional override dir; blank uses {@code <configPath>/models/form-detection}. */ + private String modelDir = ""; + + /** + * Read-only dir of models baked into the image (e.g. the Docker server image pre-downloads + * FFDNet-S here). On startup any {@code *.onnx} file found here is copied into the + * writable model dir if not already present, and activated if no model is active - so the + * feature works out-of-the-box. Blank (default) disables seeding. + */ + private String preinstalledModelDir = ""; + } + /** * Model Context Protocol (MCP) server configuration. All keys live under the top-level {@code * mcp.*} prefix. {@link #enabled} defaults to {@code false}: when off, no MCP beans are wired, diff --git a/app/common/src/main/java/stirling/software/common/util/FormUtils.java b/app/common/src/main/java/stirling/software/common/util/FormUtils.java index 25cb4b5825..cf95a2093a 100644 --- a/app/common/src/main/java/stirling/software/common/util/FormUtils.java +++ b/app/common/src/main/java/stirling/software/common/util/FormUtils.java @@ -944,6 +944,81 @@ private PDAcroForm getAcroFormSafely(PDDocument document) { } } + /** + * Create new AcroForm fields from a list of definitions (used by Auto Form Detection). Reuses + * the same field-creation and appearance logic as the rest of this class, and creates the + * AcroForm (with a Helvetica default resource) when the document has none. Field names are made + * unique against any existing fields. 
+ */ + public void addFields(PDDocument document, List definitions) + throws IOException { + if (document == null || definitions == null || definitions.isEmpty()) { + return; + } + PDDocumentCatalog documentCatalog = document.getDocumentCatalog(); + PDAcroForm acroForm = documentCatalog.getAcroForm(); + if (acroForm == null) { + acroForm = new PDAcroForm(document); + PDResources dr = new PDResources(); + dr.put(COSName.getPDFName("Helv"), new PDType1Font(Standard14Fonts.FontName.HELVETICA)); + acroForm.setDefaultResources(dr); + acroForm.setNeedAppearances(true); + documentCatalog.setAcroForm(acroForm); + } + + Set existingNames = new java.util.HashSet<>(); + for (PDField field : acroForm.getFieldTree()) { + if (field.getPartialName() != null) { + existingNames.add(field.getPartialName()); + } + } + + int pageCount = document.getNumberOfPages(); + for (NewFormFieldDefinition definition : definitions) { + Integer pageIndex = definition.pageIndex(); + if (pageIndex == null + || pageIndex < 0 + || pageIndex >= pageCount + || definition.x() == null + || definition.y() == null + || definition.width() == null + || definition.height() == null) { + continue; + } + PDPage page = document.getPage(pageIndex); + PDRectangle rectangle = + new PDRectangle( + definition.x(), + definition.y(), + definition.width(), + definition.height()); + FormFieldTypeSupport handler = FormFieldTypeSupport.forTypeName(definition.type()); + if (handler == null || handler.doesNotsupportsDefinitionCreation()) { + handler = FormFieldTypeSupport.TEXT; + } + String baseName = + (definition.name() != null && !definition.name().isBlank()) + ? 
definition.name() + : handler.typeName() + "_" + (pageIndex + 1); + String uniqueName = generateUniqueFieldName(baseName, existingNames); + existingNames.add(uniqueName); + try { + createNewField( + handler, + acroForm, + page, + rectangle, + uniqueName, + definition, + definition.options()); + } catch (Exception e) { + log.warn("Failed to create detected field '{}': {}", uniqueName, e.getMessage()); + } + } + + ensureAppearances(acroForm); + } + public String filterSingleChoiceSelection( String selection, List allowedOptions, String fieldName) { if (selection == null || selection.trim().isEmpty()) return null; diff --git a/app/common/src/test/java/stirling/software/common/util/FormUtilsAddFieldsTest.java b/app/common/src/test/java/stirling/software/common/util/FormUtilsAddFieldsTest.java new file mode 100644 index 0000000000..cf5f9a658a --- /dev/null +++ b/app/common/src/test/java/stirling/software/common/util/FormUtilsAddFieldsTest.java @@ -0,0 +1,85 @@ +package stirling.software.common.util; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.pdfbox.pdmodel.PDDocument; +import org.apache.pdfbox.pdmodel.PDPage; +import org.apache.pdfbox.pdmodel.common.PDRectangle; +import org.apache.pdfbox.pdmodel.interactive.form.PDAcroForm; +import org.apache.pdfbox.pdmodel.interactive.form.PDCheckBox; +import org.apache.pdfbox.pdmodel.interactive.form.PDField; +import org.apache.pdfbox.pdmodel.interactive.form.PDTextField; +import org.junit.jupiter.api.Test; + +import stirling.software.common.util.FormUtils.NewFormFieldDefinition; + +class FormUtilsAddFieldsTest { + + private NewFormFieldDefinition def(String type, int page, float x, float y, float w, float h) { + return new NewFormFieldDefinition( + null, null, type, page, x, y, w, h, 
false, null, null, null, null); + } + + @Test + void createsTextAndCheckboxFieldsOnPagelessDocument() throws IOException { + try (PDDocument doc = new PDDocument()) { + doc.addPage(new PDPage(new PDRectangle(612, 792))); + + FormUtils.addFields( + doc, + List.of( + def("text", 0, 100f, 700f, 200f, 20f), + def("checkbox", 0, 100f, 650f, 15f, 15f))); + + PDAcroForm form = doc.getDocumentCatalog().getAcroForm(); + assertNotNull(form, "AcroForm should be created"); + + List fields = new ArrayList<>(); + form.getFieldTree().forEach(fields::add); + assertEquals(2, fields.size()); + + boolean hasText = fields.stream().anyMatch(f -> f instanceof PDTextField); + boolean hasCheckbox = fields.stream().anyMatch(f -> f instanceof PDCheckBox); + assertTrue(hasText, "expected a text field"); + assertTrue(hasCheckbox, "expected a checkbox field"); + } + } + + @Test + void skipsOutOfRangePageAndKeepsNamesUnique() throws IOException { + try (PDDocument doc = new PDDocument()) { + doc.addPage(new PDPage(new PDRectangle(612, 792))); + + FormUtils.addFields( + doc, + List.of( + def("text", 0, 10f, 10f, 50f, 12f), + def("text", 0, 10f, 40f, 50f, 12f), + def("text", 5, 10f, 70f, 50f, 12f))); // page 5 out of range -> skipped + + PDAcroForm form = doc.getDocumentCatalog().getAcroForm(); + List fields = new ArrayList<>(); + form.getFieldTree().forEach(fields::add); + assertEquals(2, fields.size()); + + long distinctNames = fields.stream().map(PDField::getPartialName).distinct().count(); + assertEquals(2, distinctNames, "field names must be unique"); + } + } + + @Test + void noOpOnEmptyDefinitions() throws IOException { + try (PDDocument doc = new PDDocument()) { + doc.addPage(new PDPage(new PDRectangle(612, 792))); + FormUtils.addFields(doc, List.of()); + // no AcroForm forced into existence when there is nothing to add + assertEquals(null, doc.getDocumentCatalog().getAcroForm()); + } + } +} diff --git a/app/core/src/main/resources/settings.yml.template 
b/app/core/src/main/resources/settings.yml.template index 2ab074d594..fc44f25f9e 100644 --- a/app/core/src/main/resources/settings.yml.template +++ b/app/core/src/main/resources/settings.yml.template @@ -364,6 +364,13 @@ aiEngine: url: http://localhost:5001 # URL of the Python AI engine timeoutSeconds: 120 # Timeout in seconds for AI engine requests +formDetection: + enabled: true # Master on/off switch for the Auto Form Detection feature + executionMode: auto # Where detection runs: 'auto' (browser first, server fallback), 'browser' (in-browser WASM only, PDF never leaves the device), or 'server' (backend inference) + activeModelId: "" # Id of the installed Auto Form Detection model (set automatically after an admin installs one) + modelDir: "" # Optional override directory for downloaded .onnx models; blank uses <configPath>/models/form-detection + preinstalledModelDir: "" # Read-only dir of models baked into the image (Docker pre-downloads FFDNet-S here); seeded into the model dir on startup. Blank disables. + policies: # Folder automations can read from and write to the directories you allow here, so treat this as a # security boundary. Leave allowedFolderRoots empty (default) to disable folder sources/outputs diff --git a/app/proprietary/build.gradle b/app/proprietary/build.gradle index ceb44153d3..7fa0c97208 100644 --- a/app/proprietary/build.gradle +++ b/app/proprietary/build.gradle @@ -7,6 +7,12 @@ ext { jwtVersion = '0.13.0' awsSdkVersion = '2.44.12' testcontainersMinioVersion = '1.21.4' + // CPU build; self-extracting natives for linux/win/mac x64+arm64. Keep major.minor aligned + // with the frontend onnxruntime-web pin so the same .onnx runs identically on both paths. + // Must be >=1.21 to load opset-22 models: the FFDNet exports are stamped ai.onnx opset 22, + // which 1.19/1.20 hard-reject (ORT_FAIL "support is till opset 21"). 1.26.0 matches the + // onnxruntime build the upstream CommonForms inference reference runs on. 
+ onnxruntimeVersion = '1.26.0' } bootRun { @@ -72,6 +78,18 @@ dependencies { implementation 'com.google.code.gson:gson:2.13.2' + // ONNX Runtime (Java) for SERVER-SIDE Auto Form Detection inference. + // Compiled against everywhere so the module always builds, but the multi-platform native + // (~42MB) is only BUNDLED when explicitly requested via -PbundleOnnxRuntime=true - i.e. the + // Docker server image (which then slims it to one Linux arch, ~8MB). Desktop/local/core builds + // ship WITHOUT it: server-side detection cleanly reports "unavailable" and the in-browser + // (onnxruntime-web) engine handles detection instead. Tests always get it. + compileOnly "com.microsoft.onnxruntime:onnxruntime:$onnxruntimeVersion" + testImplementation "com.microsoft.onnxruntime:onnxruntime:$onnxruntimeVersion" + if ((project.findProperty('bundleOnnxRuntime') ?: 'false').toString().toBoolean()) { + runtimeOnly "com.microsoft.onnxruntime:onnxruntime:$onnxruntimeVersion" + } + api 'io.micrometer:micrometer-registry-prometheus' api "io.jsonwebtoken:jjwt-api:$jwtVersion" diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/catalog/ModelCatalogService.java b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/catalog/ModelCatalogService.java new file mode 100644 index 0000000000..c4cf1f23a1 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/catalog/ModelCatalogService.java @@ -0,0 +1,64 @@ +package stirling.software.proprietary.formdetection.catalog; + +import java.io.InputStream; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.stereotype.Service; + +import jakarta.annotation.PostConstruct; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +import 
stirling.software.proprietary.formdetection.model.ModelCatalogEntry; + +import tools.jackson.core.type.TypeReference; +import tools.jackson.databind.ObjectMapper; + +/** Loads the curated Auto Form Detection model catalog from a bundled JSON resource. */ +@Slf4j +@Service +@RequiredArgsConstructor +public class ModelCatalogService { + + private static final String CATALOG_RESOURCE = "formdetection/model-catalog.json"; + + private final ObjectMapper objectMapper; + + private volatile List entries = List.of(); + private volatile Map byId = Map.of(); + + @PostConstruct + void load() { + try (InputStream is = new ClassPathResource(CATALOG_RESOURCE).getInputStream()) { + List loaded = + objectMapper.readValue(is, new TypeReference>() {}); + Map map = new LinkedHashMap<>(); + for (ModelCatalogEntry entry : loaded) { + if (entry.getId() != null && !entry.getId().isBlank()) { + map.put(entry.getId(), entry); + } + } + this.entries = List.copyOf(map.values()); + this.byId = Map.copyOf(map); + log.info("Loaded {} Auto Form Detection model catalog entries", entries.size()); + } catch (Exception e) { + log.error( + "Failed to load Auto Form Detection model catalog from {}", + CATALOG_RESOURCE, + e); + } + } + + public List getAll() { + return entries; + } + + public Optional getById(String id) { + return Optional.ofNullable(id == null ? 
null : byId.get(id)); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/controller/FormDetectionController.java b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/controller/FormDetectionController.java new file mode 100644 index 0000000000..8d925758e8 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/controller/FormDetectionController.java @@ -0,0 +1,174 @@ +package stirling.software.proprietary.formdetection.controller; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.apache.pdfbox.pdmodel.PDDocument; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.multipart.MultipartFile; + +import io.github.pixee.security.Filenames; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +import stirling.software.common.service.CustomPDFDocumentFactory; +import stirling.software.common.util.FormUtils; +import stirling.software.common.util.FormUtils.NewFormFieldDefinition; +import stirling.software.common.util.TempFileManager; +import stirling.software.common.util.WebResponseUtils; +import stirling.software.proprietary.formdetection.inference.OnnxFormDetector; +import stirling.software.proprietary.formdetection.inference.Yolo; +import stirling.software.proprietary.formdetection.model.DetectedField; +import 
stirling.software.proprietary.formdetection.model.ModelCatalogEntry; +import stirling.software.proprietary.formdetection.render.CoordinateMapper; +import stirling.software.proprietary.formdetection.render.PageRasterizer; +import stirling.software.proprietary.formdetection.service.FormDetectionModelManager; + +/** + * Server-side detection endpoint. Gated behind the {@code form-detection} endpoint key, which is + * disabled until a model is installed (so the tool tile is greyed in the UI). Returns the shared + * detection schema, or - when {@code applyToPdf=true} - the AcroForm-applied PDF. + */ +@Slf4j +@RestController +@RequestMapping("/api/v1/ai/form-detection") +@ConditionalOnClass(name = "ai.onnxruntime.OrtEnvironment") +@RequiredArgsConstructor +@Tag(name = "Auto Form Detection") +public class FormDetectionController { + + private final FormDetectionModelManager manager; + private final OnnxFormDetector detector; + private final PageRasterizer rasterizer; + private final CustomPDFDocumentFactory pdfDocumentFactory; + private final TempFileManager tempFileManager; + + @PostMapping(value = "/detect", consumes = MediaType.MULTIPART_FORM_DATA_VALUE) + @Operation( + summary = "Detect form fields with the installed AI model", + description = + "Runs the installed ONNX model over each page and returns detected fields in" + + " PDF points. 
With applyToPdf=true, returns the fillable PDF instead.") + public ResponseEntity detect( + @RequestParam("file") MultipartFile file, + @RequestParam(value = "confThreshold", required = false) Float confThreshold, + @RequestParam(value = "applyToPdf", required = false, defaultValue = "false") + boolean applyToPdf) + throws IOException { + + if (!manager.isReady()) { + return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE) + .body( + Map.of( + "reason", + "DEPENDENCY", + "message", + "AI form-detection model is not installed")); + } + ModelCatalogEntry spec = manager.getActiveEntry().orElse(null); + if (spec == null) { + return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE) + .body( + Map.of( + "reason", + "DEPENDENCY", + "message", + "Active model spec unavailable")); + } + float score = confThreshold != null ? confThreshold : spec.getScoreThreshold(); + byte[] pdfBytes = file.getBytes(); + + List detections = new ArrayList<>(); + try { + for (PageRasterizer.RasterPage page : + rasterizer.rasterize(pdfBytes, spec.getInputSize())) { + Yolo.Preprocessed pre = + Yolo.preprocess(page.rgba(), page.widthPx(), page.heightPx(), spec); + Yolo.RawOutput out = detector.infer(pre.chw(), spec.getInputSize()); + for (Yolo.Detection d : Yolo.decode(out, spec, pre, score)) { + DetectedField.RectPt rect = CoordinateMapper.toPdfPoints(d, page); + detections.add( + new DetectedField( + fieldType(spec, d.classId()), + page.pageIndex(), + rect, + d.score())); + } + } + } catch (IllegalStateException e) { + // e.g. ONNX Runtime native unavailable for this OS/arch - report unavailable cleanly + // rather than a 500. Cannot happen on a normally-built jar (all platforms bundled), but + // keeps a slimmed/mis-targeted build from erroring. 
+ log.warn("Auto Form Detection inference unavailable: {}", e.getMessage()); + return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE) + .body(Map.of("reason", "DEPENDENCY", "message", e.getMessage())); + } + + if (applyToPdf) { + try (PDDocument document = pdfDocumentFactory.load(file)) { + FormUtils.repairMissingWidgetPageReferences(document); + List defs = new ArrayList<>(); + for (DetectedField f : detections) { + defs.add(toDefinition(f)); + } + FormUtils.addFields(document, defs); + return WebResponseUtils.pdfDocToWebResponse( + document, baseName(file) + ".pdf", tempFileManager); + } + } + return ResponseEntity.ok(new DetectResponse(detections)); + } + + private static String fieldType(ModelCatalogEntry spec, int classId) { + List types = spec.getClassFieldTypes(); + if (types != null && classId >= 0 && classId < types.size()) { + return types.get(classId); + } + return "text"; + } + + private static NewFormFieldDefinition toDefinition(DetectedField f) { + DetectedField.RectPt r = f.rectInPdfPoints(); + return new NewFormFieldDefinition( + null, + null, + f.type(), + f.page(), + (float) r.x(), + (float) r.y(), + (float) r.w(), + (float) r.h(), + Boolean.FALSE, + null, + null, + null, + null); + } + + private static String baseName(MultipartFile file) { + String original = Filenames.toSimpleFileName(file.getOriginalFilename()); + if (original == null || original.isBlank()) { + original = "document"; + } + String stem = + original.toLowerCase().endsWith(".pdf") + ? original.substring(0, original.length() - 4) + : original; + return stem + "_form"; + } + + /** Shared JSON response (mirrors the browser pipeline output). 
*/ + public record DetectResponse(List detections) {} +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/controller/FormDetectionModelController.java b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/controller/FormDetectionModelController.java new file mode 100644 index 0000000000..6c8b3f55aa --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/controller/FormDetectionModelController.java @@ -0,0 +1,119 @@ +package stirling.software.proprietary.formdetection.controller; + +import org.apache.commons.lang3.StringUtils; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.security.access.prepost.PreAuthorize; +import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; + +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; + +import lombok.Data; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +import stirling.software.proprietary.formdetection.model.ModelStatusResponse; +import stirling.software.proprietary.formdetection.service.FormDetectionModelManager; + +/** + * Admin-managed lifecycle for the Auto Form Detection model. Lives under the never-gated {@code + * form-detection-model} endpoint key so install/status stay reachable while the feature itself (the + * {@code form-detection} detect endpoint) is disabled until a model is ready. 
+ */ +@Slf4j +@RestController +@RequestMapping("/api/v1/ai/form-detection-model") +@RequiredArgsConstructor +@Tag(name = "Auto Form Detection") +public class FormDetectionModelController { + + private final FormDetectionModelManager manager; + + @GetMapping("/status") + @Operation(summary = "Auto Form Detection model status, progress and catalog") + public ResponseEntity status() { + return ResponseEntity.ok(manager.status()); + } + + @PostMapping("/install") + @PreAuthorize("hasRole('ADMIN')") + @Operation(summary = "Install (download + checksum-verify) a catalog model") + public ResponseEntity install(@RequestBody InstallRequest request) { + if (request == null || StringUtils.isBlank(request.getModelId())) { + ModelStatusResponse s = manager.status(); + s.setError("modelId is required"); + return ResponseEntity.badRequest().body(s); + } + try { + manager.startInstall(request.getModelId()); + return ResponseEntity.accepted().body(manager.status()); + } catch (IllegalStateException e) { + ModelStatusResponse s = manager.status(); + s.setError(e.getMessage()); + return ResponseEntity.status(HttpStatus.CONFLICT).body(s); + } catch (IllegalArgumentException e) { + ModelStatusResponse s = manager.status(); + s.setError(e.getMessage()); + return ResponseEntity.badRequest().body(s); + } + } + + @DeleteMapping + @PreAuthorize("hasRole('ADMIN')") + @Operation(summary = "Uninstall a model") + public ResponseEntity delete( + @RequestParam(name = "modelId", required = false) String modelId) { + try { + manager.deleteModel(modelId); + return ResponseEntity.ok(manager.status()); + } catch (IllegalStateException e) { + ModelStatusResponse s = manager.status(); + s.setError(e.getMessage()); + return ResponseEntity.status(HttpStatus.CONFLICT).body(s); + } + } + + @PostMapping("/config") + @PreAuthorize("hasRole('ADMIN')") + @Operation( + summary = "Update the feature on/off switch and execution mode (auto/browser/server)") + public ResponseEntity config(@RequestBody ConfigRequest 
request) { + if (request == null) { + return ResponseEntity.badRequest().body(manager.status()); + } + try { + if (request.getEnabled() != null) { + manager.setEnabled(request.getEnabled()); + } + if (StringUtils.isNotBlank(request.getExecutionMode())) { + manager.setExecutionMode(request.getExecutionMode()); + } + return ResponseEntity.ok(manager.status()); + } catch (IllegalArgumentException e) { + ModelStatusResponse s = manager.status(); + s.setError(e.getMessage()); + return ResponseEntity.badRequest().body(s); + } + } + + @Data + public static class ConfigRequest { + /** Master on/off; {@code null} leaves it unchanged. */ + private Boolean enabled; + + /** auto|browser|server; blank/null leaves it unchanged. */ + private String executionMode; + } + + @Data + public static class InstallRequest { + private String modelId; + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/controller/FormDetectionModelServeController.java b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/controller/FormDetectionModelServeController.java new file mode 100644 index 0000000000..1178b99795 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/controller/FormDetectionModelServeController.java @@ -0,0 +1,99 @@ +package stirling.software.proprietary.formdetection.controller; + +import java.io.IOException; +import java.nio.file.Path; +import java.time.Duration; +import java.util.List; +import java.util.Optional; + +import org.springframework.core.io.FileSystemResource; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.ResourceRegion; +import org.springframework.http.CacheControl; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpRange; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import 
org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestHeader; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +import stirling.software.proprietary.formdetection.service.FormDetectionModelManager; + +/** + * Streams the installed .onnx to browsers for in-browser inference. Supports resumable range + * requests with a stable (content-addressed) ETag and a public Cache-Control so each browser + * fetches the ~200MB model at most once. The file is streamed from disk by Spring's resource + * converters - never buffered into heap. + */ +@Slf4j +@RestController +@RequestMapping("/api/v1/ai/form-detection-model") +@RequiredArgsConstructor +@Tag(name = "Auto Form Detection") +public class FormDetectionModelServeController { + + private final FormDetectionModelManager manager; + + @GetMapping("/file") + @Operation(summary = "Stream the installed model (.onnx); supports HTTP range requests") + public ResponseEntity serveModel(@RequestHeader HttpHeaders headers) + throws IOException { + Optional active = manager.getActiveModelFile(); + if (active.isEmpty()) { + return ResponseEntity.notFound().build(); + } + Path path = active.get(); + Resource resource = new FileSystemResource(path); + long length = resource.contentLength(); + String etag = + "\"" + + manager.getActiveEtag() + .orElseGet(() -> length + "-" + path.toFile().lastModified()) + + "\""; + // cachePublic + maxAge override the interceptor's blanket "no-store" on /api responses. 
+ CacheControl cacheControl = CacheControl.maxAge(Duration.ofDays(30)).cachePublic(); + + List ranges; + try { + ranges = headers.getRange(); + } catch (IllegalArgumentException ex) { + return ResponseEntity.status(HttpStatus.REQUESTED_RANGE_NOT_SATISFIABLE) + .header(HttpHeaders.CONTENT_RANGE, "bytes */" + length) + .build(); + } + + if (ranges.isEmpty()) { + return ResponseEntity.ok() + .header(HttpHeaders.ACCEPT_RANGES, "bytes") + .eTag(etag) + .cacheControl(cacheControl) + .contentType(MediaType.APPLICATION_OCTET_STREAM) + .contentLength(length) + .body(resource); + } + + // Serve a single region (the resumable-download case). Returning a bare List would lose its + // generic element type when the method returns ResponseEntity, leaving the region + // converter unable to match it. + HttpRange range = ranges.get(0); + long start = range.getRangeStart(length); + long count = Math.min(range.getRangeEnd(length) - start + 1, length - start); + ResourceRegion region = new ResourceRegion(resource, start, count); + // No explicit content type: the region converter derives it. Presetting octet-stream makes + // the ResourceRegion converter's content-negotiation reject the body. 
+ return ResponseEntity.status(HttpStatus.PARTIAL_CONTENT) + .header(HttpHeaders.ACCEPT_RANGES, "bytes") + .eTag(etag) + .cacheControl(cacheControl) + .body(region); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/inference/OnnxFormDetector.java b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/inference/OnnxFormDetector.java new file mode 100644 index 0000000000..085cc50849 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/inference/OnnxFormDetector.java @@ -0,0 +1,153 @@ +package stirling.software.proprietary.formdetection.inference; + +import java.nio.FloatBuffer; +import java.nio.file.Path; +import java.util.Collections; +import java.util.concurrent.Semaphore; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.stereotype.Service; + +import jakarta.annotation.PreDestroy; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +import stirling.software.proprietary.formdetection.model.ModelCatalogEntry; +import stirling.software.proprietary.formdetection.service.FormDetectionModelManager; + +import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OnnxValue; +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; + +/** + * Holds the ONNX Runtime session for the active model. Lazily (re)loads when the active model + * changes, guards session swaps with a read/write lock, and bounds concurrent inferences to limit + * memory. The session input is NCHW float32 {@code [1,3,N,N]}; the raw output is returned as-is for + * {@link Yolo#decode} to interpret per the model spec. 
+ */ +@Slf4j +@Service +@ConditionalOnClass(name = "ai.onnxruntime.OrtEnvironment") +@RequiredArgsConstructor +public class OnnxFormDetector { + + private final FormDetectionModelManager manager; + + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + private final Semaphore concurrency = + new Semaphore(Math.max(1, Runtime.getRuntime().availableProcessors() / 2)); + + private volatile OrtSession session; + private volatile String loadedModelId; + private volatile String inputName; + + public Yolo.RawOutput infer(float[] chw, int inputSize) { + ensureLoaded(); + concurrency.acquireUninterruptibly(); + lock.readLock().lock(); + try { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + long[] shape = {1, 3, inputSize, inputSize}; + try (OnnxTensor tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), shape); + OrtSession.Result results = + session.run(Collections.singletonMap(inputName, tensor))) { + OnnxValue value = results.get(0); + Object raw = value.getValue(); + if (!(raw instanceof float[][][] out3) || out3.length == 0) { + throw new IllegalStateException( + "Unexpected ONNX output type: " + + (raw == null ? "null" : raw.getClass())); + } + float[][] m = out3[0]; + int d1 = m.length; + int d2 = d1 > 0 ? m[0].length : 0; + float[] flat = new float[d1 * d2]; + for (int i = 0; i < d1; i++) { + System.arraycopy(m[i], 0, flat, i * d2, d2); + } + return new Yolo.RawOutput(flat, d1, d2); + } + } catch (OrtException e) { + throw new IllegalStateException("ONNX inference failed: " + e.getMessage(), e); + } finally { + lock.readLock().unlock(); + concurrency.release(); + } + } + + /** Force the next inference to reload from disk (called after install/uninstall). 
*/ + public void unload() { + lock.writeLock().lock(); + try { + closeSession(); + loadedModelId = null; + inputName = null; + } finally { + lock.writeLock().unlock(); + } + } + + private void ensureLoaded() { + String activeId = manager.getActiveEntry().map(ModelCatalogEntry::getId).orElse(null); + if (activeId == null) { + throw new IllegalStateException("No Auto Form Detection model installed"); + } + if (activeId.equals(loadedModelId) && session != null) { + return; + } + lock.writeLock().lock(); + try { + if (activeId.equals(loadedModelId) && session != null) { + return; + } + Path file = + manager.getActiveModelFile() + .orElseThrow( + () -> new IllegalStateException("Active model file missing")); + try { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); + try { + opts.setIntraOpNumThreads( + Math.max(1, Runtime.getRuntime().availableProcessors() / 2)); + } catch (OrtException ignored) { + // best-effort tuning + } + closeSession(); + session = env.createSession(file.toString(), opts); + inputName = session.getInputNames().iterator().next(); + loadedModelId = activeId; + log.info("Loaded ONNX session for Auto Form Detection model '{}'", activeId); + } catch (OrtException | RuntimeException | LinkageError e) { + // Native library missing/incompatible for this OS+arch (e.g. a Linux-slimmed jar + // run on Windows), or a model load failure. Degrade gracefully instead of letting + // an UnsatisfiedLinkError escape - the detect endpoint reports unavailable and the + // server keeps running. 
+ throw new IllegalStateException( + "ONNX Runtime is unavailable on this platform/build: " + e.getMessage(), e); + } + } finally { + lock.writeLock().unlock(); + } + } + + @PreDestroy + void close() { + unload(); + } + + private void closeSession() { + if (session != null) { + try { + session.close(); + } catch (Exception ignored) { + // ignore + } + session = null; + } + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/inference/Yolo.java b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/inference/Yolo.java new file mode 100644 index 0000000000..c785a9f4ec --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/inference/Yolo.java @@ -0,0 +1,252 @@ +package stirling.software.proprietary.formdetection.inference; + +import java.awt.Color; +import java.awt.Graphics2D; +import java.awt.RenderingHints; +import java.awt.image.BufferedImage; +import java.util.ArrayList; +import java.util.List; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.proprietary.formdetection.model.ModelCatalogEntry; + +/** + * Pure pre/post-processing for a YOLO-style detector, driven entirely by the {@link + * ModelCatalogEntry} spec. The browser pipeline mirrors this exactly so both inference paths agree. + * + *

Coordinate spaces: {@code preprocess} maps the source bitmap into the model's NxN input; + * {@code decode} reads raw model output (boxes in input-pixel space), thresholds, runs NMS, and + * un-projects boxes back to the original bitmap-pixel space (top-left origin). Mapping to PDF + * points is done separately by {@code CoordinateMapper}. + */ +@Slf4j +public final class Yolo { + + private Yolo() {} + + /** Normalised model input plus the transform needed to invert it. */ + public record Preprocessed( + float[] chw, + int inputSize, + float scaleX, + float scaleY, + int padX, + int padY, + int srcW, + int srcH) {} + + /** Raw model output flattened to {@code data[i*d2 + j]} with dims {@code d1 x d2}. */ + public record RawOutput(float[] data, int d1, int d2) {} + + /** A detection in original bitmap-pixel space, top-left origin. */ + public record Detection(int classId, float score, float x, float y, float w, float h) {} + + /** Letterbox/stretch-resize, normalise and lay out as NCHW float32. */ + public static Preprocessed preprocess(byte[] rgba, int srcW, int srcH, ModelCatalogEntry spec) { + int n = spec.getInputSize(); + boolean letterbox = !"stretch".equalsIgnoreCase(spec.getResizeMode()); + + float scaleX; + float scaleY; + int padX; + int padY; + int drawW; + int drawH; + if (letterbox) { + float scale = Math.min((float) n / srcW, (float) n / srcH); + drawW = Math.max(1, Math.round(srcW * scale)); + drawH = Math.max(1, Math.round(srcH * scale)); + padX = (n - drawW) / 2; + padY = (n - drawH) / 2; + scaleX = scale; + scaleY = scale; + } else { + drawW = n; + drawH = n; + padX = 0; + padY = 0; + scaleX = (float) n / srcW; + scaleY = (float) n / srcH; + } + + int[] pad = spec.getPadColor(); + Color fill = + new Color( + clampByte(pad != null && pad.length > 0 ? pad[0] : 114), + clampByte(pad != null && pad.length > 1 ? pad[1] : 114), + clampByte(pad != null && pad.length > 2 ? 
pad[2] : 114)); + + BufferedImage canvas = new BufferedImage(n, n, BufferedImage.TYPE_INT_RGB); + Graphics2D g = canvas.createGraphics(); + try { + g.setRenderingHint( + RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR); + g.setColor(fill); + g.fillRect(0, 0, n, n); + g.drawImage(rgbaToImage(rgba, srcW, srcH), padX, padY, drawW, drawH, null); + } finally { + g.dispose(); + } + + boolean bgr = "bgr".equalsIgnoreCase(spec.getChannelOrder()); + float[] mean = orZeros(spec.getNormMean()); + float[] std = orOnes(spec.getNormStd()); + int plane = n * n; + float[] chw = new float[3 * plane]; + int[] px = canvas.getRGB(0, 0, n, n, null, 0, n); + for (int i = 0; i < plane; i++) { + int rgb = px[i]; + float r = ((rgb >> 16) & 0xFF) / 255f; + float gg = ((rgb >> 8) & 0xFF) / 255f; + float b = (rgb & 0xFF) / 255f; + float c0 = bgr ? b : r; + float c1 = gg; + float c2 = bgr ? r : b; + chw[i] = (c0 - mean[0]) / std[0]; + chw[plane + i] = (c1 - mean[1]) / std[1]; + chw[2 * plane + i] = (c2 - mean[2]) / std[2]; + } + return new Preprocessed(chw, n, scaleX, scaleY, padX, padY, srcW, srcH); + } + + /** Decode raw output, threshold, NMS, and un-project to original bitmap pixels. */ + public static List decode( + RawOutput out, ModelCatalogEntry spec, Preprocessed pre, float scoreThreshold) { + int numClasses = spec.getClassNames() == null ? 0 : spec.getClassNames().size(); + if (numClasses == 0) { + return List.of(); + } + boolean obj = spec.isHasObjectness(); + boolean ncFirst = !"anchors_first".equalsIgnoreCase(spec.getOutputLayout()); + int channels = ncFirst ? out.d1() : out.d2(); + int anchors = ncFirst ? out.d2() : out.d1(); + int expected = 4 + (obj ? 1 : 0) + numClasses; + if (channels < expected) { + log.warn( + "ONNX output channel count {} < expected {} (4 + obj + {} classes); skipping", + channels, + expected, + numClasses); + return List.of(); + } + int classOffset = 4 + (obj ? 
1 : 0); + float[] data = out.data(); + + List dets = new ArrayList<>(); + for (int a = 0; a < anchors; a++) { + float objScore = obj ? at(data, ncFirst, anchors, channels, 4, a) : 1f; + int bestClass = -1; + float bestScore = 0f; + for (int c = 0; c < numClasses; c++) { + float s = at(data, ncFirst, anchors, channels, classOffset + c, a) * objScore; + if (s > bestScore) { + bestScore = s; + bestClass = c; + } + } + if (bestClass < 0 || bestScore < scoreThreshold) { + continue; + } + float cx = at(data, ncFirst, anchors, channels, 0, a); + float cy = at(data, ncFirst, anchors, channels, 1, a); + float w = at(data, ncFirst, anchors, channels, 2, a); + float h = at(data, ncFirst, anchors, channels, 3, a); + float x1 = cx - w / 2f; + float y1 = cy - h / 2f; + float ox = (x1 - pre.padX()) / pre.scaleX(); + float oy = (y1 - pre.padY()) / pre.scaleY(); + float ow = w / pre.scaleX(); + float oh = h / pre.scaleY(); + // clamp to the source bitmap + float cxl = Math.max(0, Math.min(ox, pre.srcW())); + float cyl = Math.max(0, Math.min(oy, pre.srcH())); + ow = Math.max(0, Math.min(ow, pre.srcW() - cxl)); + oh = Math.max(0, Math.min(oh, pre.srcH() - cyl)); + if (ow <= 0 || oh <= 0) { + continue; + } + dets.add(new Detection(bestClass, bestScore, cxl, cyl, ow, oh)); + } + return nms(dets, spec.getNms(), spec.getIou()); + } + + private static float at( + float[] data, boolean ncFirst, int anchors, int channels, int c, int a) { + return ncFirst ? 
data[c * anchors + a] : data[a * channels + c]; + } + + private static List nms(List dets, String mode, float iouThreshold) { + if (dets.size() < 2 || "none".equalsIgnoreCase(mode)) { + return dets; + } + boolean classAgnostic = mode != null && mode.toLowerCase().contains("agnostic"); + List sorted = new ArrayList<>(dets); + sorted.sort((x, y) -> Float.compare(y.score(), x.score())); + boolean[] removed = new boolean[sorted.size()]; + List keep = new ArrayList<>(); + for (int i = 0; i < sorted.size(); i++) { + if (removed[i]) { + continue; + } + Detection di = sorted.get(i); + keep.add(di); + for (int j = i + 1; j < sorted.size(); j++) { + if (removed[j]) { + continue; + } + Detection dj = sorted.get(j); + if (!classAgnostic && di.classId() != dj.classId()) { + continue; + } + if (iou(di, dj) > iouThreshold) { + removed[j] = true; + } + } + } + return keep; + } + + private static float iou(Detection a, Detection b) { + float ax2 = a.x() + a.w(); + float ay2 = a.y() + a.h(); + float bx2 = b.x() + b.w(); + float by2 = b.y() + b.h(); + float ix1 = Math.max(a.x(), b.x()); + float iy1 = Math.max(a.y(), b.y()); + float ix2 = Math.min(ax2, bx2); + float iy2 = Math.min(ay2, by2); + float iw = Math.max(0, ix2 - ix1); + float ih = Math.max(0, iy2 - iy1); + float inter = iw * ih; + float union = a.w() * a.h() + b.w() * b.h() - inter; + return union <= 0 ? 
0 : inter / union; + } + + private static BufferedImage rgbaToImage(byte[] rgba, int w, int h) { + BufferedImage img = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB); + int[] px = new int[w * h]; + int pixels = Math.min(w * h, rgba.length / 4); + for (int i = 0; i < pixels; i++) { + int r = rgba[i * 4] & 0xFF; + int g = rgba[i * 4 + 1] & 0xFF; + int b = rgba[i * 4 + 2] & 0xFF; + int a = rgba[i * 4 + 3] & 0xFF; + px[i] = (a << 24) | (r << 16) | (g << 8) | b; + } + img.setRGB(0, 0, w, h, px, 0, w); + return img; + } + + private static int clampByte(int v) { + return Math.max(0, Math.min(255, v)); + } + + private static float[] orZeros(float[] v) { + return v != null && v.length >= 3 ? v : new float[] {0f, 0f, 0f}; + } + + private static float[] orOnes(float[] v) { + return v != null && v.length >= 3 ? v : new float[] {1f, 1f, 1f}; + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/model/DetectedField.java b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/model/DetectedField.java new file mode 100644 index 0000000000..c8cadde11d --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/model/DetectedField.java @@ -0,0 +1,16 @@ +package stirling.software.proprietary.formdetection.model; + +/** + * One detected form field. This is the shared schema returned by both the server detect endpoint + * and the in-browser pipeline so the two paths are interchangeable. + * + * @param type AcroForm field type (text|checkbox|radio|signature) + * @param page zero-based page index + * @param rectInPdfPoints rectangle in PDF points (bottom-left origin) + * @param confidence detection confidence 0-1 + */ +public record DetectedField(String type, int page, RectPt rectInPdfPoints, double confidence) { + + /** Rectangle in PDF points, bottom-left origin (PDF user space). 
/** Lifecycle state of the Auto Form Detection model install. */
public enum FormDetectionStatus {
    NOT_INSTALLED,
    DOWNLOADING,
    VERIFYING,
    READY,
    FAILED;

    /** Lowercase form of the constant name, computed once per constant. */
    private final String wireName = name().toLowerCase(Locale.ROOT);

    /** Lowercase wire form sent to the frontend, e.g. {@code not_installed}. */
    public String wire() {
        return wireName;
    }
}

NOTE: the pipeline-spec defaults below follow common Ultralytics-YOLO conventions. The precise + * values for a given model (input size, resize mode, channel order, output layout, NMS, class + * indices) MUST be verified against the actual exported {@code .onnx} before that entry's {@code + * onnxUrl}/{@code sha256} are populated. An entry with a blank {@code onnxUrl} or {@code sha256} is + * not installable, which keeps the distribution shippable without any bundled model. + */ +@Data +public class ModelCatalogEntry { + + // --- Identity / distribution ------------------------------------------------- + private String id; + private String displayName; + private String description; + private String license; + private long sizeBytes; + + /** Direct download URL of the .onnx. Blank = not yet available (install is rejected). */ + private String onnxUrl; + + /** Lower-hex SHA-256 of the .onnx. Blank = not yet available (install is rejected). */ + private String sha256; + + // --- Pre-processing (parity-critical, mirrored by the browser) --------------- + /** Square model input edge in pixels. */ + private int inputSize = 1024; + + /** "letterbox" (aspect-preserving pad) or "stretch" (resize to square). */ + private String resizeMode = "letterbox"; + + /** RGB letterbox pad colour. */ + private int[] padColor = {114, 114, 114}; + + /** "rgb" or "bgr" channel order fed to the model. */ + private String channelOrder = "rgb"; + + /** + * Per-channel mean subtracted after dividing the raw byte by 255 ({@code (raw/255 - + * mean)/std}). + */ + private float[] normMean = {0f, 0f, 0f}; + + /** Per-channel std applied after mean subtraction. */ + private float[] normStd = {1f, 1f, 1f}; + + // --- Post-processing (parity-critical) --------------------------------------- + /** "nc_first" => output [1, 4+nc, anchors]; "anchors_first" => [1, anchors, 4+nc]. 
*/ + private String outputLayout = "nc_first"; + + /** True if an objectness score column precedes the class scores (YOLOv5 style). */ + private boolean hasObjectness = false; + + /** Class index -> label. */ + private List classNames = List.of("text", "choice", "signature"); + + /** Class index -> AcroForm field type (text|checkbox|radio|signature). */ + private List classFieldTypes = List.of("text", "checkbox", "signature"); + + private float scoreThreshold = 0.25f; + + /** "none", "classAgnostic" or "perClass". */ + private String nms = "perClass"; + + private float iou = 0.45f; +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/model/ModelStatusResponse.java b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/model/ModelStatusResponse.java new file mode 100644 index 0000000000..eb1b1e72f3 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/model/ModelStatusResponse.java @@ -0,0 +1,46 @@ +package stirling.software.proprietary.formdetection.model; + +import java.util.List; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Snapshot returned by {@code GET /api/v1/ai/form-detection-model/status}. Includes the full + * catalog so the browser can read the active model's parity-critical pipeline spec. + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class ModelStatusResponse { + /** Wire state: not_installed | downloading | verifying | ready | failed. */ + private String status; + + /** Download progress 0-100 (meaningful while downloading). */ + private int progress; + + /** Id of the active/usable model, or blank when none. */ + private String activeModelId; + + /** Model ids that currently have an .onnx file on disk. */ + private List installed; + + /** Last error message, or null. */ + private String error; + + /** Whether the model directory is writable (admin install possible). 
*/ + private boolean writable; + + /** Full curated catalog (identity + pipeline spec). */ + private List catalog; + + /** Master on/off for the whole feature (admin-controlled). */ + private boolean enabled; + + /** Where detection runs: auto | browser | server. */ + private String executionMode; + + /** True if the server-side ONNX engine is bundled in this build (else only browser works). */ + private boolean serverEngineAvailable; +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/render/CoordinateMapper.java b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/render/CoordinateMapper.java new file mode 100644 index 0000000000..65892a7373 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/render/CoordinateMapper.java @@ -0,0 +1,35 @@ +package stirling.software.proprietary.formdetection.render; + +import stirling.software.proprietary.formdetection.inference.Yolo; +import stirling.software.proprietary.formdetection.model.DetectedField; + +/** + * Maps a detection (original bitmap pixels, top-left origin) to PDF points (bottom-left origin), + * accounting for the render scale and the top-left vs bottom-left origin flip. + */ +public final class CoordinateMapper { + + private CoordinateMapper() {} + + public static DetectedField.RectPt toPdfPoints( + Yolo.Detection d, PageRasterizer.RasterPage page) { + float sx = page.scaleX() > 0 ? page.scaleX() : 1f; + float sy = page.scaleY() > 0 ? page.scaleY() : 1f; + + double wPt = d.w() / sx; + double hPt = d.h() / sy; + double xPt = d.x() / sx; + // Flip Y: bitmap origin is top-left, PDF origin is bottom-left. 
+ double yPt = page.pageHeightPt() - (d.y() / sy) - hPt; + + xPt = clamp(xPt, 0, page.pageWidthPt()); + yPt = clamp(yPt, 0, page.pageHeightPt()); + wPt = clamp(wPt, 0, page.pageWidthPt() - xPt); + hPt = clamp(hPt, 0, page.pageHeightPt() - yPt); + return new DetectedField.RectPt(xPt, yPt, wPt, hPt); + } + + private static double clamp(double v, double lo, double hi) { + return v < lo ? lo : Math.min(v, hi); + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/render/PageRasterizer.java b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/render/PageRasterizer.java new file mode 100644 index 0000000000..9c5ca45e49 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/render/PageRasterizer.java @@ -0,0 +1,65 @@ +package stirling.software.proprietary.formdetection.render; + +import java.util.ArrayList; +import java.util.List; + +import org.springframework.stereotype.Service; + +import lombok.extern.slf4j.Slf4j; + +import stirling.software.jpdfium.PdfDocument; +import stirling.software.jpdfium.PdfPage; +import stirling.software.jpdfium.model.PageSize; +import stirling.software.jpdfium.model.RenderResult; + +/** + * Renders PDF pages to RGBA bitmaps via JPDFium (the same PDFium engine the browser pipeline uses, + * for closer parity than PDFBox's Java2D renderer). Each page is rendered at a DPI chosen so its + * long side is approximately the model input size, minimising any later resampling. The actual + * pixels-per-point scale is computed from the rendered dimensions so coordinate mapping does not + * depend on how {@code renderAt} interprets its argument. + */ +@Slf4j +@Service +public class PageRasterizer { + + /** A rendered page: RGBA pixels plus the page size (points) and px-per-point scale. 
*/ + public record RasterPage( + int pageIndex, + byte[] rgba, + int widthPx, + int heightPx, + float pageWidthPt, + float pageHeightPt, + float scaleX, + float scaleY) {} + + public List rasterize(byte[] pdfBytes, int inputSize) { + List pages = new ArrayList<>(); + try (PdfDocument doc = PdfDocument.open(pdfBytes)) { + int count = doc.pageCount(); + for (int i = 0; i < count; i++) { + try (PdfPage page = doc.page(i)) { + PageSize size = page.size(); + float maxSide = Math.max(size.width(), size.height()); + int dpi = maxSide <= 0 ? 150 : Math.round(72f * inputSize / maxSide); + dpi = Math.max(36, Math.min(dpi, 300)); + RenderResult r = page.renderAt(dpi); + float scaleX = size.width() > 0 ? r.width() / size.width() : dpi / 72f; + float scaleY = size.height() > 0 ? r.height() / size.height() : dpi / 72f; + pages.add( + new RasterPage( + i, + r.rgba(), + r.width(), + r.height(), + size.width(), + size.height(), + scaleX, + scaleY)); + } + } + } + return pages; + } +} diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/service/FormDetectionModelManager.java b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/service/FormDetectionModelManager.java new file mode 100644 index 0000000000..c166391768 --- /dev/null +++ b/app/proprietary/src/main/java/stirling/software/proprietary/formdetection/service/FormDetectionModelManager.java @@ -0,0 +1,509 @@ +package stirling.software.proprietary.formdetection.service; + +import static java.nio.file.StandardOpenOption.CREATE; +import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING; +import static java.nio.file.StandardOpenOption.WRITE; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URI; +import java.nio.file.AtomicMoveNotSupportedException; +import java.nio.file.DirectoryStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; 
+import java.nio.file.StandardCopyOption; +import java.security.MessageDigest; +import java.util.ArrayList; +import java.util.HexFormat; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.regex.Pattern; + +import org.apache.commons.lang3.StringUtils; +import org.springframework.stereotype.Service; + +import jakarta.annotation.PostConstruct; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +import stirling.software.SPDF.config.EndpointConfiguration; +import stirling.software.SPDF.config.EndpointConfiguration.DisableReason; +import stirling.software.common.configuration.RuntimePathConfig; +import stirling.software.common.model.ApplicationProperties; +import stirling.software.common.util.GeneralUtils; +import stirling.software.proprietary.formdetection.catalog.ModelCatalogService; +import stirling.software.proprietary.formdetection.model.FormDetectionStatus; +import stirling.software.proprietary.formdetection.model.ModelCatalogEntry; +import stirling.software.proprietary.formdetection.model.ModelStatusResponse; + +/** + * Downloads, verifies and tracks the on-demand Auto Form Detection model. Concurrency-safe + * (single-flight install), checksum-verified, and atomic-published to a mounted volume so the model + * survives container restarts/updates. Mirrors the OCR tessdata admin pattern but adds the lock, + * temp-file + atomic rename, and SHA-256 verification the spec requires. + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class FormDetectionModelManager { + + /** Endpoint key gated until a model is ready (drives the disabled tool tile in the UI). 
*/ + public static final String ENDPOINT_KEY = "form-detection"; + + private static final Pattern SAFE_ID = Pattern.compile("[a-z0-9][a-z0-9-]{0,63}"); + private static final Pattern SHA256_HEX = Pattern.compile("[0-9a-f]{64}"); + + /** + * Whether the server-side ONNX engine is bundled in this build (the onnxruntime jar is only + * included via {@code -PbundleOnnxRuntime=true}, e.g. the Docker server image). The frontend + * uses this to disable the "Server" execution mode when it isn't available. + */ + private static final boolean SERVER_ENGINE_AVAILABLE = isOnnxRuntimePresent(); + + private static boolean isOnnxRuntimePresent() { + try { + Class.forName( + "ai.onnxruntime.OrtEnvironment", + false, + FormDetectionModelManager.class.getClassLoader()); + return true; + } catch (Throwable t) { + return false; + } + } + + private final RuntimePathConfig runtimePathConfig; + private final ModelCatalogService catalog; + private final ApplicationProperties applicationProperties; + private final EndpointConfiguration endpointConfiguration; + + private final AtomicBoolean installing = new AtomicBoolean(false); + private volatile FormDetectionStatus state = FormDetectionStatus.NOT_INSTALLED; + private volatile int progress = 0; + private volatile String error = null; + private volatile String activeSha = null; + + @PostConstruct + void init() { + sweepTempFiles(); + seedPreinstalledModels(); + Optional active = getActiveModelFile(); + if (active.isPresent()) { + activeSha = getActiveEntry().map(ModelCatalogEntry::getSha256).orElse(null); + state = FormDetectionStatus.READY; + log.info("Auto Form Detection model '{}' is installed and ready", activeModelId()); + } else { + state = FormDetectionStatus.NOT_INSTALLED; + } + applyEndpointState(); + } + + /** + * Gate the {@code form-detection} endpoint (which drives the tool tile): off with reason CONFIG + * when the feature is disabled by the admin, off with reason DEPENDENCY when no model is ready, + * otherwise on. 
Execution mode (browser/server) does not affect this - the tile is active + * either way and the frontend chooses where to run. + */ + private void applyEndpointState() { + if (!isFeatureEnabled()) { + endpointConfiguration.disableEndpoint(ENDPOINT_KEY, DisableReason.CONFIG); + } else if (state == FormDetectionStatus.READY && getActiveModelFile().isPresent()) { + endpointConfiguration.enableEndpoint(ENDPOINT_KEY); + } else { + endpointConfiguration.disableEndpoint(ENDPOINT_KEY, DisableReason.DEPENDENCY); + } + } + + private boolean isFeatureEnabled() { + return applicationProperties.getFormDetection().isEnabled(); + } + + /** Master on/off (admin). Persists and re-gates the endpoint immediately. */ + public synchronized void setEnabled(boolean enabled) { + applicationProperties.getFormDetection().setEnabled(enabled); + try { + GeneralUtils.saveKeyToSettings("formDetection.enabled", enabled); + } catch (IOException e) { + log.warn("Could not persist formDetection.enabled (state kept in memory)", e); + } + applyEndpointState(); + } + + /** Set where detection runs: auto|browser|server (admin). Persists. */ + public synchronized void setExecutionMode(String mode) { + String m = mode == null ? "auto" : mode.trim().toLowerCase(Locale.ROOT); + if (!m.equals("auto") && !m.equals("browser") && !m.equals("server")) { + throw new IllegalArgumentException("executionMode must be auto, browser or server"); + } + applicationProperties.getFormDetection().setExecutionMode(m); + try { + GeneralUtils.saveKeyToSettings("formDetection.executionMode", m); + } catch (IOException e) { + log.warn("Could not persist formDetection.executionMode (state kept in memory)", e); + } + } + + /** + * Validate and kick off a background download+verify+install. Returns immediately; callers poll + * {@link #status()}. 
+ * + * @throws IllegalArgumentException unknown/invalid model id or bad checksum format + * @throws IllegalStateException no URL/checksum configured, or an install is already running + */ + public synchronized void startInstall(String modelId) { + if (!SAFE_ID.matcher(modelId).matches()) { + throw new IllegalArgumentException("Invalid model id: " + modelId); + } + ModelCatalogEntry entry = + catalog.getById(modelId) + .orElseThrow( + () -> new IllegalArgumentException("Unknown model id: " + modelId)); + // URL + checksum come ONLY from the bundled catalog (trusted constants), never from the + // request, so an admin cannot point the download at an arbitrary host (avoids SSRF). + String url = entry.getOnnxUrl(); + String sha = entry.getSha256() == null ? null : entry.getSha256().toLowerCase(Locale.ROOT); + if (StringUtils.isBlank(url) || StringUtils.isBlank(sha)) { + throw new IllegalStateException( + "Model '" + modelId + "' has no download URL/checksum configured yet"); + } + String scheme = URI.create(url).getScheme(); + if (!"https".equalsIgnoreCase(scheme) && !"http".equalsIgnoreCase(scheme)) { + throw new IllegalArgumentException("Model URL must be http(s): " + url); + } + if (!SHA256_HEX.matcher(sha).matches()) { + throw new IllegalArgumentException("Checksum must be a 64-char hex SHA-256"); + } + if (!installing.compareAndSet(false, true)) { + throw new IllegalStateException("An install is already in progress"); + } + state = FormDetectionStatus.DOWNLOADING; + progress = 0; + error = null; + final String fUrl = url; + final String fSha = sha; + Thread.ofVirtual() + .name("form-detection-install-" + modelId) + .start( + () -> { + try { + doInstall(modelId, entry, fUrl, fSha); + } catch (Exception e) { + log.error("Auto Form Detection install failed for {}", modelId, e); + error = e.getMessage(); + // Keep a previously-installed model usable if the new one failed. + state = + getActiveModelFile().isPresent() + ? 
FormDetectionStatus.READY + : FormDetectionStatus.FAILED; + } finally { + installing.set(false); + } + }); + } + + private void doInstall(String modelId, ModelCatalogEntry entry, String url, String expectedSha) + throws IOException { + Path dir = modelDir(); + Files.createDirectories(dir); + if (!isWritable(dir)) { + throw new IOException("Model directory is not writable: " + dir); + } + Path base = dir.toRealPath(); + Path target = base.resolve(modelId + ".onnx").normalize(); + if (!target.startsWith(base)) { + throw new IOException("Blocked path traversal for model id " + modelId); + } + + // Already downloaded and intact: skip the network fetch and just (re)activate it. Makes + // switching between already-downloaded models instant instead of re-fetching tens of MB. + if (Files.isRegularFile(target) && expectedSha.equals(sha256OfFile(target))) { + log.info( + "Model '{}' already present and verified; activating without re-download", + modelId); + activate(modelId, expectedSha); + return; + } + + Path tmp = base.resolve(modelId + ".onnx.tmp"); + + MessageDigest digest; + try { + digest = MessageDigest.getInstance("SHA-256"); + } catch (Exception e) { + throw new IOException("SHA-256 unavailable", e); + } + + HttpURLConnection conn = null; + try { + conn = (HttpURLConnection) URI.create(url).toURL().openConnection(); + conn.setRequestMethod("GET"); + conn.setRequestProperty("User-Agent", "Stirling-PDF-App"); + conn.setRequestProperty("Accept", "application/octet-stream"); + conn.setConnectTimeout(10000); + conn.setReadTimeout(60000); + int http = conn.getResponseCode(); + if (http != HttpURLConnection.HTTP_OK) { + throw new IOException("Download failed: HTTP " + http + " from " + url); + } + long total = + entry.getSizeBytes() > 0 ? 
entry.getSizeBytes() : conn.getContentLengthLong(); + try (InputStream in = conn.getInputStream(); + OutputStream out = + Files.newOutputStream(tmp, CREATE, TRUNCATE_EXISTING, WRITE)) { + byte[] buf = new byte[1 << 16]; + long read = 0; + int n; + while ((n = in.read(buf)) >= 0) { + out.write(buf, 0, n); + digest.update(buf, 0, n); + read += n; + if (total > 0) { + progress = (int) Math.min(99, (read * 100) / total); + } + } + } + } finally { + if (conn != null) { + conn.disconnect(); + } + } + + state = FormDetectionStatus.VERIFYING; + byte[] actual = digest.digest(); + byte[] expected = HexFormat.of().parseHex(expectedSha); + if (!MessageDigest.isEqual(actual, expected)) { + Files.deleteIfExists(tmp); + throw new IOException( + "Checksum mismatch (expected " + + expectedSha + + " got " + + HexFormat.of().formatHex(actual) + + ")"); + } + + try { + Files.move( + tmp, + target, + StandardCopyOption.ATOMIC_MOVE, + StandardCopyOption.REPLACE_EXISTING); + } catch (AtomicMoveNotSupportedException e) { + Files.move(tmp, target, StandardCopyOption.REPLACE_EXISTING); + } + + activate(modelId, expectedSha); + } + + /** Mark a verified, on-disk model as the active one and (re)enable the feature. */ + private void activate(String modelId, String expectedSha) { + applicationProperties.getFormDetection().setActiveModelId(modelId); + try { + GeneralUtils.saveKeyToSettings("formDetection.activeModelId", modelId); + } catch (IOException e) { + log.warn("Could not persist formDetection.activeModelId (state kept in memory)", e); + } + activeSha = expectedSha; + progress = 100; + state = FormDetectionStatus.READY; + applyEndpointState(); + log.info("Auto Form Detection model '{}' installed and ready", modelId); + } + + /** SHA-256 of an existing model file as lowercase hex, or {@code null} if it cannot be read. 
*/ + private String sha256OfFile(Path file) { + try (InputStream in = Files.newInputStream(file)) { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + byte[] buf = new byte[1 << 16]; + int n; + while ((n = in.read(buf)) >= 0) { + digest.update(buf, 0, n); + } + return HexFormat.of().formatHex(digest.digest()); + } catch (Exception e) { + log.debug("Could not hash existing model file {}", file, e); + return null; + } + } + + /** Remove a model file; if it was the active one, disable the feature again. */ + public synchronized void deleteModel(String modelId) { + if (installing.get()) { + throw new IllegalStateException("Cannot uninstall while an install is in progress"); + } + String id = StringUtils.isNotBlank(modelId) ? modelId : activeModelId(); + if (StringUtils.isBlank(id) || !SAFE_ID.matcher(id).matches()) { + return; + } + Path base = modelDir().toAbsolutePath().normalize(); + Path file = base.resolve(id + ".onnx").normalize(); + if (!file.startsWith(base)) { + return; // path-traversal guard (SAFE_ID already blocks it; this also satisfies CodeQL) + } + try { + Files.deleteIfExists(file); + } catch (IOException e) { + log.warn("Failed to delete model file {}", file, e); + } + if (id.equals(activeModelId())) { + applicationProperties.getFormDetection().setActiveModelId(""); + try { + GeneralUtils.saveKeyToSettings("formDetection.activeModelId", ""); + } catch (IOException e) { + log.warn("Could not clear formDetection.activeModelId", e); + } + activeSha = null; + } + if (getActiveModelFile().isEmpty()) { + state = FormDetectionStatus.NOT_INSTALLED; + error = null; + } + applyEndpointState(); + } + + public ModelStatusResponse status() { + Path dir = modelDir(); + List installed = new ArrayList<>(); + if (Files.isDirectory(dir)) { + try (DirectoryStream s = Files.newDirectoryStream(dir, "*.onnx")) { + for (Path p : s) { + String fn = p.getFileName().toString(); + installed.add(fn.substring(0, fn.length() - ".onnx".length())); + } + } catch 
(IOException e) { + log.debug("Could not list installed models in {}", dir, e); + } + } + return new ModelStatusResponse( + state.wire(), + progress, + activeModelId(), + installed, + error, + isWritable(dir), + catalog.getAll(), + isFeatureEnabled(), + applicationProperties.getFormDetection().getExecutionMode(), + SERVER_ENGINE_AVAILABLE); + } + + public Optional getActiveModelFile() { + String id = activeModelId(); + if (StringUtils.isBlank(id)) { + return Optional.empty(); + } + Path f = modelDir().resolve(id + ".onnx"); + return Files.isRegularFile(f) ? Optional.of(f) : Optional.empty(); + } + + public Optional getActiveEntry() { + return catalog.getById(activeModelId()); + } + + public Optional getActiveEtag() { + return Optional.ofNullable(activeSha); + } + + public boolean isReady() { + return isFeatureEnabled() + && state == FormDetectionStatus.READY + && getActiveModelFile().isPresent(); + } + + private String activeModelId() { + return applicationProperties.getFormDetection().getActiveModelId(); + } + + private Path modelDir() { + return Paths.get(runtimePathConfig.getFormDetectionModelPath()); + } + + /** + * Copy any image-baked models (see {@code formDetection.preinstalledModelDir}) into the + * writable model dir if not already present, and activate one when nothing is active yet. Lets + * the Docker server image ship with FFDNet-S ready without an admin install. No-op when the dir + * is unset or missing (desktop/local). 
+ */ + private void seedPreinstalledModels() { + String preDir = applicationProperties.getFormDetection().getPreinstalledModelDir(); + if (StringUtils.isBlank(preDir)) { + return; + } + Path src = Paths.get(preDir); + if (!Files.isDirectory(src)) { + return; + } + Path dir = modelDir(); + try { + Files.createDirectories(dir); + } catch (IOException e) { + log.warn("Cannot create model dir to seed pre-installed models: {}", e.getMessage()); + return; + } + if (!isWritable(dir)) { + log.warn("Model dir {} not writable; skipping pre-installed model seeding", dir); + return; + } + try (DirectoryStream models = Files.newDirectoryStream(src, "*.onnx")) { + for (Path p : models) { + String fn = p.getFileName().toString(); + String id = fn.substring(0, fn.length() - ".onnx".length()); + if (!SAFE_ID.matcher(id).matches() || catalog.getById(id).isEmpty()) { + continue; + } + Path target = dir.resolve(id + ".onnx"); + if (!Files.exists(target)) { + Files.copy(p, target, StandardCopyOption.COPY_ATTRIBUTES); + log.info("Seeded pre-installed Auto Form Detection model '{}'", id); + } + if (StringUtils.isBlank(activeModelId())) { + applicationProperties.getFormDetection().setActiveModelId(id); + try { + GeneralUtils.saveKeyToSettings("formDetection.activeModelId", id); + } catch (IOException e) { + log.warn("Could not persist seeded activeModelId: {}", e.getMessage()); + } + } + } + } catch (IOException e) { + log.warn("Failed to seed pre-installed models from {}: {}", src, e.getMessage()); + } + } + + private void sweepTempFiles() { + Path dir = modelDir(); + if (!Files.isDirectory(dir)) { + return; + } + try (DirectoryStream s = Files.newDirectoryStream(dir, "*.tmp")) { + for (Path p : s) { + try { + Files.deleteIfExists(p); + } catch (IOException ignored) { + // best-effort sweep of interrupted downloads + } + } + } catch (IOException e) { + log.debug("No stale form-detection temp files to sweep", e); + } + } + + private boolean isWritable(Path dir) { + try { + 
Files.createDirectories(dir); + if (!Files.isWritable(dir)) { + return false; + } + Path probe = Files.createTempFile(dir, "fd-write-test", ".tmp"); + Files.deleteIfExists(probe); + return true; + } catch (IOException e) { + return false; + } + } +} diff --git a/app/proprietary/src/main/resources/formdetection/model-catalog.json b/app/proprietary/src/main/resources/formdetection/model-catalog.json new file mode 100644 index 0000000000..1bae82f939 --- /dev/null +++ b/app/proprietary/src/main/resources/formdetection/model-catalog.json @@ -0,0 +1,46 @@ +[ + { + "id": "ffdnet-s", + "displayName": "CommonForms FFDNet-S (Small)", + "description": "Small, fast detector with lower memory use - a good default for most forms. Finds text inputs, checkboxes, and signature fields.", + "license": "CommonForms dataset CC-BY-4.0; Ultralytics-YOLO lineage AGPL-3.0. Downloaded on demand, never bundled.", + "sizeBytes": 38370092, + "onnxUrl": "https://huggingface.co/jbarrow/FFDNet-S-cpu/resolve/main/FFDNet-S.onnx", + "sha256": "93bccf47c048f9f947f9b1b52d002edf144a8a583dae39f164d9e5725321acc0", + "inputSize": 1216, + "resizeMode": "stretch", + "padColor": [114, 114, 114], + "channelOrder": "bgr", + "normMean": [0.0, 0.0, 0.0], + "normStd": [1.0, 1.0, 1.0], + "outputLayout": "nc_first", + "hasObjectness": false, + "classNames": ["text", "choice", "signature"], + "classFieldTypes": ["text", "checkbox", "signature"], + "scoreThreshold": 0.3, + "nms": "perClass", + "iou": 0.45 + }, + { + "id": "ffdnet-l", + "displayName": "CommonForms FFDNet-L (Large)", + "description": "Larger, higher-accuracy detector (~25M parameters) with better recall on dense or complex forms - uses more memory. Finds text inputs, checkboxes, and signature fields.", + "license": "CommonForms dataset CC-BY-4.0; Ultralytics-YOLO lineage AGPL-3.0. 
Downloaded on demand, never bundled.", + "sizeBytes": 101944542, + "onnxUrl": "https://huggingface.co/jbarrow/FFDNet-L-cpu/resolve/main/FFDNet-L.onnx", + "sha256": "e00c59edd9a5275dab5847d38f042c8ecc827063650c8aac22b0e486c414cd35", + "inputSize": 1216, + "resizeMode": "stretch", + "padColor": [114, 114, 114], + "channelOrder": "bgr", + "normMean": [0.0, 0.0, 0.0], + "normStd": [1.0, 1.0, 1.0], + "outputLayout": "nc_first", + "hasObjectness": false, + "classNames": ["text", "choice", "signature"], + "classFieldTypes": ["text", "checkbox", "signature"], + "scoreThreshold": 0.3, + "nms": "perClass", + "iou": 0.45 + } +] diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/catalog/ModelCatalogServiceTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/catalog/ModelCatalogServiceTest.java new file mode 100644 index 0000000000..4365d1c5c3 --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/catalog/ModelCatalogServiceTest.java @@ -0,0 +1,49 @@ +package stirling.software.proprietary.formdetection.catalog; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import stirling.software.proprietary.formdetection.model.ModelCatalogEntry; + +import tools.jackson.databind.json.JsonMapper; + +class ModelCatalogServiceTest { + + @Test + void loadsBundledCatalogWithSpecDefaults() { + ModelCatalogService service = new ModelCatalogService(JsonMapper.builder().build()); + service.load(); + + List all = service.getAll(); + assertTrue(all.size() >= 2, "catalog should ship with at least two entries"); + assertTrue(service.getById("ffdnet-s").isPresent()); + + ModelCatalogEntry l = service.getById("ffdnet-l").orElseThrow(); + 
assertEquals(3, l.getClassNames().size()); + assertEquals(3, l.getClassFieldTypes().size()); + assertTrue(l.getInputSize() > 0); + + // Model-free distribution: the jar bundles no weights. Every entry instead carries a + // download URL and a SHA-256 so the model is fetched and integrity-verified on demand. + for (ModelCatalogEntry e : all) { + assertNotNull(e.getOnnxUrl(), e.getId() + " must declare a download URL"); + assertFalse(e.getOnnxUrl().isBlank(), e.getId() + " must declare a download URL"); + assertNotNull(e.getSha256(), e.getId() + " must declare a SHA-256 checksum"); + assertFalse(e.getSha256().isBlank(), e.getId() + " must declare a SHA-256 checksum"); + } + } + + @Test + void unknownIdReturnsEmpty() { + ModelCatalogService service = new ModelCatalogService(JsonMapper.builder().build()); + service.load(); + assertTrue(service.getById("does-not-exist").isEmpty()); + assertTrue(service.getById(null).isEmpty()); + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/controller/FormDetectionControllerTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/controller/FormDetectionControllerTest.java new file mode 100644 index 0000000000..746f91912e --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/controller/FormDetectionControllerTest.java @@ -0,0 +1,70 @@ +package stirling.software.proprietary.formdetection.controller; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.multipart; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +import java.util.List; +import java.util.Optional; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.mock.web.MockMultipartFile; +import org.springframework.test.web.servlet.MockMvc; +import 
org.springframework.test.web.servlet.setup.MockMvcBuilders; + +import stirling.software.common.service.CustomPDFDocumentFactory; +import stirling.software.common.util.TempFileManager; +import stirling.software.proprietary.formdetection.inference.OnnxFormDetector; +import stirling.software.proprietary.formdetection.model.ModelCatalogEntry; +import stirling.software.proprietary.formdetection.render.PageRasterizer; +import stirling.software.proprietary.formdetection.service.FormDetectionModelManager; + +class FormDetectionControllerTest { + + private MockMvc mvc( + FormDetectionModelManager manager, + OnnxFormDetector detector, + PageRasterizer rasterizer) { + FormDetectionController controller = + new FormDetectionController( + manager, + detector, + rasterizer, + Mockito.mock(CustomPDFDocumentFactory.class), + Mockito.mock(TempFileManager.class)); + return MockMvcBuilders.standaloneSetup(controller).build(); + } + + private MockMultipartFile pdf() { + return new MockMultipartFile("file", "test.pdf", "application/pdf", "%PDF-1.4".getBytes()); + } + + @Test + void detectReturns503WhenModelNotReady() throws Exception { + FormDetectionModelManager manager = Mockito.mock(FormDetectionModelManager.class); + Mockito.when(manager.isReady()).thenReturn(false); + + mvc(manager, Mockito.mock(OnnxFormDetector.class), Mockito.mock(PageRasterizer.class)) + .perform(multipart("/api/v1/ai/form-detection/detect").file(pdf())) + .andExpect(status().isServiceUnavailable()) + .andExpect(jsonPath("$.reason").value("DEPENDENCY")); + } + + @Test + void detectReturnsEmptyDetectionsForBlankRender() throws Exception { + FormDetectionModelManager manager = Mockito.mock(FormDetectionModelManager.class); + Mockito.when(manager.isReady()).thenReturn(true); + Mockito.when(manager.getActiveEntry()).thenReturn(Optional.of(new ModelCatalogEntry())); + + PageRasterizer rasterizer = Mockito.mock(PageRasterizer.class); + Mockito.when(rasterizer.rasterize(Mockito.any(), Mockito.anyInt())) + 
.thenReturn(List.of()); // no pages -> no detections, detector never called + + mvc(manager, Mockito.mock(OnnxFormDetector.class), rasterizer) + .perform(multipart("/api/v1/ai/form-detection/detect").file(pdf())) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.detections").isArray()) + .andExpect(jsonPath("$.detections").isEmpty()); + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/controller/FormDetectionModelControllerTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/controller/FormDetectionModelControllerTest.java new file mode 100644 index 0000000000..274771f4d4 --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/controller/FormDetectionModelControllerTest.java @@ -0,0 +1,95 @@ +package stirling.software.proprietary.formdetection.controller; + +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.http.MediaType; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; + +import stirling.software.proprietary.formdetection.model.ModelStatusResponse; +import stirling.software.proprietary.formdetection.service.FormDetectionModelManager; + +class FormDetectionModelControllerTest { + + private ModelStatusResponse notInstalled() { + return new ModelStatusResponse( + "not_installed", 0, "", List.of(), null, true, List.of(), true, "auto", true); + } + + private MockMvc 
mvc(FormDetectionModelManager manager) { + return MockMvcBuilders.standaloneSetup(new FormDetectionModelController(manager)).build(); + } + + @Test + void statusReturnsJson() throws Exception { + FormDetectionModelManager manager = Mockito.mock(FormDetectionModelManager.class); + Mockito.when(manager.status()).thenReturn(notInstalled()); + + mvc(manager) + .perform(get("/api/v1/ai/form-detection-model/status")) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.status").value("not_installed")) + .andExpect(jsonPath("$.writable").value(true)); + } + + @Test + void installWithBlankModelIdReturns400() throws Exception { + FormDetectionModelManager manager = Mockito.mock(FormDetectionModelManager.class); + Mockito.when(manager.status()).thenReturn(notInstalled()); + + mvc(manager) + .perform( + post("/api/v1/ai/form-detection-model/install") + .contentType(MediaType.APPLICATION_JSON) + .content("{\"modelId\":\"\"}")) + .andExpect(status().isBadRequest()); + } + + @Test + void installValidReturns202() throws Exception { + FormDetectionModelManager manager = Mockito.mock(FormDetectionModelManager.class); + Mockito.when(manager.status()).thenReturn(notInstalled()); + + mvc(manager) + .perform( + post("/api/v1/ai/form-detection-model/install") + .contentType(MediaType.APPLICATION_JSON) + .content("{\"modelId\":\"ffdnet-s\"}")) + .andExpect(status().isAccepted()); + } + + @Test + void installWhileBusyReturns409() throws Exception { + FormDetectionModelManager manager = Mockito.mock(FormDetectionModelManager.class); + Mockito.when(manager.status()).thenReturn(notInstalled()); + Mockito.doThrow(new IllegalStateException("An install is already in progress")) + .when(manager) + .startInstall(Mockito.eq("ffdnet-s")); + + mvc(manager) + .perform( + post("/api/v1/ai/form-detection-model/install") + .contentType(MediaType.APPLICATION_JSON) + .content("{\"modelId\":\"ffdnet-s\"}")) + .andExpect(status().isConflict()); + } + + @Test + void deleteReturnsStatus() throws Exception { 
+ FormDetectionModelManager manager = Mockito.mock(FormDetectionModelManager.class); + Mockito.when(manager.status()).thenReturn(notInstalled()); + + mvc(manager) + .perform(delete("/api/v1/ai/form-detection-model")) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.status").value("not_installed")); + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/controller/FormDetectionModelServeControllerTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/controller/FormDetectionModelServeControllerTest.java new file mode 100644 index 0000000000..a01fc1df63 --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/controller/FormDetectionModelServeControllerTest.java @@ -0,0 +1,81 @@ +package stirling.software.proprietary.formdetection.controller; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Optional; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.Mockito; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.ResourceRegion; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpRange; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; + +import stirling.software.proprietary.formdetection.service.FormDetectionModelManager; + +/** + * Tests the serve controller's logic directly (status, headers, region bounds). Serializing a + * ResourceRegion needs the full Spring resource converters, which standaloneSetup does not wire, so + * exercising the method directly is both simpler and converter-independent. 
+ */ +class FormDetectionModelServeControllerTest { + + private FormDetectionModelManager managerWith(Path model) { + FormDetectionModelManager manager = Mockito.mock(FormDetectionModelManager.class); + Mockito.when(manager.getActiveModelFile()) + .thenReturn(model == null ? Optional.empty() : Optional.of(model)); + Mockito.when(manager.getActiveEtag()).thenReturn(Optional.of("a".repeat(64))); + return manager; + } + + @Test + void servesFullModelWithPublicCacheAndAcceptRanges(@TempDir Path dir) throws Exception { + Path model = dir.resolve("m.onnx"); + Files.write(model, "abcdefghij".getBytes()); + + ResponseEntity resp = + new FormDetectionModelServeController(managerWith(model)) + .serveModel(new HttpHeaders()); + + assertEquals(HttpStatus.OK, resp.getStatusCode()); + assertEquals("bytes", resp.getHeaders().getFirst(HttpHeaders.ACCEPT_RANGES)); + assertNotNull(resp.getHeaders().getFirst(HttpHeaders.CACHE_CONTROL)); + assertTrue(resp.getHeaders().getFirst(HttpHeaders.CACHE_CONTROL).contains("public")); + assertNotNull(resp.getHeaders().getETag()); + assertTrue(resp.getBody() instanceof Resource); + } + + @Test + void servesRangeRequestAsSingleRegion(@TempDir Path dir) throws Exception { + Path model = dir.resolve("m.onnx"); + Files.write(model, "abcdefghij".getBytes()); + HttpHeaders headers = new HttpHeaders(); + headers.setRange(List.of(HttpRange.createByteRange(0, 3))); + + ResponseEntity resp = + new FormDetectionModelServeController(managerWith(model)).serveModel(headers); + + assertEquals(HttpStatus.PARTIAL_CONTENT, resp.getStatusCode()); + assertEquals("bytes", resp.getHeaders().getFirst(HttpHeaders.ACCEPT_RANGES)); + assertTrue(resp.getBody() instanceof ResourceRegion, "body should be a single region"); + ResourceRegion region = (ResourceRegion) resp.getBody(); + assertEquals(0L, region.getPosition()); + assertEquals(4L, region.getCount()); // bytes 0-3 inclusive + } + + @Test + void returns404WhenNoModelInstalled() throws Exception { + ResponseEntity 
resp = + new FormDetectionModelServeController(managerWith(null)) + .serveModel(new HttpHeaders()); + assertEquals(HttpStatus.NOT_FOUND, resp.getStatusCode()); + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/inference/YoloTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/inference/YoloTest.java new file mode 100644 index 0000000000..ecd3438122 --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/inference/YoloTest.java @@ -0,0 +1,98 @@ +package stirling.software.proprietary.formdetection.inference; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import stirling.software.proprietary.formdetection.model.ModelCatalogEntry; + +class YoloTest { + + private ModelCatalogEntry spec() { + ModelCatalogEntry e = new ModelCatalogEntry(); + e.setInputSize(10); + e.setResizeMode("letterbox"); + e.setChannelOrder("rgb"); + e.setOutputLayout("nc_first"); + e.setHasObjectness(false); + e.setClassNames(List.of("text", "choice")); + e.setClassFieldTypes(List.of("text", "checkbox")); + e.setNms("perClass"); + e.setIou(0.5f); + return e; + } + + @Test + void decodeThresholdsAndSuppressesOverlaps() { + ModelCatalogEntry spec = spec(); + // identity transform (scale 1, no pad), 10x10 source + Yolo.Preprocessed pre = new Yolo.Preprocessed(new float[0], 10, 1f, 1f, 0, 0, 10, 10); + + // nc_first layout [channels=6][anchors=3], data[c*anchors + a] + // box A (cx5,cy5,w4,h4) twice (overlapping) + box B (cx8,cy8,w2,h2) + float[] data = { + 5, + 5, + 8, // cx + 5, + 5, + 8, // cy + 4, + 4, + 2, // w + 4, + 4, + 2, // h + 0.9f, + 0.8f, + 0.7f, // text score + 0.1f, + 0.1f, + 0.1f // choice score + }; + Yolo.RawOutput out = new Yolo.RawOutput(data, 6, 3); + + List dets = Yolo.decode(out, spec, pre, 0.5f); + + // a0 (box A, 0.9) kept; a1 (box A', 0.8) suppressed by NMS; a2 (box B, 0.7) kept + 
assertEquals(2, dets.size()); + + Yolo.Detection a = dets.get(0); + assertEquals(0, a.classId()); + assertEquals(0.9f, a.score(), 1e-5); + assertEquals(3f, a.x(), 1e-4); + assertEquals(3f, a.y(), 1e-4); + assertEquals(4f, a.w(), 1e-4); + assertEquals(4f, a.h(), 1e-4); + + Yolo.Detection b = dets.get(1); + assertEquals(0.7f, b.score(), 1e-5); + assertEquals(7f, b.x(), 1e-4); + assertEquals(7f, b.y(), 1e-4); + assertEquals(2f, b.w(), 1e-4); + assertEquals(2f, b.h(), 1e-4); + } + + @Test + void decodeIsDeterministic() { + ModelCatalogEntry spec = spec(); + Yolo.Preprocessed pre = new Yolo.Preprocessed(new float[0], 10, 1f, 1f, 0, 0, 10, 10); + float[] data = {5, 5, 8, 5, 5, 8, 4, 4, 2, 4, 4, 2, 0.9f, 0.8f, 0.7f, 0.1f, 0.1f, 0.1f}; + Yolo.RawOutput out = new Yolo.RawOutput(data, 6, 3); + assertEquals( + Yolo.decode(out, spec, pre, 0.5f).toString(), + Yolo.decode(out, spec, pre, 0.5f).toString()); + } + + @Test + void thresholdDropsLowScores() { + ModelCatalogEntry spec = spec(); + Yolo.Preprocessed pre = new Yolo.Preprocessed(new float[0], 10, 1f, 1f, 0, 0, 10, 10); + float[] data = {5, 5, 8, 5, 5, 8, 4, 4, 2, 4, 4, 2, 0.9f, 0.8f, 0.7f, 0.1f, 0.1f, 0.1f}; + Yolo.RawOutput out = new Yolo.RawOutput(data, 6, 3); + // threshold above every score -> nothing survives + assertEquals(0, Yolo.decode(out, spec, pre, 0.95f).size()); + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/render/CoordinateMapperTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/render/CoordinateMapperTest.java new file mode 100644 index 0000000000..5fd9f56d81 --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/render/CoordinateMapperTest.java @@ -0,0 +1,45 @@ +package stirling.software.proprietary.formdetection.render; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import 
stirling.software.proprietary.formdetection.inference.Yolo; +import stirling.software.proprietary.formdetection.model.DetectedField; + +class CoordinateMapperTest { + + @Test + void mapsBitmapPixelsToPdfPointsWithYFlip() { + // 200x300pt page rendered at 2 px/pt (400x600 px) + PageRasterizer.RasterPage page = + new PageRasterizer.RasterPage(0, new byte[0], 400, 600, 200f, 300f, 2f, 2f); + // detection at top-left (10,20) px, 40x60 px + Yolo.Detection d = new Yolo.Detection(0, 0.9f, 10f, 20f, 40f, 60f); + + DetectedField.RectPt r = CoordinateMapper.toPdfPoints(d, page); + + assertEquals(5.0, r.x(), 1e-4); // 10/2 + assertEquals(20.0, r.w(), 1e-4); // 40/2 + assertEquals(30.0, r.h(), 1e-4); // 60/2 + // Y flip: pageHeight - (yTopPx/scale) - hPt = 300 - 10 - 30 + assertEquals(260.0, r.y(), 1e-4); + } + + @Test + void clampsToPageBounds() { + PageRasterizer.RasterPage page = + new PageRasterizer.RasterPage(0, new byte[0], 200, 200, 100f, 100f, 2f, 2f); + // box partly off the right/bottom edge in px + Yolo.Detection d = new Yolo.Detection(0, 0.5f, 180f, 0f, 60f, 40f); + + DetectedField.RectPt r = CoordinateMapper.toPdfPoints(d, page); + + // x = 90pt, width clamped to 100-90 = 10pt + assertEquals(90.0, r.x(), 1e-4); + assertEquals(10.0, r.w(), 1e-4); + // stays within the page + org.junit.jupiter.api.Assertions.assertEquals(true, r.x() + r.w() <= 100.0 + 1e-6); + org.junit.jupiter.api.Assertions.assertEquals(true, r.y() >= -1e-6); + } +} diff --git a/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/service/FormDetectionModelManagerTest.java b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/service/FormDetectionModelManagerTest.java new file mode 100644 index 0000000000..1f14eb5b0e --- /dev/null +++ b/app/proprietary/src/test/java/stirling/software/proprietary/formdetection/service/FormDetectionModelManagerTest.java @@ -0,0 +1,173 @@ +package stirling.software.proprietary.formdetection.service; + +import static 
org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.net.InetSocketAddress; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.MessageDigest; +import java.util.HexFormat; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.Mockito; + +import com.sun.net.httpserver.HttpServer; + +import stirling.software.SPDF.config.EndpointConfiguration; +import stirling.software.common.configuration.RuntimePathConfig; +import stirling.software.common.model.ApplicationProperties; +import stirling.software.proprietary.formdetection.catalog.ModelCatalogService; +import stirling.software.proprietary.formdetection.model.ModelCatalogEntry; + +class FormDetectionModelManagerTest { + + private HttpServer server; + private byte[] modelBytes; + private String modelSha; + private int port; + + @BeforeEach + void startServer() throws Exception { + modelBytes = "fake-onnx-model-content-1234567890".getBytes(); + modelSha = + HexFormat.of().formatHex(MessageDigest.getInstance("SHA-256").digest(modelBytes)); + server = HttpServer.create(new InetSocketAddress("127.0.0.1", 0), 0); + port = server.getAddress().getPort(); + server.createContext( + "/model.onnx", + ex -> { + ex.sendResponseHeaders(200, modelBytes.length); + ex.getResponseBody().write(modelBytes); + ex.close(); + }); + server.start(); + } + + @AfterEach + void stopServer() { + if (server != null) { + server.stop(0); + } + } + + private ModelCatalogEntry entry(String url, String sha) { + 
ModelCatalogEntry e = new ModelCatalogEntry(); + e.setId("test-model"); + e.setOnnxUrl(url); + e.setSha256(sha); + e.setSizeBytes(modelBytes.length); + return e; + } + + private FormDetectionModelManager manager( + Path dir, ModelCatalogEntry entry, EndpointConfiguration ep) { + RuntimePathConfig paths = Mockito.mock(RuntimePathConfig.class); + Mockito.when(paths.getFormDetectionModelPath()).thenReturn(dir.toString()); + ModelCatalogService catalog = Mockito.mock(ModelCatalogService.class); + Mockito.when(catalog.getById("test-model")).thenReturn(Optional.of(entry)); + Mockito.when(catalog.getById(Mockito.argThat(s -> !"test-model".equals(s)))) + .thenReturn(Optional.empty()); + Mockito.when(catalog.getAll()).thenReturn(List.of(entry)); + return new FormDetectionModelManager(paths, catalog, new ApplicationProperties(), ep); + } + + private void awaitState(FormDetectionModelManager m, String wire, long timeoutMs) + throws InterruptedException { + long deadline = System.currentTimeMillis() + timeoutMs; + while (System.currentTimeMillis() < deadline) { + if (wire.equals(m.status().getStatus())) { + return; + } + Thread.sleep(25); + } + fail("Timed out waiting for state '" + wire + "', was '" + m.status().getStatus() + "'"); + } + + @Test + void installsDownloadsVerifiesAndPublishesAtomically(@TempDir Path dir) throws Exception { + EndpointConfiguration ep = Mockito.mock(EndpointConfiguration.class); + FormDetectionModelManager m = + manager(dir, entry("http://127.0.0.1:" + port + "/model.onnx", modelSha), ep); + + m.startInstall("test-model"); + awaitState(m, "ready", 5000); + + Path onnx = dir.resolve("test-model.onnx"); + assertTrue(Files.exists(onnx), "model file should be published"); + assertArrayEquals(modelBytes, Files.readAllBytes(onnx)); + assertFalse(Files.exists(dir.resolve("test-model.onnx.tmp")), "temp file should be gone"); + assertTrue(m.isReady()); + Mockito.verify(ep).enableEndpoint("form-detection"); + } + + @Test + void 
rejectsChecksumMismatchAndLeavesNoFile(@TempDir Path dir) throws Exception { + EndpointConfiguration ep = Mockito.mock(EndpointConfiguration.class); + FormDetectionModelManager m = + manager(dir, entry("http://127.0.0.1:" + port + "/model.onnx", "0".repeat(64)), ep); + + m.startInstall("test-model"); + awaitState(m, "failed", 5000); + + assertFalse(Files.exists(dir.resolve("test-model.onnx")), "no model on mismatch"); + assertFalse(Files.exists(dir.resolve("test-model.onnx.tmp")), "temp cleaned up"); + assertFalse(m.isReady()); + Mockito.verify(ep, Mockito.never()).enableEndpoint("form-detection"); + } + + @Test + void secondConcurrentInstallIsRejected(@TempDir Path dir) throws Exception { + CountDownLatch gate = new CountDownLatch(1); + server.createContext( + "/gated.onnx", + ex -> { + try { + gate.await(3, TimeUnit.SECONDS); + } catch (InterruptedException ignored) { + Thread.currentThread().interrupt(); + } + ex.sendResponseHeaders(200, modelBytes.length); + ex.getResponseBody().write(modelBytes); + ex.close(); + }); + FormDetectionModelManager m = + manager( + dir, + entry("http://127.0.0.1:" + port + "/gated.onnx", modelSha), + Mockito.mock(EndpointConfiguration.class)); + + m.startInstall("test-model"); // begins, blocks in handler + // installing flag is set synchronously before the worker thread spawns + assertThrows(IllegalStateException.class, () -> m.startInstall("test-model")); + gate.countDown(); + awaitState(m, "ready", 5000); + } + + @Test + void rejectsBlankUrl(@TempDir Path dir) { + FormDetectionModelManager m = + manager(dir, entry("", ""), Mockito.mock(EndpointConfiguration.class)); + assertThrows(IllegalStateException.class, () -> m.startInstall("test-model")); + } + + @Test + void rejectsUnknownModelId(@TempDir Path dir) { + FormDetectionModelManager m = + manager( + dir, + entry("http://127.0.0.1:" + port + "/model.onnx", modelSha), + Mockito.mock(EndpointConfiguration.class)); + assertThrows(IllegalArgumentException.class, () -> 
m.startInstall("unknown")); + } +} diff --git a/docker/embedded/Dockerfile b/docker/embedded/Dockerfile index 89913af09a..86377c6397 100644 --- a/docker/embedded/Dockerfile +++ b/docker/embedded/Dockerfile @@ -19,6 +19,15 @@ RUN apt-get update \ && rm /tmp/task.deb \ && rm -rf /var/lib/apt/lists/* +# Pre-download the default Auto Form Detection model (FFDNet-S, ~37MB) so the feature works +# out-of-the-box. Verified by checksum here; seeded into the writable configs model dir at startup +# by FormDetectionModelManager (FORMDETECTION_PREINSTALLEDMODELDIR below). +ARG FORM_DETECTION_MODEL_URL=https://huggingface.co/jbarrow/FFDNet-S-cpu/resolve/main/FFDNet-S.onnx +ARG FORM_DETECTION_MODEL_SHA256=93bccf47c048f9f947f9b1b52d002edf144a8a583dae39f164d9e5725321acc0 +RUN mkdir -p /preinstalled-models \ + && curl -fSL "${FORM_DETECTION_MODEL_URL}" -o /preinstalled-models/ffdnet-s.onnx \ + && echo "${FORM_DETECTION_MODEL_SHA256} /preinstalled-models/ffdnet-s.onnx" | sha256sum -c - + # JDK 25+: --add-exports is no longer accepted via JAVA_TOOL_OPTIONS; use JDK_JAVA_OPTIONS instead ENV JDK_JAVA_OPTIONS="--add-exports=jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED \ --add-exports=jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \ @@ -47,6 +56,7 @@ RUN STIRLING_FLAVOR=${STIRLING_FLAVOR} \ gradle clean build \ -PbuildWithFrontend=true \ -PprototypesMode=${PROTOTYPES_BUILD} \ + -PbundleOnnxRuntime=true \ -x spotlessApply -x spotlessCheck -x test -x sonarqube \ --no-daemon @@ -56,6 +66,15 @@ WORKDIR /tmp COPY --from=app-build /app/app/core/build/libs/*.jar app.jar RUN java -Djarmode=tools -jar app.jar extract --layers --destination /layers +# Slim the bundled ONNX Runtime (included only via -PbundleOnnxRuntime above) to the target Linux +# arch (~42MB jar -> ~8MB). Docker server image only; desktop/local builds don't bundle it at all. +# The detection model itself is unaffected - it is still downloaded on demand at runtime. 
+COPY scripts/slim-onnxruntime.sh /tmp/slim-onnxruntime.sh +RUN apt-get update \ + && apt-get install -y --no-install-recommends zip \ + && rm -rf /var/lib/apt/lists/* \ + && sh /tmp/slim-onnxruntime.sh /layers/dependencies + # Stage 3: Final runtime image on top of pre-built base FROM ${BASE_IMAGE} @@ -74,6 +93,9 @@ COPY --link --from=app-build --chown=1000:1000 \ /app/build/libs/restart-helper.jar /restart-helper.jar COPY --link --chown=1000:1000 scripts/ /scripts/ +# Pre-bundled default detection model; FormDetectionModelManager seeds it into /configs on startup. +COPY --link --from=app-build --chown=1000:1000 /preinstalled-models/ /opt/stirling/preinstalled-models/ + # Fonts go to system dir, root ownership is correct (world-readable) COPY app/core/src/main/resources/static/fonts/*.ttf /usr/share/fonts/truetype/ @@ -97,6 +119,7 @@ RUN echo "${VERSION_TAG:-dev}" > /etc/stirling_version # Environment variables ENV VERSION_TAG=$VERSION_TAG \ + FORMDETECTION_PREINSTALLEDMODELDIR="/opt/stirling/preinstalled-models" \ STIRLING_AOT_ENABLE="false" \ STIRLING_JVM_PROFILE="balanced" \ _JVM_OPTS_BALANCED="-XX:+ExitOnOutOfMemoryError -XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=/configs/heap_dumps -XX:+UseG1GC -XX:MaxGCPauseMillis=200 -XX:G1HeapRegionSize=4m -XX:G1PeriodicGCInterval=60000 -XX:+UseStringDeduplication -XX:+UseCompactObjectHeaders -XX:+ExplicitGCInvokesConcurrent -Dspring.threads.virtual.enabled=true -Djava.awt.headless=true" \ diff --git a/docker/embedded/Dockerfile.fat b/docker/embedded/Dockerfile.fat index f9754b4721..a270ab06bc 100644 --- a/docker/embedded/Dockerfile.fat +++ b/docker/embedded/Dockerfile.fat @@ -20,6 +20,15 @@ RUN apt-get update \ && rm /tmp/task.deb \ && rm -rf /var/lib/apt/lists/* +# Pre-download the default Auto Form Detection model (FFDNet-S, ~37MB) so the feature works +# out-of-the-box - especially important for this air-gapped fat image (no runtime internet). 
+# Verified by checksum; seeded into the writable configs model dir at startup. +ARG FORM_DETECTION_MODEL_URL=https://huggingface.co/jbarrow/FFDNet-S-cpu/resolve/main/FFDNet-S.onnx +ARG FORM_DETECTION_MODEL_SHA256=93bccf47c048f9f947f9b1b52d002edf144a8a583dae39f164d9e5725321acc0 +RUN mkdir -p /preinstalled-models \ + && curl -fSL "${FORM_DETECTION_MODEL_URL}" -o /preinstalled-models/ffdnet-s.onnx \ + && echo "${FORM_DETECTION_MODEL_SHA256} /preinstalled-models/ffdnet-s.onnx" | sha256sum -c - + # JDK 25+: --add-exports is no longer accepted via JAVA_TOOL_OPTIONS; use JDK_JAVA_OPTIONS instead ENV JDK_JAVA_OPTIONS="--add-exports=jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED \ --add-exports=jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \ @@ -43,6 +52,7 @@ COPY . . RUN DISABLE_ADDITIONAL_FEATURES=false \ gradle clean build \ -PbuildWithFrontend=true \ + -PbundleOnnxRuntime=true \ -x spotlessApply -x spotlessCheck -x test -x sonarqube \ --no-daemon @@ -52,6 +62,14 @@ WORKDIR /tmp COPY --from=app-build /app/app/core/build/libs/*.jar app.jar RUN java -Djarmode=tools -jar app.jar extract --layers --destination /layers +# Slim the bundled ONNX Runtime (included only via -PbundleOnnxRuntime above) to the target Linux +# arch (~42MB jar -> ~8MB). Docker server image only; desktop/local builds don't bundle it at all. +COPY scripts/slim-onnxruntime.sh /tmp/slim-onnxruntime.sh +RUN apt-get update \ + && apt-get install -y --no-install-recommends zip \ + && rm -rf /var/lib/apt/lists/* \ + && sh /tmp/slim-onnxruntime.sh /layers/dependencies + # Stage 3: Final runtime image on top of pre-built base FROM ${BASE_IMAGE} @@ -70,6 +88,9 @@ COPY --link --from=app-build --chown=1000:1000 \ /app/build/libs/restart-helper.jar /restart-helper.jar COPY --link --chown=1000:1000 scripts/ /scripts/ +# Pre-bundled default detection model; FormDetectionModelManager seeds it into /configs on startup. 
+COPY --link --from=app-build --chown=1000:1000 /preinstalled-models/ /opt/stirling/preinstalled-models/ + # Fonts go to system dir, root ownership is correct (world-readable) COPY app/core/src/main/resources/static/fonts/*.ttf /usr/share/fonts/truetype/ @@ -92,6 +113,7 @@ RUN echo "${VERSION_TAG:-dev}" > /etc/stirling_version # Environment variables ENV VERSION_TAG=$VERSION_TAG \ + FORMDETECTION_PREINSTALLEDMODELDIR="/opt/stirling/preinstalled-models" \ STIRLING_AOT_ENABLE="false" \ STIRLING_JVM_PROFILE="balanced" \ _JVM_OPTS_BALANCED="-XX:+ExitOnOutOfMemoryError -XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=/configs/heap_dumps -XX:+UseG1GC -XX:MaxGCPauseMillis=200 -XX:G1HeapRegionSize=4m -XX:G1PeriodicGCInterval=60000 -XX:+UseStringDeduplication -XX:+UseCompactObjectHeaders -XX:+ExplicitGCInvokesConcurrent -Dspring.threads.virtual.enabled=true -Djava.awt.headless=true" \ diff --git a/frontend/editor/public/locales/en-US/translation.toml b/frontend/editor/public/locales/en-US/translation.toml index 373a253b05..b007f20c63 100644 --- a/frontend/editor/public/locales/en-US/translation.toml +++ b/frontend/editor/public/locales/en-US/translation.toml @@ -558,6 +558,39 @@ close = "Close" error = "Error" expand = "Expand" +[admin.formDetection] +active = "Active" +description = "Install the AI model used to auto-detect form fields. The selected model is downloaded on demand (about 40-100MB) into the configs volume and is not bundled with Stirling-PDF." +enableFeature = "Enable feature" +engineDescription = "Browser keeps the PDF on the device (downloads a ~12MB runtime once, then cached); Server runs it on the backend; Auto prefers the browser and falls back to the server." +engineLabel = "Where detection runs" +install = "Install" +license = "License" +notAvailable = "This catalog entry has no download URL/checksum configured yet, so it cannot be installed." +notWritable = "The model directory is not writable; check the configs volume mount." 
+selectModel = "Model" +selectPlaceholder = "Select a model" +status = "Status" +switch = "Switch to this model" +title = "AI Form Detection" +uninstall = "Uninstall" + +[admin.formDetection.airgap] +hide = "Hide offline install instructions" +intro = "No internet on the server? Install the model manually:" +noSha = "(checksum not set)" +show = "Air-gapped / offline install instructions" +step1 = "1. On a machine with internet, download the model file:" +step2 = "2. Verify its SHA-256 checksum matches:" +step3 = "3. Copy it onto the Stirling-PDF server into the model directory:" +step4 = "4. Set formDetection.activeModelId to this model's id in settings.yml, then restart (an installed model is auto-detected on boot). <configs> is the configs volume (e.g. /configs in Docker). The Docker image also pre-bundles this model, so it is ready with no download." + +[admin.formDetection.engine] +auto = "Auto" +browser = "Browser" +server = "Server" +serverUnavailable = "The server engine is not bundled in this build, so detection runs in the browser. Use Auto or Browser." + [admin.settings] discard = "Discard" error = "Failed to save settings" @@ -1837,6 +1870,16 @@ bullet2 = "Creates a clean, valid filename from the detected title" bullet3 = "Keeps the original name if no suitable title is found" title = "Smart Renaming" +[autoFormDetection] +loading = "Detecting form fields..." +submit = "Detect & make fillable" + +[autoFormDetection.error] +failed = "An error occurred while detecting form fields." + +[autoFormDetection.results] +title = "Detected Form Fields" + [automate] copyToSaved = "Copy to Saved" desc = "Build multi-step workflows by chaining together PDF actions. Ideal for recurring tasks." 
@@ -4207,6 +4250,11 @@ desc = "Highlight, draw, add notes and shapes in the viewer" tags = "annotate,highlight,draw,markup,comment,notes,review,redline,feedback,markup tools,sticky notes,shapes,arrows,text box,freehand" title = "Annotate" +[home.autoFormDetection] +desc = "Automatically detect form fields with AI and make your PDF fillable." +tags = "form,detect,fillable,acroform,auto detect,ai,fields,make fillable,form detection" +title = "Auto Form Detection" + [home.automate] desc = "Build multi-step workflows by chaining together PDF actions. Ideal for recurring tasks." tags = "workflow,sequence,automation,automate,batch,batch processing,pipeline,chain,multi-step,recurring,scheduled,automatic,process multiple,bulk operations" diff --git a/frontend/editor/src/core/data/useTranslatedToolRegistry.tsx b/frontend/editor/src/core/data/useTranslatedToolRegistry.tsx index 74b676dceb..f884c04452 100644 --- a/frontend/editor/src/core/data/useTranslatedToolRegistry.tsx +++ b/frontend/editor/src/core/data/useTranslatedToolRegistry.tsx @@ -25,6 +25,7 @@ import { addWatermarkOperationConfig } from "@app/hooks/tools/addWatermark/useAd import { addStampOperationConfig } from "@app/components/tools/addStamp/useAddStampOperation"; import { addAttachmentsOperationConfig } from "@app/hooks/tools/addAttachments/useAddAttachmentsOperation"; import { unlockPdfFormsOperationConfig } from "@app/hooks/tools/unlockPdfForms/useUnlockPdfFormsOperation"; +import { autoFormDetectionOperationConfig } from "@app/hooks/tools/autoFormDetection/useAutoFormDetectionOperation"; import { singleLargePageOperationConfig } from "@app/hooks/tools/singleLargePage/useSingleLargePageOperation"; import { ocrOperationConfig } from "@app/hooks/tools/ocr/useOCROperation"; import { convertOperationConfig } from "@app/hooks/tools/convert/useConvertOperation"; @@ -446,6 +447,39 @@ export function useTranslatedToolCatalog(): TranslatedToolCatalog { supportsAutomate: false, synonyms: ["form", "fill", "fillable", 
"input", "field", "acroform"], }, + autoFormDetection: { + icon: ( + + ), + name: t("home.autoFormDetection.title", "Auto Form Detection"), + component: lazy( + () => import("@app/tools/autoFormDetection/AutoFormDetection"), + ), + description: t( + "home.autoFormDetection.desc", + "Automatically detect form fields with AI and make your PDF fillable.", + ), + categoryId: ToolCategoryId.STANDARD_TOOLS, + subcategoryId: SubcategoryId.GENERAL, + maxFiles: 1, + endpoints: ["form-detection"], + operationConfig: autoFormDetectionOperationConfig, + synonyms: [ + "form", + "detect", + "fillable", + "acroform", + "ai", + "fields", + "auto", + ], + automationSettings: null, + supportsAutomate: false, + }, changePermissions: { icon: , name: t("home.changePermissions.title", "Change Permissions"), diff --git a/frontend/editor/src/core/hooks/tools/autoFormDetection/useAutoFormDetectionOperation.ts b/frontend/editor/src/core/hooks/tools/autoFormDetection/useAutoFormDetectionOperation.ts new file mode 100644 index 0000000000..88a84387c1 --- /dev/null +++ b/frontend/editor/src/core/hooks/tools/autoFormDetection/useAutoFormDetectionOperation.ts @@ -0,0 +1,133 @@ +import { useTranslation } from "react-i18next"; + +import { + ToolType, + useToolOperation, +} from "@app/hooks/tools/shared/useToolOperation"; +import { + FormDetectionModelStatus, + FormDetectionCatalogEntry, +} from "@app/hooks/useFormDetectionModelStatus"; +import apiClient from "@app/services/apiClient"; +import { createStandardErrorHandler } from "@app/utils/toolErrorHandler"; +import { + AutoFormDetectionParameters, + defaultParameters, +} from "@app/hooks/tools/autoFormDetection/useAutoFormDetectionParameters"; + +const DETECT_ENDPOINT = "/api/v1/ai/form-detection/detect"; +const STATUS_URL = "/api/v1/ai/form-detection-model/status"; + +// Static function shared by the hook and the automation executor. 
/**
 * Builds the multipart payload for the server detection endpoint.
 * Static function shared by the hook and the automation executor.
 *
 * @param parameters tool parameters; `confidence` is forwarded only when it is a finite number
 * @param file the PDF to analyse
 * @returns form data carrying `file`, `applyToPdf=true` and, optionally, `confThreshold`
 */
export const buildAutoFormDetectionFormData = (
  parameters: AutoFormDetectionParameters,
  file: File,
): FormData => {
  const formData = new FormData();
  formData.append("file", file);
  // Server path detects and applies the AcroForm in one call, returning the fillable PDF.
  formData.append("applyToPdf", "true");
  // `typeof NaN === "number"`, so also require finiteness; otherwise a cleared numeric input
  // would be sent as the literal string "NaN" instead of falling back to the model default.
  const { confidence } = parameters;
  if (typeof confidence === "number" && Number.isFinite(confidence)) {
    formData.append("confThreshold", String(confidence));
  }
  return formData;
};

/** Derives the output file name: the input name without a trailing `.pdf`, plus `_form.pdf`. */
function outputName(file: File): string {
  const base = (file.name || "document").replace(/\.pdf$/i, "");
  return `${base}_form.pdf`;
}

/** Uploads the PDF to the backend and wraps the returned blob as the fillable output file. */
async function serverDetect(
  parameters: AutoFormDetectionParameters,
  file: File,
): Promise<File> {
  const res = await apiClient.post(
    DETECT_ENDPOINT,
    buildAutoFormDetectionFormData(parameters, file),
    { responseType: "blob" },
  );
  return new File([res.data as Blob], outputName(file), {
    type: "application/pdf",
  });
}

/** Runs detection in the browser for the given catalog entry and returns the fillable PDF. */
async function browserDetect(
  parameters: AutoFormDetectionParameters,
  file: File,
  entry: FormDetectionCatalogEntry,
): Promise<File> {
  // Lazy-load the in-browser engine (onnxruntime-web + the ~12MB wasm) only when browser-mode
  // detection actually runs - it is never pulled into the initial bundle or loaded on the homepage.
  const { runBrowserDetection } =
    await import("@app/services/formDetection/runBrowserPipeline");
  const bytes = await file.arrayBuffer();
  const { appliedPdf } = await runBrowserDetection(
    bytes,
    entry,
    parameters.confidence,
  );
  return new File([new Uint8Array(appliedPdf)], outputName(file), {
    type: "application/pdf",
  });
}
/**
 * Runs detection where the admin configured it to run: 'server' (upload), 'browser' (in-browser
 * WASM only - the PDF never leaves the device, no fallback), or 'auto' (browser first, falling back
 * to the server on any browser-path error).
 *
 * @param parameters tool parameters forwarded to whichever engine runs
 * @param files input files; only the first one is processed
 * @returns the fillable PDF wrapped in `{ files }`, or an empty list when no file was supplied
 * @throws when strict browser mode is configured but no active model entry is available, or when
 *   the selected engine fails and no fallback is permitted
 */
async function processAutoFormDetection(
  parameters: AutoFormDetectionParameters,
  files: File[],
): Promise<{ files: File[] }> {
  const file = files[0];
  if (!file) {
    // Nothing to process; avoid dereferencing `undefined` further down.
    return { files: [] };
  }

  const status = (await apiClient.get(STATUS_URL))
    .data as FormDetectionModelStatus;
  const mode = status.executionMode ?? "auto";
  const activeEntry = (status.catalog ?? []).find(
    (c) => c.id === status.activeModelId,
  );

  if (mode === "server") {
    return { files: [await serverDetect(parameters, file)] };
  }
  if (mode === "browser") {
    // Strict: never fall back to the server, so the PDF truly stays on the device. That must
    // also hold when the active model cannot be resolved - fail instead of uploading.
    if (!activeEntry) {
      throw new Error(
        "Form detection is configured to run in the browser only, but no active model is available.",
      );
    }
    return { files: [await browserDetect(parameters, file, activeEntry)] };
  }
  // auto: prefer the browser, fall back to the server if it fails or has no model to run.
  if (!activeEntry) {
    return { files: [await serverDetect(parameters, file)] };
  }
  try {
    return { files: [await browserDetect(parameters, file, activeEntry)] };
  } catch (e) {
    if (status.serverEngineAvailable === false) {
      // The backend reported that it has no server engine; surface the real browser error
      // rather than masking it with a second, guaranteed failure.
      throw e;
    }
    console.warn(
      "[AutoFormDetection] in-browser engine failed; falling back to server",
      e,
    );
    return { files: [await serverDetect(parameters, file)] };
  }
}

/** Operation config shared by the tool registry and the operation hook. */
export const autoFormDetectionOperationConfig = {
  toolType: ToolType.custom,
  customProcessor: processAutoFormDetection,
  operationType: "autoFormDetection",
  // Used only for cloud/credit routing of the server path; execution is the custom processor.
  endpoint: DETECT_ENDPOINT,
  defaultParameters,
} as const;

/** Hook wiring the shared operation config to a translated, standard error message. */
export const useAutoFormDetectionOperation = () => {
  const { t } = useTranslation();

  return useToolOperation({
    ...autoFormDetectionOperationConfig,
    getErrorMessage: createStandardErrorHandler(
      t(
        "autoFormDetection.error.failed",
        "An error occurred while detecting form fields.",
      ),
    ),
  });
};
+ endpointName: "form-detection", + }); + }; diff --git a/frontend/editor/src/core/hooks/useEndpointConfig.ts b/frontend/editor/src/core/hooks/useEndpointConfig.ts index 9bd5f17126..da533a3a43 100644 --- a/frontend/editor/src/core/hooks/useEndpointConfig.ts +++ b/frontend/editor/src/core/hooks/useEndpointConfig.ts @@ -245,3 +245,11 @@ export function useMultipleEndpointsEnabled(endpoints: string[]): { refetch: () => fetchAllEndpointStatuses(true), }; } + +/** + * Invalidate the cached endpoint-availability map so the next check refetches. + * Call after an admin action changes availability (e.g. installing the form-detection model). + */ +export function invalidateEndpointCache() { + resetGlobalCache(); +} diff --git a/frontend/editor/src/core/hooks/useFormDetectionModelStatus.ts b/frontend/editor/src/core/hooks/useFormDetectionModelStatus.ts new file mode 100644 index 0000000000..b3ac5badfd --- /dev/null +++ b/frontend/editor/src/core/hooks/useFormDetectionModelStatus.ts @@ -0,0 +1,139 @@ +import { useCallback, useEffect, useState } from "react"; +import apiClient from "@app/services/apiClient"; +import { invalidateEndpointCache } from "@app/hooks/useEndpointConfig"; + +export interface FormDetectionCatalogEntry { + id: string; + displayName: string; + description: string; + license: string; + sizeBytes: number; + onnxUrl: string; + sha256: string; + // Pipeline spec (parity with the backend ModelCatalogEntry) - drives the in-browser engine. 
+ inputSize: number; + resizeMode?: string; + padColor?: number[]; + channelOrder?: string; + normMean?: number[]; + normStd?: number[]; + outputLayout?: string; + hasObjectness?: boolean; + classNames?: string[]; + classFieldTypes?: string[]; + scoreThreshold?: number; + nms?: string; + iou?: number; +} + +export type FormDetectionState = + | "not_installed" + | "downloading" + | "verifying" + | "ready" + | "failed"; + +export type FormDetectionExecutionMode = "auto" | "browser" | "server"; + +export interface FormDetectionModelStatus { + status: FormDetectionState; + progress: number; + activeModelId: string; + installed: string[]; + error: string | null; + writable: boolean; + catalog: FormDetectionCatalogEntry[]; + enabled: boolean; + executionMode: FormDetectionExecutionMode; + serverEngineAvailable: boolean; +} + +const STATUS_URL = "/api/v1/ai/form-detection-model/status"; +const INSTALL_URL = "/api/v1/ai/form-detection-model/install"; +const CONFIG_URL = "/api/v1/ai/form-detection-model/config"; +const MODEL_URL = "/api/v1/ai/form-detection-model"; + +/** + * Polls the Auto Form Detection model status and exposes admin install/uninstall actions. + * Polling only runs while a download/verify is in flight. When readiness flips, the shared + * endpoint-availability cache is invalidated so the tool tile re-enables/disables. + */ +export function useFormDetectionModelStatus() { + const [status, setStatus] = useState(null); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + + const fetchStatus = useCallback(async () => { + try { + const res = await apiClient.get(STATUS_URL); + setStatus(res.data); + setError(null); + } catch (e) { + setError(e instanceof Error ? e.message : "Failed to load model status"); + } finally { + setLoading(false); + } + }, []); + + useEffect(() => { + fetchStatus(); + }, [fetchStatus]); + + const active = status?.status; + + // Poll only while an install is in flight. 
+ useEffect(() => { + if (active === "downloading" || active === "verifying") { + const id = setInterval(fetchStatus, 1500); + return () => clearInterval(id); + } + return undefined; + }, [active, fetchStatus]); + + // When readiness flips, the tool availability cache must be refreshed. + useEffect(() => { + if (active === "ready" || active === "not_installed") { + invalidateEndpointCache(); + } + }, [active]); + + const install = useCallback( + async (modelId: string) => { + await apiClient.post(INSTALL_URL, { modelId }); + await fetchStatus(); + }, + [fetchStatus], + ); + + const uninstall = useCallback( + async (modelId?: string) => { + const url = modelId + ? `${MODEL_URL}?modelId=${encodeURIComponent(modelId)}` + : MODEL_URL; + await apiClient.delete(url); + await fetchStatus(); + }, + [fetchStatus], + ); + + const setConfig = useCallback( + async (config: { + enabled?: boolean; + executionMode?: FormDetectionExecutionMode; + }) => { + await apiClient.post(CONFIG_URL, config); + await fetchStatus(); + }, + [fetchStatus], + ); + + return { + status, + loading, + error, + refetch: fetchStatus, + install, + uninstall, + setConfig, + }; +} diff --git a/frontend/editor/src/core/services/formDetection/applyFields.ts b/frontend/editor/src/core/services/formDetection/applyFields.ts new file mode 100644 index 0000000000..f1ac5502d9 --- /dev/null +++ b/frontend/editor/src/core/services/formDetection/applyFields.ts @@ -0,0 +1,51 @@ +// Apply detected fields as a real AcroForm using @cantoo/pdf-lib (the same library the rest of the +// app uses). The browser counterpart of FormUtils.addFields; coordinates are already in PDF points +// with a bottom-left origin, which is exactly what addToPage expects. 
/**
 * Adds the detected fields to the PDF as AcroForm widgets and returns the saved document bytes.
 *
 * Checkboxes become check boxes; text and signature detections both become text fields.
 * Fields are named `<kind>_<page number, 1-based>_<running count per kind>`.
 *
 * @param pdfBytes the source PDF
 * @param fields detections whose rectangles are already in PDF points (bottom-left origin)
 * @returns the bytes of the PDF with the form fields added
 */
export async function applyFields(
  pdfBytes: ArrayBuffer | Uint8Array,
  fields: DetectedField[],
): Promise<Uint8Array> {
  const pdfDoc = await PDFDocument.load(pdfBytes, {
    ignoreEncryption: true,
    throwOnInvalidObject: false,
  });
  const form = pdfDoc.getForm();
  const pages = pdfDoc.getPages();
  const counts: Record<string, number> = {};

  for (const f of fields) {
    const page = pages[f.page];
    if (!page) continue;
    const r = f.rectInPdfPoints;
    // Reject NaN/Infinity as well as empty boxes: `NaN <= 0` is false, so a plain sign check
    // would let a degenerate rectangle through to the widget.
    if (![r.x, r.y, r.w, r.h].every(Number.isFinite)) continue;
    if (r.w <= 0 || r.h <= 0) continue;

    const kind =
      f.type === "checkbox"
        ? "checkbox"
        : f.type === "signature"
          ? "signature"
          : "text";
    counts[kind] = (counts[kind] ?? 0) + 1;
    const name = `${kind}_${f.page + 1}_${counts[kind]}`;

    try {
      if (kind === "checkbox") {
        const cb = form.createCheckBox(name);
        cb.addToPage(page, { x: r.x, y: r.y, width: r.w, height: r.h });
      } else {
        // pdf-lib has no first-class signature widget; a text field keeps it fillable.
        const tf = form.createTextField(name);
        tf.addToPage(page, { x: r.x, y: r.y, width: r.w, height: r.h });
      }
    } catch (e) {
      // Skip a field that fails to add (e.g. a duplicate name) rather than abort the whole doc,
      // but leave a trace so missing fields can be diagnosed.
      console.warn(`[AutoFormDetection] skipped field "${name}"`, e);
    }
  }

  return pdfDoc.save();
}
/** Page geometry needed to convert raster pixels back into PDF points. */
export interface RasterPageInfo {
  pageWidthPt: number;
  pageHeightPt: number;
  scaleX: number; // pixels per point
  scaleY: number;
}

/** Returns `lo` when `v` is below it, otherwise the smaller of `v` and `hi`. */
function clamp(v: number, lo: number, hi: number): number {
  if (v < lo) {
    return lo;
  }
  return Math.min(v, hi);
}

/**
 * Converts a detection given in bitmap pixels (top-left origin) into a rectangle in PDF points
 * (bottom-left origin), clamped so it stays inside the page.
 *
 * A non-positive scale is treated as 1 pixel per point.
 */
export function toPdfPoints(d: Detection, page: RasterPageInfo): RectPt {
  const pxPerPtX = page.scaleX > 0 ? page.scaleX : 1;
  const pxPerPtY = page.scaleY > 0 ? page.scaleY : 1;

  const rawWidth = d.w / pxPerPtX;
  const rawHeight = d.h / pxPerPtY;

  // The bitmap's Y axis grows downwards, the PDF's upwards: measure the box's bottom edge from
  // the bottom of the page.
  const x = clamp(d.x / pxPerPtX, 0, page.pageWidthPt);
  const y = clamp(
    page.pageHeightPt - d.y / pxPerPtY - rawHeight,
    0,
    page.pageHeightPt,
  );

  return {
    x,
    y,
    w: clamp(rawWidth, 0, page.pageWidthPt - x),
    h: clamp(rawHeight, 0, page.pageHeightPt - y),
  };
}
+ +const spec: ModelPipelineSpec = { + inputSize: 10, + resizeMode: "letterbox", + padColor: [114, 114, 114], + channelOrder: "rgb", + normMean: [0, 0, 0], + normStd: [1, 1, 1], + outputLayout: "nc_first", + hasObjectness: false, + classNames: ["text", "choice"], + classFieldTypes: ["text", "checkbox"], + scoreThreshold: 0.5, + nms: "perClass", + iou: 0.5, +}; + +// identity transform (scale 1, no pad), 10x10 source +const pre: Preprocessed = { + chw: new Float32Array(0), + inputSize: 10, + scaleX: 1, + scaleY: 1, + padX: 0, + padY: 0, + srcW: 10, + srcH: 10, +}; + +// nc_first layout [channels=6][anchors=3], data[c*anchors + a] +// box A (cx5,cy5,w4,h4) twice (overlapping) + box B (cx8,cy8,w2,h2) +const data = [ + 5, + 5, + 8, // cx + 5, + 5, + 8, // cy + 4, + 4, + 2, // w + 4, + 4, + 2, // h + 0.9, + 0.8, + 0.7, // text score + 0.1, + 0.1, + 0.1, // choice score +]; +const out: RawOutput = { data, d1: 6, d2: 3 }; + +describe("formDetection decode (parity with backend Yolo.decode)", () => { + it("thresholds and suppresses overlaps identically to the Java golden fixture", () => { + const dets = decode(out, spec, pre, 0.5); + // a0 (box A, 0.9) kept; a1 (box A', 0.8) suppressed by NMS; a2 (box B, 0.7) kept + expect(dets).toHaveLength(2); + + expect(dets[0].classId).toBe(0); + expect(dets[0].score).toBeCloseTo(0.9, 5); + expect(dets[0].x).toBeCloseTo(3, 4); + expect(dets[0].y).toBeCloseTo(3, 4); + expect(dets[0].w).toBeCloseTo(4, 4); + expect(dets[0].h).toBeCloseTo(4, 4); + + expect(dets[1].score).toBeCloseTo(0.7, 5); + expect(dets[1].x).toBeCloseTo(7, 4); + expect(dets[1].y).toBeCloseTo(7, 4); + expect(dets[1].w).toBeCloseTo(2, 4); + expect(dets[1].h).toBeCloseTo(2, 4); + }); + + it("drops everything when the threshold exceeds all scores", () => { + expect(decode(out, spec, pre, 0.95)).toHaveLength(0); + }); +}); diff --git a/frontend/editor/src/core/services/formDetection/decode.ts b/frontend/editor/src/core/services/formDetection/decode.ts new file mode 100644 
index 0000000000..3dd32b76f4
--- /dev/null
+++ b/frontend/editor/src/core/services/formDetection/decode.ts
@@ -0,0 +1,123 @@
+// Pure decode/NMS/un-projection - a 1:1 port of Yolo.decode in the backend. Kept free of any
+// browser API so it can be unit-tested for parity against the Java golden output.
+
+import {
+  Detection,
+  ModelPipelineSpec,
+  Preprocessed,
+  RawOutput,
+} from "@app/services/formDetection/types";
+
+/**
+ * Read channel `c` of anchor `a` from the flat output buffer.
+ *
+ * When `ncFirst` is true the buffer is channel-major (`data[c * anchors + a]`); otherwise it is
+ * anchor-major (`data[a * channels + c]`).
+ */
+function at(
+  data: Float32Array | number[],
+  ncFirst: boolean,
+  anchors: number,
+  channels: number,
+  c: number,
+  a: number,
+): number {
+  return ncFirst ? data[c * anchors + a] : data[a * channels + c];
+}
+
+/**
+ * Intersection-over-union of two boxes given as top-left `x`/`y` plus `w`/`h`.
+ * Returns 0 when the union area is not positive.
+ */
+function iou(a: Detection, b: Detection): number {
+  const ax2 = a.x + a.w;
+  const ay2 = a.y + a.h;
+  const bx2 = b.x + b.w;
+  const by2 = b.y + b.h;
+  const ix1 = Math.max(a.x, b.x);
+  const iy1 = Math.max(a.y, b.y);
+  const ix2 = Math.min(ax2, bx2);
+  const iy2 = Math.min(ay2, by2);
+  // Disjoint boxes give a negative extent; clamp so the intersection area is 0.
+  const iw = Math.max(0, ix2 - ix1);
+  const ih = Math.max(0, iy2 - iy1);
+  const inter = iw * ih;
+  const union = a.w * a.h + b.w * b.h - inter;
+  return union <= 0 ? 0 : inter / union;
+}
+
+/**
+ * Greedy non-maximum suppression.
+ *
+ * With fewer than two detections, or when `mode` is "none" (case-insensitive), the input array
+ * is returned as-is (not re-sorted). Otherwise a copy is sorted by descending score and every
+ * later box whose IoU with a kept box is strictly greater than `iouThreshold` is dropped. Boxes
+ * of different classes only suppress each other when `mode` contains "agnostic".
+ */
+function nms(
+  dets: Detection[],
+  mode: string,
+  iouThreshold: number,
+): Detection[] {
+  if (dets.length < 2 || (mode ?? "").toLowerCase() === "none") {
+    return dets;
+  }
+  const classAgnostic = (mode ?? "").toLowerCase().includes("agnostic");
+  const sorted = [...dets].sort((x, y) => y.score - x.score);
+  const removed = new Array(sorted.length).fill(false);
+  const keep: Detection[] = [];
+  for (let i = 0; i < sorted.length; i++) {
+    if (removed[i]) continue;
+    const di = sorted[i];
+    keep.push(di);
+    // Only lower-scored boxes (j > i) can be suppressed by the box just kept.
+    for (let j = i + 1; j < sorted.length; j++) {
+      if (removed[j]) continue;
+      const dj = sorted[j];
+      if (!classAgnostic && di.classId !== dj.classId) continue;
+      if (iou(di, dj) > iouThreshold) removed[j] = true;
+    }
+  }
+  return keep;
+}
+
+/**
+ * Decode raw output, threshold, NMS, and un-project to original bitmap pixels.
+ *
+ * Per anchor the channels are read as cx, cy, w, h (0-3), an optional objectness value (4, only
+ * when `spec.hasObjectness`), then one score per class. The class score is multiplied by the
+ * objectness value (1 when absent) and the best-scoring class is kept.
+ *
+ * @param out flat model output with its two dimensions
+ * @param spec pipeline spec (layout, class names, NMS mode and IoU threshold)
+ * @param pre transform produced by preprocessing, used to map boxes back to source pixels
+ * @param scoreThreshold anchors whose best score is below this value are dropped
+ * @returns detections in source-bitmap pixels (top-left origin); an empty array when the spec
+ *   has no class names or the output has fewer channels than the spec requires
+ */
+export function decode(
+  out: RawOutput,
+  spec: ModelPipelineSpec,
+  pre: Preprocessed,
+  scoreThreshold: number,
+): Detection[] {
+  const numClasses = spec.classNames?.length ?? 0;
+  if (numClasses === 0) return [];
+  const obj = spec.hasObjectness;
+  // Anything other than "anchors_first" is treated as channel-major.
+  const ncFirst = (spec.outputLayout ?? "").toLowerCase() !== "anchors_first";
+  const channels = ncFirst ? out.d1 : out.d2;
+  const anchors = ncFirst ? out.d2 : out.d1;
+  const expected = 4 + (obj ? 1 : 0) + numClasses;
+  if (channels < expected) {
+    return [];
+  }
+  const classOffset = 4 + (obj ? 1 : 0);
+  const data = out.data;
+
+  const dets: Detection[] = [];
+  for (let a = 0; a < anchors; a++) {
+    const objScore = obj ? at(data, ncFirst, anchors, channels, 4, a) : 1;
+    let bestClass = -1;
+    // Starting at 0 means a class is only selected when its score is strictly positive.
+    let bestScore = 0;
+    for (let c = 0; c < numClasses; c++) {
+      const s =
+        at(data, ncFirst, anchors, channels, classOffset + c, a) * objScore;
+      if (s > bestScore) {
+        bestScore = s;
+        bestClass = c;
+      }
+    }
+    if (bestClass < 0 || bestScore < scoreThreshold) continue;
+    const cx = at(data, ncFirst, anchors, channels, 0, a);
+    const cy = at(data, ncFirst, anchors, channels, 1, a);
+    const w = at(data, ncFirst, anchors, channels, 2, a);
+    const h = at(data, ncFirst, anchors, channels, 3, a);
+    // Centre/size -> top-left corner, still in model-input pixels.
+    const x1 = cx - w / 2;
+    const y1 = cy - h / 2;
+    // Undo the preprocessing transform: remove the padding offset, then the resize scale.
+    const ox = (x1 - pre.padX) / pre.scaleX;
+    const oy = (y1 - pre.padY) / pre.scaleY;
+    let ow = w / pre.scaleX;
+    let oh = h / pre.scaleY;
+    // Clamp the origin into the source bitmap, then the size to what remains from that origin.
+    const cxl = Math.max(0, Math.min(ox, pre.srcW));
+    const cyl = Math.max(0, Math.min(oy, pre.srcH));
+    ow = Math.max(0, Math.min(ow, pre.srcW - cxl));
+    oh = Math.max(0, Math.min(oh, pre.srcH - cyl));
+    if (ow <= 0 || oh <= 0) continue;
+    dets.push({
+      classId: bestClass,
+      score: bestScore,
+      x: cxl,
+      y: cyl,
+      w: ow,
+      h: oh,
+    });
+  }
+  return nms(dets, spec.nms, spec.iou);
+}
diff --git a/frontend/editor/src/core/services/formDetection/modelCache.ts b/frontend/editor/src/core/services/formDetection/modelCache.ts
new file mode 100644
index
0000000000..98f2d37775
--- /dev/null
+++ b/frontend/editor/src/core/services/formDetection/modelCache.ts
@@ -0,0 +1,69 @@
+// Fetch the active .onnx from the backend serve endpoint, verify its SHA-256, and keep it in the
+// Cache API keyed by checksum so it is downloaded only once per device (then reused across reloads).
+
+const MODEL_FILE_URL = "/api/v1/ai/form-detection-model/file";
+const CACHE_NAME = "stirling-form-detection-models";
+
+/** Lower-case hex encoding of a byte buffer, two digits per byte. */
+function toHex(buf: ArrayBuffer): string {
+  const bytes = new Uint8Array(buf);
+  let out = "";
+  for (let i = 0; i < bytes.length; i++) {
+    out += bytes[i].toString(16).padStart(2, "0");
+  }
+  return out;
+}
+
+/**
+ * Check `bytes` against an expected SHA-256 (hex, compared case-insensitively).
+ * Does nothing when no checksum is supplied.
+ *
+ * @throws Error when the digest differs, or when Web Crypto is unavailable so the check cannot run
+ */
+async function verify(bytes: ArrayBuffer, expectedSha?: string): Promise<void> {
+  if (!expectedSha) return;
+  // Fail closed with a readable message instead of a TypeError when crypto.subtle is missing.
+  const subtle = globalThis.crypto?.subtle;
+  if (!subtle) {
+    throw new Error(
+      "Cannot verify model checksum: Web Crypto (crypto.subtle) is unavailable in this context",
+    );
+  }
+  const digest = await subtle.digest("SHA-256", bytes);
+  const actual = toHex(digest);
+  if (actual.toLowerCase() !== expectedSha.toLowerCase()) {
+    throw new Error(
+      `Model checksum mismatch (expected ${expectedSha}, got ${actual})`,
+    );
+  }
+}
+
+/**
+ * Return the active model bytes, from the Cache API when present (and checksum-valid) or by
+ * downloading from the backend. The cache key is the checksum, so a model swap naturally misses.
+ *
+ * @param expectedSha expected SHA-256 of the model (hex); when omitted no verification is done
+ * @returns the model file contents
+ * @throws Error when the download fails or the downloaded bytes do not match `expectedSha`
+ */
+export async function loadModelBytes(
+  expectedSha?: string,
+): Promise<ArrayBuffer> {
+  const cacheKey = `${MODEL_FILE_URL}#${expectedSha ?? "nosha"}`;
+  // Cache API is unavailable in non-secure contexts; degrade to a plain download in that case.
+  // Where the `caches` global is not defined at all, touching it would throw synchronously
+  // before `.catch` is attached, so feature-test it first.
+  const cache =
+    typeof caches !== "undefined"
+      ? await caches.open(CACHE_NAME).catch(() => null)
+      : null;
+
+  if (cache) {
+    const hit = await cache.match(cacheKey).catch(() => undefined);
+    if (hit) {
+      try {
+        // Reading the body is inside the try so an unreadable entry is also re-downloaded.
+        const buf = await hit.arrayBuffer();
+        await verify(buf, expectedSha);
+        return buf;
+      } catch {
+        await cache.delete(cacheKey).catch(() => false); // stale/corrupt - re-download
+      }
+    }
+  }
+
+  const res = await fetch(MODEL_FILE_URL, { credentials: "include" });
+  if (!res.ok) {
+    throw new Error(`Model download failed: HTTP ${res.status}`);
+  }
+  const buf = await res.arrayBuffer();
+  await verify(buf, expectedSha);
+
+  if (cache) {
+    await cache
+      .put(
+        cacheKey,
+        new Response(buf, {
+          headers: { "Content-Type": "application/octet-stream" },
+        }),
+      )
+      .catch(() => undefined); // best-effort; ignore quota/availability errors
+  }
+  return buf;
+}
diff --git a/frontend/editor/src/core/services/formDetection/onnxSession.ts b/frontend/editor/src/core/services/formDetection/onnxSession.ts
new file mode 100644
index 0000000000..4bfd582531
--- /dev/null
+++ b/frontend/editor/src/core/services/formDetection/onnxSession.ts
@@ -0,0 +1,60 @@
+// Wraps onnxruntime-web: points it at the locally-hosted CPU WASM (copied by Vite into /ort/),
+// runs single-threaded (the app sets no COOP/COEP so SharedArrayBuffer threading is unavailable),
+// and caches one session per model checksum. Output is returned in the same flat layout the
+// backend uses so decode.ts can interpret it identically.
+
+import * as ort from "onnxruntime-web";
+
+import { RawOutput } from "@app/services/formDetection/types";
+
+let configured = false;
+function configureOrt(): void {
+  if (configured) return;
+  ort.env.wasm.numThreads = 1;
+  // The CPU SIMD .wasm + its loader are copied next to the app under /ort/ by vite.config.ts.
+  ort.env.wasm.wasmPaths = new URL("ort/", document.baseURI).href;
+  configured = true;
+}
+
+// Module-level cache: at most one live session, remembered together with the key it was built for.
+let session: ort.InferenceSession | null = null;
+let sessionKey: string | null = null;
+
+/**
+ * Create (or reuse) a session for the given model bytes; keyed by checksum so swaps reload.
+ *
+ * @param modelBytes contents of the .onnx file; only read when a new session has to be created
+ * @param key identity of the model; a call with the same key as the cached session reuses it
+ * @returns the cached session, or a newly created one after releasing the previous session
+ */
+export async function getSession(
+  modelBytes: ArrayBuffer,
+  key: string,
+): Promise<ort.InferenceSession> {
+  configureOrt();
+  if (session && sessionKey === key) return session;
+  if (session) {
+    try {
+      await session.release();
+    } catch {
+      // ignore: a failed release must not block loading the replacement model
+    }
+    session = null;
+    sessionKey = null;
+  }
+  session = await ort.InferenceSession.create(modelBytes, {
+    executionProviders: ["wasm"],
+    graphOptimizationLevel: "all",
+  });
+  // The key is recorded only after creation succeeds, so a failed load is retried next call.
+  sessionKey = key;
+  return session;
+}
+
+/**
+ * Run one forward pass and return the first output in flat form.
+ *
+ * @param s session to run
+ * @param chw input laid out as [1, 3, inputSize, inputSize] float32
+ * @param inputSize side length of the square model input
+ * @returns the first output's data with its second and third dimensions as `d1`/`d2`; a missing
+ *   dimension is reported as 0
+ */
+export async function runInference(
+  s: ort.InferenceSession,
+  chw: Float32Array,
+  inputSize: number,
+): Promise<RawOutput> {
+  const inputName = s.inputNames[0];
+  const tensor = new ort.Tensor("float32", chw, [1, 3, inputSize, inputSize]);
+  const result = await s.run({ [inputName]: tensor });
+  const out = result[s.outputNames[0]];
+  const dims = out.dims;
+  // Expect [1, d1, d2]; data is flat row-major so data[i*d2 + j] == out[0][i][j].
+  const d1 = dims.length >= 2 ? Number(dims[1]) : 0;
+  const d2 = dims.length >= 3 ? Number(dims[2]) : 0;
+  return { data: out.data as Float32Array, d1, d2 };
+}
diff --git a/frontend/editor/src/core/services/formDetection/pdfRender.ts b/frontend/editor/src/core/services/formDetection/pdfRender.ts
new file mode 100644
index 0000000000..4545b9f3b6
--- /dev/null
+++ b/frontend/editor/src/core/services/formDetection/pdfRender.ts
@@ -0,0 +1,87 @@
+// Render each PDF page to an RGBA bitmap via PDF.js, at a DPI chosen so the long side is about the
+// model input size. Mirrors the backend PageRasterizer: the px-per-point scale is taken from the
+// rendered canvas so coordinate mapping does not depend on how the DPI rounds.
+
+/** One rendered page: RGBA pixels plus the geometry needed to map pixels back to PDF points. */
+export interface RasterPage {
+  pageIndex: number; // zero-based
+  rgba: Uint8ClampedArray;
+  widthPx: number;
+  heightPx: number;
+  pageWidthPt: number;
+  pageHeightPt: number;
+  scaleX: number; // pixels per point, measured from the rendered canvas
+  scaleY: number;
+}
+
+declare global {
+  interface Window {
+    pdfjsLib?: {
+      getDocument: (src: { data: ArrayBuffer | Uint8Array }) => {
+        promise: Promise<PdfDocumentProxy>;
+      };
+    };
+  }
+}
+
+// Minimal structural views of the PDF.js objects this module touches.
+interface PdfViewport {
+  width: number;
+  height: number;
+}
+interface PdfPageProxy {
+  getViewport: (opts: { scale: number }) => PdfViewport;
+  render: (opts: {
+    canvasContext: CanvasRenderingContext2D;
+    viewport: PdfViewport;
+  }) => { promise: Promise<void> };
+}
+interface PdfDocumentProxy {
+  numPages: number;
+  getPage: (n: number) => Promise<PdfPageProxy>;
+  // Optional so a build exposing an object without it still type-checks and runs.
+  destroy?: () => Promise<void>;
+}
+
+/**
+ * Rasterise every page of a PDF to RGBA.
+ *
+ * The DPI is chosen per page so the longer side is about `inputSize` pixels, limited to 36-300.
+ *
+ * @param pdfBytes the PDF document
+ * @param inputSize side length of the model input, used to pick the render resolution
+ * @returns one entry per page, in page order
+ * @throws Error when `window.pdfjsLib` is missing or a 2D canvas context cannot be created
+ */
+export async function renderPages(
+  pdfBytes: ArrayBuffer | Uint8Array,
+  inputSize: number,
+): Promise<RasterPage[]> {
+  const pdfjs = window.pdfjsLib;
+  if (!pdfjs) throw new Error("PDF.js is not available in this build");
+
+  const data =
+    pdfBytes instanceof Uint8Array ? pdfBytes : new Uint8Array(pdfBytes);
+  const pdf = await pdfjs.getDocument({ data }).promise;
+  const pages: RasterPage[] = [];
+  try {
+    for (let i = 1; i <= pdf.numPages; i++) {
+      const page = await pdf.getPage(i);
+      // At scale 1 the viewport size is the page size in points.
+      const base = page.getViewport({ scale: 1 });
+      const pageWidthPt = base.width;
+      const pageHeightPt = base.height;
+      const maxSide = Math.max(pageWidthPt, pageHeightPt);
+      let dpi = maxSide <= 0 ? 150 : Math.round((72 * inputSize) / maxSide);
+      dpi = Math.max(36, Math.min(dpi, 300));
+      const scale = dpi / 72;
+      const vp = page.getViewport({ scale });
+
+      const canvas = document.createElement("canvas");
+      canvas.width = Math.max(1, Math.ceil(vp.width));
+      canvas.height = Math.max(1, Math.ceil(vp.height));
+      const ctx = canvas.getContext("2d", { willReadFrequently: true });
+      if (!ctx) throw new Error("2D canvas context unavailable");
+      // Forms render on white; PDF.js does not paint a background.
+      ctx.fillStyle = "#ffffff";
+      ctx.fillRect(0, 0, canvas.width, canvas.height);
+      await page.render({ canvasContext: ctx, viewport: vp }).promise;
+
+      const rgba = ctx.getImageData(0, 0, canvas.width, canvas.height).data;
+      pages.push({
+        pageIndex: i - 1,
+        rgba,
+        widthPx: canvas.width,
+        heightPx: canvas.height,
+        pageWidthPt,
+        pageHeightPt,
+        scaleX: pageWidthPt > 0 ? canvas.width / pageWidthPt : scale,
+        scaleY: pageHeightPt > 0 ? canvas.height / pageHeightPt : scale,
+      });
+    }
+  } finally {
+    // The pixels have been copied out, so release the document (also on a render failure).
+    try {
+      await pdf.destroy?.();
+    } catch {
+      // best-effort cleanup; never mask the original error or the result
+    }
+  }
+  return pages;
+}
diff --git a/frontend/editor/src/core/services/formDetection/preprocess.ts b/frontend/editor/src/core/services/formDetection/preprocess.ts
new file mode 100644
index 0000000000..97c3749be7
--- /dev/null
+++ b/frontend/editor/src/core/services/formDetection/preprocess.ts
@@ -0,0 +1,107 @@
+// Letterbox/stretch-resize, channel-swap, normalise and lay out as NCHW float32 - the browser
+// counterpart of Yolo.preprocess. The resize uses a 2D canvas (bilinear-ish); the normalisation
+// step is split out as a pure function so it can be unit-tested against the Java golden vectors.
+
+import {
+  ModelPipelineSpec,
+  Preprocessed,
+} from "@app/services/formDetection/types";
+
+/** Round to the nearest integer and limit to 0-255. */
+function clampByte(v: number): number {
+  return Math.max(0, Math.min(255, Math.round(v)));
+}
+
+/** Use `v` when it has at least three entries, otherwise [0, 0, 0]. */
+function orZeros(v?: number[]): number[] {
+  return v && v.length >= 3 ? v : [0, 0, 0];
+}
+
+/** Use `v` when it has at least three entries, otherwise [1, 1, 1]. */
+function orOnes(v?: number[]): number[] {
+  return v && v.length >= 3 ? v : [1, 1, 1];
+}
+
+/**
+ * Pure: turn an NxN RGBA buffer into normalised NCHW float32 per the spec.
+ *
+ * Each colour value is scaled to 0-1, then `(value - mean) / std` is applied per channel. Alpha
+ * is ignored. With `channelOrder` "bgr" (case-insensitive) the first and third planes are swapped.
+ *
+ * @param rgbaNxN n*n pixels, 4 bytes each
+ * @param n side length of the square image
+ * @param spec supplies `channelOrder`, `normMean` and `normStd`
+ * @returns three consecutive planes of n*n values
+ */
+export function normalizeToCHW(
+  rgbaNxN: Uint8ClampedArray | Uint8Array,
+  n: number,
+  spec: ModelPipelineSpec,
+): Float32Array {
+  const bgr = (spec.channelOrder ?? "").toLowerCase() === "bgr";
+  const mean = orZeros(spec.normMean);
+  const std = orOnes(spec.normStd);
+  const plane = n * n;
+  const chw = new Float32Array(3 * plane);
+  for (let i = 0; i < plane; i++) {
+    const r = rgbaNxN[i * 4] / 255;
+    const g = rgbaNxN[i * 4 + 1] / 255;
+    const b = rgbaNxN[i * 4 + 2] / 255;
+    const c0 = bgr ? b : r;
+    const c1 = g;
+    const c2 = bgr ? r : b;
+    chw[i] = (c0 - mean[0]) / std[0];
+    chw[plane + i] = (c1 - mean[1]) / std[1];
+    chw[2 * plane + i] = (c2 - mean[2]) / std[2];
+  }
+  return chw;
+}
+
+/**
+ * Browser: resize source RGBA into the model's NxN input and normalise to NCHW float32.
+ *
+ * Any `resizeMode` other than "stretch" letterboxes: the image is scaled uniformly to fit,
+ * centred, and the remainder is filled with `padColor`.
+ *
+ * @param rgba source pixels, 4 bytes each
+ * @param srcW source width in pixels
+ * @param srcH source height in pixels
+ * @param spec supplies `inputSize`, `resizeMode`, `padColor` and the normalisation settings
+ * @returns the model input plus the scale/pad values needed to map boxes back to the source
+ * @throws Error when a 2D canvas context cannot be created
+ */
+export function preprocess(
+  rgba: Uint8ClampedArray | Uint8Array,
+  srcW: number,
+  srcH: number,
+  spec: ModelPipelineSpec,
+): Preprocessed {
+  const n = spec.inputSize;
+  const letterbox = (spec.resizeMode ?? "").toLowerCase() !== "stretch";
+
+  let scaleX: number;
+  let scaleY: number;
+  let padX: number;
+  let padY: number;
+  let drawW: number;
+  let drawH: number;
+  if (letterbox) {
+    const scale = Math.min(n / srcW, n / srcH);
+    drawW = Math.max(1, Math.round(srcW * scale));
+    drawH = Math.max(1, Math.round(srcH * scale));
+    padX = Math.floor((n - drawW) / 2);
+    padY = Math.floor((n - drawH) / 2);
+    scaleX = scale;
+    scaleY = scale;
+  } else {
+    drawW = n;
+    drawH = n;
+    padX = 0;
+    padY = 0;
+    scaleX = n / srcW;
+    scaleY = n / srcH;
+  }
+
+  const pad = spec.padColor ?? [114, 114, 114];
+  const canvas = document.createElement("canvas");
+  canvas.width = n;
+  canvas.height = n;
+  const ctx = canvas.getContext("2d", { willReadFrequently: true });
+  if (!ctx) throw new Error("2D canvas context unavailable");
+  ctx.fillStyle = `rgb(${clampByte(pad[0] ?? 114)},${clampByte(pad[1] ?? 114)},${clampByte(pad[2] ?? 114)})`;
+  ctx.fillRect(0, 0, n, n);
+
+  // Source RGBA -> temp canvas, then draw scaled into the padded NxN canvas.
+  const src = document.createElement("canvas");
+  src.width = srcW;
+  src.height = srcH;
+  const sctx = src.getContext("2d");
+  if (!sctx) throw new Error("2D canvas context unavailable");
+  // Copy into a fresh ArrayBuffer-backed array: ImageData's type rejects the
+  // Uint8ClampedArray<ArrayBufferLike> form (TS 5.7 typed-array generics).
+  const buf = new Uint8ClampedArray(rgba);
+  sctx.putImageData(new ImageData(buf, srcW, srcH), 0, 0);
+
+  ctx.imageSmoothingEnabled = true;
+  ctx.imageSmoothingQuality = "high";
+  ctx.drawImage(src, padX, padY, drawW, drawH);
+
+  const px = ctx.getImageData(0, 0, n, n).data;
+  const chw = normalizeToCHW(px, n, spec);
+  return { chw, inputSize: n, scaleX, scaleY, padX, padY, srcW, srcH };
+}
diff --git a/frontend/editor/src/core/services/formDetection/runBrowserPipeline.ts b/frontend/editor/src/core/services/formDetection/runBrowserPipeline.ts
new file mode 100644
index 0000000000..316a0f5e65
--- /dev/null
+++ b/frontend/editor/src/core/services/formDetection/runBrowserPipeline.ts
@@ -0,0 +1,64 @@
+// Orchestrates the in-browser engine: fetch+cache the model, render each page, preprocess, run
+// onnxruntime-web, decode, map to PDF points, then build the fillable AcroForm - entirely on the
+// device, so the PDF never leaves the browser. Mirrors FormDetectionController.detect server-side.
+
+import { FormDetectionCatalogEntry } from "@app/hooks/useFormDetectionModelStatus";
+
+import { applyFields } from "@app/services/formDetection/applyFields";
+import { toPdfPoints } from "@app/services/formDetection/coordinateMapping";
+import { decode } from "@app/services/formDetection/decode";
+import { loadModelBytes } from "@app/services/formDetection/modelCache";
+import {
+  getSession,
+  runInference,
+} from "@app/services/formDetection/onnxSession";
+import { renderPages } from "@app/services/formDetection/pdfRender";
+import { preprocess } from "@app/services/formDetection/preprocess";
+import { DetectedField, resolveSpec } from "@app/services/formDetection/types";
+
+/** Result of an in-browser run: the detected fields and the PDF with those fields applied. */
+export interface BrowserDetectResult {
+  fields: DetectedField[];
+  appliedPdf: Uint8Array;
+}
+
+/**
+ * Detect form fields in a PDF entirely in the browser and apply them to the document.
+ *
+ * @param pdfBytes the PDF to analyse; it is copied before each consumer, never passed directly
+ * @param activeEntry catalog entry of the active model (pipeline spec, checksum, id)
+ * @param confThreshold optional score threshold; when omitted or NaN the model's own
+ *   `scoreThreshold` is used
+ * @returns the detected fields (zero-based page index, rectangle in PDF points) and the output PDF
+ */
+export async function runBrowserDetection(
+  pdfBytes: ArrayBuffer,
+  activeEntry: FormDetectionCatalogEntry,
+  confThreshold?: number,
+): Promise<BrowserDetectResult> {
+  const spec = resolveSpec(activeEntry);
+  // `typeof NaN` is "number", and decode rejects with `bestScore < threshold`, which is never
+  // true for NaN - so a NaN override would silently disable thresholding. Treat it as absent.
+  const score =
+    typeof confThreshold === "number" && !Number.isNaN(confThreshold)
+      ? confThreshold
+      : spec.scoreThreshold;
+
+  const modelBytes = await loadModelBytes(activeEntry.sha256);
+  // The session is keyed by checksum, falling back to the model id when no checksum is set.
+  const session = await getSession(
+    modelBytes,
+    activeEntry.sha256 || activeEntry.id,
+  );
+
+  // Map a class index to its field type; unknown indices default to a text field.
+  const fieldType = (classId: number): string => {
+    const types = spec.classFieldTypes;
+    return types && classId >= 0 && classId < types.length
+      ? types[classId]
+      : "text";
+  };
+
+  // pdf.js may detach the input buffer, so give each consumer its own copy.
+  const pages = await renderPages(pdfBytes.slice(0), spec.inputSize);
+  const fields: DetectedField[] = [];
+  for (const page of pages) {
+    const pre = preprocess(page.rgba, page.widthPx, page.heightPx, spec);
+    const out = await runInference(session, pre.chw, spec.inputSize);
+    for (const d of decode(out, spec, pre, score)) {
+      fields.push({
+        type: fieldType(d.classId),
+        page: page.pageIndex,
+        rectInPdfPoints: toPdfPoints(d, page),
+        confidence: d.score,
+      });
+    }
+  }
+
+  const appliedPdf = await applyFields(pdfBytes.slice(0), fields);
+  return { fields, appliedPdf };
+}
diff --git a/frontend/editor/src/core/services/formDetection/types.ts b/frontend/editor/src/core/services/formDetection/types.ts
new file mode 100644
index 0000000000..c2a9ced9fd
--- /dev/null
+++ b/frontend/editor/src/core/services/formDetection/types.ts
@@ -0,0 +1,86 @@
+// Shared types for the in-browser Auto Form Detection pipeline. The numeric pipeline mirrors the
+// backend (Yolo.java / CoordinateMapper.java) 1:1 so browser and server produce the same fields.
+
+import { FormDetectionCatalogEntry } from "@app/hooks/useFormDetectionModelStatus";
+
+/** Pipeline spec resolved from the active catalog entry (with backend defaults applied). */
+export interface ModelPipelineSpec {
+  inputSize: number;
+  resizeMode: string; // "stretch" | "letterbox"
+  padColor: number[];
+  channelOrder: string; // "rgb" | "bgr"
+  normMean: number[];
+  normStd: number[];
+  outputLayout: string; // "nc_first" | "anchors_first"
+  hasObjectness: boolean;
+  classNames: string[];
+  classFieldTypes: string[];
+  scoreThreshold: number;
+  nms: string; // "none" | "perClass" | "classAgnostic"
+  iou: number;
+}
+
+/** A detection in original bitmap-pixel space, top-left origin. */
+export interface Detection {
+  classId: number;
+  score: number;
+  x: number;
+  y: number;
+  w: number;
+  h: number;
+}
+
+/** Normalised model input plus the transform needed to invert it.
*/
+export interface Preprocessed {
+  chw: Float32Array; // three consecutive inputSize*inputSize planes
+  inputSize: number;
+  scaleX: number; // model-input pixels per source pixel
+  scaleY: number;
+  padX: number; // offset of the drawn image inside the square input (0 when stretching)
+  padY: number;
+  srcW: number; // source bitmap size in pixels
+  srcH: number;
+}
+
+/** Raw model output flattened to data[i*d2 + j] with dims d1 x d2. */
+export interface RawOutput {
+  data: Float32Array | number[];
+  d1: number;
+  d2: number;
+}
+
+/** Axis-aligned rectangle in PDF points. */
+export interface RectPt {
+  x: number;
+  y: number;
+  w: number;
+  h: number;
+}
+
+/** Shared output schema (mirrors the server /detect response). */
+export interface DetectedField {
+  type: string;
+  page: number; // zero-based page index
+  rectInPdfPoints: RectPt;
+  confidence: number;
+}
+
+/**
+ * Resolve a catalog entry's pipeline spec, applying the same defaults the backend uses.
+ *
+ * `inputSize` falls back to 1216 unless it is a positive number; every other field falls back
+ * only when it is null or undefined.
+ */
+export function resolveSpec(
+  entry: FormDetectionCatalogEntry,
+): ModelPipelineSpec {
+  return {
+    inputSize: entry.inputSize > 0 ? entry.inputSize : 1216,
+    resizeMode: entry.resizeMode ?? "letterbox",
+    padColor: entry.padColor ?? [114, 114, 114],
+    channelOrder: entry.channelOrder ?? "rgb",
+    normMean: entry.normMean ?? [0, 0, 0],
+    normStd: entry.normStd ?? [1, 1, 1],
+    outputLayout: entry.outputLayout ?? "nc_first",
+    hasObjectness: entry.hasObjectness ?? false,
+    classNames: entry.classNames ?? [],
+    classFieldTypes: entry.classFieldTypes ?? [],
+    scoreThreshold: entry.scoreThreshold ?? 0.3,
+    nms: entry.nms ?? "perClass",
+    iou: entry.iou ?? 0.45,
+  };
+}
diff --git a/frontend/editor/src/core/tests/stubbed/auto-form-detection.spec.ts b/frontend/editor/src/core/tests/stubbed/auto-form-detection.spec.ts
new file mode 100644
index 0000000000..ea483e2e7f
--- /dev/null
+++ b/frontend/editor/src/core/tests/stubbed/auto-form-detection.spec.ts
@@ -0,0 +1,124 @@
+import { test, expect, type Page } from "@playwright/test";
+import {
+  bypassOnboarding,
+  mockAppApis,
+  seedCookieConsent,
+} from "@app/tests/helpers/api-stubs";
+
+/**
+ * Stubbed coverage for the Auto Form Detection tool and its admin install panel.
+ * - the tool tile renders even when the `form-detection` endpoint is disabled
+ *   (model not installed), so users can discover it,
+ * - the tile opens the tool once the endpoint reports enabled,
+ * - an admin sees the "AI Form Detection" box inside the Features settings section.
+ */
+
+// Status payload for a server with one catalog entry and no model installed.
+const MODEL_STATUS_NOT_INSTALLED = {
+  status: "not_installed",
+  progress: 0,
+  activeModelId: "",
+  installed: [],
+  error: null,
+  writable: true,
+  catalog: [
+    {
+      id: "ffdnet-s",
+      displayName: "CommonForms FFDNet-S",
+      description: "Small form-field detector",
+      license: "CC-BY-4.0",
+      sizeBytes: 0,
+      onnxUrl: "",
+      sha256: "",
+      inputSize: 1024,
+    },
+  ],
+};
+
+/** Answer every model-status request with the "not installed" payload above. */
+async function stubModelStatus(page: Page) {
+  await page.route("**/api/v1/ai/form-detection-model/status", (route) =>
+    route.fulfill({ json: MODEL_STATUS_NOT_INSTALLED }),
+  );
+}
+
+test.describe("Auto Form Detection tool", () => {
+  test("tile is present even when the model endpoint is disabled", async ({
+    page,
+  }) => {
+    await seedCookieConsent(page);
+    await bypassOnboarding(page);
+    await stubModelStatus(page);
+    await mockAppApis(page, {
+      endpointsAvailability: { "form-detection": { enabled: false } },
+    });
+    await page.goto("/");
+
+    const tile = page.locator('[data-tour="tool-button-autoFormDetection"]');
+    await expect(tile.first()).toBeVisible({ timeout: 10_000 });
+  });
+
+  test("clicking the tile opens the tool when the endpoint is enabled", async ({
+    page,
+  }) => {
+    await seedCookieConsent(page);
+    await bypassOnboarding(page);
+    await stubModelStatus(page);
+    await mockAppApis(page, {
+      endpointsAvailability: { "form-detection": { enabled: true } },
+    });
+    await page.goto("/");
+
+    await page
+      .locator('[data-tour="tool-button-autoFormDetection"]')
+      .first()
+      .click();
+    await expect(page).toHaveURL(/auto-form-detection/i);
+  });
+
+  test("admin sees the AI Form Detection box inside the Features section", async ({
+    page,
+  }) => {
+    await seedCookieConsent(page);
+    await bypassOnboarding(page);
+    // Seed JWT
+    // so the auth-gated dashboard chrome renders for the admin user.
+    await page.addInitScript(() => {
+      localStorage.setItem(
+        "stirling_jwt",
+        "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJhZG1pbiJ9.signature",
+      );
+    });
+    await stubModelStatus(page);
+    await mockAppApis(page, {
+      enableLogin: true,
+      isAdmin: true,
+      user: {
+        id: 1,
+        username: "admin",
+        email: "admin@example.com",
+        roles: ["ROLE_ADMIN"],
+      },
+    });
+    await page.route("**/api/v1/proprietary/ui-data/account", (route) =>
+      route.fulfill({
+        json: { username: "admin", email: "admin@example.com", isAdmin: true },
+      }),
+    );
+    await page.goto("/");
+
+    const configBtn = page.locator('[data-testid="config-button"]').first();
+    // locator.isVisible() returns immediately and ignores its `timeout` option, so it would
+    // report "not visible" (and skip the test) whenever the app has not rendered yet. Wait for
+    // the button explicitly and only skip if it still has not appeared after 5 s.
+    const configBtnVisible = await configBtn
+      .waitFor({ state: "visible", timeout: 5_000 })
+      .then(() => true)
+      .catch(() => false);
+    if (!configBtnVisible) {
+      test.skip(true, "Config button not rendered for admin on this build");
+      return;
+    }
+    await configBtn.click();
+    const dialog = page.locator(".mantine-Modal-content").first();
+    await expect(dialog).toBeVisible({ timeout: 5_000 });
+
+    // AI Form Detection now lives as a box inside the "Features" section,
+    // not as its own nav entry - navigate there first.
+    await dialog.getByText("Features", { exact: true }).first().click();
+
+    await expect(dialog.getByText(/AI Form Detection/i).first()).toBeVisible({
+      timeout: 5_000,
+    });
+  });
+});
diff --git a/frontend/editor/src/core/tools/autoFormDetection/AutoFormDetection.tsx b/frontend/editor/src/core/tools/autoFormDetection/AutoFormDetection.tsx
new file mode 100644
index 0000000000..8b045576c8
--- /dev/null
+++ b/frontend/editor/src/core/tools/autoFormDetection/AutoFormDetection.tsx
@@ -0,0 +1,45 @@
+import { useTranslation } from "react-i18next";
+import { createToolFlow } from "@app/components/tools/shared/createToolFlow";
+import { useAutoFormDetectionParameters } from "@app/hooks/tools/autoFormDetection/useAutoFormDetectionParameters";
+import { useAutoFormDetectionOperation } from "@app/hooks/tools/autoFormDetection/useAutoFormDetectionOperation";
+import { useBaseTool } from "@app/hooks/tools/shared/useBaseTool";
+import { BaseToolProps, ToolComponent } from "@app/types/tool";
+
+/**
+ * Auto Form Detection tool panel.
+ *
+ * The flow has no parameter steps: it shows the selected files, an execute button that is
+ * hidden once results exist, and a review section that appears with the results.
+ */
+const AutoFormDetection = (props: BaseToolProps) => {
+  const { t } = useTranslation();
+
+  const base = useBaseTool(
+    "autoFormDetection",
+    useAutoFormDetectionParameters,
+    useAutoFormDetectionOperation,
+    props,
+  );
+
+  return createToolFlow({
+    files: {
+      selectedFiles: base.selectedFiles,
+      // Collapse the file list as soon as there are files or results to show.
+      isCollapsed: base.hasFiles || base.hasResults,
+    },
+    steps: [],
+    executeButton: {
+      text: t("autoFormDetection.submit", "Detect & make fillable"),
+      isVisible: !base.hasResults,
+      loadingText: t("autoFormDetection.loading", "Detecting form fields..."),
+      onClick: base.handleExecute,
+      endpointEnabled: base.endpointEnabled,
+      paramsValid: base.params.validateParameters(),
+    },
+    review: {
+      isVisible: base.hasResults,
+      operation: base.operation,
+      title: t("autoFormDetection.results.title", "Detected Form Fields"),
+      onFileClick: base.handleThumbnailClick,
+      onUndo: base.handleUndo,
+    },
+  });
+};
+
+// Static method to get the operation hook for automation.
+AutoFormDetection.tool = () => useAutoFormDetectionOperation; + +export default AutoFormDetection as ToolComponent; diff --git a/frontend/editor/src/core/types/toolId.ts b/frontend/editor/src/core/types/toolId.ts index 5286aec080..196d832d93 100644 --- a/frontend/editor/src/core/types/toolId.ts +++ b/frontend/editor/src/core/types/toolId.ts @@ -63,6 +63,7 @@ export const CORE_REGULAR_TOOL_IDS = [ "bookletImposition", "pdfTextEditor", "formFill", + "autoFormDetection", ] as const; export const CORE_SUPER_TOOL_IDS = ["multiTool", "read", "automate"] as const; diff --git a/frontend/editor/src/proprietary/components/shared/config/configSections/AdminFeaturesSection.tsx b/frontend/editor/src/proprietary/components/shared/config/configSections/AdminFeaturesSection.tsx index 86edd74446..6382fe25dd 100644 --- a/frontend/editor/src/proprietary/components/shared/config/configSections/AdminFeaturesSection.tsx +++ b/frontend/editor/src/proprietary/components/shared/config/configSections/AdminFeaturesSection.tsx @@ -22,6 +22,7 @@ import { SettingsStickyFooter } from "@app/components/shared/config/SettingsStic import apiClient from "@app/services/apiClient"; import { useLoginRequired } from "@app/hooks/useLoginRequired"; import LoginRequiredBanner from "@app/components/shared/config/LoginRequiredBanner"; +import AdminFormDetectionSection from "@app/components/shared/config/configSections/AdminFormDetectionSection"; interface FeaturesSettingsData { serverCertificate?: { @@ -355,6 +356,8 @@ export default function AdminFeaturesSection() { + + = 1024 ? 
`${(mb / 1024).toFixed(1)}GB` : `${Math.round(mb)}MB`; +} + +function badgeColor(s?: FormDetectionState): string { + switch (s) { + case "ready": + return "green"; + case "downloading": + case "verifying": + return "blue"; + case "failed": + return "red"; + default: + return "gray"; + } +} + +export default function AdminFormDetectionSection() { + const { t } = useTranslation(); + const { status, loading, error, install, uninstall, setConfig } = + useFormDetectionModelStatus(); + const [selectedId, setSelectedId] = useState(null); + const [busy, setBusy] = useState(false); + const [configBusy, setConfigBusy] = useState(false); + const [actionError, setActionError] = useState(null); + + const enabled = status?.enabled ?? true; + const executionMode: FormDetectionExecutionMode = + status?.executionMode ?? "auto"; + const serverEngineAvailable = status?.serverEngineAvailable ?? true; + + const catalog = status?.catalog ?? []; + const selectData = useMemo( + () => + catalog.map((c) => ({ + value: c.id, + label: `${c.displayName} · ${formatSize(c.sizeBytes)}`, + })), + [catalog], + ); + + const effectiveId = + selectedId ?? status?.activeModelId ?? catalog[0]?.id ?? null; + const selectedEntry = catalog.find((c) => c.id === effectiveId); + const installable = Boolean(selectedEntry?.onnxUrl && selectedEntry?.sha256); + const st = status?.status; + const inFlight = st === "downloading" || st === "verifying"; + const activeId = status?.activeModelId || null; + const activeEntry = catalog.find((c) => c.id === activeId); + const installedIds = status?.installed ?? []; + // The action depends on the *selected* model, not just the overall status: + // only the currently-active model can be uninstalled; any other selection installs/switches.
+ const selectedIsActive = + st === "ready" && effectiveId != null && effectiveId === activeId; + const selectedIsInstalled = + effectiveId != null && installedIds.includes(effectiveId); + + const doInstall = async () => { + if (!effectiveId) return; + setBusy(true); + setActionError(null); + try { + await install(effectiveId); + } catch (e) { + setActionError(e instanceof Error ? e.message : "Install failed"); + } finally { + setBusy(false); + } + }; + + const doUninstall = async () => { + setBusy(true); + setActionError(null); + try { + await uninstall(status?.activeModelId || undefined); + } catch (e) { + setActionError(e instanceof Error ? e.message : "Uninstall failed"); + } finally { + setBusy(false); + } + }; + + const doSetConfig = async (config: { + enabled?: boolean; + executionMode?: FormDetectionExecutionMode; + }) => { + setConfigBusy(true); + setActionError(null); + try { + await setConfig(config); + } catch (e) { + setActionError(e instanceof Error ? e.message : "Failed to save setting"); + } finally { + setConfigBusy(false); + } + }; + + return ( + + +
+ + + {t("admin.formDetection.title", "AI Form Detection")} + + + doSetConfig({ enabled: e.currentTarget.checked }) + } + disabled={configBusy || (loading && !status)} + size="sm" + aria-label={t( + "admin.formDetection.enableFeature", + "Enable feature", + )} + /> + + + {t( + "admin.formDetection.description", + "Install the AI model used to auto-detect form fields. The selected model is downloaded on demand (about 40-100MB) into the configs volume and is not bundled with Stirling-PDF.", + )} + +
+ + {loading && !status ? ( + + ) : ( + + + {t("admin.formDetection.status", "Status")}: + + {st ?? "unknown"} + + {activeId ? ( + + {t("admin.formDetection.active", "Active")}:{" "} + {activeEntry?.displayName ?? activeId} + + ) : null} + + +
+ + {t("admin.formDetection.engineLabel", "Where detection runs")} + + + {t( + "admin.formDetection.engineDescription", + "Browser keeps the PDF on the device (downloads a ~12MB runtime once, then cached); Server runs it on the backend; Auto prefers the browser and falls back to the server.", + )} + + + doSetConfig({ + executionMode: v as FormDetectionExecutionMode, + }) + } + disabled={configBusy || !enabled} + data={[ + { + label: t("admin.formDetection.engine.auto", "Auto"), + value: "auto", + }, + { + label: t("admin.formDetection.engine.browser", "Browser"), + value: "browser", + }, + { + label: t("admin.formDetection.engine.server", "Server"), + value: "server", + disabled: !serverEngineAvailable, + }, + ]} + /> + {!serverEngineAvailable ? ( + + {t( + "admin.formDetection.engine.serverUnavailable", + "The server engine is not bundled in this build, so detection runs in the browser. Use Auto or Browser.", + )} + + ) : null} +
+ + {inFlight ? ( + + ) : null} + +