import spaces import torch import torch.nn.functional as F import gradio as gr from gradio_image_annotation import image_annotator from models.counter_infer import build_model from utils.arg_parser import get_argparser from utils.data import resize_and_pad import torchvision.ops as ops from torchvision import transforms as T from PIL import Image, ImageDraw from huggingface_hub import hf_hub_download import numpy as np import colorsys # ----------------------------- _MODEL = None _ARGS = None _WEIGHTS_PATH = None # ----------------------------- def _get_args(): global _ARGS if _ARGS is None: args = get_argparser().parse_args() args.zero_shot = True _ARGS = args return _ARGS def _get_weights_path(): global _WEIGHTS_PATH if _WEIGHTS_PATH is None: _WEIGHTS_PATH = hf_hub_download( repo_id="jerpelhan/geco2-assets", filename="weights/CNTQG_multitrain_ca44.pth", repo_type="dataset", ) return _WEIGHTS_PATH def _strip_module_prefix(state_dict: dict) -> dict: """ If weights were saved from torch.nn.DataParallel, keys are often prefixed with 'module.'. When loading into a non-DataParallel model, strip that prefix. """ if not isinstance(state_dict, dict) or len(state_dict) == 0: return state_dict # Only strip if it looks like DP has_module = any(k.startswith("module.") for k in state_dict.keys()) if not has_module: return state_dict return {k[len("module.") :]: v for k, v in state_dict.items()} def _extract_state_dict(ckpt) -> dict: """ Robustly extract a state_dict from typical checkpoint formats. """ if isinstance(ckpt, dict): # Common keys if "model" in ckpt and isinstance(ckpt["model"], dict): return ckpt["model"] if "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict): return ckpt["state_dict"] # Fallback: checkpoint itself is the state_dict return ckpt def get_model_on_device(device: torch.device): """ Lazily build and load model, then move to the requested device. IMPORTANT: model is constructed/loaded without initializing CUDA in the main process. This function will be called from inside the @spaces.GPU worker. """ global _MODEL if _MODEL is None: args = _get_args() # Build on CPU first to avoid CUDA init in the wrong process model = build_model(args) weights_path = _get_weights_path() ckpt = torch.load(weights_path, map_location="cpu") # keep compatibility across torch versions state = _extract_state_dict(ckpt) state = _strip_module_prefix(state) model.load_state_dict(state, strict=False) model.eval() _MODEL = model _MODEL = _MODEL.to(device) if device.type == "cuda": torch.backends.cudnn.benchmark = True return _MODEL # ----------------------------- # Rotation helper (in case annotator reports orientation) # ----------------------------- def _rotate_image_and_boxes(image_np: np.ndarray, boxes: list[dict], angle: int): if angle is None: return image_np, boxes a = int(angle) % 4 if a == 0: return image_np, boxes H, W = image_np.shape[:2] # rotate image using the same convention as the component docs image_rot = np.rot90(image_np, k=-a) def clamp_box(xmin, ymin, xmax, ymax, newW, newH): xmin = max(0, min(newW, xmin)) xmax = max(0, min(newW, xmax)) ymin = max(0, min(newH, ymin)) ymax = max(0, min(newH, ymax)) if xmax < xmin: xmin, xmax = xmax, xmin if ymax < ymin: ymin, ymax = ymax, ymin return xmin, ymin, xmax, ymax boxes_rot = [] if a == 1: # 90 deg clockwise: (x,y) -> (H - 1 - y, x) newH, newW = W, H for b in boxes: xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"] nxmin = H - ymax nxmax = H - ymin nymin = xmin nymax = xmax nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH) bb = dict(b) bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax}) boxes_rot.append(bb) elif a == 2: # 180 deg: (x,y) -> (W - 1 - x, H - 1 - y) newH, newW = H, W for b in boxes: xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"] nxmin = W - xmax nxmax = W - xmin nymin = H - ymax nymax = H - ymin nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH) bb = dict(b) bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax}) boxes_rot.append(bb) else: # a == 3 # 90 deg counter-clockwise: (x,y) -> (y, W - 1 - x) newH, newW = W, H for b in boxes: xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"] nxmin = ymin nxmax = ymax nymin = W - xmax nymax = W - xmin nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH) bb = dict(b) bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax}) boxes_rot.append(bb) return image_rot, boxes_rot # ----------------------------- # Function to Process Image Once (GPU) # ----------------------------- @spaces.GPU def process_image_once(inputs, enable_mask): """ inputs is AnnotatedImageValue-like dict from gradio_image_annotation: { "image": np.ndarray | PIL | str, "boxes": [ {xmin,ymin,xmax,ymax,label?,color?}, ... ], "orientation": int? } """ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = get_model_on_device(device) if inputs is None or inputs.get("image", None) is None: # keep behavior simple: return empty outputs return None, [{"pred_boxes": torch.empty(0, 4), "box_v": torch.empty(0)}], [None], torch.empty(1), 1.0, [] image = inputs["image"] boxes = inputs.get("boxes", []) or [] # Ensure numpy image (support numpy, PIL, OR local path string) if isinstance(image, Image.Image): image = np.array(image.convert("RGB")) elif isinstance(image, str): image = np.array(Image.open(image).convert("RGB")) elif isinstance(image, np.ndarray): pass else: raise ValueError(f"Unsupported image type from annotator: {type(image)}") angle = inputs.get("orientation", None) if angle is not None: image, boxes = _rotate_image_and_boxes(image, boxes, angle) drawn_boxes = [] for b in boxes: drawn_boxes.append([float(b["xmin"]), float(b["ymin"]), 0.0, float(b["xmax"]), float(b["ymax"])]) # If no boxes, do not call model (caller will handle warning) if len(drawn_boxes) == 0: return image, [{"pred_boxes": torch.empty(0, 4), "box_v": torch.empty(0)}], [None], torch.empty(1), 1.0, [] image_tensor = torch.tensor(image).to(device) image_tensor = image_tensor.permute(2, 0, 1).float() / 255.0 image_tensor = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image_tensor) bboxes_tensor = torch.tensor( [[box[0], box[1], box[3], box[4]] for box in drawn_boxes], dtype=torch.float32, ).to(device) img, bboxes, scale = resize_and_pad(image_tensor, bboxes_tensor, size=1024.0) img = img.unsqueeze(0).to(device) bboxes = bboxes.unsqueeze(0).to(device) # Faster inference mode use_amp = (device.type == "cuda") with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp): model.return_masks = enable_mask outputs, _, _, _, masks = model(img, bboxes) # Return ONLY CPU-native objects to main process. out0 = outputs[0] pred_boxes_cpu = out0["pred_boxes"].detach().float().cpu() box_v_cpu = out0["box_v"].detach().float().cpu() outputs_cpu = [{"pred_boxes": pred_boxes_cpu, "box_v": box_v_cpu}] if enable_mask and masks is not None and masks[0] is not None: masks_cpu = [masks[0].detach().float().cpu()] else: masks_cpu = [None] img_cpu = img.detach().cpu() return image, outputs_cpu, masks_cpu, img_cpu, float(scale), drawn_boxes # ----------------------------- # Pastel visualization helpers # ----------------------------- def _hsv_to_rgb255(h, s, v): r, g, b = colorsys.hsv_to_rgb(h, s, v) return (int(255 * r), int(255 * g), int(255 * b)) def instance_colors(i: int): h = (i * 0.618033988749895) % 1.0 mask_rgb = _hsv_to_rgb255(h, s=0.28, v=1.00) box_rgb = _hsv_to_rgb255(h, s=0.42, v=0.95) return mask_rgb, box_rgb def overlay_single_mask(base_rgba: Image.Image, mask_bool: np.ndarray, rgb, alpha=0.45): if mask_bool.dtype != np.bool_: mask_bool = mask_bool.astype(bool) h, w = mask_bool.shape overlay = np.zeros((h, w, 4), dtype=np.uint8) overlay[..., 0] = rgb[0] overlay[..., 1] = rgb[1] overlay[..., 2] = rgb[2] overlay[..., 3] = (mask_bool.astype(np.uint8) * int(255 * alpha)) overlay_img = Image.fromarray(overlay, mode="RGBA") return Image.alpha_composite(base_rgba, overlay_img) # ----------------------------- # Post-process and Update Output # ----------------------------- def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold): idx = 0 threshold = 1 / threshold score = outputs[idx]["box_v"] if score.numel() == 0: # no predictions image_pil = Image.fromarray((image).astype(np.uint8)).convert("RGB") return image_pil, 0 score_mask = score > score.max() / threshold keep = ops.nms( outputs[idx]["pred_boxes"][score_mask], score[score_mask], 0.5, ) pred_boxes = outputs[idx]["pred_boxes"][score_mask][keep] pred_boxes = torch.clamp(pred_boxes, 0, 1) pred_boxes = (pred_boxes / scale * img.shape[-1]).tolist() image = Image.fromarray((image).astype(np.uint8)).convert("RGBA") if enable_mask and masks is not None and masks[idx] is not None: masks_sel = masks[idx][score_mask[0]] if score_mask.ndim > 1 else masks[idx][score_mask] masks_sel = masks_sel[keep] target_h = int(img.shape[2] / scale) target_w = int(img.shape[3] / scale) resize_nearest = T.Resize((target_h, target_w), interpolation=T.InterpolationMode.NEAREST) W, H = image.size for i in range(masks_sel.shape[0]): mask_i = masks_sel[i] if mask_i.ndim == 3: mask_i = mask_i[0] mask_rs = resize_nearest(mask_i.unsqueeze(0))[0] mask_rs = mask_rs[:H, :W] mask_bool = (mask_rs > 0.0).cpu().numpy().astype(bool) mask_rgb, _ = instance_colors(i) image = overlay_single_mask(image, mask_bool, mask_rgb, alpha=0.45) draw = ImageDraw.Draw(image) box_width = 2 for i, box in enumerate(pred_boxes): _, box_rgb = instance_colors(i) x1, y1, x2, y2 = map(float, box) draw.rectangle([x1, y1, x2, y2], outline=box_rgb, width=box_width) exemplar_outline = (255, 255, 255, 255) exemplar_inner = (0, 0, 0, 255) for box in drawn_boxes: x1, y1, x2, y2 = box[0], box[1], box[3], box[4] draw.rectangle([x1, y1, x2, y2], outline=exemplar_outline, width=2) draw.rectangle([x1 + 1, y1 + 1, x2 - 1, y2 - 1], outline=exemplar_inner, width=1) return image.convert("RGB"), len(pred_boxes) # ----------------------------- # Examples: gallery click -> set annotator value # ----------------------------- EXAMPLE_PATHS = ["material/1.jpg", "material/2.jpg", "material/3.jpg", "material/4.jpg", "material/5.jpg"] def load_example_from_gallery(evt: gr.SelectData): """ When user clicks a thumbnail in the gallery, load that image into the annotator. """ idx = int(evt.index) path = EXAMPLE_PATHS[idx] return {"image": path, "boxes": []} # ----------------------------- # Gradio UI # ----------------------------- iface = gr.Blocks( title="GeCo2 Gradio Demo", ) with iface: gr.Markdown( """ # GeCo2: Generalized-Scale Object Counting with Gradual Query Aggregation GeCo2 is a few-shot, category-agnostic detection counter. With only a small number of exemplars, GeCo2 can detect and count all instances of the target object in an image without any retraining. 1) Upload an image or click an example below. 2) Draw bounding boxes on the target object (preferably ~3 instances). 3) Click **Count**. 4) If needed, adjust the threshold. """ ) # Store intermediate states image_input = gr.State() outputs_state = gr.State() masks_state = gr.State() img_state = gr.State() scale_state = gr.State() drawn_boxes_state = gr.State() with gr.Row(): annotator = image_annotator( value=None, image_type="numpy", # ensures inputs["image"] is a numpy array label_list=["Object"], label_colors=[(0, 255, 0)], use_default_label=True, enable_keyboard_shortcuts=True, interactive=True, show_label=False, ) image_output = gr.Image(type="pil") with gr.Row(): count_output = gr.Number(label="Total Count") enable_mask = gr.Checkbox(label="Predict masks", value=True) threshold = gr.Slider(0.05, 0.95, value=0.33, step=0.01, label="Threshold") count_button = gr.Button("Count") gallery = gr.Gallery( value=EXAMPLE_PATHS, columns=5, height=450, label="Examples (click an image to load it into the annotator)", show_label=True, allow_preview=False, ) gallery.select( fn=load_example_from_gallery, inputs=None, outputs=annotator, ) def initial_process(inputs, enable_mask, threshold): # Validate: must have at least one box if inputs is None or inputs.get("image", None) is None: gr.Warning("please delineate at least one target category object") return None, 0, None, None, None, None, None, None img_val = inputs.get("image", None) boxes = inputs.get("boxes", []) or [] if len(boxes) == 0: # Try to show current image in the output even if no boxes if isinstance(img_val, str): preview = Image.open(img_val).convert("RGB") elif isinstance(img_val, Image.Image): preview = img_val.convert("RGB") elif isinstance(img_val, np.ndarray): preview = Image.fromarray(img_val.astype(np.uint8)).convert("RGB") else: preview = None gr.Warning("please delineate at least one target category object") return preview, 0, None, None, None, None, None, None image, outputs, masks, img, scale, drawn_boxes = process_image_once(inputs, enable_mask) if image is None: return None, 0, None, None, None, None, None, None out_img, cnt = post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold) return ( out_img, cnt, image, outputs, masks, img, scale, drawn_boxes, ) def update_threshold(threshold, image, outputs, masks, img, scale, drawn_boxes, enable_mask): if image is None or outputs is None or img is None: return None, 0 return post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold) count_button.click( initial_process, [annotator, enable_mask, threshold], [image_output, count_output, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state], ) threshold.change( update_threshold, [threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask], [image_output, count_output], ) enable_mask.change( update_threshold, [threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask], [image_output, count_output], ) if __name__ == "__main__": iface.queue().launch(ssr_mode=False)