Skip to content
Closed
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
ea3e5f2
chore(deps): add sqlalchemy and testcontainers for database env
aniruddh-alt May 8, 2026
53f2904
feat(configs): add DatabaseConnectionConfig with structured/DSN modes
aniruddh-alt May 8, 2026
1674ecf
feat(environments): add ExecutableTool base class
aniruddh-alt May 8, 2026
24c6468
feat(environments): add DatabaseExecutableTool with per-tool timeout
aniruddh-alt May 8, 2026
ff7bfe9
feat(environments): add ExecutableEnvironment abstract base class
aniruddh-alt May 8, 2026
04937c3
feat(environments): add DatabaseExecutableEnvironment skeleton with f…
aniruddh-alt May 8, 2026
7cabe53
test(environments): cover DatabaseExecutableEnvironment happy-path step
aniruddh-alt May 8, 2026
6e5e5cd
feat(environments): install dialect-specific read-only and timeout gu…
aniruddh-alt May 8, 2026
d378f72
feat(environments): auto-wrap DBAPIError and reject updated_state
aniruddh-alt May 8, 2026
3a32191
feat(environments): add per-tool statement_timeout override with chec…
aniruddh-alt May 8, 2026
a100031
feat(environments): cross-validate read_only and per-tool timeout aga…
aniruddh-alt May 8, 2026
c770b4e
feat(environments): add per-tool-call audit logging (opt-in)
aniruddh-alt May 8, 2026
1cbe44e
feat(environments): export DatabaseExecutableEnvironment from package…
aniruddh-alt May 8, 2026
a26484c
feat(examples): add EHR DB schema and seed fixture
aniruddh-alt May 8, 2026
d2005ea
feat(examples): add EHR DB SQL executors with constraint-aware error …
aniruddh-alt May 8, 2026
75523ae
feat(configs): add EHR DB synthesis example config
aniruddh-alt May 8, 2026
e73be5b
test(e2e): add SQLite-backed EHR DB synthesis e2e test
aniruddh-alt May 8, 2026
76c2790
test(integration): add Postgres-gated DatabaseExecutableEnvironment t…
aniruddh-alt May 8, 2026
3c1c866
style: ruff/lint fixes for new database executable env files
aniruddh-alt May 8, 2026
3ffc91c
fix(environments): satisfy pre-commit pyright on test type-narrowing
aniruddh-alt May 8, 2026
92df591
refactor: remove AI slop from database executable env files
aniruddh-alt May 8, 2026
3a26c48
Merge branch 'main' into aniruddh-alt/db-executable-env
aniruddh-alt May 11, 2026
7134081
Merge branch 'main' into aniruddh-alt/db-executable-env
aniruddh-alt May 12, 2026
1176b5b
Merge branch 'main' into aniruddh-alt/db-executable-env
aniruddh-alt May 12, 2026
181ccdc
Merge branch 'main' into aniruddh-alt/db-executable-env
aniruddh-alt May 14, 2026
35a57ea
fix(environments): read SQLSTATE from psycopg2 pgcode too
aniruddh-alt May 15, 2026
382b7c7
Merge branch 'main' into aniruddh-alt/db-executable-env
aniruddh-alt May 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions configs/examples/synthesis/ehr_db_synth.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# EHR Database Agent — Tool-Use Conversation Synthesis (real DB)
#
# Six tools that read or mutate live SQL state via SQLAlchemy. The DB *is*
# the state — nothing lives in memory.
#
# Setup before `oumi synth`:
# 1. Create the demo DB and load schema + seed:
# sqlite3 /tmp/ehr_demo.db < src/oumi/examples/ehr_db/schema.sql
# sqlite3 /tmp/ehr_demo.db < src/oumi/examples/ehr_db/seed.sql
# 2. Set env var (or change `database:` below):
# export EHR_DB_PATH=/tmp/ehr_demo.db
# 3. Set OPENAI_API_KEY (or change inference_config).
#
# Usage:
# oumi synth -c configs/examples/synthesis/ehr_db_synth.yaml
#
# See also:
# - Executors: src/oumi/examples/ehr_db/executors.py
# - Schema: src/oumi/examples/ehr_db/schema.sql
# - Seed: src/oumi/examples/ehr_db/seed.sql

strategy: GENERAL
num_samples: 50
output_path: ehr_db_dataset.jsonl

environment_config:
environments:
- id: ehr_db
env_type: database
name: EHR Database
description: A clinical EHR database backed by real SQL.
env_kwargs:
connection:
driver: sqlite
database: ${oc.env:EHR_DB_PATH,/tmp/ehr_demo.db}
pool_size: 10
pool_max_overflow: 10
read_only: false
audit: true
tools:
- id: list_patients
name: list_patients
description: >-
Return a list of patient summaries (patient_id, name, dob, status).
Call this first when the clinician refers to a patient by NAME so
you can resolve the name to a patient_id.
parameters:
type: object
properties: {}
read_only: true
executor: oumi.examples.ehr_db.executors.list_patients

