DEIT / app.py
Godreign's picture
Update app.py
f2c33af verified
raw
history blame
6.62 kB
import gradio as gr
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from efficientnet_pytorch import EfficientNet
import timm
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import torchvision.transforms as T
import urllib.request
import json
import cv2
# ---------------------------
# Model Configs
# ---------------------------
# Maps the display name shown in the UI to the backend used to load the model
# ("type") and the checkpoint / architecture identifier for that backend ("id").
MODEL_CONFIGS = {
    "DeiT-Tiny": {"type": "hf", "id": "facebook/deit-tiny-patch16-224"},
    "DeiT-Small": {"type": "hf", "id": "facebook/deit-small-patch16-224"},
    "ViT-Base": {"type": "hf", "id": "google/vit-base-patch16-224"},
    "ConvNeXt-Tiny": {"type": "timm", "id": "convnext_tiny"},
    "EfficientNet-B0": {"type": "efficientnet", "id": "efficientnet-b0"},
    # NOTE(review): timm does not appear to register a "squeezenet1_1"
    # architecture (that name comes from torchvision) -- confirm with
    # timm.list_models("*squeeze*") and switch loaders if it is missing.
    "SqueezeNet": {"type": "timm", "id": "squeezenet1_1"},
    # timm names MobileNet-V2 (width multiplier 1.0) "mobilenetv2_100";
    # "mobilenet_v2" is the torchvision spelling and is unknown to timm.
    "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100"},
}
# ---------------------------
# ImageNet Labels
# ---------------------------
# Human-readable class names, indexed by ImageNet class id. They are used to
# label the predictions of the non-Hugging-Face models.
IMAGENET_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
# The download happens at import time; the timeout keeps a stalled connection
# from hanging application start-up forever (urlopen blocks indefinitely by
# default).
with urllib.request.urlopen(IMAGENET_URL, timeout=30) as response:
    IMAGENET_LABELS = json.load(response)
# ---------------------------
# Lazy Load
# ---------------------------
# Cache of already-instantiated models: display name -> (model, extractor).
loaded_models = {}


def load_model(model_name):
    """Return the ``(model, extractor)`` pair for *model_name*, loading it on first use.

    Args:
        model_name: A key of ``MODEL_CONFIGS``.

    Returns:
        A ``(model, extractor)`` tuple. ``extractor`` is the Hugging Face
        feature extractor for ``"hf"`` models and ``None`` for every other
        backend. The model is switched to eval mode before it is cached.

    Raises:
        KeyError: If *model_name* is not a key of ``MODEL_CONFIGS``.
        ValueError: If the config entry has an unsupported ``"type"``.
    """
    if model_name in loaded_models:
        return loaded_models[model_name]

    config = MODEL_CONFIGS[model_name]
    kind = config["type"]
    extractor = None

    if kind == "hf":
        extractor = AutoFeatureExtractor.from_pretrained(config["id"])
        # output_attentions=True makes the forward pass return the attention
        # tensors that the attention-map visualisation reads.
        model = AutoModelForImageClassification.from_pretrained(
            config["id"], output_attentions=True
        )
    elif kind == "timm":
        model = timm.create_model(config["id"], pretrained=True)
    elif kind == "efficientnet":
        model = EfficientNet.from_pretrained(config["id"])
    else:
        # Fail loudly instead of falling through to an UnboundLocalError.
        raise ValueError(
            f"Unsupported model type {kind!r} for model {model_name!r}"
        )

    model.eval()
    loaded_models[model_name] = (model, extractor)
    return model, extractor
