Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions configs/examples/database_env/ehr_database_env.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Minimal EHR database executable environment.
# Swap schema_sql/seed_sql for `db_path: /path/to/your.sqlite` to bring your own DB.
id: ehr_demo
name: EHR demo
description: Example EHR database environment with three tools.
env_type: database
env_kwargs:
schema_sql: "CREATE TABLE patients (id INTEGER PRIMARY KEY, name TEXT, meds TEXT);"
seed_sql: "INSERT INTO patients VALUES (1,'Bob','aspirin'),(2,'Alice','ibuprofen');"
tools:
- id: list_patients
name: list_patients
description: List all patients.
parameters: {type: object, properties: {}}
executor: oumi.environments.examples.ehr.list_patients
read_only: true
- id: lookup_patient
name: lookup_patient
description: Look up one patient by id.
parameters:
type: object
properties: {pat_id: {type: integer}}
required: [pat_id]
executor: oumi.environments.examples.ehr.lookup_patient
read_only: true
- id: update_meds
name: update_meds
description: Update a patient's medication.
parameters:
type: object
properties: {pat_id: {type: integer}, medication: {type: string}}
required: [pat_id, medication]
executor: oumi.environments.examples.ehr.update_meds
read_only: false
2 changes: 2 additions & 0 deletions src/oumi/datasets/grpo/rewards/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
from oumi.datasets.grpo.rewards.count_letters_rewards import compute_letter_count_reward
from oumi.datasets.grpo.rewards.countdown_rewards import countdown_reward
from oumi.datasets.grpo.rewards.gsm8k_reward import gsm8k_reward
from oumi.datasets.grpo.rewards.sql_execution_match import sql_execution_match

__all__ = [
"compute_letter_count_reward",
"compute_soft_target_token_length_reward",
"compute_sharp_target_token_length_reward",
"countdown_reward",
"gsm8k_reward",
"sql_execution_match",
]
71 changes: 71 additions & 0 deletions src/oumi/datasets/grpo/rewards/sql_execution_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Execution-match reward for SQL tasks, graded on an isolated snapshot."""

from __future__ import annotations

import sqlite3
from pathlib import Path
from typing import Any

from oumi.core.registry import RegistryType, register
from oumi.environments.db_isolation import RollbackSession, materialize_sqlite_snapshot


def _run(connection: sqlite3.Connection, sql: str) -> list[tuple] | None:
try:
return connection.execute(sql).fetchall()
except sqlite3.Error:
return None


@register("sql_execution_match", RegistryType.REWARD_FUNCTION)
def sql_execution_match(
data_source: str,
solution_str: str,
ground_truth: str,
extra_info: dict[str, Any] | None = None,
**kwargs: Any,
) -> float:
"""Return 1.0 if candidate SQL's result set matches the gold's, else 0.0.

``extra_info`` carries the DB descriptor: either ``db_path`` or
``schema_sql`` (+ optional ``seed_sql``). Grading runs on a rollback
session, so a candidate that writes never mutates the snapshot.
"""
info = extra_info or {}
owns = False
if info.get("db_path"):
path = Path(info["db_path"])
else:
path = materialize_sqlite_snapshot(
schema_sql=info["schema_sql"], seed_sql=info.get("seed_sql")
)
owns = True
# Grade gold and candidate on separate sessions so each runs against the
# pristine snapshot — a mutating gold query can't contaminate the candidate.
gold_session = RollbackSession(path)
try:
gold_rows = _run(gold_session.connection, ground_truth)
finally:
gold_session.close()
cand_session = RollbackSession(path, owns_file=owns)
try:
cand_rows = _run(cand_session.connection, solution_str)
finally:
cand_session.close()
if gold_rows is None or cand_rows is None:
return 0.0
return 1.0 if gold_rows == cand_rows else 0.0
4 changes: 4 additions & 0 deletions src/oumi/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
)
from oumi.core.types.tool_call import JSONSchema, ToolResult
from oumi.environments.base_environment import BaseEnvironment
from oumi.environments.database_executable_environment import (
DatabaseExecutableEnvironment,
)
from oumi.environments.deterministic_environment import (
DeterministicEnvironment,
DeterministicEnvironmentKwargs,
Expand All @@ -47,6 +50,7 @@

__all__ = [
"BaseEnvironment",
"DatabaseExecutableEnvironment",
"DeterministicEnvironment",
"DeterministicEnvironmentKwargs",
"ExecutableEnvironment",
Expand Down
109 changes: 109 additions & 0 deletions src/oumi/environments/database_executable_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Executable environment backed by a rollback-isolated SQLite session."""

from __future__ import annotations

import importlib
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from typing import Any, ClassVar

from oumi.core.configs.params.environment_params import EnvironmentParams
from oumi.core.registry import register_environment
from oumi.environments.db_isolation import RollbackSession, materialize_sqlite_snapshot
from oumi.environments.executable_environment import ExecutableEnvironment
from oumi.environments.executable_tool import ExecutableTool


