diff --git a/configs/examples/synthesis/ehr_db_synth.yaml b/configs/examples/synthesis/ehr_db_synth.yaml new file mode 100644 index 0000000000..472764591f --- /dev/null +++ b/configs/examples/synthesis/ehr_db_synth.yaml @@ -0,0 +1,141 @@ +# EHR Database Agent — Tool-Use Conversation Synthesis (real DB) +# +# Six tools that read or mutate live SQL state via SQLAlchemy. The DB *is* +# the state — nothing lives in memory. +# +# Setup before `oumi synth`: +# 1. Create the demo DB and load schema + seed: +# sqlite3 /tmp/ehr_demo.db < src/oumi/examples/ehr_db/schema.sql +# sqlite3 /tmp/ehr_demo.db < src/oumi/examples/ehr_db/seed.sql +# 2. Set env var (or change `database:` below): +# export EHR_DB_PATH=/tmp/ehr_demo.db +# 3. Set OPENAI_API_KEY (or change inference_config). +# +# Usage: +# oumi synth -c configs/examples/synthesis/ehr_db_synth.yaml +# +# See also: +# - Executors: src/oumi/examples/ehr_db/executors.py +# - Schema: src/oumi/examples/ehr_db/schema.sql +# - Seed: src/oumi/examples/ehr_db/seed.sql + +strategy: GENERAL +num_samples: 50 +output_path: ehr_db_dataset.jsonl + +environment_config: + environments: + - id: ehr_db + env_type: database + name: EHR Database + description: A clinical EHR database backed by real SQL. + env_kwargs: + connection: + driver: sqlite + database: ${oc.env:EHR_DB_PATH,/tmp/ehr_demo.db} + pool_size: 10 + pool_max_overflow: 10 + read_only: false + audit: true + tools: + - id: list_patients + name: list_patients + description: >- + Return a list of patient summaries (patient_id, name, dob, status). + Call this first when the clinician refers to a patient by NAME so + you can resolve the name to a patient_id. + parameters: + type: object + properties: {} + read_only: true + executor: oumi.examples.ehr_db.executors.list_patients + + - id: get_patient + name: get_patient + description: >- + Return the full record for a patient_id, including allergies, + active medications, diagnoses, and vitals history. + parameters: + type: object + properties: + patient_id: { type: string } + required: [patient_id] + read_only: true + executor: oumi.examples.ehr_db.executors.get_patient + + - id: record_vitals + name: record_vitals + description: >- + Append a vitals reading (timestamp, blood pressure, heart rate, + temperature in F) to a patient's vitals_history. + parameters: + type: object + properties: + patient_id: { type: string } + timestamp: { type: string } + bp: { type: string } + hr: { type: integer } + temp_f: { type: number } + required: [patient_id, timestamp, bp, hr, temp_f] + read_only: false + executor: oumi.examples.ehr_db.executors.record_vitals + + - id: add_diagnosis + name: add_diagnosis + description: >- + Append a new diagnosis (ICD-10 code, description, date) to a + patient. Returns an error if the same code is already present. + parameters: + type: object + properties: + patient_id: { type: string } + code: { type: string } + description: { type: string } + date: { type: string } + required: [patient_id, code, description, date] + read_only: false + executor: oumi.examples.ehr_db.executors.add_diagnosis + + - id: prescribe_medication + name: prescribe_medication + description: >- + Add an active medication to a patient's medication list. Returns + an error if the same medication is already prescribed or if it + conflicts with a known allergy. + parameters: + type: object + properties: + patient_id: { type: string } + name: { type: string } + dose: { type: string } + required: [patient_id, name, dose] + read_only: false + executor: oumi.examples.ehr_db.executors.prescribe_medication + + - id: update_allergies + name: update_allergies + description: >- + Replace a patient's allergy list with the supplied list. Pass an + empty list to clear all allergies. + parameters: + type: object + properties: + patient_id: { type: string } + allergies: + type: array + items: { type: string } + required: [patient_id, allergies] + read_only: false + executor: oumi.examples.ehr_db.executors.update_allergies + +inference_config: + model: + model_name: gpt-4o-mini + engine: OPENAI + generation: + max_new_tokens: 2048 + temperature: 0.7 + top_p: 0.9 + remote_params: + num_workers: 8 + politeness_policy: 60 diff --git a/pyproject.toml b/pyproject.toml index 61b4ee0f59..7ccd9b0a4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ dependencies = [ "responses>=0.25,<0.27", "safetensors>=0.6,<0.8", "skypilot>=0.11.1,<0.13", # 0.11.1 includes db locking fix + "sqlalchemy>=2.0,<3.0", # Used by DatabaseExecutableEnvironment "tensorboard>=2.20,<2.21", # Optional, for monitoring training "tiktoken>=0.7,<1.0", "torch>=2.6,<2.13.0", # vllm only supports up to torch==2.7 @@ -108,6 +109,7 @@ dev = [ "pytest", "responses", "ruff", + "testcontainers>=4.0,<5.0", # Postgres fixture for opt-in integration tests "torchfix", # Tool for automatically fixing common PyTorch issues ] docs = [ @@ -342,6 +344,7 @@ markers = [ "e2e_eternal: Extremely slow e2e integration tests (for manual/selective runs)", "single_gpu: The test uses max 1 GPU (can be potentially skipped on multi-GPU machine to conserve GPU resources)", "multi_gpu: The test should run on a machine with multiple GPU-s", + "requires_postgres: marks tests that need a real Postgres instance (run with -m requires_postgres)", ] [tool.coverage.run] diff --git a/src/oumi/core/configs/params/database_connection_params.py b/src/oumi/core/configs/params/database_connection_params.py new file mode 100644 index 0000000000..164b647cc3 --- /dev/null +++ b/src/oumi/core/configs/params/database_connection_params.py @@ -0,0 +1,111 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Connection parameters for DatabaseExecutableEnvironment.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass + +import sqlalchemy +import sqlalchemy.engine + +from oumi.core.configs.params.base_params import BaseParams + + +@dataclass +class DatabaseConnectionConfig(BaseParams): + """SQLAlchemy connection config for DatabaseExecutableEnvironment. + + Two modes (mutually exclusive): + + - **Structured fields:** populate ``driver`` and ``database`` (and host/port/ + username as needed). The password is resolved at ``resolve_url()`` time + from ``password_env_var`` so it never appears in YAML. + - **DSN env var:** populate ``dsn_env_var`` with the name of an env var + holding a complete SQLAlchemy URL. + """ + + # Mode 1: structured fields + driver: str = "" # e.g. "postgresql+psycopg", "mysql+pymysql", "sqlite" + host: str = "" + port: int | None = None + database: str = "" + username: str = "" + password_env_var: str = "" # name of env var holding the password + + # Mode 2: full DSN from env (escape hatch) + dsn_env_var: str = "" # mutually exclusive with structured fields + + # Pool / timeouts + pool_size: int = 5 + pool_max_overflow: int = 10 + pool_pre_ping: bool = True + connect_timeout_s: float = 10.0 + + def __finalize_and_validate__(self) -> None: + """Validate mode XOR (structured vs DSN) and pool/timeout numeric bounds.""" + structured = bool(self.driver and self.database) + dsn = bool(self.dsn_env_var) + if structured and dsn: + raise ValueError( + "DatabaseConnectionConfig: set EITHER (driver, database, ...) " + "OR dsn_env_var, not both." + ) + if not structured and not dsn: + raise ValueError( + "DatabaseConnectionConfig: must set either (driver + database) " + "OR dsn_env_var. Got neither." + ) + if self.pool_size < 1: + raise ValueError( + f"DatabaseConnectionConfig.pool_size must be >= 1, " + f"got {self.pool_size}." + ) + if self.pool_max_overflow < 0: + raise ValueError( + f"DatabaseConnectionConfig.pool_max_overflow must be >= 0, " + f"got {self.pool_max_overflow}." + ) + if self.connect_timeout_s <= 0: + raise ValueError( + f"DatabaseConnectionConfig.connect_timeout_s must be > 0, " + f"got {self.connect_timeout_s}." + ) + + def resolve_url(self) -> sqlalchemy.URL: + """Build the SQLAlchemy URL. + + Reads ``password_env_var`` / ``dsn_env_var`` from the environment at + call time (not at YAML-parse time), so configs are safe to log/dump. + """ + if self.dsn_env_var: + dsn = os.environ.get(self.dsn_env_var) + if not dsn: + raise ValueError( + f"DatabaseConnectionConfig: env var '{self.dsn_env_var}' " + f"is not set; cannot build connection URL." + ) + return sqlalchemy.engine.make_url(dsn) + password = ( + os.environ.get(self.password_env_var) if self.password_env_var else None + ) + return sqlalchemy.URL.create( + drivername=self.driver, + host=self.host or None, + port=self.port, + database=self.database, + username=self.username or None, + password=password, + ) diff --git a/src/oumi/environments/__init__.py b/src/oumi/environments/__init__.py index 8681b66074..a57c31aef9 100644 --- a/src/oumi/environments/__init__.py +++ b/src/oumi/environments/__init__.py @@ -18,6 +18,9 @@ concrete environment's `@register_environment(...)` decorator. """ +from oumi.core.configs.params.database_connection_params import ( + DatabaseConnectionConfig, +) from oumi.core.configs.params.grounding_params import ( GroundingConfig, GroundingFact, @@ -32,11 +35,18 @@ ) from oumi.core.types.tool_call import JSONSchema, ToolResult from oumi.environments.base_environment import BaseEnvironment +from oumi.environments.database_executable_environment import ( + DatabaseExecutableEnvironment, + DatabaseExecutableEnvironmentKwargs, +) +from oumi.environments.database_executable_tool import DatabaseExecutableTool from oumi.environments.deterministic_environment import ( DeterministicEnvironment, DeterministicEnvironmentKwargs, ToolLookupEntry, ) +from oumi.environments.executable_environment import ExecutableEnvironment +from oumi.environments.executable_tool import ExecutableTool from oumi.environments.synthetic_environment import ( SyntheticEnvironment, SyntheticEnvironmentKwargs, @@ -45,8 +55,14 @@ __all__ = [ "BaseEnvironment", + "DatabaseConnectionConfig", + "DatabaseExecutableEnvironment", + "DatabaseExecutableEnvironmentKwargs", + "DatabaseExecutableTool", "DeterministicEnvironment", "DeterministicEnvironmentKwargs", + "ExecutableEnvironment", + "ExecutableTool", "GroundingConfig", "GroundingFact", "JSONSchema", diff --git a/src/oumi/environments/database_executable_environment.py b/src/oumi/environments/database_executable_environment.py new file mode 100644 index 0000000000..caf6512284 --- /dev/null +++ b/src/oumi/environments/database_executable_environment.py @@ -0,0 +1,342 @@ +# Copyright 2025 - Oumi +# Licensed under the Apache License, Version 2.0 +"""SQLAlchemy-backed environment that runs user-supplied executors.""" + +from __future__ import annotations + +import time +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, ClassVar + +import sqlalchemy +import sqlalchemy.engine +import sqlalchemy.event as sa_event +import sqlalchemy.exc + +from oumi.core.configs.params.base_params import BaseParams +from oumi.core.configs.params.database_connection_params import ( + DatabaseConnectionConfig, +) +from oumi.core.configs.params.environment_params import EnvironmentParams +from oumi.core.configs.params.tool_params import ToolError +from oumi.core.registry import register_environment +from oumi.core.types.tool_call import ToolResult +from oumi.environments.database_executable_tool import DatabaseExecutableTool +from oumi.environments.executable_environment import ( + ExecutableEnvironment, + _import_executor, +) +from oumi.environments.executable_tool import ExecutableTool +from oumi.utils.logging import logger + + +def _install_dialect_guards( + engine: sqlalchemy.engine.Engine, + kwargs: DatabaseExecutableEnvironmentKwargs, +) -> None: + """Install dialect-specific session settings. + + - ``connect`` event: every new pool connection gets read-only + env-level + timeout. + - ``checkin`` event: when a connection returns to the pool, reset its + timeout to env-level (or RESET if env-level is unset) so per-tool + overrides don't leak across checkouts. + """ + dialect = engine.dialect.name + read_only = kwargs.read_only + env_timeout_ms = kwargs.statement_timeout_ms + + @sa_event.listens_for(engine, "connect") + def _on_connect(dbapi_conn, connection_record): # noqa: ANN001 + cursor = dbapi_conn.cursor() + try: + if dialect == "postgresql": + if read_only: + cursor.execute("SET default_transaction_read_only = on") + if env_timeout_ms is not None: + cursor.execute(f"SET statement_timeout = {int(env_timeout_ms)}") + elif dialect == "mysql": + if read_only: + cursor.execute("SET SESSION TRANSACTION READ ONLY") + if env_timeout_ms is not None: + cursor.execute( + f"SET SESSION max_execution_time = {int(env_timeout_ms)}" + ) + elif dialect == "sqlite": + if read_only: + cursor.execute("PRAGMA query_only = ON") + if env_timeout_ms is not None: + logger.warning( + "SQLite does not support statement_timeout; ignoring " + "statement_timeout_ms=%s.", + env_timeout_ms, + ) + else: + if read_only or env_timeout_ms is not None: + logger.warning( + "Dialect '%s' has no registered guard handler; " + "read_only=%s and statement_timeout_ms=%s will not be " + "enforced. Set them at the DB role / DSN level instead.", + dialect, + read_only, + env_timeout_ms, + ) + finally: + cursor.close() + + @sa_event.listens_for(engine, "checkin") + def _on_checkin(dbapi_conn, connection_record): # noqa: ANN001 + if dialect not in {"postgresql", "mysql"}: + return + try: + cursor = dbapi_conn.cursor() + except Exception: + return # connection may already be in a bad state; best-effort + try: + if dialect == "postgresql": + if env_timeout_ms is not None: + cursor.execute(f"SET statement_timeout = {int(env_timeout_ms)}") + else: + cursor.execute("RESET statement_timeout") + elif dialect == "mysql": + if env_timeout_ms is not None: + cursor.execute( + f"SET SESSION max_execution_time = {int(env_timeout_ms)}" + ) + else: + cursor.execute("SET SESSION max_execution_time = 0") + except Exception: + # Connection may already be closing/dead; checkin reset is best-effort. + pass + finally: + try: + cursor.close() + except Exception: + pass + + +@dataclass +class DatabaseExecutableEnvironmentKwargs(BaseParams): + """Type-specific kwargs for DatabaseExecutableEnvironment.""" + + connection: DatabaseConnectionConfig | None = None + read_only: bool = False + statement_timeout_ms: int | None = None + audit: bool = False + + def __post_init__(self) -> None: + """Coerce ``connection`` dict into a ``DatabaseConnectionConfig``.""" + if isinstance(self.connection, dict): + self.connection = DatabaseConnectionConfig(**self.connection) + + def __finalize_and_validate__(self) -> None: + """Validate that connection is set and numeric fields are in range.""" + if self.connection is None: + raise ValueError( + "DatabaseExecutableEnvironmentKwargs.connection is required." + ) + if self.statement_timeout_ms is not None and self.statement_timeout_ms <= 0: + raise ValueError( + f"statement_timeout_ms must be > 0, got {self.statement_timeout_ms}." + ) + + +@register_environment("database") +class DatabaseExecutableEnvironment(ExecutableEnvironment): + """Environment that runs user-supplied executors against a real SQL database. + + The DB *is* the state. Each ``step`` checks out a connection from the + SQLAlchemy pool, runs the executor in autocommit mode, and returns the + connection. SQL errors that escape the executor are auto-wrapped as a + structured ``ToolResult`` so the agent can self-correct. + """ + + tool_params_cls = DatabaseExecutableTool + _executor_context_kwarg: ClassVar[str] = "db" + + def __init__( + self, + params: EnvironmentParams, + kwargs: DatabaseExecutableEnvironmentKwargs, + engine: sqlalchemy.engine.Engine, + ) -> None: + """Initialize the env. Use ``from_params`` rather than calling directly.""" + self._params = params + self._kwargs = kwargs + self._engine = engine + self._executors: dict[str, Any] = {} + + @classmethod + def _validate_tools( + cls, + tools: list[Any], + kwargs: DatabaseExecutableEnvironmentKwargs, + ) -> None: + """Validate per-tool config invariants relative to env-level kwargs.""" + env_timeout = kwargs.statement_timeout_ms + for tool in tools: + if not isinstance(tool, DatabaseExecutableTool): + raise ValueError( + f"DatabaseExecutableEnvironment tool " + f"'{getattr(tool, 'id', '?')}' must be a " + f"DatabaseExecutableTool, got {type(tool).__name__}." + ) + if kwargs.read_only and not tool.read_only: + raise ValueError( + f"DatabaseExecutableEnvironment is read_only=True but " + f"tool '{tool.id}' has read_only=False. Read-only envs " + f"require all tools to be read_only=True." + ) + if tool.statement_timeout_ms is not None: + if env_timeout is None: + raise ValueError( + f"Tool '{tool.id}' has statement_timeout_ms=" + f"{tool.statement_timeout_ms} but the env has no " + f"statement_timeout_ms set. Set env-level timeout " + f"first, or remove the per-tool override." + ) + if tool.statement_timeout_ms > env_timeout: + raise ValueError( + f"Tool '{tool.id}' has statement_timeout_ms=" + f"{tool.statement_timeout_ms} which exceeds env-level " + f"statement_timeout_ms={env_timeout}. Per-tool overrides " + f"must be stricter (<=) than env-level." + ) + + @classmethod + def from_params(cls, params: EnvironmentParams) -> DatabaseExecutableEnvironment: + """Build a DatabaseExecutableEnvironment from its params object.""" + kwargs = DatabaseExecutableEnvironmentKwargs(**(params.env_kwargs or {})) + kwargs.finalize_and_validate() + cls._validate_tools(params.tools, kwargs) + assert kwargs.connection is not None # validated above + + url = kwargs.connection.resolve_url() + engine_kwargs: dict[str, Any] = { + "pool_pre_ping": kwargs.connection.pool_pre_ping, + "isolation_level": "AUTOCOMMIT", + "future": True, + } + # SQLite uses SingletonThreadPool which does not accept pool_size / + # max_overflow; skip those args for in-process databases. + if url.get_dialect().name != "sqlite": + engine_kwargs["pool_size"] = kwargs.connection.pool_size + engine_kwargs["max_overflow"] = kwargs.connection.pool_max_overflow + engine = sqlalchemy.create_engine(url, **engine_kwargs) + _install_dialect_guards(engine, kwargs) + + # Fail-fast: prove the connection works before any tool runs. + try: + with engine.connect() as conn: + conn.execute(sqlalchemy.text("SELECT 1")) + except Exception as e: + engine.dispose() + raise ValueError( + f"Failed to connect to database for env '{params.id}': {e}" + ) from e + + env = cls(params, kwargs, engine) + for tool in params.tools: + assert isinstance(tool, DatabaseExecutableTool) + env._executors[tool.id] = _import_executor(tool.executor, tool.id) + return env + + @contextmanager + def _build_execution_context( + self, tool: ExecutableTool, arguments: dict[str, Any] + ) -> Iterator[sqlalchemy.engine.Connection]: + """Check out a connection from the pool for one tool call. + + Per-tool ``statement_timeout_ms`` is applied as a session-level SET on + the checked-out connection. The engine's ``checkin`` event handler + (installed by ``_install_dialect_guards``) resets the session timeout + back to env-level on connection return, so the override doesn't leak. + """ + conn = self._engine.connect() + try: + assert isinstance(tool, DatabaseExecutableTool) + if tool.statement_timeout_ms is not None: + self._set_session_timeout(conn, tool.statement_timeout_ms) + yield conn + finally: + conn.close() # returns to pool; checkin handler resets timeout + + def _set_session_timeout( + self, conn: sqlalchemy.engine.Connection, timeout_ms: int + ) -> None: + """Apply a per-tool session-level timeout to a checked-out connection.""" + dialect = self._engine.dialect.name + if dialect == "postgresql": + conn.exec_driver_sql(f"SET statement_timeout = {int(timeout_ms)}") + elif dialect == "mysql": + conn.exec_driver_sql(f"SET SESSION max_execution_time = {int(timeout_ms)}") + # sqlite / unknown: no-op (warned at engine setup time). + + def close(self) -> None: + """Dispose the engine and its connection pool.""" + self._engine.dispose() + + def _invoke_executor( + self, + executor, + arguments, + ctx, + tool, + ): + """Run the executor; auto-wrap SQL errors as structured ToolResults. + + Returns ``(result, was_auto_wrapped)``. When ``was_auto_wrapped`` is + True the caller skips ``output_schema`` validation because the wrap + shape replaces the executor's normal return value. + """ + try: + return executor(arguments=arguments, db=ctx), False + except sqlalchemy.exc.DBAPIError as e: + orig = e.orig + # psycopg3 exposes SQLSTATE as ``sqlstate``; psycopg2 as ``pgcode``. + sql_state = ( + getattr(orig, "sqlstate", None) or getattr(orig, "pgcode", None) + if orig + else None + ) + wrapped = ToolResult( + output={ + "status": "error", + "error": type(e).__name__, + "message": str(orig) if orig else str(e), + "sql_state": sql_state, + } + ) + return wrapped, True + + def _absorb_result(self, tool, result) -> None: + """DB envs hold state in the DB; reject any ToolResult.updated_state.""" + if result.updated_state is not None: + raise ToolError( + f"DatabaseExecutable tool '{tool.id}' returned updated_state; " + f"DB-backed envs hold state in the database, not in ToolResult." + ) + + def step(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult: + """Run a tool call with optional per-call audit logging.""" + if not self._kwargs.audit: + return super().step(tool_id, arguments) + + started = time.monotonic() + status = "ok" + try: + return super().step(tool_id, arguments) + except Exception: + status = "error" + raise + finally: + duration_ms = (time.monotonic() - started) * 1000.0 + logger.info( + "db_tool_call env_id=%s tool_id=%s status=%s duration_ms=%.2f", + self._params.id, + tool_id, + status, + duration_ms, + ) diff --git a/src/oumi/environments/database_executable_tool.py b/src/oumi/environments/database_executable_tool.py new file mode 100644 index 0000000000..6d6a2e2f85 --- /dev/null +++ b/src/oumi/environments/database_executable_tool.py @@ -0,0 +1,46 @@ +# Copyright 2025 - Oumi +# Licensed under the Apache License, Version 2.0 +"""Database-specific ExecutableTool subclass.""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +from oumi.environments.executable_tool import ExecutableTool + + +@dataclass +class DatabaseExecutableTool(ExecutableTool): + """`ExecutableTool` for ``DatabaseExecutableEnvironment``. + + Adds an optional per-tool ``statement_timeout_ms`` that may **only** tighten + the env-level timeout (validated by the env at construction time). + """ + + statement_timeout_ms: int | None = None + + @classmethod + def create(cls, raw: Any) -> DatabaseExecutableTool: + """Create a DatabaseExecutableTool from raw config data.""" + if isinstance(raw, DatabaseExecutableTool): + return raw + if not isinstance(raw, Mapping): + raise TypeError( + f"Tool definitions must be tool objects or mappings, got {type(raw)}" + ) + return cls( + id=raw["id"], + name=raw["name"], + description=raw["description"], + parameters=dict(raw.get("parameters") or {"type": "object"}), + output_schema=( + dict(raw["output_schema"]) + if raw.get("output_schema") is not None + else None + ), + read_only=raw.get("read_only", True), + executor=raw["executor"], + statement_timeout_ms=raw.get("statement_timeout_ms"), + ) diff --git a/src/oumi/environments/executable_environment.py b/src/oumi/environments/executable_environment.py new file mode 100644 index 0000000000..61d9d40b88 --- /dev/null +++ b/src/oumi/environments/executable_environment.py @@ -0,0 +1,137 @@ +# Copyright 2025 - Oumi +# Licensed under the Apache License, Version 2.0 +"""Abstract base for envs backed by user-supplied dotted-path executors.""" + +from __future__ import annotations + +import importlib +from abc import ABC, abstractmethod +from collections.abc import Callable +from contextlib import AbstractContextManager +from typing import Any, ClassVar, cast + +import jsonschema + +from oumi.core.configs.params.environment_params import EnvironmentParams +from oumi.core.configs.params.tool_params import ( + ToolError, + ToolParams, +) +from oumi.core.types.tool_call import ToolResult +from oumi.environments.base_environment import BaseEnvironment +from oumi.environments.executable_tool import ExecutableTool + +_MISSING = object() + + +def _import_executor(dotted: str, tool_id: str) -> Callable[..., Any]: + """Resolve a dotted import path to a callable. Raises ValueError on failure.""" + module_path, _, attr = dotted.rpartition(".") + if not module_path or not attr: + raise ValueError( + f"ExecutableTool '{tool_id}' executor '{dotted}' must be a dotted path." + ) + try: + module = importlib.import_module(module_path) + except ImportError as e: + raise ValueError( + f"ExecutableTool '{tool_id}' executor '{dotted}': " + f"could not import module '{module_path}': {e}" + ) from e + fn = getattr(module, attr, _MISSING) + if fn is _MISSING: + raise ValueError( + f"ExecutableTool '{tool_id}' executor '{dotted}': " + f"module '{module_path}' has no attribute '{attr}'." + ) + if not callable(fn): + raise ValueError( + f"ExecutableTool '{tool_id}' executor '{dotted}' is not callable." + ) + return fn + + +class ExecutableEnvironment(BaseEnvironment, ABC): + """Abstract base for envs that run user-supplied dotted-path executors. + + Subclasses provide the per-call execution context (DB connection, HTTP + client, ...) by implementing ``_build_execution_context`` as a context + manager. The orchestration (executor resolution, ``ToolResult`` + validation, schema validation, ``_absorb_result`` post-hook, ``close`` + lifecycle) lives here. + """ + + tool_params_cls: type[ToolParams] = ExecutableTool + + #: Keyword name under which subclasses pass the execution context to + #: user executors. Defaults to ``"context"``; ``DatabaseExecutableEnvironment`` + #: overrides to ``"db"``. + _executor_context_kwarg: ClassVar[str] = "context" + + _params: EnvironmentParams + _executors: dict[str, Callable[..., Any]] + + @abstractmethod + def _build_execution_context( + self, tool: ExecutableTool, arguments: dict[str, Any] + ) -> AbstractContextManager[Any]: + """Yield the per-call execution context (DB conn, HTTP client, ...).""" + + def _absorb_result(self, tool: ExecutableTool, result: ToolResult) -> None: + """Post-hook called after a successful executor call. Default no-op.""" + return None + + def close(self) -> None: + """Release any resources owned by this env. Default no-op.""" + return None + + def _invoke_executor( + self, + executor: Callable[..., Any], + arguments: dict[str, Any], + ctx: Any, + tool: ExecutableTool, + ) -> tuple[ToolResult, bool]: + """Run the executor; return (result, was_auto_wrapped). + + Default: pass ``ctx`` via ``_executor_context_kwarg``; never auto-wrap. + Subclasses override to translate transport-level exceptions into + structured ``ToolResult``s. + """ + return ( + executor(arguments=arguments, **{self._executor_context_kwarg: ctx}), + False, + ) + + def _lookup_tool(self, tool_id: str) -> ExecutableTool: + for tool in self._params.tools: + if tool.id == tool_id: + return cast(ExecutableTool, tool) + raise ValueError( + f"Tool '{tool_id}' not found in environment '{self._params.id}'. " + f"Available tools: {[tool.id for tool in self._params.tools]}" + ) + + def step(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult: + """Execute a single tool call and return its result.""" + tool = self._lookup_tool(tool_id) + tool.validate_arguments(arguments) + executor = self._executors[tool_id] + + with self._build_execution_context(tool, arguments) as ctx: + result, auto_wrapped = self._invoke_executor(executor, arguments, ctx, tool) + + if not isinstance(result, ToolResult): + raise ToolError( + f"Executor '{tool.executor}' must return ToolResult, " + f"got {type(result).__name__}." + ) + if tool.output_schema is not None and not auto_wrapped: + try: + jsonschema.validate(result.output, tool.output_schema) + except jsonschema.ValidationError as e: + raise ToolError( + f"Tool '{tool_id}' output failed schema validation: {e.message}" + ) from e + self._absorb_result(tool, result) + return result diff --git a/src/oumi/environments/executable_tool.py b/src/oumi/environments/executable_tool.py new file mode 100644 index 0000000000..1d23505e7b --- /dev/null +++ b/src/oumi/environments/executable_tool.py @@ -0,0 +1,29 @@ +# Copyright 2025 - Oumi +# Licensed under the Apache License, Version 2.0 +"""Shared base class for tools that resolve a dotted-path Python executor.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from oumi.core.configs.params.tool_params import ToolParams + + +@dataclass +class ExecutableTool(ToolParams): + """`ToolParams` variant for envs that take user-supplied dotted-path executors. + + Subclasses (``DatabaseExecutableTool``, future ``HTTPExecutableTool``, etc.) + inherit this and may add transport-specific per-tool overrides. + """ + + executor: str = "" + + def __post_init__(self) -> None: + """Validate inherited fields and enforce non-empty executor.""" + super().__post_init__() + if not self.executor: + raise ValueError( + f"{type(self).__name__} '{self.id}' must declare a non-empty " + f"executor (dotted import path)." + ) diff --git a/src/oumi/examples/__init__.py b/src/oumi/examples/__init__.py new file mode 100644 index 0000000000..254d44f4be --- /dev/null +++ b/src/oumi/examples/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runnable example modules referenced by configs in ``configs/examples/``.""" diff --git a/src/oumi/examples/ehr_db/__init__.py b/src/oumi/examples/ehr_db/__init__.py new file mode 100644 index 0000000000..5470dd1e21 --- /dev/null +++ b/src/oumi/examples/ehr_db/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""EHR Database example — DatabaseExecutableEnvironment with SQLAlchemy. + +A clinical EHR demo with patients, allergies, medications, diagnoses, and +vitals tables, plus six tool executors for read/write agentic workflows. + +Run the synthesis example with: + + oumi synth -c configs/examples/synthesis/ehr_db_synth.yaml +""" diff --git a/src/oumi/examples/ehr_db/executors.py b/src/oumi/examples/ehr_db/executors.py new file mode 100644 index 0000000000..74879d4eba --- /dev/null +++ b/src/oumi/examples/ehr_db/executors.py @@ -0,0 +1,272 @@ +# Copyright 2025 - Oumi +# Licensed under the Apache License, Version 2.0 +"""EHR Database tool executors. + +Each executor takes a SQLAlchemy connection (autocommit mode — every +``conn.execute(...)`` is its own transaction). Executors that handle known +constraint violations (duplicate diagnosis, allergy conflict, etc.) catch +the matching ``IntegrityError`` and return a structured ``{"status": "error"}`` +``ToolResult`` so the agent can self-correct. Unhandled SQL errors propagate +to the env, which auto-wraps them into a generic structured error. +""" + +from __future__ import annotations + +from typing import Any + +import sqlalchemy +import sqlalchemy.exc +from sqlalchemy.engine import Connection + +from oumi.core.types.tool_call import ToolResult + + +def _patient_exists(conn: Connection, patient_id: str) -> bool: + row = conn.execute( + sqlalchemy.text("SELECT 1 FROM patients WHERE patient_id = :pid"), + {"pid": patient_id}, + ).first() + return row is not None + + +def list_patients(arguments: dict[str, Any], db: Connection) -> ToolResult: + """Return patient summaries (read-only).""" + rows = ( + db.execute( + sqlalchemy.text( + "SELECT patient_id, name, dob, status FROM patients ORDER BY name" + ) + ) + .mappings() + .all() + ) + return ToolResult(output={"patients": [dict(r) for r in rows]}) + + +def get_patient(arguments: dict[str, Any], db: Connection) -> ToolResult: + """Fetch the full record for a patient_id (read-only).""" + patient_id = arguments["patient_id"] + base = ( + db.execute( + sqlalchemy.text( + "SELECT patient_id, name, dob, status FROM patients " + "WHERE patient_id = :pid" + ), + {"pid": patient_id}, + ) + .mappings() + .first() + ) + if base is None: + return ToolResult( + output={"status": "error", "error": "not_found", "patient_id": patient_id} + ) + + allergies = ( + db.execute( + sqlalchemy.text("SELECT substance FROM allergies WHERE patient_id = :pid"), + {"pid": patient_id}, + ) + .scalars() + .all() + ) + medications = ( + db.execute( + sqlalchemy.text( + "SELECT name, dose FROM medications WHERE patient_id = :pid" + ), + {"pid": patient_id}, + ) + .mappings() + .all() + ) + diagnoses = ( + db.execute( + sqlalchemy.text( + "SELECT code, description, date FROM diagnoses WHERE patient_id = :pid" + ), + {"pid": patient_id}, + ) + .mappings() + .all() + ) + vitals_history = ( + db.execute( + sqlalchemy.text( + "SELECT timestamp, bp, hr, temp_f FROM vitals " + "WHERE patient_id = :pid ORDER BY timestamp" + ), + {"pid": patient_id}, + ) + .mappings() + .all() + ) + + return ToolResult( + output={ + "status": "ok", + "patient": { + **dict(base), + "allergies": list(allergies), + "medications": [dict(m) for m in medications], + "diagnoses": [dict(d) for d in diagnoses], + "vitals_history": [dict(v) for v in vitals_history], + }, + } + ) + + +def record_vitals(arguments: dict[str, Any], db: Connection) -> ToolResult: + """Append a vitals reading.""" + patient_id = arguments["patient_id"] + if not _patient_exists(db, patient_id): + return ToolResult( + output={"status": "error", "error": "not_found", "patient_id": patient_id} + ) + entry = { + "patient_id": patient_id, + "timestamp": arguments["timestamp"], + "bp": arguments["bp"], + "hr": arguments["hr"], + "temp_f": arguments["temp_f"], + } + db.execute( + sqlalchemy.text( + "INSERT INTO vitals (patient_id, timestamp, bp, hr, temp_f) " + "VALUES (:patient_id, :timestamp, :bp, :hr, :temp_f)" + ), + entry, + ) + return ToolResult( + output={ + "status": "ok", + "patient_id": patient_id, + "vitals_recorded": { + k: entry[k] for k in ("timestamp", "bp", "hr", "temp_f") + }, + } + ) + + +def add_diagnosis(arguments: dict[str, Any], db: Connection) -> ToolResult: + """Append an ICD-10 diagnosis. Refuses duplicate codes.""" + patient_id = arguments["patient_id"] + if not _patient_exists(db, patient_id): + return ToolResult( + output={"status": "error", "error": "not_found", "patient_id": patient_id} + ) + new_diagnosis = { + "patient_id": patient_id, + "code": arguments["code"], + "description": arguments["description"], + "date": arguments["date"], + } + try: + db.execute( + sqlalchemy.text( + "INSERT INTO diagnoses (patient_id, code, description, date) " + "VALUES (:patient_id, :code, :description, :date)" + ), + new_diagnosis, + ) + except sqlalchemy.exc.IntegrityError: + return ToolResult( + output={ + "status": "error", + "error": "duplicate_diagnosis", + "patient_id": patient_id, + "code": new_diagnosis["code"], + } + ) + return ToolResult( + output={ + "status": "ok", + "patient_id": patient_id, + "diagnosis_added": { + k: new_diagnosis[k] for k in ("code", "description", "date") + }, + } + ) + + +def prescribe_medication(arguments: dict[str, Any], db: Connection) -> ToolResult: + """Add a medication. Refuses duplicates and allergy conflicts.""" + patient_id = arguments["patient_id"] + if not _patient_exists(db, patient_id): + return ToolResult( + output={"status": "error", "error": "not_found", "patient_id": patient_id} + ) + name = arguments["name"] + dose = arguments["dose"] + + # Allergy conflict check (case-insensitive match on the substance name). + allergy_hit = db.execute( + sqlalchemy.text( + "SELECT 1 FROM allergies WHERE patient_id = :pid " + "AND LOWER(substance) = LOWER(:name)" + ), + {"pid": patient_id, "name": name}, + ).first() + if allergy_hit is not None: + return ToolResult( + output={ + "status": "error", + "error": "allergy_conflict", + "patient_id": patient_id, + "medication": name, + } + ) + + try: + db.execute( + sqlalchemy.text( + "INSERT INTO medications (patient_id, name, dose) " + "VALUES (:patient_id, :name, :dose)" + ), + {"patient_id": patient_id, "name": name, "dose": dose}, + ) + except sqlalchemy.exc.IntegrityError: + return ToolResult( + output={ + "status": "error", + "error": "already_prescribed", + "patient_id": patient_id, + "medication": name, + } + ) + return ToolResult( + output={ + "status": "ok", + "patient_id": patient_id, + "medication_added": {"name": name, "dose": dose}, + } + ) + + +def update_allergies(arguments: dict[str, Any], db: Connection) -> ToolResult: + """Replace a patient's allergy list with the supplied list.""" + patient_id = arguments["patient_id"] + if not _patient_exists(db, patient_id): + return ToolResult( + output={"status": "error", "error": "not_found", "patient_id": patient_id} + ) + new_allergies = list(arguments["allergies"]) + db.execute( + sqlalchemy.text("DELETE FROM allergies WHERE patient_id = :pid"), + {"pid": patient_id}, + ) + for substance in new_allergies: + db.execute( + sqlalchemy.text( + "INSERT INTO allergies (patient_id, substance) " + "VALUES (:pid, :substance)" + ), + {"pid": patient_id, "substance": substance}, + ) + return ToolResult( + output={ + "status": "ok", + "patient_id": patient_id, + "allergies": new_allergies, + } + ) diff --git a/src/oumi/examples/ehr_db/schema.sql b/src/oumi/examples/ehr_db/schema.sql new file mode 100644 index 0000000000..1406ce8d25 --- /dev/null +++ b/src/oumi/examples/ehr_db/schema.sql @@ -0,0 +1,42 @@ +-- EHR demo schema (SQLite-compatible; uses standard SQL where possible). + +CREATE TABLE patients ( + patient_id TEXT PRIMARY KEY, + name TEXT NOT NULL, + dob TEXT NOT NULL, + status TEXT NOT NULL +); + +CREATE TABLE allergies ( + patient_id TEXT NOT NULL, + substance TEXT NOT NULL, + PRIMARY KEY (patient_id, substance), + FOREIGN KEY (patient_id) REFERENCES patients(patient_id) +); + +CREATE TABLE medications ( + patient_id TEXT NOT NULL, + name TEXT NOT NULL, + dose TEXT NOT NULL, + PRIMARY KEY (patient_id, name), + FOREIGN KEY (patient_id) REFERENCES patients(patient_id) +); + +CREATE TABLE diagnoses ( + patient_id TEXT NOT NULL, + code TEXT NOT NULL, + description TEXT NOT NULL, + date TEXT NOT NULL, + PRIMARY KEY (patient_id, code), + FOREIGN KEY (patient_id) REFERENCES patients(patient_id) +); + +CREATE TABLE vitals ( + patient_id TEXT NOT NULL, + timestamp TEXT NOT NULL, + bp TEXT NOT NULL, + hr INTEGER NOT NULL, + temp_f REAL NOT NULL, + PRIMARY KEY (patient_id, timestamp), + FOREIGN KEY (patient_id) REFERENCES patients(patient_id) +); diff --git a/src/oumi/examples/ehr_db/seed.sql b/src/oumi/examples/ehr_db/seed.sql new file mode 100644 index 0000000000..86aab97784 --- /dev/null +++ b/src/oumi/examples/ehr_db/seed.sql @@ -0,0 +1,35 @@ +-- 6-patient subset of an EHR fixture for the e2e test. + +INSERT INTO patients (patient_id, name, dob, status) VALUES + ('P001', 'Jane Smith', '1985-03-15', 'active'), + ('P002', 'Marcus Lee', '1972-11-04', 'active'), + ('P003', 'Aisha Khan', '1990-07-22', 'active'), + ('P004', 'Rafael Ortiz', '1958-09-30', 'active'), + ('P005', 'Priya Iyer', '1995-12-01', 'active'), + ('P006', 'Daniel Park', '1980-04-19', 'active'); + +INSERT INTO allergies (patient_id, substance) VALUES + ('P001', 'penicillin'), + ('P003', 'sulfa'), + ('P005', 'latex'), + ('P005', 'shellfish'); + +INSERT INTO medications (patient_id, name, dose) VALUES + ('P001', 'lisinopril', '10mg daily'), + ('P003', 'levothyroxine', '50mcg daily'), + ('P004', 'metformin', '1000mg twice daily'), + ('P004', 'atorvastatin', '20mg nightly'), + ('P006', 'sertraline', '100mg daily'); + +INSERT INTO diagnoses (patient_id, code, description, date) VALUES + ('P001', 'I10', 'Essential hypertension', '2024-06-12'), + ('P003', 'E03.9', 'Hypothyroidism, unspecified', '2023-02-08'), + ('P004', 'E11.9', 'Type 2 diabetes mellitus without complications', '2022-04-19'), + ('P004', 'E78.5', 'Hyperlipidemia, unspecified', '2022-04-19'), + ('P006', 'F41.1', 'Generalized anxiety disorder', '2024-09-02'); + +INSERT INTO vitals (patient_id, timestamp, bp, hr, temp_f) VALUES + ('P001', '2024-06-12T10:00', '138/85', 72, 98.4), + ('P003', '2025-11-04T08:15', '118/74', 64, 98.2), + ('P004', '2026-01-15T14:00', '146/92', 78, 98.6), + ('P006', '2025-12-10T11:20', '124/80', 70, 98.5); diff --git a/tests/e2e/synthesis/__init__.py b/tests/e2e/synthesis/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/e2e/synthesis/test_ehr_db_e2e.py b/tests/e2e/synthesis/test_ehr_db_e2e.py new file mode 100644 index 0000000000..d4ed5251c1 --- /dev/null +++ b/tests/e2e/synthesis/test_ehr_db_e2e.py @@ -0,0 +1,178 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end checks for the EHR database example. + +Loads the shipped YAML, builds the env via the registry, seeds a fresh +SQLite DB from the bundled schema/seed SQL, then walks through realistic +clinical flows by calling ``env.step`` directly. No LLM required. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +import sqlalchemy + +from oumi.builders.environments import build_environment +from oumi.core.configs.synthesis_config import SynthesisConfig +from oumi.environments.database_executable_environment import ( + DatabaseExecutableEnvironment, +) + +CONFIG_PATH = ( + Path(__file__).resolve().parents[3] + / "configs" + / "examples" + / "synthesis" + / "ehr_db_synth.yaml" +) +SCHEMA_DIR = ( + Path(__file__).resolve().parents[3] / "src" / "oumi" / "examples" / "ehr_db" +) + + +def _exec_sql_file(conn, path: Path) -> None: + """Execute statements in a `.sql` file. Strips line comments before splitting.""" + raw = path.read_text() + lines = [line for line in raw.splitlines() if not line.lstrip().startswith("--")] + body = "\n".join(lines) + for stmt in body.split(";"): + if stmt.strip(): + conn.execute(sqlalchemy.text(stmt)) + + +@pytest.fixture +def db_path(tmp_path): + """Create a fresh SQLite DB seeded from schema.sql + seed.sql.""" + db_file = tmp_path / "ehr_test.db" + engine = sqlalchemy.create_engine( + f"sqlite:///{db_file}", isolation_level="AUTOCOMMIT", future=True + ) + with engine.connect() as conn: + _exec_sql_file(conn, SCHEMA_DIR / "schema.sql") + _exec_sql_file(conn, SCHEMA_DIR / "seed.sql") + engine.dispose() + return db_file + + +@pytest.fixture +def env_params(db_path): + cfg = SynthesisConfig.from_yaml(str(CONFIG_PATH)) + assert cfg.environment_config is not None + env_params = cfg.environment_config.environments[0] + # YAML uses an OmegaConf env-var interpolation for the DB path; in tests we + # override directly so each test gets its own fresh sqlite file. + assert env_params.env_kwargs is not None + env_params.env_kwargs["connection"]["database"] = str(db_path) + return env_params + + +def test_yaml_loads_and_env_builds(env_params): + env = build_environment(env_params) + try: + assert isinstance(env, DatabaseExecutableEnvironment) + assert set(env._executors.keys()) == { + "list_patients", + "get_patient", + "record_vitals", + "add_diagnosis", + "prescribe_medication", + "update_allergies", + } + finally: + if isinstance(env, DatabaseExecutableEnvironment): + env.close() + + +def test_chart_review_flow(env_params): + """Clinician asks for the chart for Jane Smith.""" + env = build_environment(env_params) + assert isinstance(env, DatabaseExecutableEnvironment) + try: + listing = env.step("list_patients", {}) + assert isinstance(listing.output, dict) + summaries = listing.output["patients"] + jane = next(p for p in summaries if p["name"] == "Jane Smith") + + record = env.step("get_patient", {"patient_id": jane["patient_id"]}) + assert isinstance(record.output, dict) + assert record.output["status"] == "ok" + patient = record.output["patient"] + assert "penicillin" in patient["allergies"] + assert any(m["name"] == "lisinopril" for m in patient["medications"]) + finally: + env.close() + + +def test_record_vitals_flow(env_params): + env = build_environment(env_params) + assert isinstance(env, DatabaseExecutableEnvironment) + try: + write = env.step( + "record_vitals", + { + "patient_id": "P002", + "timestamp": "2026-05-01T08:00", + "bp": "118/76", + "hr": 70, + "temp_f": 98.4, + }, + ) + assert isinstance(write.output, dict) + assert write.output["status"] == "ok" + + read = env.step("get_patient", {"patient_id": "P002"}) + assert isinstance(read.output, dict) + history = read.output["patient"]["vitals_history"] + assert any(v["timestamp"] == "2026-05-01T08:00" for v in history) + finally: + env.close() + + +def test_allergy_conflict_blocks_prescription(env_params): + """Jane Smith is allergic to penicillin — prescribing it must be refused.""" + env = build_environment(env_params) + assert isinstance(env, DatabaseExecutableEnvironment) + try: + result = env.step( + "prescribe_medication", + {"patient_id": "P001", "name": "penicillin", "dose": "500mg"}, + ) + assert isinstance(result.output, dict) + assert result.output["status"] == "error" + assert result.output["error"] == "allergy_conflict" + finally: + env.close() + + +def test_duplicate_diagnosis_returns_error(env_params): + env = build_environment(env_params) + assert isinstance(env, DatabaseExecutableEnvironment) + try: + result = env.step( + "add_diagnosis", + { + "patient_id": "P001", + "code": "I10", + "description": "Essential hypertension", + "date": "2026-05-01", + }, + ) + assert isinstance(result.output, dict) + assert result.output["status"] == "error" + assert result.output["error"] == "duplicate_diagnosis" + finally: + env.close() diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/environments/__init__.py b/tests/integration/environments/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/environments/test_database_executable_environment_postgres.py b/tests/integration/environments/test_database_executable_environment_postgres.py new file mode 100644 index 0000000000..c49b083a25 --- /dev/null +++ b/tests/integration/environments/test_database_executable_environment_postgres.py @@ -0,0 +1,218 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Postgres integration tests for DatabaseExecutableEnvironment. + +Gated behind the ``requires_postgres`` marker. Run via: + + uv run --extra dev pytest -m requires_postgres + +Starts a Postgres container per session via testcontainers-python. +Verifies dialect-specific behavior that SQLite cannot exercise: +``default_transaction_read_only``, real ``statement_timeout``, and the +``IntegrityError.sqlstate`` propagation. +""" + +from __future__ import annotations + +import pytest +import sqlalchemy + +from oumi.core.configs.params.environment_params import EnvironmentParams +from oumi.core.types.tool_call import ToolResult +from oumi.environments.database_executable_environment import ( + DatabaseExecutableEnvironment, +) +from oumi.environments.database_executable_tool import DatabaseExecutableTool + +pytestmark = pytest.mark.requires_postgres + + +# Module-scope executors (dotted-path resolution). + + +def _setup_executor(arguments, db): + db.execute( + sqlalchemy.text( + "CREATE TABLE IF NOT EXISTS patients (" + "id INTEGER PRIMARY KEY, name TEXT NOT NULL UNIQUE)" + ) + ) + db.execute( + sqlalchemy.text( + "INSERT INTO patients (id, name) VALUES (1, 'Jane') " + "ON CONFLICT (id) DO NOTHING" + ) + ) + return ToolResult(output={"status": "ok"}) + + +def _try_insert_executor(arguments, db): + db.execute(sqlalchemy.text("INSERT INTO patients (id, name) VALUES (2, 'Marcus')")) + return ToolResult(output={"status": "ok"}) + + +def _slow_query_executor(arguments, db): + db.execute(sqlalchemy.text("SELECT pg_sleep(2)")) + return ToolResult(output={"status": "ok"}) + + +def _duplicate_pk_executor(arguments, db): + db.execute(sqlalchemy.text("INSERT INTO patients (id, name) VALUES (1, 'X')")) + return ToolResult(output={"status": "should not reach"}) + + +@pytest.fixture(scope="session") +def postgres_dsn(): + """Yield a SQLAlchemy DSN for a fresh Postgres container.""" + from testcontainers.postgres import PostgresContainer + + with PostgresContainer("postgres:16-alpine") as pg: + yield pg.get_connection_url() + + +@pytest.fixture +def setup_db(postgres_dsn, monkeypatch): + monkeypatch.setenv("PG_TEST_DSN", postgres_dsn) + setup_params = EnvironmentParams( + id="setup", + name="setup", + description="seed", + env_type="database", + env_kwargs={ + "connection": {"dsn_env_var": "PG_TEST_DSN"}, + }, + tools=[ + DatabaseExecutableTool( + id="setup", + name="setup", + description="seed", + executor=( + "tests.integration.environments." + "test_database_executable_environment_postgres._setup_executor" + ), + read_only=False, + ) + ], + ) + env = DatabaseExecutableEnvironment.from_params(setup_params) + try: + env.step("setup", {}) + finally: + env.close() + yield postgres_dsn + + +def test_read_only_blocks_writes(setup_db, monkeypatch): + monkeypatch.setenv("PG_TEST_DSN", setup_db) + params = EnvironmentParams( + id="ro", + name="ro", + description="ro", + env_type="database", + env_kwargs={ + "connection": {"dsn_env_var": "PG_TEST_DSN"}, + "read_only": True, + }, + tools=[ + DatabaseExecutableTool( + id="t", + name="t", + description="t", + executor=( + "tests.integration.environments." + "test_database_executable_environment_postgres._try_insert_executor" + ), + read_only=True, + ) + ], + ) + env = DatabaseExecutableEnvironment.from_params(params) + try: + result = env.step("t", {}) + assert isinstance(result.output, dict) + assert result.output["status"] == "error" + msg = result.output["message"].lower() + assert "read-only" in msg or "read only" in msg + finally: + env.close() + + +def test_statement_timeout_cancels_long_query(setup_db, monkeypatch): + monkeypatch.setenv("PG_TEST_DSN", setup_db) + params = EnvironmentParams( + id="slow", + name="slow", + description="slow", + env_type="database", + env_kwargs={ + "connection": {"dsn_env_var": "PG_TEST_DSN"}, + "statement_timeout_ms": 200, + }, + tools=[ + DatabaseExecutableTool( + id="slow", + name="slow", + description="slow", + executor=( + "tests.integration.environments." + "test_database_executable_environment_postgres._slow_query_executor" + ), + read_only=True, + ) + ], + ) + env = DatabaseExecutableEnvironment.from_params(params) + try: + result = env.step("slow", {}) + assert isinstance(result.output, dict) + assert result.output["status"] == "error" + msg = result.output["message"].lower() + assert "statement timeout" in msg or "canceling statement" in msg + finally: + env.close() + + +def test_integrity_error_carries_sqlstate(setup_db, monkeypatch): + monkeypatch.setenv("PG_TEST_DSN", setup_db) + params = EnvironmentParams( + id="dup", + name="dup", + description="dup", + env_type="database", + env_kwargs={ + "connection": {"dsn_env_var": "PG_TEST_DSN"}, + }, + tools=[ + DatabaseExecutableTool( + id="dup", + name="dup", + description="dup", + executor=( + "tests.integration.environments." + "test_database_executable_environment_postgres._duplicate_pk_executor" + ), + read_only=False, + ) + ], + ) + env = DatabaseExecutableEnvironment.from_params(params) + try: + result = env.step("dup", {}) + assert isinstance(result.output, dict) + assert result.output["error"] == "IntegrityError" + # Postgres unique-violation SQLSTATE is 23505. + assert result.output["sql_state"] == "23505" + finally: + env.close() diff --git a/tests/unit/configs/__init__.py b/tests/unit/configs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/configs/params/__init__.py b/tests/unit/configs/params/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/configs/params/test_database_connection_params.py b/tests/unit/configs/params/test_database_connection_params.py new file mode 100644 index 0000000000..e3d4eb91d3 --- /dev/null +++ b/tests/unit/configs/params/test_database_connection_params.py @@ -0,0 +1,98 @@ +# Copyright 2025 - Oumi +# Licensed under the Apache License, Version 2.0 +"""Unit tests for DatabaseConnectionConfig.""" + +from __future__ import annotations + +import pytest + +from oumi.core.configs.params.database_connection_params import ( + DatabaseConnectionConfig, +) + + +def test_structured_fields_resolve_url(monkeypatch): + monkeypatch.setenv("PWD_VAR", "secret123") + cfg = DatabaseConnectionConfig( + driver="postgresql+psycopg", + host="db.local", + port=5432, + database="ehr", + username="oumi", + password_env_var="PWD_VAR", + ) + cfg.finalize_and_validate() + url = cfg.resolve_url() + assert url.drivername == "postgresql+psycopg" + assert url.host == "db.local" + assert url.port == 5432 + assert url.database == "ehr" + assert url.username == "oumi" + assert url.password == "secret123" + + +def test_dsn_env_var_resolve_url(monkeypatch): + monkeypatch.setenv("DSN_VAR", "sqlite:///:memory:") + cfg = DatabaseConnectionConfig(dsn_env_var="DSN_VAR") + cfg.finalize_and_validate() + url = cfg.resolve_url() + assert url.drivername == "sqlite" + assert url.database == ":memory:" + + +def test_structured_and_dsn_mutually_exclusive(): + cfg = DatabaseConnectionConfig( + driver="postgresql+psycopg", + database="ehr", + dsn_env_var="DSN_VAR", + ) + with pytest.raises(ValueError, match="not both"): + cfg.finalize_and_validate() + + +def test_neither_structured_nor_dsn_raises(): + cfg = DatabaseConnectionConfig() + with pytest.raises(ValueError, match="Got neither"): + cfg.finalize_and_validate() + + +def test_dsn_env_var_unset_raises(monkeypatch): + monkeypatch.delenv("DSN_VAR", raising=False) + cfg = DatabaseConnectionConfig(dsn_env_var="DSN_VAR") + cfg.finalize_and_validate() + with pytest.raises(ValueError, match="not set"): + cfg.resolve_url() + + +def test_password_env_var_unset_yields_no_password(monkeypatch): + monkeypatch.delenv("MISSING_PWD", raising=False) + cfg = DatabaseConnectionConfig( + driver="postgresql+psycopg", + database="ehr", + password_env_var="MISSING_PWD", + ) + cfg.finalize_and_validate() + url = cfg.resolve_url() + assert url.password is None + + +def test_pool_size_must_be_positive(): + cfg = DatabaseConnectionConfig(driver="sqlite", database=":memory:", pool_size=0) + with pytest.raises(ValueError, match="pool_size"): + cfg.finalize_and_validate() + + +def test_connect_timeout_must_be_positive(): + cfg = DatabaseConnectionConfig( + driver="sqlite", database=":memory:", connect_timeout_s=0 + ) + with pytest.raises(ValueError, match="connect_timeout_s"): + cfg.finalize_and_validate() + + +def test_pool_max_overflow_must_be_nonnegative(): + cfg = DatabaseConnectionConfig( + driver="sqlite", database=":memory:", pool_max_overflow=-1 + ) + with pytest.raises(ValueError, match="pool_max_overflow"): + cfg.finalize_and_validate() diff --git a/tests/unit/environments/test_database_executable_environment.py b/tests/unit/environments/test_database_executable_environment.py new file mode 100644 index 0000000000..26e62e2df2 --- /dev/null +++ b/tests/unit/environments/test_database_executable_environment.py @@ -0,0 +1,540 @@ +# Copyright 2025 - Oumi +# Licensed under the Apache License, Version 2.0 +"""Unit tests for DatabaseExecutableEnvironment (SQLite-backed).""" + +from __future__ import annotations + +import logging +import tempfile +from pathlib import Path + +import pytest +import sqlalchemy +import sqlalchemy.exc + +from oumi.builders.environments import build_environment +from oumi.core.configs.params.environment_params import EnvironmentParams +from oumi.core.types.tool_call import ToolResult +from oumi.environments.database_executable_environment import ( + DatabaseExecutableEnvironment, + DatabaseExecutableEnvironmentKwargs, +) +from oumi.environments.database_executable_tool import DatabaseExecutableTool + + +# Test executors at module scope (dotted-path resolution). +def _select_one_executor(arguments, db): + rows = db.execute(sqlalchemy.text("SELECT 1 AS one")).mappings().all() + return ToolResult(output={"rows": [dict(r) for r in rows]}) + + +def _make_tool(executor, tool_id="t1", read_only=True, **extra): + return DatabaseExecutableTool( + id=tool_id, + name=tool_id.upper(), + description="A DB test tool.", + executor=executor, + read_only=read_only, + **extra, + ) + + +def _make_params(tools, env_kwargs=None, env_id="env1"): + if env_kwargs is None: + env_kwargs = { + "connection": { + "driver": "sqlite", + "database": ":memory:", + } + } + return EnvironmentParams( + id=env_id, + name="Env1", + description="A DB env.", + env_type="database", + tools=tools, + env_kwargs=env_kwargs, + ) + + +def test_from_params_constructs_engine(): + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._select_one_executor" + ) + ] + ) + env = DatabaseExecutableEnvironment.from_params(params) + assert isinstance(env, DatabaseExecutableEnvironment) + assert isinstance(env._kwargs, DatabaseExecutableEnvironmentKwargs) + assert env._engine is not None + env.close() + + +def test_from_params_requires_connection(): + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._select_one_executor" + ) + ], + env_kwargs={}, + ) + with pytest.raises(ValueError, match="connection"): + DatabaseExecutableEnvironment.from_params(params) + + +def test_from_params_fail_fast_on_bad_url(): + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._select_one_executor" + ) + ], + env_kwargs={ + "connection": { + "driver": "sqlite", + "database": "/nonexistent_dir_for_test/db.sqlite", + } + }, + ) + with pytest.raises(ValueError, match="Failed to connect"): + DatabaseExecutableEnvironment.from_params(params) + + +def test_close_disposes_engine(): + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._select_one_executor" + ) + ] + ) + env = DatabaseExecutableEnvironment.from_params(params) + engine = env._engine + assert engine is not None + env.close() + # After dispose() we still hold the engine reference; the test confirms + # close() ran without error. + assert env._engine is engine + + +def _create_table_executor(arguments, db): + db.execute( + sqlalchemy.text("CREATE TABLE patients (id INTEGER PRIMARY KEY, name TEXT)") + ) + db.execute( + sqlalchemy.text( + "INSERT INTO patients (id, name) VALUES (1, 'Jane'), (2, 'Marcus')" + ) + ) + return ToolResult(output={"status": "ok"}) + + +def _list_patients_executor(arguments, db): + rows = ( + db.execute(sqlalchemy.text("SELECT id, name FROM patients ORDER BY id")) + .mappings() + .all() + ) + return ToolResult(output={"patients": [dict(r) for r in rows]}) + + +def test_step_runs_executor_and_returns_rows(): + # Use a file-backed SQLite DB so two tool calls hit the same database. + with tempfile.TemporaryDirectory() as tmp: + db_path = Path(tmp) / "test.db" + env_kwargs = { + "connection": { + "driver": "sqlite", + "database": str(db_path), + } + } + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._create_table_executor", + tool_id="setup", + read_only=False, + ), + _make_tool( + "tests.unit.environments.test_database_executable_environment._list_patients_executor", + tool_id="list", + ), + ], + env_kwargs=env_kwargs, + ) + env = DatabaseExecutableEnvironment.from_params(params) + try: + setup_result = env.step("setup", {}) + assert setup_result.output == {"status": "ok"} + + list_result = env.step("list", {}) + assert list_result.output == { + "patients": [ + {"id": 1, "name": "Jane"}, + {"id": 2, "name": "Marcus"}, + ] + } + finally: + env.close() + + +def _try_insert_executor(arguments, db): + """Catches OperationalError and returns the message.""" + try: + db.execute( + sqlalchemy.text("INSERT INTO patients (id, name) VALUES (3, 'Aisha')") + ) + return ToolResult(output={"status": "ok"}) + except sqlalchemy.exc.OperationalError as e: + return ToolResult( + output={"status": "error", "message": str(e.orig) if e.orig else str(e)} + ) + + +def test_dialect_guards_sqlite_read_only_rejects_writes(): + """With read_only=True the SQLite connection has PRAGMA query_only=ON.""" + with tempfile.TemporaryDirectory() as tmp: + db_path = Path(tmp) / "ro.db" + # Seed a row first under a separate, non-read-only env. + seed_params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._create_table_executor", + tool_id="setup", + read_only=False, + ) + ], + env_kwargs={"connection": {"driver": "sqlite", "database": str(db_path)}}, + ) + seed_env = DatabaseExecutableEnvironment.from_params(seed_params) + try: + seed_env.step("setup", {}) + finally: + seed_env.close() + + # Now open a read-only env on the same file. + ro_params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._try_insert_executor", + tool_id="try_write", + read_only=True, + ) + ], + env_kwargs={ + "connection": {"driver": "sqlite", "database": str(db_path)}, + "read_only": True, + }, + ) + ro_env = DatabaseExecutableEnvironment.from_params(ro_params) + try: + result = ro_env.step("try_write", {}) + assert isinstance(result.output, dict) + assert result.output["status"] == "error" + msg = result.output["message"].lower() + assert "read" in msg or "readonly" in msg + finally: + ro_env.close() + + +def test_dialect_guards_sqlite_statement_timeout_warns(caplog): + """SQLite doesn't support statement_timeout — env warns instead of failing.""" + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._select_one_executor" + ) + ], + env_kwargs={ + "connection": {"driver": "sqlite", "database": ":memory:"}, + "statement_timeout_ms": 5000, + }, + ) + with caplog.at_level(logging.WARNING): + env = DatabaseExecutableEnvironment.from_params(params) + try: + env.step("t1", {}) + finally: + env.close() + assert any( + "SQLite" in record.message and "statement_timeout" in record.message + for record in caplog.records + ) + + +def _executor_returns_updated_state(arguments, db): + return ToolResult(output={"status": "ok"}, updated_state={"x": 1}) + + +def _executor_lets_integrity_error_escape(arguments, db): + db.execute( + sqlalchemy.text("CREATE TABLE IF NOT EXISTS uq (id INTEGER PRIMARY KEY)") + ) + db.execute(sqlalchemy.text("INSERT INTO uq (id) VALUES (1)")) + db.execute(sqlalchemy.text("INSERT INTO uq (id) VALUES (1)")) # PK conflict + return ToolResult(output={"status": "should not reach"}) + + +def _executor_lets_programming_error_escape(arguments, db): + db.execute(sqlalchemy.text("SELECT * FROM table_that_does_not_exist")) + return ToolResult(output={"status": "should not reach"}) + + +def test_step_rejects_updated_state_in_result(): + """DB envs hold state in the DB; ToolResult.updated_state is forbidden.""" + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._executor_returns_updated_state", + read_only=False, + ) + ] + ) + env = DatabaseExecutableEnvironment.from_params(params) + try: + with pytest.raises(Exception, match="updated_state"): + env.step("t1", {}) + finally: + env.close() + + +def test_step_auto_wraps_integrity_error(): + with tempfile.TemporaryDirectory() as tmp: + db_path = Path(tmp) / "iw.db" + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._executor_lets_integrity_error_escape", + read_only=False, + ) + ], + env_kwargs={ + "connection": {"driver": "sqlite", "database": str(db_path)}, + }, + ) + env = DatabaseExecutableEnvironment.from_params(params) + try: + result = env.step("t1", {}) + assert isinstance(result.output, dict) + assert result.output["status"] == "error" + assert result.output["error"] == "IntegrityError" + msg_lower = result.output["message"].lower() + assert ("unique" in msg_lower) or ("primary key" in msg_lower) + finally: + env.close() + + +def test_step_auto_wraps_programming_error(): + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._executor_lets_programming_error_escape" + ) + ] + ) + env = DatabaseExecutableEnvironment.from_params(params) + try: + result = env.step("t1", {}) + assert isinstance(result.output, dict) + assert result.output["status"] == "error" + # SQLite raises OperationalError, not ProgrammingError, for missing tables. + assert result.output["error"] in {"OperationalError", "ProgrammingError"} + assert ( + "table_that_does_not_exist" in result.output["message"] + or "no such table" in result.output["message"].lower() + ) + finally: + env.close() + + +def test_auto_wrap_skips_output_schema_validation(): + """A tool with strict output_schema still surfaces auto-wrap shape on SQL error.""" + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._executor_lets_programming_error_escape", + output_schema={ + "type": "object", + "properties": {"some_field": {"type": "string"}}, + "required": ["some_field"], + }, + ) + ] + ) + env = DatabaseExecutableEnvironment.from_params(params) + try: + result = env.step("t1", {}) + # The wrap shape doesn't include "some_field"; we should get the wrap, + # not a schema-validation error. + assert isinstance(result.output, dict) + assert result.output["status"] == "error" + finally: + env.close() + + +def test_per_tool_timeout_override_on_sqlite_is_no_op_smoke(): + """SQLite can't enforce per-statement timeouts, but the SET path must not crash.""" + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._select_one_executor", + statement_timeout_ms=500, + ) + ], + env_kwargs={ + "connection": {"driver": "sqlite", "database": ":memory:"}, + "statement_timeout_ms": 5000, + }, + ) + env = DatabaseExecutableEnvironment.from_params(params) + try: + result = env.step("t1", {}) + assert result.output == {"rows": [{"one": 1}]} + finally: + env.close() + + +def test_read_only_env_requires_read_only_tools(): + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._select_one_executor", + tool_id="writer", + read_only=False, + ) + ], + env_kwargs={ + "connection": {"driver": "sqlite", "database": ":memory:"}, + "read_only": True, + }, + ) + with pytest.raises(ValueError, match="read_only"): + DatabaseExecutableEnvironment.from_params(params) + + +def test_per_tool_timeout_must_be_stricter_than_env(): + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._select_one_executor", + statement_timeout_ms=10000, + ) + ], + env_kwargs={ + "connection": {"driver": "sqlite", "database": ":memory:"}, + "statement_timeout_ms": 5000, + }, + ) + with pytest.raises(ValueError, match="statement_timeout"): + DatabaseExecutableEnvironment.from_params(params) + + +def test_per_tool_timeout_at_env_level_ok(): + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._select_one_executor", + statement_timeout_ms=5000, + ) + ], + env_kwargs={ + "connection": {"driver": "sqlite", "database": ":memory:"}, + "statement_timeout_ms": 5000, + }, + ) + # Equal is acceptable. + env = DatabaseExecutableEnvironment.from_params(params) + env.close() + + +def test_audit_off_by_default(caplog): + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._select_one_executor" + ) + ] + ) + with caplog.at_level(logging.INFO): + env = DatabaseExecutableEnvironment.from_params(params) + try: + env.step("t1", {}) + finally: + env.close() + audit_records = [r for r in caplog.records if "db_tool_call" in r.message] + assert audit_records == [] + + +def test_audit_on_logs_per_tool_call(caplog): + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._select_one_executor" + ) + ], + env_kwargs={ + "connection": {"driver": "sqlite", "database": ":memory:"}, + "audit": True, + }, + ) + with caplog.at_level(logging.INFO): + env = DatabaseExecutableEnvironment.from_params(params) + try: + env.step("t1", {}) + finally: + env.close() + audit_records = [r for r in caplog.records if "db_tool_call" in r.message] + assert len(audit_records) == 1 + msg = audit_records[0].message + assert "env1" in msg + assert "t1" in msg + assert "status=ok" in msg + + +def _executor_raises_value_error(arguments, db): + raise ValueError("boom") + + +def test_audit_logs_error_status_when_executor_raises_non_db_error(caplog): + """Bugs (non-DBAPIError exceptions) propagate but still get an audit entry.""" + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._executor_raises_value_error", + tool_id="bad", + ) + ], + env_kwargs={ + "connection": {"driver": "sqlite", "database": ":memory:"}, + "audit": True, + }, + ) + env = DatabaseExecutableEnvironment.from_params(params) + with caplog.at_level(logging.INFO): + try: + with pytest.raises(ValueError): + env.step("bad", {}) + finally: + env.close() + audit_records = [r for r in caplog.records if "db_tool_call" in r.message] + assert len(audit_records) == 1 + assert "status=error" in audit_records[0].message + + +def test_env_registered_under_database_key(): + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_database_executable_environment._select_one_executor" + ) + ] + ) + env = build_environment(params) + try: + assert isinstance(env, DatabaseExecutableEnvironment) + finally: + if isinstance(env, DatabaseExecutableEnvironment): + env.close() diff --git a/tests/unit/environments/test_database_executable_tool.py b/tests/unit/environments/test_database_executable_tool.py new file mode 100644 index 0000000000..b7d516a672 --- /dev/null +++ b/tests/unit/environments/test_database_executable_tool.py @@ -0,0 +1,70 @@ +# Copyright 2025 - Oumi +# Licensed under the Apache License, Version 2.0 +"""Unit tests for DatabaseExecutableTool.""" + +from __future__ import annotations + +import pytest + +from oumi.environments.database_executable_tool import DatabaseExecutableTool + + +def test_default_no_per_tool_overrides(): + tool = DatabaseExecutableTool( + id="t1", + name="T1", + description="A tool.", + executor="some.module.func", + ) + assert tool.statement_timeout_ms is None + + +def test_per_tool_timeout_override(): + tool = DatabaseExecutableTool( + id="t1", + name="T1", + description="A tool.", + executor="some.module.func", + statement_timeout_ms=500, + ) + assert tool.statement_timeout_ms == 500 + + +def test_create_from_mapping(): + raw = { + "id": "t1", + "name": "T1", + "description": "A tool.", + "parameters": {"type": "object"}, + "executor": "some.module.func", + "statement_timeout_ms": 250, + } + tool = DatabaseExecutableTool.create(raw) + assert isinstance(tool, DatabaseExecutableTool) + assert tool.executor == "some.module.func" + assert tool.statement_timeout_ms == 250 + + +def test_create_passes_through_existing_tool(): + tool = DatabaseExecutableTool( + id="t1", + name="T1", + description="A tool.", + executor="some.module.func", + ) + assert DatabaseExecutableTool.create(tool) is tool + + +def test_create_rejects_non_mapping(): + with pytest.raises(TypeError, match="mappings"): + DatabaseExecutableTool.create(["not", "a", "mapping"]) + + +def test_empty_executor_raises(): + with pytest.raises(ValueError, match="executor"): + DatabaseExecutableTool( + id="t1", + name="T1", + description="A tool.", + executor="", + ) diff --git a/tests/unit/environments/test_ehr_db_executors.py b/tests/unit/environments/test_ehr_db_executors.py new file mode 100644 index 0000000000..6cc0db5349 --- /dev/null +++ b/tests/unit/environments/test_ehr_db_executors.py @@ -0,0 +1,150 @@ +# Copyright 2025 - Oumi +# Licensed under the Apache License, Version 2.0 +"""Unit tests for EHR DB executors (SQLite-backed in-memory).""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +import sqlalchemy + +from oumi.core.types.tool_call import ToolResult +from oumi.examples.ehr_db import executors as ehr_db + +_REPO_ROOT = Path(__file__).resolve().parents[3] +SCHEMA_DIR = _REPO_ROOT / "src" / "oumi" / "examples" / "ehr_db" + + +@pytest.fixture +def db(): + """In-memory SQLite engine seeded from schema.sql + seed.sql.""" + engine = sqlalchemy.create_engine( + "sqlite:///:memory:", isolation_level="AUTOCOMMIT", future=True + ) + schema_sql = (SCHEMA_DIR / "schema.sql").read_text() + seed_sql = (SCHEMA_DIR / "seed.sql").read_text() + + def _exec_sql(conn: sqlalchemy.engine.Connection, raw: str) -> None: + # Strip single-line comments before splitting on ";" so that + # leading comment lines don't get concatenated into the first statement. + lines = [ln for ln in raw.splitlines() if not ln.lstrip().startswith("--")] + cleaned = "\n".join(lines) + for stmt in cleaned.split(";"): + if stmt.strip(): + conn.execute(sqlalchemy.text(stmt)) + + with engine.connect() as conn: + _exec_sql(conn, schema_sql) + _exec_sql(conn, seed_sql) + yield engine + engine.dispose() + + +def test_list_patients(db): + with db.connect() as conn: + result = ehr_db.list_patients({}, conn) + assert isinstance(result, ToolResult) + assert isinstance(result.output, dict) + patients = result.output["patients"] + assert len(patients) == 6 + assert {p["patient_id"] for p in patients} == {f"P00{i}" for i in range(1, 7)} + assert all({"patient_id", "name", "dob", "status"} <= set(p) for p in patients) + + +def test_get_patient_known(db): + with db.connect() as conn: + result = ehr_db.get_patient({"patient_id": "P001"}, conn) + assert isinstance(result.output, dict) + assert result.output["status"] == "ok" + patient = result.output["patient"] + assert patient["name"] == "Jane Smith" + assert "penicillin" in patient["allergies"] + assert any(m["name"] == "lisinopril" for m in patient["medications"]) + + +def test_get_patient_unknown(db): + with db.connect() as conn: + result = ehr_db.get_patient({"patient_id": "P999"}, conn) + assert result.output == { + "status": "error", + "error": "not_found", + "patient_id": "P999", + } + + +def test_record_vitals_appends(db): + args = { + "patient_id": "P001", + "timestamp": "2026-05-01T09:00", + "bp": "120/80", + "hr": 70, + "temp_f": 98.6, + } + with db.connect() as conn: + result = ehr_db.record_vitals(args, conn) + assert isinstance(result.output, dict) + assert result.output["status"] == "ok" + with db.connect() as conn: + rows = ( + conn.execute( + sqlalchemy.text( + "SELECT timestamp, bp, hr, temp_f FROM vitals " + "WHERE patient_id='P001' ORDER BY timestamp" + ) + ) + .mappings() + .all() + ) + assert any(r["timestamp"] == "2026-05-01T09:00" for r in rows) + + +def test_add_diagnosis_duplicate_returns_error(db): + args = { + "patient_id": "P001", + "code": "I10", # already present in seed + "description": "Essential hypertension", + "date": "2026-05-01", + } + with db.connect() as conn: + result = ehr_db.add_diagnosis(args, conn) + assert isinstance(result.output, dict) + assert result.output["status"] == "error" + assert result.output["error"] == "duplicate_diagnosis" + + +def test_prescribe_medication_allergy_conflict(db): + args = {"patient_id": "P001", "name": "Penicillin", "dose": "500mg"} + with db.connect() as conn: + result = ehr_db.prescribe_medication(args, conn) + assert isinstance(result.output, dict) + assert result.output["status"] == "error" + assert result.output["error"] == "allergy_conflict" + + +def test_prescribe_medication_already_prescribed(db): + args = {"patient_id": "P001", "name": "lisinopril", "dose": "20mg daily"} + with db.connect() as conn: + result = ehr_db.prescribe_medication(args, conn) + assert isinstance(result.output, dict) + assert result.output["status"] == "error" + assert result.output["error"] == "already_prescribed" + + +def test_update_allergies_replaces(db): + args = {"patient_id": "P001", "allergies": ["latex"]} + with db.connect() as conn: + result = ehr_db.update_allergies(args, conn) + assert isinstance(result.output, dict) + assert result.output["status"] == "ok" + with db.connect() as conn: + rows = ( + conn.execute( + sqlalchemy.text( + "SELECT substance FROM allergies WHERE patient_id='P001'" + ) + ) + .scalars() + .all() + ) + assert sorted(rows) == ["latex"] diff --git a/tests/unit/environments/test_executable_environment.py b/tests/unit/environments/test_executable_environment.py new file mode 100644 index 0000000000..fadb8743b6 --- /dev/null +++ b/tests/unit/environments/test_executable_environment.py @@ -0,0 +1,145 @@ +# Copyright 2025 - Oumi +# Licensed under the Apache License, Version 2.0 +"""Unit tests for the abstract ExecutableEnvironment base class.""" + +from __future__ import annotations + +from contextlib import contextmanager + +import pytest + +from oumi.core.configs.params.environment_params import EnvironmentParams +from oumi.core.configs.params.tool_params import ToolError +from oumi.core.registry import register_environment +from oumi.core.types.tool_call import ToolResult +from oumi.environments.executable_environment import ( + ExecutableEnvironment, + _import_executor, +) +from oumi.environments.executable_tool import ExecutableTool + + +# Module-scope test executors (so dotted-path resolution works). +def _ok_executor(arguments, ctx): + return ToolResult(output={"echo": arguments, "ctx": ctx}) + + +def _bad_executor_returns_str(arguments, ctx): + return "not a ToolResult" + + +def _executor_raises_value_error(arguments, ctx): + raise ValueError("boom") + + +_NOT_CALLABLE = 42 + + +@register_environment("_test_fake_executable") +class _FakeExecutableEnvironment(ExecutableEnvironment): + """Concrete subclass for testing — yields a sentinel string as the context.""" + + tool_params_cls = ExecutableTool + _executor_context_kwarg = "ctx" + + def __init__(self, params, kwargs=None): + self._params = params + self._kwargs = kwargs + self._executors = {} + + @classmethod + def from_params(cls, params): + env = cls(params) + for tool in params.tools: + env._executors[tool.id] = _import_executor(tool.executor, tool.id) + return env + + @contextmanager + def _build_execution_context(self, tool, arguments): + yield "fake-ctx-sentinel" + + +def _make_tool(executor, tool_id="t1"): + return ExecutableTool( + id=tool_id, + name=tool_id.upper(), + description="A test tool.", + executor=executor, + ) + + +def _make_params(tools): + return EnvironmentParams( + id="env1", + name="Env1", + description="A fake executable env.", + env_type="_test_fake_executable", + tools=tools, + ) + + +def test_step_returns_executor_result(): + params = _make_params( + [_make_tool("tests.unit.environments.test_executable_environment._ok_executor")] + ) + env = _FakeExecutableEnvironment.from_params(params) + result = env.step("t1", {"k": "v"}) + assert isinstance(result, ToolResult) + assert result.output == {"echo": {"k": "v"}, "ctx": "fake-ctx-sentinel"} + + +def test_step_rejects_non_toolresult_return(): + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_executable_environment._bad_executor_returns_str" + ) + ] + ) + env = _FakeExecutableEnvironment.from_params(params) + with pytest.raises(ToolError, match="ToolResult"): + env.step("t1", {}) + + +def test_step_propagates_executor_value_error(): + params = _make_params( + [ + _make_tool( + "tests.unit.environments.test_executable_environment._executor_raises_value_error" + ) + ] + ) + env = _FakeExecutableEnvironment.from_params(params) + with pytest.raises(ValueError, match="boom"): + env.step("t1", {}) + + +def test_import_executor_rejects_missing_module(): + with pytest.raises(ValueError, match="could not import module"): + _import_executor("oumi.does_not_exist.foo", "t1") + + +def test_import_executor_rejects_missing_attr(): + with pytest.raises(ValueError, match="has no attribute"): + _import_executor("oumi.environments.executable_tool.NopeNotHere", "t1") + + +def test_import_executor_rejects_non_callable(): + with pytest.raises(ValueError, match="not callable"): + _import_executor( + "tests.unit.environments.test_executable_environment._NOT_CALLABLE", + "t1", + ) + + +def test_import_executor_rejects_non_dotted_path(): + with pytest.raises(ValueError, match="dotted path"): + _import_executor("just_a_name", "t1") + + +def test_close_default_is_no_op(): + params = _make_params( + [_make_tool("tests.unit.environments.test_executable_environment._ok_executor")] + ) + env = _FakeExecutableEnvironment.from_params(params) + env.close() # should not raise diff --git a/tests/unit/environments/test_executable_tool.py b/tests/unit/environments/test_executable_tool.py new file mode 100644 index 0000000000..1db4bfc436 --- /dev/null +++ b/tests/unit/environments/test_executable_tool.py @@ -0,0 +1,40 @@ +# Copyright 2025 - Oumi +# Licensed under the Apache License, Version 2.0 +"""Unit tests for ExecutableTool base class.""" + +from __future__ import annotations + +import pytest + +from oumi.environments.executable_tool import ExecutableTool + + +def test_valid_tool_constructs(): + tool = ExecutableTool( + id="t1", + name="T1", + description="A tool.", + executor="some.module.func", + ) + assert tool.executor == "some.module.func" + + +def test_empty_executor_raises(): + with pytest.raises(ValueError, match="executor"): + ExecutableTool( + id="t1", + name="T1", + description="A tool.", + executor="", + ) + + +def test_inherits_toolparams_validation(): + # Empty id should still fail via parent ToolParams.__post_init__ + with pytest.raises(ValueError, match="id"): + ExecutableTool( + id="", + name="T1", + description="A tool.", + executor="some.module.func", + )