Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Form, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| from PIL import Image | |
| from io import BytesIO | |
| import base64 | |
| import torch | |
| import re | |
| import logging | |
| import asyncio | |
| from contextlib import asynccontextmanager | |
# Logging setup: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face identifier of the checkpoint this service loads.
model_name = "microsoft/GUI-Actor-2B-Qwen2-VL"

# Shared inference state. load_model() declares these as globals and fills
# them in; model_loaded is the flag the endpoints consult before inferring.
model = processor = tokenizer = None
model_loaded = False
async def load_model():
    """Load the processor, tokenizer and model into the module globals.

    The globals ``processor``, ``tokenizer``, ``model`` and ``model_loaded``
    are only updated once every loading step has succeeded, so a failure
    part-way through never leaves a half-initialised set behind.

    Returns:
        bool: True when loading succeeded (``model_loaded`` is then True),
        False when any step raised (``model_loaded`` is then False and the
        failure, including its traceback, has been logged).
    """
    global model, processor, tokenizer, model_loaded
    try:
        logger.info("Starting model loading...")
        # Imported inside the try block so that an import failure is handled
        # by the same error path as a failed download or load.
        from transformers import AutoProcessor, AutoModelForCausalLM

        logger.info("Loading processor...")
        # Use AutoProcessor for better compatibility
        loaded_processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        logger.info("Processor loaded successfully")
        loaded_tokenizer = loaded_processor.tokenizer

        logger.info("Loading model...")
        # Use AutoModelForCausalLM for better compatibility
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map=None,  # CPU only
            trust_remote_code=True,
            low_cpu_mem_usage=True,  # For better memory management
        ).eval()
        logger.info("Model loaded successfully!")
    except Exception as e:
        # logger.exception keeps the traceback, which the plain error message
        # alone does not provide.
        logger.exception("Error loading model: %s", e)
        model_loaded = False
        return False

    # Publish everything together now that no step can fail any more.
    processor = loaded_processor
    tokenizer = loaded_tokenizer
    model = loaded_model
    model_loaded = True
    return True
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown work around the application's lifetime.

    Startup logs a message and awaits ``load_model()``. Its boolean result is
    not inspected here, so the application starts even when loading failed;
    the ``model_loaded`` flag records the outcome. Shutdown only logs.

    Args:
        app: The FastAPI application being started (not used in the body).

    Yields:
        None, once startup work has finished.
    """
    # Startup
    logger.info("Starting up GUI-Actor API...")
    await load_model()
    yield
    # Shutdown
    logger.info("Shutting down GUI-Actor API...")
# Application instance; startup and shutdown work is delegated to lifespan().
app = FastAPI(lifespan=lifespan, title="GUI-Actor API", version="1.0.0")
class Base64Request(BaseModel):
    # Request payload for a click prediction.

    # Screenshot encoded as base64 text.
    image_base64: str

    # Task description; sent to the model as the user's text message.
    instruction: str
def extract_coordinates(text, image_width=1920, image_height=1080):
    """Extract a single click point from model output text.

    The text is lower-cased and searched with a list of patterns, from the
    most specific form to the most generic one; the first pattern that
    matches supplies the point.

    A pair in which either value is greater than 1 is treated as pixel
    coordinates and BOTH values are divided by the reference size, so the
    pair is always scaled as a unit. Otherwise the pair is returned as is.

    Args:
        text: Model output to search. ``None`` or an empty string yields the
            default point.
        image_width: Reference width in pixels used to normalise pixel
            coordinates. Must be non-zero. Defaults to 1920.
        image_height: Reference height in pixels used to normalise pixel
            coordinates. Must be non-zero. Defaults to 1080.

    Returns:
        list[tuple[float, float]]: A one-element list ``[(x, y)]``. When no
        coordinates are found the centre point ``[(0.5, 0.5)]`` is returned.
    """
    default_point = [(0.5, 0.5)]
    if not text:
        return default_point

    # Specific forms come first; the bare "x, y" form would also match inside
    # any of the others, so it has to be tried last.
    patterns = (
        r'click\s*\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)',  # click(x, y)
        r'\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]',  # [x, y]
        r'point:\s*\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)',  # point: (x, y)
        r'(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)',  # x, y
    )

    lowered = text.lower()
    for pattern in patterns:
        match = re.search(pattern, lowered)
        if match is None:
            continue
        x, y = float(match.group(1)), float(match.group(2))
        # Any value above 1 means the pair is in pixels: normalise both axes.
        if x > 1 or y > 1:
            x /= image_width
            y /= image_height
        return [(x, y)]

    # Fall back to the centre when nothing was found.
    return default_point
def cpu_inference(conversation, model, tokenizer, processor):
    """Run one generation pass on CPU and extract click coordinates.

    Args:
        conversation: Chat messages. The image is read from
            ``conversation[1]["content"][0]["image"]``, i.e. the first content
            item of the second message.
        model: Object providing ``generate``.
        tokenizer: Object providing ``decode`` and ``eos_token_id``.
        processor: Object providing ``apply_chat_template`` and callable to
            build the model inputs.

    Returns:
        dict: ``{"topk_points": [(x, y)], "response": str, "success": bool}``.
        Any exception is caught and logged with its traceback; the result
        then carries the centre point, an error message and
        ``success=False``.
    """
    try:
        # Apply chat template
        prompt = processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Get image from conversation
        image = conversation[1]["content"][0]["image"]

        # Process inputs
        inputs = processor(
            text=[prompt],
            images=[image],
            return_tensors="pt",
        )

        # Generate response. Sampling is enabled, so repeated calls with the
        # same input may produce different text.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.3,
                top_p=0.8,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Skip the first input-length tokens so only the generated part is
        # decoded.
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = outputs[0][prompt_length:]
        response = tokenizer.decode(generated_ids, skip_special_tokens=True)

        return {
            "topk_points": extract_coordinates(response),
            "response": response,
            "success": True,
        }
    except Exception as e:
        # Best-effort contract: report the failure in the result instead of
        # raising, but keep the traceback in the log.
        logger.exception("Inference error: %s", e)
        return {
            "topk_points": [(0.5, 0.5)],
            "response": f"Error during inference: {str(e)}",
            "success": False,
        }
async def root():
    """Return a liveness message together with the model-loading flag."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view; confirm how it is registered on the app.
    payload = dict(
        message="GUI-Actor API is running",
        status="healthy",
    )
    payload["model_loaded"] = model_loaded
    return payload
