|
|
from typing import Any, List, Optional, Type, Union, get_args, get_origin |
|
|
|
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
from app.tool import BaseTool |
|
|
|
|
|
|
|
|
class CreateChatCompletion(BaseTool):
    """Tool through which the model returns its final, structured answer.

    The JSON schema advertised as ``parameters`` is derived from
    ``response_type`` when the tool is constructed, and :meth:`execute`
    converts the arguments the model supplies back into that type where it
    can.
    """

    name: str = "create_chat_completion"
    description: str = (
        "Creates a structured completion with specified output formatting."
    )

    # Python type -> JSON-schema "type" keyword. Types not listed here are
    # advertised as "string".
    type_mapping: dict = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        dict: "object",
        list: "array",
        # Lets Optional[X] / X | None advertise a real "null" branch.
        type(None): "null",
    }
    # The Python type the final response should be converted to.
    response_type: Optional[Type] = None
    # Argument names marked as required for non-model response types;
    # by default the single "response" argument.
    required: List[str] = Field(default_factory=lambda: ["response"])

    def __init__(self, response_type: Optional[Type] = str):
        """Initialize with a specific response type.

        Args:
            response_type: Type the model's answer should take, e.g. ``str``,
                ``int``, ``List[str]``, ``Dict[str, int]``, a union, or a
                pydantic ``BaseModel`` subclass. Defaults to ``str``.
        """
        super().__init__()
        self.response_type = response_type
        self.parameters = self._build_parameters()

    @staticmethod
    def _is_model_class(type_hint: Any) -> bool:
        """Return True if *type_hint* is a pydantic ``BaseModel`` subclass.

        Parameterised builtins such as ``list[int]`` pass
        ``isinstance(..., type)`` on Python < 3.11 but make ``issubclass``
        raise ``TypeError``, so generic aliases are excluded first.
        """
        return (
            isinstance(type_hint, type)
            and get_origin(type_hint) is None
            and issubclass(type_hint, BaseModel)
        )

    @staticmethod
    def _is_union(origin: Any) -> bool:
        """Return True if *origin* is ``typing.Union`` or the PEP 604 union type.

        ``get_origin(int | str)`` is ``types.UnionType``, not ``typing.Union``,
        so both spellings must be recognised.
        """
        import types

        # Fall back to Union itself on interpreters without types.UnionType,
        # so a missing attribute can never make this match ``None``.
        return origin is Union or origin is getattr(types, "UnionType", Union)

    def _wrap_response(self, response_schema: dict) -> dict:
        """Wrap *response_schema* as the single ``response`` tool argument."""
        return {
            "type": "object",
            "properties": {"response": response_schema},
            "required": self.required,
        }

    def _build_parameters(self) -> dict:
        """Build the tool ``parameters`` JSON schema for ``response_type``.

        Returns:
            An ``"type": "object"`` JSON schema. For ``str`` and other
            non-model types it has a single ``response`` property. For a
            ``BaseModel`` subclass it has the model's own properties.
        """
        if self.response_type is str:
            return self._wrap_response(
                {
                    "type": "string",
                    "description": "The response text that should be delivered to the user.",
                }
            )

        if self._is_model_class(self.response_type):
            schema = self.response_type.model_json_schema()
            parameters = {
                "type": "object",
                "properties": schema.get("properties", {}),
                # A model whose fields all have defaults has no "required"
                # key. Require nothing rather than the default ["response"],
                # which is not one of the model's properties.
                "required": schema.get("required", []),
            }
            # Nested models are emitted as "$ref": "#/$defs/...". Keep the
            # definitions so those references still resolve.
            if "$defs" in schema:
                parameters["$defs"] = schema["$defs"]
            return parameters

        return self._create_type_schema(self.response_type)

    def _create_type_schema(self, type_hint: Type) -> dict:
        """Create the tool-argument schema for a non-``str``, non-model type.

        Bare types, ``List[...]``, ``Dict[..., ...]`` and unions are wrapped
        as a single ``response`` argument. Any other generic (``Tuple``,
        ``Literal``, ``set[...]``, ...) is advertised as a string.
        """
        origin = get_origin(type_hint)
        args = get_args(type_hint)

        if origin is None:
            return self._wrap_response(
                {
                    "type": self.type_mapping.get(type_hint, "string"),
                    # getattr: hints such as None have no __name__.
                    "description": f"Response of type {getattr(type_hint, '__name__', 'any')}",
                }
            )

        if origin is list:
            item_type = args[0] if args else Any
            return self._wrap_response(
                {"type": "array", "items": self._get_type_info(item_type)}
            )

        if origin is dict:
            value_type = args[1] if len(args) > 1 else Any
            return self._wrap_response(
                {
                    "type": "object",
                    "additionalProperties": self._get_type_info(value_type),
                }
            )

        if self._is_union(origin):
            return self._create_union_schema(args)

        # Unsupported generic. Delegating to _build_parameters() here would
        # dispatch straight back to this method and recurse until
        # RecursionError, so advertise a plain string instead.
        return self._wrap_response(
            {"type": "string", "description": f"Response of type {type_hint}"}
        )

    def _get_type_info(self, type_hint: Type) -> dict:
        """Get the JSON schema for a single (possibly nested) value type."""
        if self._is_model_class(type_hint):
            # NOTE(review): "$ref" entries inside a model schema point at
            # "#/$defs/...", which resolves against the root schema, not this
            # fragment. Models containing nested models may need their $defs
            # hoisted.
            return type_hint.model_json_schema()

        origin = get_origin(type_hint)
        args = get_args(type_hint)

        # Recurse into nested generics so e.g. List[List[int]] is advertised
        # as an array of arrays rather than falling through to "string".
        if origin is list:
            return {
                "type": "array",
                "items": self._get_type_info(args[0] if args else Any),
            }
        if origin is dict:
            return {
                "type": "object",
                "additionalProperties": self._get_type_info(
                    args[1] if len(args) > 1 else Any
                ),
            }
        if self._is_union(origin):
            return {"anyOf": [self._get_type_info(t) for t in args]}

        return {
            "type": self.type_mapping.get(type_hint, "string"),
            "description": f"Value of type {getattr(type_hint, '__name__', 'any')}",
        }

    def _create_union_schema(self, types: tuple) -> dict:
        """Create the tool-argument schema for a union: ``anyOf`` its members."""
        return self._wrap_response(
            {"anyOf": [self._get_type_info(t) for t in types]}
        )

    async def execute(self, required: list | None = None, **kwargs) -> Any:
        """Execute the chat completion with type conversion.

        Args:
            required: List of required field names, or None. Falls back to
                ``self.required`` when None or empty.
            **kwargs: Response data supplied by the model.

        Returns:
            With more than one required field: a dict of those fields
            (missing ones default to ``""``), with no type conversion.

            Otherwise, the value of the single required field (``"response"``
            by default), converted according to ``response_type``:

            - returned as-is for ``str``, ``List[...]`` and ``Dict[...]``;
            - for a ``BaseModel`` subclass, a model built from all of
              ``kwargs`` (validation errors propagate);
            - otherwise ``response_type(value)``, or the raw value if that
              call raises ``ValueError``/``TypeError``.
        """
        required = required or self.required

        if isinstance(required, list) and len(required) > 0:
            if len(required) == 1:
                required_field = required[0]
                result = kwargs.get(required_field, "")
            else:
                # Several fields requested: return them together, unconverted.
                return {field: kwargs.get(field, "") for field in required}
        else:
            required_field = "response"
            result = kwargs.get(required_field, "")

        if self.response_type is str:
            return result

        if self._is_model_class(self.response_type):
            # Models take their fields directly from the tool arguments.
            return self.response_type(**kwargs)

        if get_origin(self.response_type) in (list, dict):
            return result

        try:
            return self.response_type(result)
        except (ValueError, TypeError):
            # Not usable as a converter (a union, Any, None, ...) or the value
            # does not convert: hand back what the model sent.
            return result
|
|
|