import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification
from efficientnet_pytorch import EfficientNet
import timm
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import torchvision.models as tv_models
import torchvision.transforms as T
import urllib.request
import json
import cv2

# ---------------------------
# Model Configs
# ---------------------------
# Note: timm has no SqueezeNet and names MobileNet-V2 "mobilenetv2_100",
# so SqueezeNet is loaded from torchvision instead.
MODEL_CONFIGS = {
    "DeiT-Tiny": {"type": "hf", "id": "facebook/deit-tiny-patch16-224"},
    "DeiT-Small": {"type": "hf", "id": "facebook/deit-small-patch16-224"},
    "ViT-Base": {"type": "hf", "id": "google/vit-base-patch16-224"},
    "ConvNeXt-Tiny": {"type": "timm", "id": "convnext_tiny"},
    "EfficientNet-B0": {"type": "efficientnet", "id": "efficientnet-b0"},
    "SqueezeNet": {"type": "torchvision", "id": "squeezenet1_1"},
    "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100"}
}

# ---------------------------
# ImageNet Labels
# ---------------------------
IMAGENET_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
with urllib.request.urlopen(IMAGENET_URL) as resp:
    IMAGENET_LABELS = json.load(resp)

# ---------------------------
# Lazy Load
# ---------------------------
loaded_models = {}

def load_model(model_name):
    if model_name in loaded_models:
        return loaded_models[model_name]

    config = MODEL_CONFIGS[model_name]
    if config["type"] == "hf":
        processor = AutoImageProcessor.from_pretrained(config["id"])
        model = AutoModelForImageClassification.from_pretrained(
            config["id"], output_attentions=True
        )
    elif config["type"] == "timm":
        model = timm.create_model(config["id"], pretrained=True)
        processor = None
    elif config["type"] == "efficientnet":
        model = EfficientNet.from_pretrained(config["id"])
        processor = None
    elif config["type"] == "torchvision":
        # weights="DEFAULT" requires torchvision >= 0.13
        model = getattr(tv_models, config["id"])(weights="DEFAULT")
        processor = None
    else:
        raise ValueError(f"Unknown model type: {config['type']}")

    model.eval()
    loaded_models[model_name] = (model, processor)
    return model, processor

# ---------------------------
# Grad-CAM Helper for CNNs
# ---------------------------
def get_gradcam(model, image_tensor):
    grad = None
    fmap = None

    def forward_hook(module, input, output):
        nonlocal fmap
        fmap = output.detach()

    def backward_hook(module, grad_in, grad_out):
        nonlocal grad
        grad = grad_out[0].detach()

    # Hook the last conv layer (named_modules() returns a generator, so it
    # must be materialized as a list before it can be reversed)
    last_conv = None
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, torch.nn.Conv2d):
            last_conv = module
            break
    if last_conv is None:
        return np.ones((224, 224))

    handle_fwd = last_conv.register_forward_hook(forward_hook)
    handle_bwd = last_conv.register_full_backward_hook(backward_hook)

    try:
        # Grad-CAM needs a backward pass, so force grad mode even if the
        # caller is inside torch.no_grad()
        with torch.enable_grad():
            out = model(image_tensor)
            class_idx = out.argmax(dim=1).item()
            score = out[0, class_idx]
            model.zero_grad()
            score.backward()
    finally:
        handle_fwd.remove()
        handle_bwd.remove()

    weights = grad.mean(dim=(2, 3), keepdim=True)    # GAP over the gradients
    cam = (weights * fmap).sum(dim=1, keepdim=True)  # weighted sum of feature maps
    cam = F.relu(cam)
    cam = cam.squeeze().cpu().numpy()
    cam = cv2.resize(cam, (224, 224))
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    return cam

# ---------------------------
# ViT Attention Rollout Helper
# ---------------------------
def vit_attention_rollout(outputs):
    # Simplified "rollout": uses only the last layer's CLS-token attention,
    # averaged over heads, rather than multiplying attention across layers.
    if getattr(outputs, "attentions", None) is None:
        return np.ones((14, 14))
    attn = outputs.attentions[-1]      # last layer: (batch, heads, tokens, tokens)
    attn = attn.mean(1)                # average over heads
    attn = attn[:, 0, 1:]              # CLS token's attention to the patch tokens
    grid = int(attn.shape[-1] ** 0.5)  # e.g. 196 patches -> 14x14 grid
    attn_map = attn.reshape(grid, grid).detach().cpu().numpy()
    attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
    return attn_map

# ---------------------------
# Overlay Helper
# ---------------------------
def overlay_attention(pil_img, attention_map):
    heatmap = (attention_map * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = Image.fromarray(heatmap).resize(pil_img.size)
    blended = Image.blend(pil_img.convert("RGBA"), heatmap.convert("RGBA"), alpha=0.5)
    return blended

# ---------------------------
# Main Prediction Function
# ---------------------------
def predict(image, model_name):
    try:
        model, processor = load_model(model_name)
        image = image.convert("RGB")  # guard against RGBA/grayscale uploads

        transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
        ])
        x = transform(image).unsqueeze(0)

        if MODEL_CONFIGS[model_name]["type"] == "hf":
            inputs = processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = model(**inputs)
            probs = F.softmax(outputs.logits, dim=-1)[0]
            top5_prob, top5_idx = torch.topk(probs, k=5)
            top5_labels = [model.config.id2label[idx.item()] for idx in top5_idx]
            att_map = vit_attention_rollout(outputs)
        else:
            with torch.no_grad():
                outputs = model(x)
            probs = F.softmax(outputs, dim=-1)[0]
            top5_prob, top5_idx = torch.topk(probs, k=5)
            top5_labels = [IMAGENET_LABELS[idx.item()] for idx in top5_idx]
            # Grad-CAM runs its own forward/backward pass, so it must not
            # be wrapped in torch.no_grad()
            att_map = get_gradcam(model, x)

        overlay = overlay_attention(image, att_map)
        result = {label: float(prob) for label, prob in zip(top5_labels, top5_prob)}
        return result, overlay
    except Exception as e:
        return {"Error": f"Model '{model_name}' failed: {str(e)}"}, None

# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Multi-Model Image Classifier with Real Attention Maps")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            model_dropdown = gr.Dropdown(
                list(MODEL_CONFIGS.keys()),
                value="DeiT-Tiny",
                label="Select Model"
            )
            run_button = gr.Button("Run Model")
        with gr.Column(scale=2):
            output_label = gr.Label(num_top_classes=5, label="Predictions")
            output_image = gr.Image(label="Attention Map Overlay")
    gr.Markdown(
        "💡 *This app dynamically loads models (DeiT, ViT, ConvNeXt, etc.) "
        "and visualizes what the model focused on!*"
    )

    run_button.click(
        predict,
        inputs=[input_image, model_dropdown],
        outputs=[output_label, output_image]
    )

if __name__ == "__main__":
    demo.launch()
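
# ---------------------------
# Notes
# ---------------------------
# Rough dependency list, inferred from the imports above (the PyPI package
# names are an assumption): gradio, transformers, timm, efficientnet_pytorch,
# torch, torchvision (>= 0.13 for weights="DEFAULT"), pillow, numpy,
# opencv-python.
# Launch locally with:  python <this_file>.py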