- id: get_patient
name: get_patient
description: >-
Return the full record for a patient_id, including allergies,
active medications, diagnoses, and vitals history.
parameters:
type: object
properties:
patient_id: { type: string }
required: [patient_id]
read_only: true
executor: oumi.examples.ehr_db.executors.get_patient

- id: record_vitals
name: record_vitals
description: >-
Append a vitals reading (timestamp, blood pressure, heart rate,
temperature in F) to a patient's vitals_history.
parameters:
type: object
properties:
patient_id: { type: string }
timestamp: { type: string }
bp: { type: string }
hr: { type: integer }
temp_f: { type: number }
required: [patient_id, timestamp, bp, hr, temp_f]
read_only: false
executor: oumi.examples.ehr_db.executors.record_vitals

- id: add_diagnosis
name: add_diagnosis
description: >-
Append a new diagnosis (ICD-10 code, description, date) to a
patient. Returns an error if the same code is already present.
parameters:
type: object
properties:
patient_id: { type: string }
code: { type: string }
description: { type: string }
date: { type: string }
required: [patient_id, code, description, date]
read_only: false
executor: oumi.examples.ehr_db.executors.add_diagnosis

- id: prescribe_medication
name: prescribe_medication
description: >-
Add an active medication to a patient's medication list. Returns
an error if the same medication is already prescribed or if it
conflicts with a known allergy.
parameters:
type: object
properties:
patient_id: { type: string }
name: { type: string }
dose: { type: string }
required: [patient_id, name, dose]
read_only: false
executor: oumi.examples.ehr_db.executors.prescribe_medication

- id: update_allergies
name: update_allergies
description: >-
Replace a patient's allergy list with the supplied list. Pass an
empty list to clear all allergies.
parameters:
type: object
properties:
patient_id: { type: string }
allergies:
type: array
items: { type: string }
required: [patient_id, allergies]
read_only: false
executor: oumi.examples.ehr_db.executors.update_allergies