def _decode_base64_image(image_base64):
    """Decode base64 text (optionally a data URL) into an RGB PIL image.

    Raises:
        HTTPException: 400 when the text is not valid base64 or the decoded
            bytes cannot be opened as an image.
    """
    try:
        # Handle data URL format: only the part after the last comma is the
        # payload. Without a comma the split returns the whole string.
        image_data = base64.b64decode(image_base64.split(",")[-1])
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid base64 image: {e}") from e

    try:
        return Image.open(BytesIO(image_data)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image format: {e}") from e


def _build_conversation(pil_image, instruction):
    """Build the system + user messages passed to cpu_inference()."""
    return [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task. Please provide the click coordinates.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                # cpu_inference() reads the image from this exact position.
                {
                    "type": "image",
                    "image": pil_image,
                },
                {
                    "type": "text",
                    "text": instruction,
                },
            ],
        },
    ]


async def predict_click_base64(data: Base64Request):
    """Predict a click point for a base64-encoded screenshot.

    Args:
        data: Request carrying the base64 image and the instruction text.

    Returns:
        JSONResponse: ``x`` and ``y`` rounded to 4 decimals, the raw model
        ``response`` text and the ``success`` flag reported by
        cpu_inference().

    Raises:
        HTTPException: 503 when the model is not loaded, 400 for an invalid
            image, 500 for any other failure.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view; confirm how it is registered on the app.
    if not model_loaded:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded properly"
        )

    try:
        pil_image = _decode_base64_image(data.image_base64)
        conversation = _build_conversation(pil_image, data.instruction)

        # Run inference. cpu_inference() is a synchronous call, so this
        # coroutine does not yield to the event loop until it returns.
        pred = cpu_inference(conversation, model, tokenizer, processor)
        px, py = pred["topk_points"][0]

        return JSONResponse(content={
            "x": round(px, 4),
            "y": round(py, 4),
            "response": pred["response"],
            "success": pred["success"]
        })
    except HTTPException:
        # Already a well-formed HTTP error (e.g. the 400s above): pass through.
        raise
    except Exception as e:
        logger.exception("Prediction error: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        ) from e
async def health_check():
    """Report service health derived from the model-loading flag."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view; confirm how it is registered on the app.
    if model_loaded:
        status = "healthy"
    else:
        status = "unhealthy"
    return dict(
        status=status,
        model=model_name,
        device="cpu",
        torch_dtype="float32",
        model_loaded=model_loaded,
    )
async def predict_click_form(
    image_base64: str = Form(...),
    instruction: str = Form(...)
):
    """Form-field variant of the click prediction endpoint.

    Wraps the two form fields in a Base64Request and returns whatever
    predict_click_base64() returns for it.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view; confirm how it is registered on the app.
    request_body = Base64Request(
        instruction=instruction,
        image_base64=image_base64,
    )
    result = await predict_click_base64(request_body)
    return result