diff --git a/src/bespokelabs/curator/request_processor/batch/gemini_batch_request_processor.py b/src/bespokelabs/curator/request_processor/batch/gemini_batch_request_processor.py index 694e88b9..5abec81f 100644 --- a/src/bespokelabs/curator/request_processor/batch/gemini_batch_request_processor.py +++ b/src/bespokelabs/curator/request_processor/batch/gemini_batch_request_processor.py @@ -1,5 +1,6 @@ import copy import json +import mimetypes import os import typing as t from functools import lru_cache @@ -228,8 +229,8 @@ def create_api_specific_request_batch(self, generic_request: GenericRequest) -> """Creates an API-specific request body from a generic request. Transforms a GenericRequest into the format expected by Gemini's batch API. - Combines and constructs a system message with schema and instructions using - the instructor package for JSON response formatting. + Handles multi-modal inputs by parsing the dictionary structure created by the + curator framework, which combines text and images into a single message. Args: generic_request: The generic request object containing model, messages, @@ -239,17 +240,57 @@ def create_api_specific_request_batch(self, generic_request: GenericRequest) -> dict: API specific request body formatted for Gemini's batch API, including custom_id and request parameters. """ - contents = [] + parts = [] + # The prompt() function returns (text, image), which curator combines into + # a single message with a dictionary as its content. We need to parse this dict. for message in generic_request.messages: - contents.append({"role": message["role"], "parts": [{"text": message["content"]}]}) - request_object = {"contents": contents} + content = message["content"] + + # Primary Path: Handle the multimodal dictionary from curator + if isinstance(content, dict) and ("texts" in content or "images" in content): + # Add all text parts from the 'texts' list + for text_item in content.get("texts", []): + if isinstance(text_item, str): + parts.append({"text": text_item}) + + # Add all image parts from the 'images' list + for image_item in content.get("images", []): + # image_item is a dict like {'url': '...', 'mime_type': '...'} + url = image_item.get("url") + if not url: + continue # Skip if there's no URL + + mime_type = image_item.get("mime_type") + if not mime_type: + # Fallback to guessing the mime type if not provided + mime_type, _ = mimetypes.guess_type(url) + if not mime_type: + logger.warning(f"Could not determine MIME type for {url}. Defaulting to 'image/png'.") + mime_type = "image/png" + + parts.append({"fileData": {"fileUri": url, "mimeType": mime_type}}) + + # Fallback for simple text-only messages + elif isinstance(content, str): + parts.append({"text": content}) + + # Fallback for any other unexpected type + else: + logger.warning(f"Unsupported message content type: {type(content)}. Converting to string.") + parts.append({"text": str(content)}) + + # For multi-modal requests, Gemini expects a single 'contents' object + # with a 'parts' list. The role is typically 'user'. + request_object = {"contents": [{"role": "user", "parts": parts}]} + if generic_request.response_format: - request_object.update( + # Ensure generationConfig exists before trying to update it + if "generationConfig" not in request_object: + request_object["generationConfig"] = {} + request_object["generationConfig"].update( { - "generationConfig": { - "responseMimeType": "application/json", - "responseSchema": _response_format_to_json(self.prompt_formatter.response_format), - } + "responseMimeType": "application/json", + "responseSchema": _response_format_to_json(self.prompt_formatter.response_format), } )