GUI-Agent / app.py
abiyyufahri's picture
Install error fix attempt 6
0b96209
raw
history blame
5.9 kB
from fastapi import FastAPI, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from PIL import Image
from io import BytesIO
import base64
import torch
import re
from transformers import AutoModelForCausalLM, AutoProcessor
# FastAPI application object that the route decorators below attach to.
app = FastAPI()

# Hugging Face model id shared by the processor load, the model load and the
# /health report.
model_name = "microsoft/GUI-Actor-2B-Qwen2-VL"

# Prefer the auto-resolved processor. If that fails for any reason, report the
# failure and fall back to the Qwen2-VL processor class, which is imported
# lazily so it is only required when the fallback is actually taken.
try:
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
except Exception as e:
    print(f"Failed to load AutoProcessor: {e}")
    from transformers import Qwen2VLProcessor

    processor = Qwen2VLProcessor.from_pretrained(model_name)

tokenizer = processor.tokenizer

# Load settings chosen by the original author for CPU-only serving.
_model_load_options = dict(
    torch_dtype=torch.float32,  # float32 for CPU
    device_map=None,  # no device map (CPU only, per the original note)
    trust_remote_code=True,  # the checkpoint ships custom model code
    attn_implementation=None,  # no attention backend requested (original note: skip flash attention)
)

# Instantiate the weights, then switch the module to evaluation mode.
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_options)
model = model.eval()
# JSON request body accepted by the /click/base64 endpoint (and built from
# form fields by /click/form).
class Base64Request(BaseModel):
    # Screenshot encoded as base64. The handler keeps only the text after the
    # last comma, so a "data:...;base64," prefix is tolerated.
    image_base64: str
    # Instruction text passed to the model as the user message.
    instruction: str
def extract_coordinates(text, image_width=1920, image_height=1080):
    """Extract the first click coordinate mentioned in model output text.

    The text is searched case-insensitively for the formats below, from most
    to least specific. The first format that matches wins, and within that
    format the first occurrence is used:

    1. ``click(x, y)``
    2. ``[x, y]``
    3. ``point: (x, y)``
    4. a bare ``x, y`` pair

    The bare pair is tried last on purpose: it matches any two numbers
    separated by a comma, so checking it earlier would shadow the explicit
    ``point:`` format.

    Args:
        text: Raw text generated by the model. ``None`` or an empty string is
            treated as "no coordinates found".
        image_width: Width in pixels used to normalize pixel coordinates.
            Defaults to 1920, the resolution previously hard-coded here.
        image_height: Height in pixels used to normalize pixel coordinates.
            Defaults to 1080.

    Returns:
        A one-element list ``[(x, y)]`` with both values in ``[0, 1]``.
        ``[(0.5, 0.5)]`` (the center) is returned when nothing is found.
    """
    if not text:
        return [(0.5, 0.5)]

    patterns = (
        r'click\s*\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)',  # click(x, y)
        r'\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]',  # [x, y]
        r'point:\s*\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)',  # point: (x, y)
        r'(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)',  # x, y (generic, last)
    )

    # Lower-case once instead of once per pattern.
    lowered = text.lower()

    for pattern in patterns:
        match = re.search(pattern, lowered)
        if match is None:
            continue

        x, y = float(match.group(1)), float(match.group(2))

        # A value above 1 cannot be a normalized coordinate, so the whole pair
        # is interpreted as pixels. Both axes are scaled together: a pair is
        # expressed in a single unit system, e.g. (1, 540) is pixel column 1,
        # not the right-hand edge of the screen.
        if x > 1 or y > 1:
            x /= image_width
            y /= image_height

        # The patterns never match a minus sign, so only the upper bound needs
        # clamping (pixels beyond the assumed resolution).
        return [(min(x, 1.0), min(y, 1.0))]

    # Default to the center when no coordinate is found.
    return [(0.5, 0.5)]