# ---------------------------
# Grad-CAM Helper for CNNs
# ---------------------------
def get_gradcam(model, image_tensor):
    """Compute a Grad-CAM heatmap for the model's top-1 class.

    The activations and gradients are taken at the last ``torch.nn.Conv2d``
    found when walking ``model.modules()``.

    Args:
        model: A ``torch.nn.Module`` whose output is a ``(batch, classes)``
            score tensor. Only the first batch item is explained.
        image_tensor: Input batch whose last two dimensions are
            ``(height, width)``.

    Returns:
        A 2-D NumPy array of shape ``(height, width)`` scaled to ``[0, 1]``.
        An all-ones array of that shape is returned when the model has no
        ``Conv2d`` layer or no gradient reached the selected layer.
    """
    height, width = image_tensor.shape[-2:]

    # modules() is a generator, so it cannot be passed to reversed(); walk it
    # forward and keep the last convolution seen.
    last_conv = None
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            last_conv = module
    if last_conv is None:
        return np.ones((height, width))

    captured = {}

    def forward_hook(module, inputs, output):
        captured["fmap"] = output.detach()
        # A tensor hook on the layer output yields d(score)/d(output) without
        # relying on the deprecated Module.register_backward_hook.
        if output.requires_grad:
            output.register_hook(
                lambda g: captured.__setitem__("grad", g.detach())
            )

    handle = last_conv.register_forward_hook(forward_hook)
    try:
        # Grad-CAM needs autograd even when the caller is inside no_grad().
        with torch.enable_grad():
            # A detached copy that requires grad guarantees a graph is built
            # even if every model parameter is frozen; the caller's tensor is
            # left untouched.
            inputs = image_tensor.detach().requires_grad_(True)
            out = model(inputs)
            class_idx = out[0].argmax().item()
            model.zero_grad()
            out[0, class_idx].backward()
    finally:
        # Always detach the hook so a failed pass does not leave it installed
        # on a cached model.
        handle.remove()

    fmap = captured.get("fmap")
    grad = captured.get("grad")
    if fmap is None or grad is None:
        return np.ones((height, width))

    # Channel weights are the spatially averaged gradients.
    weights = grad.mean(dim=(2, 3), keepdim=True)
    cam = F.relu((weights * fmap).sum(dim=1, keepdim=True))
    cam = F.interpolate(
        cam, size=(height, width), mode="bilinear", align_corners=False
    )
    # Index instead of squeeze() so 1x1 feature maps stay two-dimensional.
    cam = cam[0, 0].cpu().numpy()

    # Min-max normalise: shift first, then divide by the *shifted* maximum.
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)
    return cam
# ---------------------------
# ViT Attention Rollout Helper
# ---------------------------
def vit_attention_rollout(outputs):
    """Build a patch-grid attention map from the last attention layer.

    Despite the name, no multi-layer rollout is performed: the map is the
    head-averaged attention that the first token (CLS) of the last layer pays
    to the patch tokens, for the first item of the batch.

    Args:
        outputs: Model output exposing an ``attentions`` attribute holding a
            sequence of ``(batch, heads, tokens, tokens)`` tensors.

    Returns:
        A square 2-D NumPy array min-max scaled to ``[0, 1]``. A 14x14
        all-ones array is returned when no attentions are available.
    """
    import math

    fallback = np.ones((14, 14))
    attentions = getattr(outputs, "attentions", None)
    if attentions is None or len(attentions) == 0:
        return fallback

    attn = attentions[-1]      # last layer
    attn = attn.mean(1)        # mean over heads
    attn = attn[0, 0, 1:]      # first image, CLS row, CLS column dropped

    # Derive the grid size from the token count instead of assuming 14x14.
    # Tokens left over in front of the largest square grid are treated as
    # extra special tokens (e.g. a distillation token) and skipped --
    # NOTE(review): this assumes a square patch grid; confirm for
    # non-square inputs.
    num_tokens = attn.shape[0]
    side = math.isqrt(num_tokens)
    if side == 0:
        return fallback
    attn = attn[num_tokens - side * side:]

    attn_map = attn.reshape(side, side).detach().cpu().numpy()
    # Min-max normalise with the range (max - min) as the denominator so the
    # strongest patch maps to 1.
    lowest = attn_map.min()
    highest = attn_map.max()
    return (attn_map - lowest) / (highest - lowest + 1e-8)
