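"""Gradio demo: multi-model image classifier with attention / Grad-CAM overlays.

Loads ImageNet classifiers from Hugging Face Transformers, timm, and
efficientnet_pytorch on demand, predicts the top-5 classes for an uploaded
image, and overlays either a ViT attention map or a Grad-CAM heatmap showing
where the model focused.

Assumed third-party dependencies (adjust as needed): gradio, transformers,
timm, efficientnet_pytorch, torch, torchvision, Pillow, numpy, opencv-python.
"""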
import gradio as gr
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from efficientnet_pytorch import EfficientNet
import timm
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import torchvision.transforms as T
import urllib.request
import json
import cv2

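# Registry of selectable models: "hf" entries are loaded via Transformers,
# "timm" via timm.create_model, and "efficientnet" via efficientnet_pytorch.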
MODEL_CONFIGS = {
    "DeiT-Tiny": {"type": "hf", "id": "facebook/deit-tiny-patch16-224"},
    "DeiT-Small": {"type": "hf", "id": "facebook/deit-small-patch16-224"},
    "ViT-Base": {"type": "hf", "id": "google/vit-base-patch16-224"},
    "ConvNeXt-Tiny": {"type": "timm", "id": "convnext_tiny"},
    "EfficientNet-B0": {"type": "efficientnet", "id": "efficientnet-b0"},
    # NOTE: timm does not appear to register a SqueezeNet variant; if this entry
    # fails to load, swap the id for a name listed by timm.list_models().
    "SqueezeNet": {"type": "timm", "id": "squeezenet1_1"},
    "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100"},  # timm's MobileNetV2 name
}

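# Human-readable ImageNet-1k class names, used for the timm/EfficientNet models
# (the Transformers models carry their own id2label mapping).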
IMAGENET_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
with urllib.request.urlopen(IMAGENET_URL) as url:
    IMAGENET_LABELS = json.load(url)

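# Cache of already-instantiated models so switching back to a model is fast.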
loaded_models = {}

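# Lazily load (and cache) a model plus, for Transformers models, its feature
# extractor; timm and efficientnet_pytorch models instead use the manual
# torchvision transform defined in predict().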
def load_model(model_name):
    if model_name in loaded_models:
        return loaded_models[model_name]

    config = MODEL_CONFIGS[model_name]
    if config["type"] == "hf":
        extractor = AutoFeatureExtractor.from_pretrained(config["id"])
        model = AutoModelForImageClassification.from_pretrained(config["id"], output_attentions=True)
        model.eval()
    elif config["type"] == "timm":
        model = timm.create_model(config["id"], pretrained=True)
        model.eval()
        extractor = None
    elif config["type"] == "efficientnet":
        model = EfficientNet.from_pretrained(config["id"])
        model.eval()
        extractor = None

    loaded_models[model_name] = (model, extractor)
    return model, extractor

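# Grad-CAM for the CNN-style models: hook the last Conv2d layer, backprop the
# top-class score, and weight the feature maps by the spatially averaged
# gradients to obtain a 224x224 heatmap in [0, 1].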
def get_gradcam(model, image_tensor):
    grad = None
    fmap = None

    def forward_hook(module, input, output):
        nonlocal fmap
        fmap = output.detach()

    def backward_hook(module, grad_in, grad_out):
        nonlocal grad
        grad = grad_out[0].detach()

    # Find the last Conv2d layer. named_modules() returns a generator, so it
    # must be materialized before it can be reversed.
    last_conv = None
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, torch.nn.Conv2d):
            last_conv = module
            break
    if last_conv is None:
        return np.ones((224, 224))

    handle_fwd = last_conv.register_forward_hook(forward_hook)
    handle_bwd = last_conv.register_full_backward_hook(backward_hook)

    # Re-enable autograd locally so Grad-CAM still works when the caller wraps
    # inference in torch.no_grad().
    with torch.enable_grad():
        out = model(image_tensor)
        class_idx = out.argmax(dim=1).item()
        score = out[0, class_idx]
        model.zero_grad()
        score.backward()

    weights = grad.mean(dim=(2, 3), keepdim=True)
    cam = (weights * fmap).sum(dim=1, keepdim=True)
    cam = F.relu(cam)
    cam = cam.squeeze().cpu().numpy()
    cam = cv2.resize(cam, (224, 224))
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

    handle_fwd.remove()
    handle_bwd.remove()
    return cam

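# For the Transformers models (ViT/DeiT), visualize the last layer's attention
# from the [CLS] token to every image patch, averaged over heads. This is a
# single-layer map rather than a full attention rollout across layers.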
def vit_attention_rollout(outputs):
    if getattr(outputs, "attentions", None) is None:
        return np.ones((14, 14))

    attn = outputs.attentions[-1]       # (batch, heads, tokens, tokens)
    attn = attn.mean(1)                 # average over heads
    attn = attn[:, 0, 1:]               # [CLS] attention to the patch tokens
    grid = int(attn.shape[-1] ** 0.5)   # e.g. 196 patches -> 14x14 grid
    attn_map = attn.reshape(grid, grid)
    attn_map = attn_map.detach().cpu().numpy()
    attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
    return attn_map

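# Blend a normalized [0, 1] attention/CAM map over the original image as a
# JET-colormapped heatmap.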
def overlay_attention(pil_img, attention_map):
    heatmap = (attention_map * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = Image.fromarray(heatmap).resize(pil_img.size)
    blended = Image.blend(pil_img.convert("RGBA"), heatmap.convert("RGBA"), alpha=0.5)
    return blended

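# Main inference entry point: classify the image with the selected model and
# return the top-5 class probabilities plus the attention/Grad-CAM overlay.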
def predict(image, model_name):
    try:
        model, extractor = load_model(model_name)
        # Make sure the input has three channels before normalizing.
        image = image.convert("RGB")
        transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
        ])
        x = transform(image).unsqueeze(0)

        with torch.no_grad():
            if MODEL_CONFIGS[model_name]["type"] == "hf":
                inputs = extractor(images=image, return_tensors="pt")
                outputs = model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)[0]
                top5_prob, top5_idx = torch.topk(probs, k=5)
                top5_labels = [model.config.id2label[idx.item()] for idx in top5_idx]
                att_map = vit_attention_rollout(outputs)
            else:
                outputs = model(x)
                probs = F.softmax(outputs, dim=-1)[0]
                top5_prob, top5_idx = torch.topk(probs, k=5)
                top5_labels = [IMAGENET_LABELS[idx.item()] for idx in top5_idx]
                # get_gradcam re-enables autograd internally, so calling it
                # under no_grad() is safe.
                att_map = get_gradcam(model, x)

        overlay = overlay_attention(image, att_map)
        result = {label: float(prob) for label, prob in zip(top5_labels, top5_prob)}
        return result, overlay

    except Exception as e:
        # gr.Label accepts a plain string, so surface the error as text.
        return f"Model '{model_name}' failed: {e}", None

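# Gradio UI: image upload and model picker on the left, predictions and the
# attention overlay on the right.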
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Multi-Model Image Classifier with Real Attention Maps")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            model_dropdown = gr.Dropdown(list(MODEL_CONFIGS.keys()), value="DeiT-Tiny", label="Select Model")
            run_button = gr.Button("Run Model")
        with gr.Column(scale=2):
            output_label = gr.Label(num_top_classes=5, label="Predictions")
            output_image = gr.Image(label="Attention Map Overlay")
    gr.Markdown(
        "💡 *This app dynamically loads models (DeiT, ViT, ConvNeXt, etc.) and visualizes what the model focused on!*"
    )

    run_button.click(
        predict,
        inputs=[input_image, model_dropdown],
        outputs=[output_label, output_image]
    )

if __name__ == "__main__":
    demo.launch()