# handler.py
import inspect
import io
import json
from base64 import b64decode

import torch
import torchvision.transforms as T
from PIL import Image

# Define class labels (must match training order)
CLASS_LABELS = [
    "glove_outline",
    "webbing",
    "thumb",
    "palm_pocket",
    "hand",
    "glove_exterior",
]

# Default checkpoint location, relative to the process working directory.
MODEL_PATH = "pytorch_model.bin"


# ----------------------------
# Load model directly from full .bin
# ----------------------------
def load_model(path=MODEL_PATH):
    """Load a fully pickled ``torch.nn.Module`` from *path* onto the CPU.

    Args:
        path: Filesystem path of a checkpoint saved with ``torch.save(model)``
            (the whole module object, not just its ``state_dict``).

    Returns:
        The loaded module, switched to eval mode.

    Raises:
        TypeError: if the file does not contain an ``nn.Module`` (e.g. it holds
            a ``state_dict``, which needs the model class to be rebuilt first).
    """
    load_kwargs = {"map_location": "cpu"}
    # PyTorch >= 2.6 defaults ``weights_only=True``, which refuses to unpickle a
    # full module object. Opt out explicitly, but only on versions that know the
    # keyword (added in 1.13); older versions already unpickle full objects.
    # SECURITY: this runs arbitrary pickle code -- only load trusted checkpoints.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    loaded = torch.load(path, **load_kwargs)
    if not isinstance(loaded, torch.nn.Module):
        raise TypeError(
            f"{path!r} contains a {type(loaded).__name__}, not a torch.nn.Module; "
            "a state_dict checkpoint must be loaded into an instantiated model "
            "via model.load_state_dict(...)"
        )
    loaded.eval()
    return loaded


model = load_model()

# ----------------------------
# Preprocessing
# ----------------------------
# Resize to a fixed 720x1280 (H, W) and convert to a float tensor in [0, 1].
# NOTE(review): no mean/std normalization is applied -- confirm this matches
# the preprocessing used at training time.
transform = T.Compose([
    T.Resize((720, 1280)),
    T.ToTensor(),
])


def preprocess(input_bytes):
    """Decode encoded image bytes into a batched ``[1, 3, 720, 1280]`` tensor."""
    image = Image.open(io.BytesIO(input_bytes)).convert("RGB")
    tensor = transform(image).unsqueeze(0)  # [1, 3, H, W]
    return tensor


# ----------------------------
# Dummy input wrapper
# ----------------------------
class DummyInput:
    """Attribute bag handed to ``model`` in place of a real batch object.

    Fills every field with an empty/neutral placeholder (zero masks, no
    prompts) sized to the image tensor.
    NOTE(review): the attribute names presumably mirror the batch structure
    the model's ``forward`` reads -- verify against the model definition.
    """

    def __init__(self, image_tensor):
        # image_tensor is expected to be [B, C, H, W] (as produced by preprocess).
        B, C, H, W = image_tensor.shape
        self.images = image_tensor
        self.masks = [torch.zeros(B, H, W, dtype=torch.bool)]
        self.num_frames = 1
        self.original_size = [(H, W)]
        self.target_size = [(H, W)]
        # No point / box prompts are supplied.
        self.point_coords = [None]
        self.point_labels = [None]
        self.boxes = [None]
        self.mask_inputs = torch.zeros(B, 1, H, W)
        self.video_mask = torch.zeros(B, 1, H, W)
        self.flat_obj_to_img_idx = [[0]]


# ----------------------------
# Postprocessing
# ----------------------------
def postprocess(output_tensor):
    """Turn model output into a per-pixel class-index map for the first image.

    Accepts either a raw logits tensor or a dict carrying it under ``"masks"``.
    Takes ``argmax`` over dim 1 and returns the first batch item as nested
    Python lists (JSON-serializable).
    NOTE(review): assumes logits are laid out ``[B, num_classes, H, W]``.
    """
    if isinstance(output_tensor, dict) and "masks" in output_tensor:
        logits = output_tensor["masks"]
    else:
        logits = output_tensor
    pred = torch.argmax(logits, dim=1)[0].cpu().numpy()
    return pred.tolist()


# ----------------------------
# Inference Entry Point
# ----------------------------
def infer(payload):
    """Run segmentation on one image.

    Args:
        payload: Either raw encoded image bytes (``bytes``/``bytearray``/
            ``memoryview``), or a dict whose ``"inputs"`` value is the
            base64-encoded image.

    Returns:
        ``{"mask": <nested list of class indices>, "classes": <label list>}``.

    Raises:
        ValueError: if *payload* is neither of the supported forms.
    """
    if isinstance(payload, (bytes, bytearray, memoryview)):
        image_bytes = bytes(payload)
    elif isinstance(payload, dict) and "inputs" in payload:
        image_bytes = b64decode(payload["inputs"])
    else:
        raise ValueError("Unsupported input format")

    image_tensor = preprocess(image_bytes)
    input_obj = DummyInput(image_tensor)
    with torch.no_grad():
        output = model(input_obj)
    mask = postprocess(output)
    # Return a copy so callers mutating the response cannot corrupt the global.
    return {"mask": mask, "classes": list(CLASS_LABELS)}