# ---------------------------
# Overlay Helper
# ---------------------------
def overlay_attention(pil_img, attention_map):
    """Blend a JET-coloured heatmap of *attention_map* over *pil_img*.

    Args:
        pil_img: The PIL image to draw on.
        attention_map: 2-D array of intensities, expected in ``[0, 1]``. It is
            resized to the image size, so it may have any resolution.

    Returns:
        An RGBA PIL image of the same size as *pil_img*: a 50/50 blend of the
        image and the heatmap.
    """
    # Clip (and neutralise NaNs) before the uint8 cast: values outside [0, 1]
    # would otherwise wrap around and show up as the wrong colour.
    scaled = np.clip(
        np.nan_to_num(np.asarray(attention_map, dtype=np.float32)), 0.0, 1.0
    )
    heatmap = (scaled * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # Reorder the channels from OpenCV's BGR to RGB before handing the array
    # to PIL.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = Image.fromarray(heatmap).resize(pil_img.size)
    blended = Image.blend(
        pil_img.convert("RGBA"), heatmap.convert("RGBA"), alpha=0.5
    )
    return blended
# ---------------------------
# Main Prediction Function
# ---------------------------
def predict(image, model_name):
    """Classify *image* with the selected model and visualise its focus.

    Args:
        image: PIL image from the UI, or ``None`` when nothing was uploaded.
        model_name: A key of ``MODEL_CONFIGS``.

    Returns:
        ``(result, overlay)`` where ``result`` maps the top-5 labels to their
        probabilities and ``overlay`` is the image blended with the attention
        / Grad-CAM heatmap. On failure ``result`` is
        ``{"Error": <message>}`` and ``overlay`` is ``None``.
    """
    # NOTE(review): the error dict carries a string where the Label output
    # normally receives a float confidence -- confirm the installed Gradio
    # version renders it as intended.
    if image is None:
        return {"Error": "Please upload an image first."}, None

    try:
        model, extractor = load_model(model_name)
        # Both preprocessing paths and the overlay work on 3-channel input.
        image = image.convert("RGB")

        if MODEL_CONFIGS[model_name]["type"] == "hf":
            inputs = extractor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = model(**inputs)
            probs = F.softmax(outputs.logits, dim=-1)[0]
            top5_prob, top5_idx = torch.topk(probs, k=5)
            top5_labels = [model.config.id2label[idx.item()] for idx in top5_idx]
            att_map = vit_attention_rollout(outputs)
        else:
            # Manual preprocessing is only needed for the non-HF models; the
            # HF extractor does its own resizing and normalisation.
            transform = T.Compose([
                T.Resize((224, 224)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
            ])
            x = transform(image).unsqueeze(0)
            with torch.no_grad():
                logits = model(x)
            probs = F.softmax(logits, dim=-1)[0]
            top5_prob, top5_idx = torch.topk(probs, k=5)
            top5_labels = [IMAGENET_LABELS[idx.item()] for idx in top5_idx]
            # Grad-CAM back-propagates through the model, so it must run
            # outside the no_grad() block above.
            att_map = get_gradcam(model, x)

        overlay = overlay_attention(image, att_map)
        result = {label: float(prob) for label, prob in zip(top5_labels, top5_prob)}
        return result, overlay
    except Exception as e:
        # UI boundary: report the failure in the output widget rather than
        # crashing the request.
        return {"Error": f"Model '{model_name}' failed: {str(e)}"}, None
# ---------------------------
# Gradio UI
# ---------------------------
# Two-column layout: inputs (image, model choice, run button) on the left,
# outputs (top-5 label widget, heatmap overlay) on the right.
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Multi-Model Image Classifier with Real Attention Maps")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            model_names = list(MODEL_CONFIGS.keys())
            # Pre-select the first model so that clicking "Run Model" without
            # touching the dropdown never passes None to predict().
            model_dropdown = gr.Dropdown(
                model_names, value=model_names[0], label="Select Model"
            )
            run_button = gr.Button("Run Model")
        with gr.Column(scale=2):
            output_label = gr.Label(num_top_classes=5, label="Predictions")
            output_image = gr.Image(label="Attention Map Overlay")
    gr.Markdown(
        "💡 *This app dynamically loads models (DeiT, ViT, ConvNeXt, etc.) and visualizes what the model focused on!*"
    )
    run_button.click(
        predict,
        inputs=[input_image, model_dropdown],
        outputs=[output_label, output_image],
    )

if __name__ == "__main__":
    demo.launch()