Song Yi committed on
Commit
ca8c2ab
·
verified ·
1 Parent(s): ca016db

Create api_server.py

Browse files
Files changed (1) hide show
  1. api_server.py +355 -0
api_server.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Kirim-1-Math API Server
3
+ FastAPI-based REST API for mathematical reasoning
4
+ """
5
+
6
import asyncio
import json
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

from inference_math import KirimMath, MathToolExecutor
18
+
19
# Configure logging: timestamped records for every module logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app (metadata surfaces in the generated OpenAPI docs).
app = FastAPI(
    title="Kirim-1-Math API",
    description="Advanced Mathematical Reasoning API with Tool Calling",
    version="1.0.0"
)

# Add CORS middleware.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# disallowed by the CORS spec (browsers reject wildcard origins on
# credentialed requests) — confirm and restrict origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model instance; None until the startup hook loads it.
model_instance = None
44
+
45
+
46
# Request/Response models
class MathProblemRequest(BaseModel):
    """Request body for POST /solve: one problem plus generation settings."""
    problem: str = Field(..., description="Mathematical problem to solve")
    show_work: bool = Field(True, description="Show step-by-step solution")
    use_tools: bool = Field(True, description="Enable tool calling")
    temperature: float = Field(0.1, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: int = Field(4096, ge=1, le=8192, description="Maximum tokens to generate")
    # NOTE(review): `language` is accepted but never forwarded to the model
    # in /solve — confirm whether it should be wired through.
    language: Optional[str] = Field("auto", description="Response language: 'auto', 'en', 'zh'")
54
+
55
+
56
class ToolCallRequest(BaseModel):
    """Request body for POST /tools/call: a tool name and its arguments."""
    tool_name: str = Field(..., description="Name of the tool to call")
    arguments: Dict[str, Any] = Field(..., description="Tool arguments")
59
+
60
+
61
class BatchMathRequest(BaseModel):
    """Request body for POST /solve/batch: settings apply to every problem."""
    problems: List[str] = Field(..., description="List of problems to solve")
    show_work: bool = Field(True, description="Show work for all problems")
    use_tools: bool = Field(True, description="Enable tool calling")
    temperature: float = Field(0.1, ge=0.0, le=2.0)
66
+
67
+
68
class MathProblemResponse(BaseModel):
    """Response body for POST /solve."""
    problem: str
    solution: str
    # Pydantic deep-copies field defaults, so this empty-list default does
    # not suffer Python's shared-mutable-default pitfall.
    tools_used: List[str] = []
    execution_time_ms: float
    # Estimated, not exact — see the word-count heuristic in solve_problem.
    tokens_generated: int
    model: str = "Kirim-1-Math"
75
+
76
+
77
class ToolCallResponse(BaseModel):
    """Response body for POST /tools/call."""
    tool_name: str
    result: str
    # False when the tool raised or its output mentions "error" (heuristic).
    success: bool
    execution_time_ms: float
82
+
83
+
84
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    status: str
    model_loaded: bool
    cuda_available: bool
    # Both GPU figures are reported as 0 when CUDA is unavailable.
    gpu_memory_used_gb: float
    gpu_memory_total_gb: float
90
+
91
+
92
class ModelInfoResponse(BaseModel):
    """Response body for GET /info: static model metadata."""
    model_name: str
    parameters: str
    capabilities: List[str]
    supported_tools: List[str]
    version: str
98
+
99
+
100
# Startup event
@app.on_event("startup")
async def load_model():
    """Instantiate the global KirimMath model when the server boots.

    Re-raises any loading failure so the process fails fast instead of
    serving 503s indefinitely.
    """
    # NOTE(review): on_event("startup") is deprecated in newer FastAPI
    # releases in favor of lifespan handlers — confirm the pinned version.
    global model_instance

    try:
        logger.info("Loading Kirim-1-Math model...")
        model_instance = KirimMath(
            model_path="Kirim-ai/Kirim-1-Math",
            device="auto",
            load_in_4bit=False  # Change to True for lower memory
        )
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    else:
        logger.info("Model loaded successfully!")
117
+
118
+
119
# Health check endpoint
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report model load state plus CUDA availability and GPU memory use."""
    has_cuda = torch.cuda.is_available()

    # GPU figures default to 0 on CPU-only hosts.
    used_gb, total_gb = 0, 0
    if has_cuda:
        used_gb = torch.cuda.memory_allocated() / 1e9
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

    return HealthResponse(
        status="healthy" if model_instance else "model_not_loaded",
        model_loaded=model_instance is not None,
        cuda_available=has_cuda,
        gpu_memory_used_gb=round(used_gb, 2),
        gpu_memory_total_gb=round(total_gb, 2),
    )
139
+
140
+
141
# Model info endpoint
@app.get("/info", response_model=ModelInfoResponse)
async def model_info():
    """Describe the model's capabilities and the tools it can call."""
    capability_list = [
        "mathematical_reasoning",
        "tool_calling",
        "code_execution",
        "symbolic_computation",
        "bilingual (Chinese/English)",
    ]
    tool_names = [
        "calculator",
        "symbolic_solver",
        "derivative",
        "integrate",
        "simplify",
        "latex_formatter",
        "code_executor",
    ]
    return ModelInfoResponse(
        model_name="Kirim-1-Math",
        parameters="30B",
        capabilities=capability_list,
        supported_tools=tool_names,
        version="1.0.0",
    )