inference_config:
model:
model_name: gpt-4o-mini
engine: OPENAI
generation:
max_new_tokens: 2048
temperature: 0.7
top_p: 0.9
remote_params:
num_workers: 8
politeness_policy: 60
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ dependencies = [
"responses>=0.25,<0.27",
"safetensors>=0.6,<0.8",
"skypilot>=0.11.1,<0.13", # 0.11.1 includes db locking fix
"sqlalchemy>=2.0,<3.0", # Used by DatabaseExecutableEnvironment
"tensorboard>=2.20,<2.21", # Optional, for monitoring training
"tiktoken>=0.7,<1.0",
"torch>=2.6,<2.11.0", # vllm only supports up to torch==2.7
Expand Down Expand Up @@ -108,6 +109,7 @@ dev = [
"pytest",
"responses",
"ruff",
"testcontainers>=4.0,<5.0", # Postgres fixture for opt-in integration tests
"torchfix", # Tool for automatically fixing common PyTorch issues
]
docs = [
Expand Down Expand Up @@ -342,6 +344,7 @@ markers = [
"e2e_eternal: Extremely slow e2e integration tests (for manual/selective runs)",
"single_gpu: The test uses max 1 GPU (can be potentially skipped on multi-GPU machine to conserve GPU resources)",
"multi_gpu: The test should run on a machine with multiple GPU-s",
"requires_postgres: marks tests that need a real Postgres instance (run with -m requires_postgres)",
]

[tool.coverage.run]
Expand Down
111 changes: 111 additions & 0 deletions src/oumi/core/configs/params/database_connection_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Connection parameters for DatabaseExecutableEnvironment."""

from __future__ import annotations

import os
from dataclasses import dataclass

import sqlalchemy
import sqlalchemy.engine

from oumi.core.configs.params.base_params import BaseParams


@dataclass
class DatabaseConnectionConfig(BaseParams):
    """SQLAlchemy connection config for DatabaseExecutableEnvironment.

    Two modes (mutually exclusive):

    - **Structured fields:** populate ``driver`` and ``database`` (and host/port/
      username as needed). The password is resolved at ``resolve_url()`` time
      from ``password_env_var`` so it never appears in YAML.
    - **DSN env var:** populate ``dsn_env_var`` with the name of an env var
      holding a complete SQLAlchemy URL.

    Only the *names* of environment variables are stored on this object; the
    secret values are read from the process environment in ``resolve_url()``.
    """

    # Mode 1: structured fields
    driver: str = ""  # e.g. "postgresql+psycopg", "mysql+pymysql", "sqlite"
    host: str = ""
    port: int | None = None
    database: str = ""
    username: str = ""
    password_env_var: str = ""  # name of env var holding the password

    # Mode 2: full DSN from env (escape hatch)
    dsn_env_var: str = ""  # mutually exclusive with structured fields

    # Pool / timeouts
    pool_size: int = 5
    pool_max_overflow: int = 10
    pool_pre_ping: bool = True
    connect_timeout_s: float = 10.0

    def __finalize_and_validate__(self) -> None:
        """Validate mode XOR (structured vs DSN) and pool/timeout numeric bounds.

        Raises:
            ValueError: If ``dsn_env_var`` is combined with any structured
                field, if neither mode is fully specified, or if
                ``pool_size`` / ``pool_max_overflow`` / ``connect_timeout_s``
                are out of range.
        """
        # "Structured mode" requires at least driver + database.
        structured = bool(self.driver and self.database)
        # Any single structured field counts as a conflict with DSN mode:
        # resolve_url() ignores all of them when dsn_env_var is set, so
        # accepting e.g. `host` + `dsn_env_var` would silently drop `host`.
        any_structured_field = bool(
            self.driver
            or self.host
            or self.port is not None
            or self.database
            or self.username
            or self.password_env_var
        )
        dsn = bool(self.dsn_env_var)
        if dsn and any_structured_field:
            raise ValueError(
                "DatabaseConnectionConfig: set EITHER (driver, database, ...) "
                "OR dsn_env_var, not both."
            )
        if not structured and not dsn:
            raise ValueError(
                "DatabaseConnectionConfig: must set either (driver + database) "
                "OR dsn_env_var. Got neither."
            )
        if self.pool_size < 1:
            raise ValueError(
                f"DatabaseConnectionConfig.pool_size must be >= 1, "
                f"got {self.pool_size}."
            )
        if self.pool_max_overflow < 0:
            raise ValueError(
                f"DatabaseConnectionConfig.pool_max_overflow must be >= 0, "
                f"got {self.pool_max_overflow}."
            )
        if self.connect_timeout_s <= 0:
            raise ValueError(
                f"DatabaseConnectionConfig.connect_timeout_s must be > 0, "
                f"got {self.connect_timeout_s}."
            )

    def resolve_url(self) -> sqlalchemy.URL:
        """Build the SQLAlchemy URL.

        Reads ``password_env_var`` / ``dsn_env_var`` from the environment at
        call time (not at YAML-parse time), so the config object itself only
        ever holds env var names, never the secret values.

        Returns:
            The URL parsed from the DSN env var (DSN mode), or a URL assembled
            from the structured fields (structured mode).

        Raises:
            ValueError: If ``dsn_env_var`` is set but the variable is unset or
                empty, or if ``password_env_var`` is set but the variable is
                not present in the environment.
        """
        if self.dsn_env_var:
            dsn = os.environ.get(self.dsn_env_var)
            if not dsn:
                raise ValueError(
                    f"DatabaseConnectionConfig: env var '{self.dsn_env_var}' "
                    f"is not set; cannot build connection URL."
                )
            return sqlalchemy.engine.make_url(dsn)

        password = None
        if self.password_env_var:
            password = os.environ.get(self.password_env_var)
            # Fail loudly instead of silently attempting a password-less
            # connection when the operator asked for a password. An empty
            # string is still accepted as an explicitly-set value.
            if password is None:
                raise ValueError(
                    f"DatabaseConnectionConfig: env var "
                    f"'{self.password_env_var}' is not set; cannot build "
                    f"connection URL."
                )
        return sqlalchemy.URL.create(
            drivername=self.driver,
            # Empty strings mean "not provided"; URL.create expects None.
            host=self.host or None,
            port=self.port,
            database=self.database,
            username=self.username or None,
            password=password,
        )
16 changes: 16 additions & 0 deletions src/oumi/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
concrete environment's `@register_environment(...)` decorator.
"""

from oumi.core.configs.params.database_connection_params import (
DatabaseConnectionConfig,
)
from oumi.core.configs.params.grounding_params import (
GroundingConfig,
GroundingFact,
Expand All @@ -31,11 +34,18 @@
)
from oumi.core.types.tool_call import JSONSchema, ToolResult
from oumi.environments.base_environment import BaseEnvironment
from oumi.environments.database_executable_environment import (
DatabaseExecutableEnvironment,
DatabaseExecutableEnvironmentKwargs,
)
from oumi.environments.database_executable_tool import DatabaseExecutableTool
from oumi.environments.deterministic_environment import (
DeterministicEnvironment,
DeterministicEnvironmentKwargs,
ToolLookupEntry,
)
from oumi.environments.executable_environment import ExecutableEnvironment
from oumi.environments.executable_tool import ExecutableTool
from oumi.environments.synthetic_environment import (
SyntheticEnvironment,
SyntheticEnvironmentKwargs,
Expand All @@ -44,8 +54,14 @@

__all__ = [
"BaseEnvironment",
"DatabaseConnectionConfig",
"DatabaseExecutableEnvironment",
"DatabaseExecutableEnvironmentKwargs",
"DatabaseExecutableTool",
"DeterministicEnvironment",
"DeterministicEnvironmentKwargs",
"ExecutableEnvironment",
"ExecutableTool",
"GroundingConfig",
"GroundingFact",
"JSONSchema",
Expand Down
Loading
Loading