def cpu_inference(conversation, model, tokenizer, processor):
    """Generate a text reply for a conversation and parse a click point from it.

    Args:
        conversation: Chat messages. The screenshot is read from the first
            content item of the second message
            (``conversation[1]["content"][0]["image"]``).
        model: Object exposing ``generate``.
        tokenizer: Object exposing ``decode`` and ``eos_token_id``.
        processor: Object exposing ``apply_chat_template`` and callable on
            ``text``/``images`` to build the model inputs.

    Returns:
        A dict with ``topk_points`` (list of ``(x, y)`` tuples), ``response``
        (the decoded reply, or an error description) and ``success``. Any
        exception is caught and reported as ``success=False`` with the center
        point ``(0.5, 0.5)``.
    """
    try:
        # Render the chat messages into one prompt string that ends with the
        # generation prompt.
        prompt = processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Pull the screenshot out of the user message.
        screenshot = conversation[1]["content"][0]["image"]

        # Build the tensors for the model from the prompt and the screenshot.
        model_inputs = processor(
            text=[prompt],
            images=[screenshot],
            return_tensors="pt",
        )

        # Sampling settings for generation.
        generation_options = dict(
            max_new_tokens=256,
            do_sample=True,
            temperature=0.3,
            top_p=0.8,
            pad_token_id=tokenizer.eos_token_id,
        )
        with torch.no_grad():
            sequences = model.generate(**model_inputs, **generation_options)

        # Drop the echoed prompt tokens so that only the new tokens are decoded.
        prompt_length = model_inputs["input_ids"].shape[1]
        new_token_ids = sequences[0][prompt_length:]
        response = tokenizer.decode(new_token_ids, skip_special_tokens=True)

        # Parse the reply into click coordinates.
        return {
            "topk_points": extract_coordinates(response),
            "response": response,
            "success": True,
        }
    except Exception as e:
        # Best-effort fallback: report the failure and point at the center.
        return {
            "topk_points": [(0.5, 0.5)],
            "response": f"Error during inference: {str(e)}",
            "success": False,
        }
def _click_error_response(error, status_code):
    """Build the JSON error payload shared by every failure path."""
    return JSONResponse(
        content={
            "error": str(error),
            "success": False,
            "x": 0.5,
            "y": 0.5,
        },
        status_code=status_code,
    )


@app.post("/click/base64")
async def predict_click_base64(data: Base64Request):
    """Predict a click position for a base64-encoded screenshot.

    Args:
        data: Request body carrying the base64 image (optionally prefixed with
            a data-URL header) and the instruction text.

    Returns:
        A ``JSONResponse``:

        * 200 with ``x``, ``y`` (rounded to 4 decimals), ``response`` and
          ``success`` when the request was processed. ``success`` is False if
          inference itself failed, in which case the point is the center.
        * 400 with ``error``, ``success=False`` and the center point when the
          image payload is not valid base64 or cannot be decoded as an image.
        * 500 with the same error payload for any other failure.
    """
    # Decoding failures are caused by the request content, so they are
    # reported as a client error instead of a server error.
    try:
        # Keep only the text after the last comma so that data URLs
        # ("data:image/png;base64,...") are accepted as well as bare base64.
        encoded = data.image_base64.split(",")[-1]
        image_data = base64.b64decode(encoded)
        pil_image = Image.open(BytesIO(image_data)).convert("RGB")
    except (ValueError, OSError) as e:
        # binascii.Error is a ValueError; Pillow reports unreadable or
        # truncated image data through OSError subclasses.
        return _click_error_response(e, 400)
    except Exception as e:
        return _click_error_response(e, 500)

    try:
        conversation = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task. Please provide the click coordinates.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": pil_image,
                    },
                    {
                        "type": "text",
                        "text": data.instruction,
                    },
                ],
            },
        ]

        # NOTE: cpu_inference is a plain synchronous call, so this coroutine
        # does not yield to the event loop while the model is running.
        pred = cpu_inference(conversation, model, tokenizer, processor)
        px, py = pred["topk_points"][0]

        return JSONResponse(content={
            "x": round(px, 4),
            "y": round(py, 4),
            "response": pred["response"],
            "success": pred["success"],
        })
    except Exception as e:
        return _click_error_response(e, 500)
@app.get("/health")
async def health_check():
    """Return a static status report naming the configured model.

    The report always says ``"healthy"``; the device and dtype entries are
    fixed strings, not values read back from the loaded model.
    """
    return dict(
        status="healthy",
        model=model_name,
        device="cpu",
        torch_dtype="float32",
    )
@app.post("/click/form")
async def predict_click_form(
    image_base64: str = Form(...),
    instruction: str = Form(...),
):
    """Form-encoded variant of the base64 click endpoint.

    Both form fields are required. They are wrapped in a ``Base64Request``
    and handed to ``predict_click_base64``, whose response is returned
    unchanged.
    """
    payload = Base64Request(
        image_base64=image_base64,
        instruction=instruction,
    )
    response = await predict_click_base64(payload)
    return response
if __name__ == "__main__":
    # Only start a server when the file is executed directly; uvicorn is
    # imported here so that importing this module does not require it.
    from uvicorn import run as serve

    # Listen on all interfaces, port 7860.
    serve(app, host="0.0.0.0", port=7860)