166
+
167
+
168
# Solve math problem endpoint
@app.post("/solve", response_model=MathProblemResponse)
async def solve_problem(request: MathProblemRequest):
    """Solve a single mathematical problem with the loaded model.

    Returns the generated solution plus coarse metadata: distinct tool
    names referenced in <tool_call> blocks, wall-clock time, and an
    estimated token count. Raises 503 when the model is not loaded and
    500 on any generation failure.
    """
    if not model_instance:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        start_time = datetime.now()

        logger.info(f"Solving problem: {request.problem[:100]}...")

        # NOTE(review): this call is synchronous and will block the event
        # loop for the whole generation — confirm whether it should be
        # dispatched to a worker thread.
        solution = model_instance.solve_problem(
            problem=request.problem,
            show_work=request.show_work,
            use_tools=request.use_tools,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature
        )

        execution_time = (datetime.now() - start_time).total_seconds() * 1000

        # Extract the distinct tool names referenced in tool calls.
        # The regex scans the whole solution text, so a literal
        # '"name": "..."' outside a <tool_call> block would also match.
        tools_used = []
        if "<tool_call>" in solution:
            tools_used = list(set(re.findall(r'"name":\s*"([^"]+)"', solution)))

        # Rough token estimate: ~1.3 tokens per whitespace-separated word.
        tokens_generated = int(len(solution.split()) * 1.3)

        return MathProblemResponse(
            problem=request.problem,
            solution=solution,
            tools_used=tools_used,
            execution_time_ms=round(execution_time, 2),
            tokens_generated=tokens_generated
        )

    except Exception as e:
        logger.error(f"Error solving problem: {e}")
        raise HTTPException(status_code=500, detail=str(e))
213
+
214
+
215
# Batch solve endpoint
@app.post("/solve/batch")
async def solve_batch(request: BatchMathRequest):
    """Solve several problems sequentially, collecting per-problem outcomes.

    Failures are reported per entry rather than aborting the whole batch.
    """
    if not model_instance:
        raise HTTPException(status_code=503, detail="Model not loaded")

    results = []
    for problem in request.problems:
        entry = {"problem": problem}
        try:
            entry["solution"] = model_instance.solve_problem(
                problem=problem,
                show_work=request.show_work,
                use_tools=request.use_tools,
                temperature=request.temperature
            )
            entry["success"] = True
        except Exception as e:
            entry["solution"] = None
            entry["success"] = False
            entry["error"] = str(e)
        results.append(entry)

    return {"results": results, "total": len(request.problems)}
247
+
248
+
249
# Direct tool call endpoint
@app.post("/tools/call", response_model=ToolCallResponse)
async def call_tool(request: ToolCallRequest):
    """Execute one mathematical tool directly, bypassing the model."""
    try:
        started = datetime.now()

        executor = MathToolExecutor()
        output = executor.execute_tool(request.tool_name, request.arguments)

        elapsed_ms = (datetime.now() - started).total_seconds() * 1000

        # Heuristic: any output mentioning "error" is treated as a failure.
        return ToolCallResponse(
            tool_name=request.tool_name,
            result=output,
            success="error" not in output.lower(),
            execution_time_ms=round(elapsed_ms, 2)
        )

    except Exception as e:
        # Tool failures are reported in-band rather than as HTTP errors.
        return ToolCallResponse(
            tool_name=request.tool_name,
            result=str(e),
            success=False,
            execution_time_ms=0
        )
276
+
277
+
278
# List available tools
@app.get("/tools/list")
async def list_tools():
    """Enumerate the mathematical tools with their parameter names."""
    # (name, description, parameters) triples, expanded below.
    catalog = [
        ("calculator", "Perform precise arithmetic calculations",
         ["expression", "precision"]),
        ("symbolic_solver", "Solve algebraic equations symbolically",
         ["equation", "variable", "domain"]),
        ("derivative", "Calculate symbolic derivatives",
         ["function", "variable", "order"]),
        ("integrate", "Calculate integrals",
         ["function", "variable", "lower_bound", "upper_bound"]),
        ("simplify", "Simplify mathematical expressions",
         ["expression", "method"]),
        ("latex_formatter", "Format expressions in LaTeX",
         ["expression", "inline"]),
    ]
    tools = [
        {"name": name, "description": desc, "parameters": params}
        for name, desc, params in catalog
    ]

    return {"tools": tools, "total": len(tools)}
316
+
317
+
318
# Statistics endpoint
@app.get("/stats")
async def get_stats():
    """Report coarse usage statistics (placeholders until tracking exists)."""
    # In production, implement proper tracking
    model_status = "active" if model_instance else "inactive"
    return {
        "requests_processed": "N/A",
        "average_response_time_ms": "N/A",
        "model_status": model_status
    }
328
+
329
+
330
# Main entry point
def main():
    """Parse CLI flags and launch the uvicorn server."""
    import argparse

    cli = argparse.ArgumentParser(description="Kirim-1-Math API Server")
    cli.add_argument("--host", type=str, default="0.0.0.0", help="Host address")
    cli.add_argument("--port", type=int, default=8000, help="Port number")
    cli.add_argument("--reload", action="store_true", help="Enable auto-reload")
    cli.add_argument("--workers", type=int, default=1, help="Number of workers")
    args = cli.parse_args()

    logger.info(f"Starting Kirim-1-Math API server on {args.host}:{args.port}")

    # NOTE(review): uvicorn ignores `workers` when `reload` is enabled.
    uvicorn.run(
        "api_server:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        workers=args.workers,
        log_level="info"
    )


if __name__ == "__main__":
    main()