diff --git a/src/bespokelabs/curator/llm/llm.py b/src/bespokelabs/curator/llm/llm.py
index a43c14e0..833e16cc 100644
--- a/src/bespokelabs/curator/llm/llm.py
+++ b/src/bespokelabs/curator/llm/llm.py
@@ -78,6 +78,7 @@ def __init__(
         generation_params: dict | None = None,
         backend_params: BackendParamsType | None = None,
         system_prompt: str | None = None,
+        default_app_id: str | None = None,
     ):
         """Initialize a LLM.
 
@@ -117,6 +118,7 @@ def __init__(
                 - gpu_memory_utilization: The GPU memory utilization to use for the VLLM backend
                 - batch_size: The size of the batch to use, only used if batch is True
             system_prompt: The system prompt to use for the LLM
+            default_app_id: The default application ID to use when opening datasets in Curator Viewer
         """
 
         generation_params = generation_params or {}
@@ -143,6 +145,8 @@ def __init__(
             return_completions_object=self.return_completions_object,
         )
 
+        self.default_app_id = default_app_id
+
     def _hash_fingerprint(self, dataset_hash: str = "", disable_cache: bool = False):
         if disable_cache:
             fingerprint = xxh64(os.urandom(8)).hexdigest()
@@ -262,6 +266,10 @@ def __call__(
             "run_hash": fingerprint,
             "batch_mode": self.batch_mode,
         }
+
+        # Only include default_app_id in metadata dictionary if it's not None
+        if self.default_app_id is not None:
+            metadata_dict["default_app_id"] = self.default_app_id
 
         existing_session_id = metadata_db.get_existing_session_id(metadata_dict["run_hash"])
         existing_viewer_sync = metadata_db.check_existing_hosted_sync(metadata_dict["run_hash"])