File size: 5,621 Bytes
88f3fce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import types
from typing import Any, List, Optional, Type, Union, get_args, get_origin

from pydantic import BaseModel, Field

from app.tool import BaseTool
class CreateChatCompletion(BaseTool):
    """Tool that requests a completion shaped like ``response_type``.

    The constructor derives a JSON-schema ``parameters`` object from
    ``response_type``; ``execute`` converts the arguments supplied for that
    schema back into a value of ``response_type``.
    """

    name: str = "create_chat_completion"
    description: str = (
        "Creates a structured completion with specified output formatting."
    )

    # Python type -> JSON-schema type name. Types missing from this table
    # are advertised as "string".
    type_mapping: dict = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        dict: "object",
        list: "array",
    }
    # Type the result of ``execute`` is converted to.
    response_type: Optional[Type] = None
    # Argument names marked as required in the generated schema.
    required: List[str] = Field(default_factory=lambda: ["response"])

    def __init__(self, response_type: Optional[Type] = str):
        """Initialize with a specific response type.

        Args:
            response_type: Type the completion should be converted to. May be
                ``str``, another builtin, a ``BaseModel`` subclass, or a
                ``typing`` generic/union. ``None`` is treated as plain text.
        """
        super().__init__()
        self.response_type = response_type
        self.parameters = self._build_parameters()

    @staticmethod
    def _is_model_type(type_hint: Any) -> bool:
        """Return True when ``type_hint`` is a plain ``BaseModel`` subclass."""
        # Parameterized generics (e.g. ``list[int]``) are excluded up front:
        # on some Python versions they pass ``isinstance(..., type)`` yet make
        # ``issubclass`` raise TypeError.
        return (
            get_origin(type_hint) is None
            and isinstance(type_hint, type)
            and issubclass(type_hint, BaseModel)
        )

    @staticmethod
    def _is_union_origin(origin: Any) -> bool:
        """Return True for ``typing.Union`` and PEP 604 ``X | Y`` origins."""
        return origin is Union or origin is types.UnionType

    def _wrap_response(self, response_schema: dict) -> dict:
        """Wrap a value schema in the single-argument ``response`` object."""
        return {
            "type": "object",
            "properties": {"response": response_schema},
            "required": self.required,
        }

    def _build_parameters(self) -> dict:
        """Build parameters schema based on response type.

        Returns:
            A JSON-schema object describing the arguments of this tool.
        """
        # ``None`` means no particular type was requested, so it gets the same
        # plain-text schema as ``str`` instead of failing further down.
        if self.response_type is None or self.response_type is str:
            return self._wrap_response(
                {
                    "type": "string",
                    "description": "The response text that should be delivered to the user.",
                }
            )

        if self._is_model_type(self.response_type):
            schema = self.response_type.model_json_schema()
            parameters = {
                "type": "object",
                "properties": schema["properties"],
                "required": schema.get("required", self.required),
            }
            # Nested models are emitted as "$ref" pointers into "$defs";
            # carry the definitions along so those pointers stay resolvable.
            if "$defs" in schema:
                parameters["$defs"] = schema["$defs"]
            return parameters

        return self._create_type_schema(self.response_type)

    def _create_type_schema(self, type_hint: Type) -> dict:
        """Create a JSON schema for the given type.

        Args:
            type_hint: A plain type or a ``typing`` generic/union.

        Returns:
            An object schema whose single ``response`` property describes
            ``type_hint``.
        """
        origin = get_origin(type_hint)

        # Handle primitive (non-generic) types
        if origin is None:
            # Not every object that can end up here has ``__name__``.
            type_name = getattr(type_hint, "__name__", None) or str(type_hint)
            return self._wrap_response(
                {
                    "type": self.type_mapping.get(type_hint, "string"),
                    "description": f"Response of type {type_name}",
                }
            )

        # Handle Union / Optional / ``X | Y``
        if self._is_union_origin(origin):
            return self._create_union_schema(get_args(type_hint))

        # Lists, dicts and every other generic. This must not call
        # ``_build_parameters``: that would re-enter this method with the
        # same type and never terminate.
        return self._wrap_response(self._get_type_info(type_hint))

    def _get_type_info(self, type_hint: Type) -> dict:
        """Get type information for a single type.

        Args:
            type_hint: Type of a value nested inside the response (list item,
                dict value, union member) or the response type itself.

        Returns:
            A JSON-schema fragment for ``type_hint``. Unknown types fall back
            to ``"string"``.
        """
        if self._is_model_type(type_hint):
            return type_hint.model_json_schema()

        # ``Optional[X]`` contributes ``NoneType`` as a union member.
        if type_hint is None or type_hint is type(None):
            return {"type": "null"}

        origin = get_origin(type_hint)
        args = get_args(type_hint)

        if origin is list:
            item_type = args[0] if args else Any
            return {"type": "array", "items": self._get_type_info(item_type)}

        if origin is dict:
            value_type = args[1] if len(args) > 1 else Any
            return {
                "type": "object",
                "additionalProperties": self._get_type_info(value_type),
            }

        if self._is_union_origin(origin):
            return {"anyOf": [self._get_type_info(member) for member in args]}

        return {
            "type": self.type_mapping.get(type_hint, "string"),
            "description": f"Value of type {getattr(type_hint, '__name__', 'any')}",
        }

    def _create_union_schema(self, types: tuple) -> dict:
        """Create schema for Union types.

        Args:
            types: The member types of the union.

        Returns:
            An object schema whose ``response`` property accepts any member.
        """
        return self._wrap_response(
            {"anyOf": [self._get_type_info(member) for member in types]}
        )

    async def execute(self, required: list | None = None, **kwargs) -> Any:
        """Execute the chat completion with type conversion.

        Args:
            required: List of required field names or None
            **kwargs: Response data

        Returns:
            Converted response based on response_type. When more than one
            required field is given, a dict of those fields is returned
            without conversion.
        """
        required = required or self.required

        if isinstance(required, list) and len(required) > 1:
            # Return multiple fields as a dictionary
            return {field: kwargs.get(field, "") for field in required}

        if isinstance(required, list) and required:
            required_field = required[0]
        else:
            required_field = "response"
        result = kwargs.get(required_field, "")

        # Type conversion logic
        if self.response_type is None or self.response_type is str:
            return result

        if self._is_model_type(self.response_type):
            return self.response_type(**kwargs)

        if get_origin(self.response_type) in (list, dict):
            return result  # Assuming result is already in correct format

        # Best-effort conversion: hand back the raw value when the target
        # type cannot be constructed from it.
        try:
            return self.response_type(result)
        except (ValueError, TypeError):
            return result
|