"""Gradio demo: street-scene semantic segmentation with SegFormer-b0 (Cityscapes).

The model and processor are loaded once at import time. The UI shows the
colorized per-pixel class mask and a blend of the input image with that mask.
The server is launched only when this file is run as a script.
"""

import os

import numpy as np
from PIL import Image
import torch
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation

MODEL_ID = "nvidia/segformer-b0-finetuned-cityscapes-512-1024"

# File extensions (compared lowercase) picked up from the examples directory.
_EXAMPLE_EXTS = (".jpg", ".jpeg", ".png")


def make_palette(num_classes: int):
    """Return a list of ``num_classes`` RGB tuples.

    Colors come from a fixed 20-entry base list and repeat (cycle) when
    ``num_classes`` exceeds 20.
    """
    base = [
        (255, 0, 0), (255, 255, 0), (0, 255, 0), (0, 0, 255), (255, 0, 255),
        (0, 255, 255), (255, 165, 0), (128, 0, 128), (255, 192, 203), (191, 255, 0),
        (0, 128, 128), (165, 42, 42), (0, 0, 128), (128, 128, 0), (128, 0, 0),
        (255, 215, 0), (192, 192, 192), (255, 127, 80), (75, 0, 130), (238, 130, 238),
    ]
    return [base[i % len(base)] for i in range(num_classes)]


def colorize(mask: np.ndarray, palette) -> Image.Image:
    """Map a 2-D integer class-id mask to an RGB PIL image.

    Args:
        mask: 2-D array of integer class ids, shape (H, W).
        palette: sequence of RGB triples; ``palette[i]`` is the color of class ``i``.

    Returns:
        An RGB ``PIL.Image`` of size (W, H). Pixels whose id is negative or
        ``>= len(palette)`` are left black.
    """
    h, w = mask.shape  # also enforces a 2-D mask
    # Lookup table of shape (num_colors, 3); reshape keeps an empty palette valid.
    lut = np.asarray(palette, dtype=np.uint8).reshape(-1, 3)
    out = np.zeros((h, w, 3), dtype=np.uint8)
    # A single vectorized gather instead of one full-image comparison per class.
    valid = (mask >= 0) & (mask < len(lut))
    out[valid] = lut[mask[valid]]
    return Image.fromarray(out)


device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID).to(device).eval()
id2label = model.config.id2label
NUM_CLASSES = len(id2label)
PALETTE = make_palette(NUM_CLASSES)


def segment(img: Image.Image, alpha: float = 0.5):
    """Segment ``img`` and return ``(mask_image, overlay_image)``.

    Args:
        img: input PIL image of any mode; it is converted to RGB first.
        alpha: mask weight in the blend. 0 shows only the image, 1 shows
            only the mask.

    Returns:
        A tuple of two RGB PIL images the same size as ``img``, or
        ``(None, None)`` when no image was provided.
    """
    if img is None:
        return None, None

    # Normalize once: the processor expects 3 channels (RGBA/L inputs would
    # break normalization), and the blend below needs the same RGB image.
    rgb = img.convert("RGB")
    target_hw = rgb.size[::-1]  # PIL size is (W, H); interpolate wants (H, W)

    with torch.inference_mode():
        inputs = processor(images=rgb, return_tensors="pt").to(device)
        logits = model(**inputs).logits
        # Upsample logits to full resolution before argmax for per-pixel labels.
        logits = torch.nn.functional.interpolate(
            logits, size=target_hw, mode="bilinear", align_corners=False
        )
        # Keep integer ids as-is: a uint8 cast would wrap for >256 classes.
        pred = logits.argmax(dim=1)[0].cpu().numpy()

    mask_img = colorize(pred, PALETTE)
    # Image.blend computes rgb * (1 - alpha) + mask * alpha.
    overlay = Image.blend(rgb, mask_img, float(alpha))
    return mask_img, overlay


def list_examples():
    """Return ``[[path], ...]`` for image files in ./examples, sorted by name.

    Returns an empty list if the directory does not exist.
    """
    exdir = "examples"
    if not os.path.isdir(exdir):
        return []
    names = sorted(
        f
        for f in os.listdir(exdir)
        if f.lower().endswith(_EXAMPLE_EXTS) and os.path.isfile(os.path.join(exdir, f))
    )
    return [[os.path.join(exdir, n)] for n in names]


title = "Cityscapes Segmentation (SegFormer-b0)"
desc = (
    "Cityscapes(19 classes)로 학습된 SegFormer-b0 모델 데모입니다. "
    "도시/도로 장면에서 차량, 보행자, 도로, 건물, 하늘 등을 분할합니다."
)

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# 🚦 {title}\n{desc}")
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type="pil", label="Input Image")
            # Higher values make the mask MORE visible in the overlay, so this
            # is an opacity control, not a transparency control.
            alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Mask Opacity")
            btn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            out_mask = gr.Image(type="pil", label="Segmentation Mask")
            out_overlay = gr.Image(type="pil", label="Overlay (Image + Mask)")

    ex = list_examples()
    if ex:
        gr.Examples(examples=ex, inputs=[inp], examples_per_page=6, label="Examples")

    btn.click(segment, inputs=[inp, alpha], outputs=[out_mask, out_overlay])

if __name__ == "__main__":
    # Launch only when run as a script, not as an import-time side effect.
    demo.launch()