From 0355ea71bbe8a20607caee3c55f56efeda806221 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Tue, 16 Jun 2026 11:40:09 -0700 Subject: [PATCH 01/12] feat(environments): implement ExecutableEnvironment executor dispatch --- .../environments/executable_environment.py | 38 ++++++++++- .../test_executable_environment.py | 63 +++++++++++++++++-- 2 files changed, 92 insertions(+), 9 deletions(-) diff --git a/src/oumi/environments/executable_environment.py b/src/oumi/environments/executable_environment.py index 2d26438c8d..264174ebbd 100644 --- a/src/oumi/environments/executable_environment.py +++ b/src/oumi/environments/executable_environment.py @@ -16,13 +16,14 @@ from __future__ import annotations +import jsonschema from abc import abstractmethod from collections.abc import Callable from contextlib import AbstractContextManager from typing import Any, ClassVar from oumi.core.configs.params.environment_params import EnvironmentParams -from oumi.core.configs.params.tool_params import ToolParams +from oumi.core.configs.params.tool_params import ToolError, ToolLookupError, ToolParams from oumi.core.types.tool_call import ToolResult from oumi.environments.base_environment import BaseEnvironment from oumi.environments.executable_tool import ExecutableTool @@ -66,6 +67,37 @@ def step(self, calls: list[tuple[str, dict[str, Any]]]) -> list[ToolResult]: """Execute a batch of tool calls; results are returned in input order.""" return [self._step_one(tool_id, arguments) for tool_id, arguments in calls] + def _lookup_tool(self, tool_id: str) -> ExecutableTool: + for tool in self._params.tools: + if tool.id == tool_id: + return tool + raise ToolLookupError( + f"Tool '{tool_id}' not found in environment '{self._params.id}'. " + f"Available tools: {[t.id for t in self._params.tools]}" + ) + + def _validate_result(self, tool: ExecutableTool, result: Any) -> ToolResult: + if not isinstance(result, ToolResult): + raise ToolError( + f"Tool '{tool.id}' executor must return ToolResult, got " + f"{type(result).__name__}." + ) + if tool.output_schema is not None: + try: + jsonschema.validate(result.output, tool.output_schema) + except jsonschema.ValidationError as e: + raise ToolError( + f"Tool '{tool.id}' executor output failed schema validation: {e}" + ) from e + return result + def _step_one(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult: - """Execute a single tool call.""" - raise NotImplementedError + tool = self._lookup_tool(tool_id) + tool.validate_arguments(arguments) + with self._build_execution_context(tool, arguments) as ctx: + result = self._executors[tool_id]( + **{"arguments": arguments, self._executor_context_kwarg: ctx} + ) + validated = self._validate_result(tool, result) + self._absorb_result(tool, validated) + return validated diff --git a/tests/unit/environments/test_executable_environment.py b/tests/unit/environments/test_executable_environment.py index cdc99e1305..e4c10d9dd1 100644 --- a/tests/unit/environments/test_executable_environment.py +++ b/tests/unit/environments/test_executable_environment.py @@ -32,7 +32,9 @@ class _MinimalExecEnv(ExecutableEnvironment): """Smallest concrete subclass that satisfies the abstract surface.""" def __init__(self) -> None: - self._params = EnvironmentParams(id="test", env_type="executable") + self._params = EnvironmentParams( + id="test", name="test", description="d", env_type="executable" + ) self._executors = {} @contextmanager @@ -72,14 +74,63 @@ def test_absorb_result_is_noop(): def test_step_batch_dispatches_to_step_one(): - """Batch step() iterates the call list and dispatches each to _step_one.""" + """Batch step() dispatches each call to _step_one, which looks up the tool.""" + from oumi.core.configs.params.tool_params import ToolLookupError + env = _MinimalExecEnv() - with pytest.raises(NotImplementedError): + with pytest.raises(ToolLookupError): env.step([("tool_a", {})]) -def test_step_one_raises_not_implemented(): - """_step_one is the per-call dispatch hook; the base implementation raises.""" +def test_step_one_unknown_tool_raises_lookup_error_minimal(): + """_step_one raises a lookup error when the tool id is unknown.""" + from oumi.core.configs.params.tool_params import ToolLookupError + env = _MinimalExecEnv() - with pytest.raises(NotImplementedError): + with pytest.raises(ToolLookupError): env._step_one("tool_a", {}) + + +class _EchoExecEnv(ExecutableEnvironment): + """Concrete env whose executor echoes the context it was handed.""" + + def __init__(self, tools: list[ExecutableTool]) -> None: + self._params = EnvironmentParams( + id="echo", name="echo", description="d", env_type="executable", tools=tools + ) + self._executors = {t.id: _echo_executor for t in tools} + + @contextmanager + def _build_execution_context( + self, tool: ExecutableTool, arguments: dict[str, Any] + ) -> Iterator[Any]: + yield {"ctx_for": tool.id} + + +def _echo_executor(*, arguments: dict[str, Any], context: Any) -> ToolResult: + return ToolResult(output={"args": arguments, "context": context}) + + +def test_step_one_dispatches_to_executor_with_context(): + tool = ExecutableTool(id="t", name="t", description="d", executor="x.y") + env = _EchoExecEnv([tool]) + [result] = env.step([("t", {"a": 1})]) + assert result.output == {"args": {"a": 1}, "context": {"ctx_for": "t"}} + + +def test_step_one_unknown_tool_raises_lookup_error(): + from oumi.core.configs.params.tool_params import ToolLookupError + + env = _EchoExecEnv([]) + with pytest.raises(ToolLookupError): + env.step([("missing", {})]) + + +def test_step_one_rejects_non_toolresult_executor_return(): + from oumi.core.configs.params.tool_params import ToolError + + tool = ExecutableTool(id="bad", name="bad", description="d", executor="x.y") + env = _EchoExecEnv([tool]) + env._executors["bad"] = lambda **_: {"not": "a ToolResult"} + with pytest.raises(ToolError): + env.step([("bad", {})]) From 05eebae536ccb87b3c23f38c849ff14893954c53 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Tue, 16 Jun 2026 11:45:10 -0700 Subject: [PATCH 02/12] fix(environments): correct isort import order in ExecutableEnvironment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move `import jsonschema` from between __future__ and stdlib imports to its correct position (stdlib → third-party → first-party), fixing the ruff I001 lint failure that was blocking CI. --- src/oumi/environments/executable_environment.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/oumi/environments/executable_environment.py b/src/oumi/environments/executable_environment.py index 264174ebbd..04120ff4b7 100644 --- a/src/oumi/environments/executable_environment.py +++ b/src/oumi/environments/executable_environment.py @@ -16,12 +16,13 @@ from __future__ import annotations -import jsonschema from abc import abstractmethod from collections.abc import Callable from contextlib import AbstractContextManager from typing import Any, ClassVar +import jsonschema + from oumi.core.configs.params.environment_params import EnvironmentParams from oumi.core.configs.params.tool_params import ToolError, ToolLookupError, ToolParams from oumi.core.types.tool_call import ToolResult From 131d63e4b2ff2592bdede3af5a3ee3b5ed9e2ab3 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Tue, 16 Jun 2026 11:48:52 -0700 Subject: [PATCH 03/12] feat(environments): add rollback-based SQLite isolation primitive --- src/oumi/environments/db_isolation.py | 73 +++++++++++++++++++ tests/unit/environments/test_db_isolation.py | 76 ++++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 src/oumi/environments/db_isolation.py create mode 100644 tests/unit/environments/test_db_isolation.py diff --git a/src/oumi/environments/db_isolation.py b/src/oumi/environments/db_isolation.py new file mode 100644 index 0000000000..4e0aafea9c --- /dev/null +++ b/src/oumi/environments/db_isolation.py @@ -0,0 +1,73 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rollback-based SQLite isolation for per-rollout database environments. + +The SkyRL pattern: each rollout gets its own connection that opens a +transaction and never commits, so uncommitted writes are visible within the +rollout and discarded on close. The environment owns the transaction; +executors must not call ``commit()``. +""" + +from __future__ import annotations + +import sqlite3 +import tempfile +import uuid +from pathlib import Path + + +def materialize_sqlite_snapshot( + *, + schema_sql: str, + seed_sql: str | None = None, + dest: Path | str | None = None, +) -> Path: + """Build a snapshot SQLite file from DDL (+ optional seed INSERTs).""" + path = ( + Path(dest) + if dest is not None + else Path(tempfile.gettempdir()) / f"oumi_snapshot_{uuid.uuid4().hex}.sqlite" + ) + conn = sqlite3.connect(path) + try: + conn.executescript(schema_sql) + if seed_sql: + conn.executescript(seed_sql) + conn.commit() + finally: + conn.close() + return path + + +class RollbackSession: + """A per-rollout SQLite connection that never commits and rolls back on close. + + Set ``owns_file=True`` when the env built a throwaway per-rollout database + that should be deleted on teardown (as opposed to a shared snapshot). + """ + + def __init__(self, db_path: Path | str, *, owns_file: bool = False) -> None: + self._path = Path(db_path) + self._owns_file = owns_file + self.connection = sqlite3.connect(self._path) + + def close(self) -> None: + """Roll back any open transaction, close, and delete an owned file.""" + try: + self.connection.rollback() + finally: + self.connection.close() + if self._owns_file: + self._path.unlink(missing_ok=True) diff --git a/tests/unit/environments/test_db_isolation.py b/tests/unit/environments/test_db_isolation.py new file mode 100644 index 0000000000..7ced417eb9 --- /dev/null +++ b/tests/unit/environments/test_db_isolation.py @@ -0,0 +1,76 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sqlite3 +from pathlib import Path + +from oumi.environments.db_isolation import RollbackSession, materialize_sqlite_snapshot + +_SCHEMA = "CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT);" +_SEED = "INSERT INTO t VALUES (1, 'a');" + + +def test_materialize_builds_a_seeded_snapshot(tmp_path): + path = materialize_sqlite_snapshot( + schema_sql=_SCHEMA, seed_sql=_SEED, dest=tmp_path / "seed.sqlite" + ) + conn = sqlite3.connect(path) + assert conn.execute("SELECT v FROM t WHERE id = 1").fetchone()[0] == "a" + conn.close() + + +def test_rollback_session_discards_uncommitted_writes(tmp_path): + path = materialize_sqlite_snapshot( + schema_sql=_SCHEMA, seed_sql=_SEED, dest=tmp_path / "seed.sqlite" + ) + session = RollbackSession(path) + # Write without committing; visible on this connection... + session.connection.execute("UPDATE t SET v = 'mutated' WHERE id = 1") + assert session.connection.execute("SELECT v FROM t WHERE id = 1").fetchone()[0] == ( + "mutated" + ) + session.close() # rolls back + closes + # ...gone from the snapshot afterwards. + conn = sqlite3.connect(path) + assert conn.execute("SELECT v FROM t WHERE id = 1").fetchone()[0] == "a" + conn.close() + + +def test_two_sessions_on_one_snapshot_do_not_see_each_others_uncommitted_writes( + tmp_path, +): + path = materialize_sqlite_snapshot( + schema_sql=_SCHEMA, seed_sql=_SEED, dest=tmp_path / "seed.sqlite" + ) + a = RollbackSession(path) + b = RollbackSession(path) + try: + a.connection.execute("UPDATE t SET v = 'from_a' WHERE id = 1") + # b never sees a's uncommitted write. + assert b.connection.execute("SELECT v FROM t WHERE id = 1").fetchone()[0] == "a" + finally: + a.close() + b.close() + + +def test_owned_session_deletes_its_file_on_close(tmp_path): + path = materialize_sqlite_snapshot( + schema_sql=_SCHEMA, dest=tmp_path / "owned.sqlite" + ) + session = RollbackSession(path, owns_file=True) + assert path.exists() + session.close() + assert not path.exists() From b5c4bb4aae220041032578c565463afd8517a385 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Tue, 16 Jun 2026 11:53:41 -0700 Subject: [PATCH 04/12] fix(environments): resolve D107 missing __init__ docstring and F401 unused import Add Google-style docstring to RollbackSession.__init__ to satisfy ruff D107, and remove unused `from pathlib import Path` in the test file (ruff F401). Both violations would hard-fail the pre-commit ruff hook. --- src/oumi/environments/db_isolation.py | 1 + tests/unit/environments/test_db_isolation.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/oumi/environments/db_isolation.py b/src/oumi/environments/db_isolation.py index 4e0aafea9c..baa719f530 100644 --- a/src/oumi/environments/db_isolation.py +++ b/src/oumi/environments/db_isolation.py @@ -59,6 +59,7 @@ class RollbackSession: """ def __init__(self, db_path: Path | str, *, owns_file: bool = False) -> None: + """Open a per-rollout connection; set owns_file to delete the DB on close.""" self._path = Path(db_path) self._owns_file = owns_file self.connection = sqlite3.connect(self._path) diff --git a/tests/unit/environments/test_db_isolation.py b/tests/unit/environments/test_db_isolation.py index 7ced417eb9..c9962a4532 100644 --- a/tests/unit/environments/test_db_isolation.py +++ b/tests/unit/environments/test_db_isolation.py @@ -15,7 +15,6 @@ from __future__ import annotations import sqlite3 -from pathlib import Path from oumi.environments.db_isolation import RollbackSession, materialize_sqlite_snapshot From 3cc60a6507fbe57034643fa7e9c6778111194a15 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Tue, 16 Jun 2026 11:57:31 -0700 Subject: [PATCH 05/12] feat(environments): add EHR example tools/executors --- src/oumi/environments/examples/__init__.py | 15 ++++++ src/oumi/environments/examples/ehr.py | 60 ++++++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 src/oumi/environments/examples/__init__.py create mode 100644 src/oumi/environments/examples/ehr.py diff --git a/src/oumi/environments/examples/__init__.py b/src/oumi/environments/examples/__init__.py new file mode 100644 index 0000000000..d866b050f8 --- /dev/null +++ b/src/oumi/environments/examples/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runnable example tools/executors for executable environments.""" diff --git a/src/oumi/environments/examples/ehr.py b/src/oumi/environments/examples/ehr.py new file mode 100644 index 0000000000..4a0cadfaa4 --- /dev/null +++ b/src/oumi/environments/examples/ehr.py @@ -0,0 +1,60 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""EHR example: schema, seed data, and three SQL tool executors. + +Executors take (*, arguments, db) where ``db`` is a sqlite3.Connection handed +in by DatabaseExecutableEnvironment, and return a ToolResult. They must NOT +commit: the environment owns the transaction and rolls back on close. +""" + +from __future__ import annotations + +import sqlite3 +from typing import Any + +from oumi.core.types.tool_call import ToolResult + +EHR_SCHEMA = ( + "CREATE TABLE patients (" + " id INTEGER PRIMARY KEY, name TEXT NOT NULL, meds TEXT);" +) +EHR_SEED = ( + "INSERT INTO patients (id, name, meds) VALUES" + " (1, 'Bob', 'aspirin'), (2, 'Alice', 'ibuprofen'), (3, 'Carol', NULL);" +) + + +def list_patients(*, arguments: dict[str, Any], db: sqlite3.Connection) -> ToolResult: + rows = db.execute("SELECT id, name FROM patients ORDER BY id").fetchall() + return ToolResult(output={"patients": [{"id": r[0], "name": r[1]} for r in rows]}) + + +def lookup_patient(*, arguments: dict[str, Any], db: sqlite3.Connection) -> ToolResult: + row = db.execute( + "SELECT name, meds FROM patients WHERE id = ?", (arguments["pat_id"],) + ).fetchone() + if row is None: + return ToolResult(output={"error": "not found"}) + return ToolResult(output={"name": row[0], "meds": row[1]}) + + +def update_meds(*, arguments: dict[str, Any], db: sqlite3.Connection) -> ToolResult: + # No commit: the environment rolls back at episode end. The write is + # visible to later calls on this same connection within the episode. + cur = db.execute( + "UPDATE patients SET meds = ? WHERE id = ?", + (arguments["medication"], arguments["pat_id"]), + ) + return ToolResult(output={"updated_rows": cur.rowcount}) From f32f6d618dc69908b9c7800121b49e6ed635c44f Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 17 Jun 2026 07:34:24 -0700 Subject: [PATCH 06/12] feat(environments): add DatabaseExecutableEnvironment with rollback isolation Per-rollout RollbackSession (never commits, rolls back on close) so writes are visible within an episode and never persist or leak across rollouts. Includes the isolation proof tests (write-then-read within an episode, rollback on close, no cross-rollout leak, shared snapshot never mutated). --- src/oumi/environments/__init__.py | 4 + .../database_executable_environment.py | 102 ++++++++++++ .../test_database_executable_environment.py | 151 ++++++++++++++++++ 3 files changed, 257 insertions(+) create mode 100644 src/oumi/environments/database_executable_environment.py create mode 100644 tests/unit/environments/test_database_executable_environment.py diff --git a/src/oumi/environments/__init__.py b/src/oumi/environments/__init__.py index 69abd2a286..83504760b1 100644 --- a/src/oumi/environments/__init__.py +++ b/src/oumi/environments/__init__.py @@ -32,6 +32,9 @@ ) from oumi.core.types.tool_call import JSONSchema, ToolResult from oumi.environments.base_environment import BaseEnvironment +from oumi.environments.database_executable_environment import ( + DatabaseExecutableEnvironment, +) from oumi.environments.deterministic_environment import ( DeterministicEnvironment, DeterministicEnvironmentKwargs, @@ -47,6 +50,7 @@ __all__ = [ "BaseEnvironment", + "DatabaseExecutableEnvironment", "DeterministicEnvironment", "DeterministicEnvironmentKwargs", "ExecutableEnvironment", diff --git a/src/oumi/environments/database_executable_environment.py b/src/oumi/environments/database_executable_environment.py new file mode 100644 index 0000000000..c1f4fc3f2c --- /dev/null +++ b/src/oumi/environments/database_executable_environment.py @@ -0,0 +1,102 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Executable environment backed by a rollback-isolated SQLite session.""" + +from __future__ import annotations + +import importlib +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from typing import Any, ClassVar + +from oumi.core.configs.params.environment_params import EnvironmentParams +from oumi.core.registry import register_environment +from oumi.environments.db_isolation import RollbackSession, materialize_sqlite_snapshot +from oumi.environments.executable_environment import ExecutableEnvironment +from oumi.environments.executable_tool import ExecutableTool + + +def _import_executor(dotted: str, tool_id: str) -> Callable[..., Any]: + """Resolve a dotted import path to a callable.""" + module_path, _, attr = dotted.rpartition(".") + if not module_path or not attr: + raise ValueError( + f"Tool '{tool_id}': executor '{dotted}' must be a dotted import path." + ) + module = importlib.import_module(module_path) + executor = getattr(module, attr, None) + if not callable(executor): + raise ValueError( + f"Tool '{tool_id}': executor '{dotted}' did not resolve to a callable." + ) + return executor + + +@register_environment("database") +class DatabaseExecutableEnvironment(ExecutableEnvironment): + """Runs SQL-executing tools against a rollback-isolated SQLite session. + + One instance owns one session for the duration of an episode. Executors + must NOT commit; the env rolls back on ``close()`` so writes never persist. + ``requires_isolation()`` is ``True``, so the router builds a fresh instance + (hence a fresh session) per rollout. See ``db_isolation`` for the contract. + """ + + _executor_context_kwarg: ClassVar[str] = "db" + + def __init__(self, params: EnvironmentParams, session: RollbackSession) -> None: + """Bind the env to its params and an already-open rollback session.""" + self._params = params + self._session = session + self._executors = { + tool.id: _import_executor(tool.executor, tool.id) for tool in params.tools + } + + @classmethod + def from_params(cls, params: EnvironmentParams) -> DatabaseExecutableEnvironment: + """Build the env, opening a rollback session over its configured DB.""" + kwargs = dict(params.env_kwargs or {}) + db_path = kwargs.get("db_path") + schema_sql = kwargs.get("schema_sql") + if db_path: + # Shared snapshot: connect read-side, roll back on close. + session = RollbackSession(db_path) + elif schema_sql: + # Inline: build a fresh per-rollout DB this instance owns. + snapshot = materialize_sqlite_snapshot( + schema_sql=schema_sql, seed_sql=kwargs.get("seed_sql") + ) + session = RollbackSession(snapshot, owns_file=True) + else: + raise ValueError( + f"DatabaseExecutableEnvironment '{params.id}': env_kwargs must " + f"provide either 'db_path' or 'schema_sql'." + ) + return cls(params, session) + + def requires_isolation(self) -> bool: + """Each rollout needs its own session; never share across samples.""" + return True + + @contextmanager + def _build_execution_context( + self, tool: ExecutableTool, arguments: dict[str, Any] + ) -> Iterator[Any]: + """Yield the episode's connection (uncommitted writes persist within it).""" + yield self._session.connection + + def close(self) -> None: + """Roll back the episode's writes and tear down the session.""" + self._session.close() diff --git a/tests/unit/environments/test_database_executable_environment.py b/tests/unit/environments/test_database_executable_environment.py new file mode 100644 index 0000000000..2e525d4c38 --- /dev/null +++ b/tests/unit/environments/test_database_executable_environment.py @@ -0,0 +1,151 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Behavior + rollback-isolation tests for DatabaseExecutableEnvironment.""" + +from __future__ import annotations + +from oumi.core.configs.params.environment_params import EnvironmentParams +from oumi.environments.database_executable_environment import ( + DatabaseExecutableEnvironment, +) +from oumi.environments.db_isolation import materialize_sqlite_snapshot + +_SCHEMA = "CREATE TABLE patients (id INTEGER PRIMARY KEY, name TEXT, meds TEXT);" +_SEED = "INSERT INTO patients VALUES (1, 'Bob', 'aspirin');" + + +def _params(tools): + return EnvironmentParams( + id="ehr", + name="ehr", + description="EHR test env", + env_type="database", + tools=tools, + env_kwargs={"schema_sql": _SCHEMA, "seed_sql": _SEED}, + ) + + +def _lookup_tool(): + return { + "id": "lookup", + "name": "lookup", + "description": "look up a patient", + "parameters": { + "type": "object", + "properties": {"pat_id": {"type": "integer"}}, + "required": ["pat_id"], + }, + "executor": "oumi.environments.examples.ehr.lookup_patient", + "read_only": True, + } + + +def _update_tool(): + return { + "id": "update", + "name": "update", + "description": "update meds", + "parameters": { + "type": "object", + "properties": { + "pat_id": {"type": "integer"}, + "medication": {"type": "string"}, + }, + "required": ["pat_id", "medication"], + }, + "executor": "oumi.environments.examples.ehr.update_meds", + "read_only": False, + } + + +def test_requires_isolation_is_true(): + env = DatabaseExecutableEnvironment.from_params(_params([_lookup_tool()])) + try: + assert env.requires_isolation() is True + finally: + env.close() + + +def test_executes_read_tool_against_isolated_db(): + env = DatabaseExecutableEnvironment.from_params(_params([_lookup_tool()])) + try: + [result] = env.step([("lookup", {"pat_id": 1})]) + assert result.output == {"name": "Bob", "meds": "aspirin"} + finally: + env.close() + + +def test_uncommitted_write_visible_within_one_episode(): + env = DatabaseExecutableEnvironment.from_params( + _params([_lookup_tool(), _update_tool()]) + ) + try: + env.step([("update", {"pat_id": 1, "medication": "statin"})]) + [seen] = env.step([("lookup", {"pat_id": 1})]) + assert seen.output == {"name": "Bob", "meds": "statin"} + finally: + env.close() + + +def test_close_rolls_back_so_a_fresh_env_starts_clean(): + params = _params([_lookup_tool(), _update_tool()]) + env = DatabaseExecutableEnvironment.from_params(params) + env.step([("update", {"pat_id": 1, "medication": "mutated"})]) + env.close() # rolls back; the inline-built DB is also discarded + fresh = DatabaseExecutableEnvironment.from_params(params) + try: + [seen] = fresh.step([("lookup", {"pat_id": 1})]) + assert seen.output["meds"] == "aspirin" + finally: + fresh.close() + + +def test_writes_do_not_leak_across_concurrent_rollouts(): + params = _params([_lookup_tool(), _update_tool()]) + # N rollouts of the same task; each builds its own inline DB. + envs = [DatabaseExecutableEnvironment.from_params(params) for _ in range(4)] + try: + for i, env in enumerate(envs): + env.step([("update", {"pat_id": 1, "medication": f"drug_{i}"})]) + for i, env in enumerate(envs): + [seen] = env.step([("lookup", {"pat_id": 1})]) + assert seen.output["meds"] == f"drug_{i}" + finally: + for env in envs: + env.close() + + +def test_shared_snapshot_is_never_mutated(tmp_path): + snapshot = materialize_sqlite_snapshot( + schema_sql=_SCHEMA, seed_sql=_SEED, dest=tmp_path / "shared.sqlite" + ) + params = EnvironmentParams( + id="ehr", + name="ehr", + description="d", + env_type="database", + tools=[_lookup_tool(), _update_tool()], + env_kwargs={"db_path": str(snapshot)}, + ) + env = DatabaseExecutableEnvironment.from_params(params) + env.step([("update", {"pat_id": 1, "medication": "mutated"})]) + env.close() # rollback + # The shared snapshot file is untouched. + fresh = DatabaseExecutableEnvironment.from_params(params) + try: + [seen] = fresh.step([("lookup", {"pat_id": 1})]) + assert seen.output["meds"] == "aspirin" + finally: + fresh.close() From 1bffb44a8649df1f55748faf2f12a5886bf7f713 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 17 Jun 2026 07:36:31 -0700 Subject: [PATCH 07/12] feat(environments): add sql_execution_match reward + EHR env config Execution-match reward grades candidate vs gold SQL on a fresh rollback session (reusing the env's isolation on the grading side). EHR YAML config + builder test exercise the 'bring your DB' entry point through build_environment. --- .../database_env/ehr_database_env.yaml | 34 ++++++++++ .../grpo/rewards/sql_execution_match.py | 65 +++++++++++++++++++ .../grpo/rewards/test_sql_execution_match.py | 60 +++++++++++++++++ .../environments/test_database_env_config.py | 41 ++++++++++++ 4 files changed, 200 insertions(+) create mode 100644 configs/examples/database_env/ehr_database_env.yaml create mode 100644 src/oumi/datasets/grpo/rewards/sql_execution_match.py create mode 100644 tests/unit/datasets/grpo/rewards/test_sql_execution_match.py create mode 100644 tests/unit/environments/test_database_env_config.py diff --git a/configs/examples/database_env/ehr_database_env.yaml b/configs/examples/database_env/ehr_database_env.yaml new file mode 100644 index 0000000000..64c02ed496 --- /dev/null +++ b/configs/examples/database_env/ehr_database_env.yaml @@ -0,0 +1,34 @@ +# Minimal EHR database executable environment. +# Swap schema_sql/seed_sql for `db_path: /path/to/your.sqlite` to bring your own DB. +id: ehr_demo +name: EHR demo +description: Example EHR database environment with three tools. +env_type: database +env_kwargs: + schema_sql: "CREATE TABLE patients (id INTEGER PRIMARY KEY, name TEXT, meds TEXT);" + seed_sql: "INSERT INTO patients VALUES (1,'Bob','aspirin'),(2,'Alice','ibuprofen');" +tools: + - id: list_patients + name: list_patients + description: List all patients. + parameters: {type: object, properties: {}} + executor: oumi.environments.examples.ehr.list_patients + read_only: true + - id: lookup_patient + name: lookup_patient + description: Look up one patient by id. + parameters: + type: object + properties: {pat_id: {type: integer}} + required: [pat_id] + executor: oumi.environments.examples.ehr.lookup_patient + read_only: true + - id: update_meds + name: update_meds + description: Update a patient's medication. + parameters: + type: object + properties: {pat_id: {type: integer}, medication: {type: string}} + required: [pat_id, medication] + executor: oumi.environments.examples.ehr.update_meds + read_only: false diff --git a/src/oumi/datasets/grpo/rewards/sql_execution_match.py b/src/oumi/datasets/grpo/rewards/sql_execution_match.py new file mode 100644 index 0000000000..5c72e2a149 --- /dev/null +++ b/src/oumi/datasets/grpo/rewards/sql_execution_match.py @@ -0,0 +1,65 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Execution-match reward for SQL tasks, graded on an isolated snapshot.""" + +from __future__ import annotations + +import sqlite3 +from pathlib import Path +from typing import Any + +from oumi.core.registry import RegistryType, register +from oumi.environments.db_isolation import RollbackSession, materialize_sqlite_snapshot + + +def _run(connection: sqlite3.Connection, sql: str) -> list[tuple] | None: + try: + return connection.execute(sql).fetchall() + except sqlite3.Error: + return None + + +@register("sql_execution_match", RegistryType.REWARD_FUNCTION) +def sql_execution_match( + data_source: str, + solution_str: str, + ground_truth: str, + extra_info: dict[str, Any] | None = None, + **kwargs: Any, +) -> float: + """Return 1.0 if candidate SQL's result set matches the gold's, else 0.0. + + ``extra_info`` carries the DB descriptor: either ``db_path`` or + ``schema_sql`` (+ optional ``seed_sql``). Grading runs on a rollback + session, so a candidate that writes never mutates the snapshot. + """ + info = extra_info or {} + owns = False + if info.get("db_path"): + path = Path(info["db_path"]) + else: + path = materialize_sqlite_snapshot( + schema_sql=info["schema_sql"], seed_sql=info.get("seed_sql") + ) + owns = True + session = RollbackSession(path, owns_file=owns) + try: + gold_rows = _run(session.connection, ground_truth) + cand_rows = _run(session.connection, solution_str) + finally: + session.close() + if gold_rows is None or cand_rows is None: + return 0.0 + return 1.0 if gold_rows == cand_rows else 0.0 diff --git a/tests/unit/datasets/grpo/rewards/test_sql_execution_match.py b/tests/unit/datasets/grpo/rewards/test_sql_execution_match.py new file mode 100644 index 0000000000..5d3f61c829 --- /dev/null +++ b/tests/unit/datasets/grpo/rewards/test_sql_execution_match.py @@ -0,0 +1,60 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the sql_execution_match reward.""" + +from __future__ import annotations + +from oumi.datasets.grpo.rewards.sql_execution_match import sql_execution_match + +_SCHEMA = "CREATE TABLE patients (id INTEGER PRIMARY KEY, name TEXT, age INTEGER);" +_SEED = "INSERT INTO patients VALUES (1,'Bob',50),(2,'Alice',40),(3,'Carol',65);" + + +def _extra(): + return {"schema_sql": _SCHEMA, "seed_sql": _SEED} + + +def test_exact_match_scores_one(): + gold = "SELECT name FROM patients WHERE age > 45 ORDER BY name" + candidate = "SELECT name FROM patients WHERE age >= 50 ORDER BY name" + score = sql_execution_match( + data_source="ehr", + solution_str=candidate, + ground_truth=gold, + extra_info=_extra(), + ) + assert score == 1.0 # Bob, Carol in both + + +def test_mismatch_scores_zero(): + gold = "SELECT name FROM patients WHERE age > 45 ORDER BY name" + candidate = "SELECT name FROM patients ORDER BY name" + score = sql_execution_match( + data_source="ehr", + solution_str=candidate, + ground_truth=gold, + extra_info=_extra(), + ) + assert score == 0.0 + + +def test_invalid_sql_scores_zero(): + score = sql_execution_match( + data_source="ehr", + solution_str="SELECT FROM nonsense(", + ground_truth="SELECT 1", + extra_info=_extra(), + ) + assert score == 0.0 diff --git a/tests/unit/environments/test_database_env_config.py b/tests/unit/environments/test_database_env_config.py new file mode 100644 index 0000000000..ead2bfde0c --- /dev/null +++ b/tests/unit/environments/test_database_env_config.py @@ -0,0 +1,41 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end: build the EHR database env from its YAML config.""" + +from __future__ import annotations + +from pathlib import Path + +from omegaconf import OmegaConf + +from oumi.builders.environments import build_environment +from oumi.core.configs.params.environment_params import EnvironmentParams +from oumi.environments.database_executable_environment import ( + DatabaseExecutableEnvironment, +) + +_CONFIG = Path("configs/examples/database_env/ehr_database_env.yaml") + + +def test_build_environment_from_yaml(): + raw = OmegaConf.to_container(OmegaConf.load(_CONFIG), resolve=True) + params = EnvironmentParams(**raw) + env = build_environment(params) + assert isinstance(env, DatabaseExecutableEnvironment) + try: + [result] = env.step([("lookup_patient", {"pat_id": 1})]) + assert result.output == {"name": "Bob", "meds": "aspirin"} + finally: + env.close() From 001e78bcb47c300bf7a0d4061c1818ed027c92f8 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 17 Jun 2026 07:44:16 -0700 Subject: [PATCH 08/12] fix(environments): address final review on the DB env prototype - RollbackSession opens isolation_level=None + explicit BEGIN so leading DDL (CREATE/DROP as the first statement) is also rolled back, not just DML. - sql_execution_match grades gold and candidate on separate sessions so a mutating gold query can't contaminate the candidate. - export sql_execution_match from rewards/__init__ so @register fires on package import (otherwise it's missing from the registry). - config test resolves its path relative to __file__, not CWD. - document that db_path isolation is read-concurrent only (writers contend). --- src/oumi/datasets/grpo/rewards/__init__.py | 2 ++ .../grpo/rewards/sql_execution_match.py | 14 ++++++++++---- .../database_executable_environment.py | 9 ++++++++- src/oumi/environments/db_isolation.py | 14 ++++++++++++-- .../environments/test_database_env_config.py | 3 ++- tests/unit/environments/test_db_isolation.py | 18 ++++++++++++++++++ 6 files changed, 52 insertions(+), 8 deletions(-) diff --git a/src/oumi/datasets/grpo/rewards/__init__.py b/src/oumi/datasets/grpo/rewards/__init__.py index b116fdb4b6..f28efdc6bc 100644 --- a/src/oumi/datasets/grpo/rewards/__init__.py +++ b/src/oumi/datasets/grpo/rewards/__init__.py @@ -21,6 +21,7 @@ from oumi.datasets.grpo.rewards.count_letters_rewards import compute_letter_count_reward from oumi.datasets.grpo.rewards.countdown_rewards import countdown_reward from oumi.datasets.grpo.rewards.gsm8k_reward import gsm8k_reward +from oumi.datasets.grpo.rewards.sql_execution_match import sql_execution_match __all__ = [ "compute_letter_count_reward", @@ -28,4 +29,5 @@ "compute_sharp_target_token_length_reward", "countdown_reward", "gsm8k_reward", + "sql_execution_match", ] diff --git a/src/oumi/datasets/grpo/rewards/sql_execution_match.py b/src/oumi/datasets/grpo/rewards/sql_execution_match.py index 5c72e2a149..636f09d851 100644 --- a/src/oumi/datasets/grpo/rewards/sql_execution_match.py +++ b/src/oumi/datasets/grpo/rewards/sql_execution_match.py @@ -54,12 +54,18 @@ def sql_execution_match( schema_sql=info["schema_sql"], seed_sql=info.get("seed_sql") ) owns = True - session = RollbackSession(path, owns_file=owns) + # Grade gold and candidate on separate sessions so each runs against the + # pristine snapshot — a mutating gold query can't contaminate the candidate. + gold_session = RollbackSession(path) try: - gold_rows = _run(session.connection, ground_truth) - cand_rows = _run(session.connection, solution_str) + gold_rows = _run(gold_session.connection, ground_truth) finally: - session.close() + gold_session.close() + cand_session = RollbackSession(path, owns_file=owns) + try: + cand_rows = _run(cand_session.connection, solution_str) + finally: + cand_session.close() if gold_rows is None or cand_rows is None: return 0.0 return 1.0 if gold_rows == cand_rows else 0.0 diff --git a/src/oumi/environments/database_executable_environment.py b/src/oumi/environments/database_executable_environment.py index c1f4fc3f2c..674d520d00 100644 --- a/src/oumi/environments/database_executable_environment.py +++ b/src/oumi/environments/database_executable_environment.py @@ -66,7 +66,14 @@ def __init__(self, params: EnvironmentParams, session: RollbackSession) -> None: @classmethod def from_params(cls, params: EnvironmentParams) -> DatabaseExecutableEnvironment: - """Build the env, opening a rollback session over its configured DB.""" + """Build the env, opening a rollback session over its configured DB. + + ``db_path`` shares one snapshot file across rollouts (rollback isolation, + scales to large DBs). It is safe for concurrent *readers*, but SQLite + serializes concurrent *writers* on one file, so concurrent rollouts that + write will contend — those tasks should use ``schema_sql`` (a fresh + per-rollout file) until copy-on-write isolation lands. + """ kwargs = dict(params.env_kwargs or {}) db_path = kwargs.get("db_path") schema_sql = kwargs.get("schema_sql") diff --git a/src/oumi/environments/db_isolation.py b/src/oumi/environments/db_isolation.py index baa719f530..51038db77f 100644 --- a/src/oumi/environments/db_isolation.py +++ b/src/oumi/environments/db_isolation.py @@ -59,10 +59,20 @@ class RollbackSession: """ def __init__(self, db_path: Path | str, *, owns_file: bool = False) -> None: - """Open a per-rollout connection; set owns_file to delete the DB on close.""" + """Open a per-rollout connection; set owns_file to delete the DB on close. + + Opens with ``isolation_level=None`` and an explicit ``BEGIN`` so the whole + session is one transaction. ``sqlite3``'s legacy mode only opens an + implicit transaction before DML, which would let a leading DDL statement + (e.g. ``CREATE TABLE`` as the first call) run in autocommit and escape the + rollback; the explicit ``BEGIN`` brings DDL under transaction control too. + Executors still must not call ``commit()`` — an explicit commit persists + regardless and there is no way to undo it. + """ self._path = Path(db_path) self._owns_file = owns_file - self.connection = sqlite3.connect(self._path) + self.connection = sqlite3.connect(self._path, isolation_level=None) + self.connection.execute("BEGIN") def close(self) -> None: """Roll back any open transaction, close, and delete an owned file.""" diff --git a/tests/unit/environments/test_database_env_config.py b/tests/unit/environments/test_database_env_config.py index ead2bfde0c..9d9c294181 100644 --- a/tests/unit/environments/test_database_env_config.py +++ b/tests/unit/environments/test_database_env_config.py @@ -26,7 +26,8 @@ DatabaseExecutableEnvironment, ) -_CONFIG = Path("configs/examples/database_env/ehr_database_env.yaml") +_REPO_ROOT = Path(__file__).parents[3] +_CONFIG = _REPO_ROOT / "configs/examples/database_env/ehr_database_env.yaml" def test_build_environment_from_yaml(): diff --git a/tests/unit/environments/test_db_isolation.py b/tests/unit/environments/test_db_isolation.py index c9962a4532..873405e71f 100644 --- a/tests/unit/environments/test_db_isolation.py +++ b/tests/unit/environments/test_db_isolation.py @@ -65,6 +65,24 @@ def test_two_sessions_on_one_snapshot_do_not_see_each_others_uncommitted_writes( b.close() +def test_leading_ddl_is_rolled_back(tmp_path): + # DDL as the first statement must still roll back. Legacy sqlite3 only opens + # an implicit transaction before DML, so without the explicit BEGIN a leading + # CREATE TABLE would run in autocommit and persist past close(). + path = materialize_sqlite_snapshot( + schema_sql=_SCHEMA, seed_sql=_SEED, dest=tmp_path / "seed.sqlite" + ) + session = RollbackSession(path) + session.connection.execute("CREATE TABLE leaked (x INTEGER)") + session.close() + conn = sqlite3.connect(path) + leaked = conn.execute( + "SELECT count(*) FROM sqlite_master WHERE name = 'leaked'" + ).fetchone()[0] + conn.close() + assert leaked == 0 + + def test_owned_session_deletes_its_file_on_close(tmp_path): path = materialize_sqlite_snapshot( schema_sql=_SCHEMA, dest=tmp_path / "owned.sqlite" From fb75184f19788c7f334d0ec79c77e1d246ac9ce9 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 17 Jun 2026 08:40:18 -0700 Subject: [PATCH 09/12] fix(environments): assert full dict equality in DB env tests ToolResult.output is str | dict; subscripting it tripped pyright's pre-push check. Compare the whole output dict instead. --- .../environments/test_database_executable_environment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/environments/test_database_executable_environment.py b/tests/unit/environments/test_database_executable_environment.py index 2e525d4c38..c30137d229 100644 --- a/tests/unit/environments/test_database_executable_environment.py +++ b/tests/unit/environments/test_database_executable_environment.py @@ -107,7 +107,7 @@ def test_close_rolls_back_so_a_fresh_env_starts_clean(): fresh = DatabaseExecutableEnvironment.from_params(params) try: [seen] = fresh.step([("lookup", {"pat_id": 1})]) - assert seen.output["meds"] == "aspirin" + assert seen.output == {"name": "Bob", "meds": "aspirin"} finally: fresh.close() @@ -121,7 +121,7 @@ def test_writes_do_not_leak_across_concurrent_rollouts(): env.step([("update", {"pat_id": 1, "medication": f"drug_{i}"})]) for i, env in enumerate(envs): [seen] = env.step([("lookup", {"pat_id": 1})]) - assert seen.output["meds"] == f"drug_{i}" + assert seen.output == {"name": "Bob", "meds": f"drug_{i}"} finally: for env in envs: env.close() @@ -146,6 +146,6 @@ def test_shared_snapshot_is_never_mutated(tmp_path): fresh = DatabaseExecutableEnvironment.from_params(params) try: [seen] = fresh.step([("lookup", {"pat_id": 1})]) - assert seen.output["meds"] == "aspirin" + assert seen.output == {"name": "Bob", "meds": "aspirin"} finally: fresh.close() From d1c47f7fd5cf4ef59c3353e2739937c8a900d11b Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 17 Jun 2026 08:42:13 -0700 Subject: [PATCH 10/12] fix(environments): cast OmegaConf container for EnvironmentParams kwargs --- tests/unit/environments/test_database_env_config.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit/environments/test_database_env_config.py b/tests/unit/environments/test_database_env_config.py index 9d9c294181..42b3762ce0 100644 --- a/tests/unit/environments/test_database_env_config.py +++ b/tests/unit/environments/test_database_env_config.py @@ -17,6 +17,7 @@ from __future__ import annotations from pathlib import Path +from typing import Any, cast from omegaconf import OmegaConf @@ -31,7 +32,10 @@ def test_build_environment_from_yaml(): - raw = OmegaConf.to_container(OmegaConf.load(_CONFIG), resolve=True) + raw = cast( + "dict[str, Any]", + OmegaConf.to_container(OmegaConf.load(_CONFIG), resolve=True), + ) params = EnvironmentParams(**raw) env = build_environment(params) assert isinstance(env, DatabaseExecutableEnvironment) From 430bec8433073462323f629e1d608ea6ab994188 Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 17 Jun 2026 08:45:08 -0700 Subject: [PATCH 11/12] style(environments): ruff format ehr example schema string --- src/oumi/environments/examples/ehr.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/oumi/environments/examples/ehr.py b/src/oumi/environments/examples/ehr.py index 4a0cadfaa4..a16b265203 100644 --- a/src/oumi/environments/examples/ehr.py +++ b/src/oumi/environments/examples/ehr.py @@ -27,8 +27,7 @@ from oumi.core.types.tool_call import ToolResult EHR_SCHEMA = ( - "CREATE TABLE patients (" - " id INTEGER PRIMARY KEY, name TEXT NOT NULL, meds TEXT);" + "CREATE TABLE patients ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, meds TEXT);" ) EHR_SEED = ( "INSERT INTO patients (id, name, meds) VALUES" From d470ad891dbcb3798fdc1e7749817e372bf202eb Mon Sep 17 00:00:00 2001 From: aniruddh-alt Date: Wed, 17 Jun 2026 08:47:37 -0700 Subject: [PATCH 12/12] docs(environments): add docstrings to EHR example executors --- src/oumi/environments/examples/ehr.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/oumi/environments/examples/ehr.py b/src/oumi/environments/examples/ehr.py index a16b265203..6077fe0d10 100644 --- a/src/oumi/environments/examples/ehr.py +++ b/src/oumi/environments/examples/ehr.py @@ -36,11 +36,13 @@ def list_patients(*, arguments: dict[str, Any], db: sqlite3.Connection) -> ToolResult: + """List every patient's id and name.""" rows = db.execute("SELECT id, name FROM patients ORDER BY id").fetchall() return ToolResult(output={"patients": [{"id": r[0], "name": r[1]} for r in rows]}) def lookup_patient(*, arguments: dict[str, Any], db: sqlite3.Connection) -> ToolResult: + """Return one patient's name and meds by id, or an error if absent.""" row = db.execute( "SELECT name, meds FROM patients WHERE id = ?", (arguments["pat_id"],) ).fetchone() @@ -50,6 +52,7 @@ def lookup_patient(*, arguments: dict[str, Any], db: sqlite3.Connection) -> Tool def update_meds(*, arguments: dict[str, Any], db: sqlite3.Connection) -> ToolResult: + """Set a patient's medication (uncommitted; rolled back at episode end).""" # No commit: the environment rolls back at episode end. The write is # visible to later calls on this same connection within the episode. cur = db.execute(