def _import_executor(dotted: str, tool_id: str) -> Callable[..., Any]:
"""Resolve a dotted import path to a callable."""
module_path, _, attr = dotted.rpartition(".")
if not module_path or not attr:
raise ValueError(
f"Tool '{tool_id}': executor '{dotted}' must be a dotted import path."
)
module = importlib.import_module(module_path)
executor = getattr(module, attr, None)
if not callable(executor):
raise ValueError(
f"Tool '{tool_id}': executor '{dotted}' did not resolve to a callable."
)
return executor


@register_environment("database")
class DatabaseExecutableEnvironment(ExecutableEnvironment):
"""Runs SQL-executing tools against a rollback-isolated SQLite session.

One instance owns one session for the duration of an episode. Executors
must NOT commit; the env rolls back on ``close()`` so writes never persist.
``requires_isolation()`` is ``True``, so the router builds a fresh instance
(hence a fresh session) per rollout. See ``db_isolation`` for the contract.
"""

_executor_context_kwarg: ClassVar[str] = "db"

def __init__(self, params: EnvironmentParams, session: RollbackSession) -> None:
"""Bind the env to its params and an already-open rollback session."""
self._params = params
self._session = session
self._executors = {
tool.id: _import_executor(tool.executor, tool.id) for tool in params.tools
}

@classmethod
def from_params(cls, params: EnvironmentParams) -> DatabaseExecutableEnvironment:
"""Build the env, opening a rollback session over its configured DB.

``db_path`` shares one snapshot file across rollouts (rollback isolation,
scales to large DBs). It is safe for concurrent *readers*, but SQLite
serializes concurrent *writers* on one file, so concurrent rollouts that
write will contend — those tasks should use ``schema_sql`` (a fresh
per-rollout file) until copy-on-write isolation lands.
"""
kwargs = dict(params.env_kwargs or {})
db_path = kwargs.get("db_path")
schema_sql = kwargs.get("schema_sql")
if db_path:
# Shared snapshot: connect read-side, roll back on close.
session = RollbackSession(db_path)
elif schema_sql:
# Inline: build a fresh per-rollout DB this instance owns.
snapshot = materialize_sqlite_snapshot(
schema_sql=schema_sql, seed_sql=kwargs.get("seed_sql")
)
session = RollbackSession(snapshot, owns_file=True)
else:
raise ValueError(
f"DatabaseExecutableEnvironment '{params.id}': env_kwargs must "
f"provide either 'db_path' or 'schema_sql'."
)
return cls(params, session)

def requires_isolation(self) -> bool:
"""Each rollout needs its own session; never share across samples."""
return True

@contextmanager
def _build_execution_context(
self, tool: ExecutableTool, arguments: dict[str, Any]
) -> Iterator[Any]:
"""Yield the episode's connection (uncommitted writes persist within it)."""
yield self._session.connection

def close(self) -> None:
"""Roll back the episode's writes and tear down the session."""
self._session.close()
84 changes: 84 additions & 0 deletions src/oumi/environments/db_isolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Rollback-based SQLite isolation for per-rollout database environments.

The SkyRL pattern: each rollout gets its own connection that opens a
transaction and never commits, so uncommitted writes are visible within the
rollout and discarded on close. The environment owns the transaction;
executors must not call ``commit()``.
"""

from __future__ import annotations

import sqlite3
import tempfile
import uuid
from pathlib import Path


def materialize_sqlite_snapshot(
*,
schema_sql: str,
seed_sql: str | None = None,
dest: Path | str | None = None,
) -> Path:
"""Build a snapshot SQLite file from DDL (+ optional seed INSERTs)."""
path = (
Path(dest)
if dest is not None
else Path(tempfile.gettempdir()) / f"oumi_snapshot_{uuid.uuid4().hex}.sqlite"
)
conn = sqlite3.connect(path)
try:
conn.executescript(schema_sql)
if seed_sql:
conn.executescript(seed_sql)
conn.commit()
finally:
conn.close()
return path


class RollbackSession:
"""A per-rollout SQLite connection that never commits and rolls back on close.

Set ``owns_file=True`` when the env built a throwaway per-rollout database
that should be deleted on teardown (as opposed to a shared snapshot).
"""

def __init__(self, db_path: Path | str, *, owns_file: bool = False) -> None:
"""Open a per-rollout connection; set owns_file to delete the DB on close.

Opens with ``isolation_level=None`` and an explicit ``BEGIN`` so the whole
session is one transaction. ``sqlite3``'s legacy mode only opens an
implicit transaction before DML, which would let a leading DDL statement
(e.g. ``CREATE TABLE`` as the first call) run in autocommit and escape the
rollback; the explicit ``BEGIN`` brings DDL under transaction control too.
Executors still must not call ``commit()`` — an explicit commit persists
regardless and there is no way to undo it.
"""
self._path = Path(db_path)
self._owns_file = owns_file
self.connection = sqlite3.connect(self._path, isolation_level=None)
self.connection.execute("BEGIN")

def close(self) -> None:
"""Roll back any open transaction, close, and delete an owned file."""
try:
self.connection.rollback()
finally:
self.connection.close()
if self._owns_file:
self._path.unlink(missing_ok=True)
15 changes: 15 additions & 0 deletions src/oumi/environments/examples/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runnable example tools/executors for executable environments."""
Loading
Loading