silverjini0 committed on
Commit
c0c9b2d
Β·
1 Parent(s): 1c8b451

cityscapes segformer-b0 demo

Browse files
.idea/.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Python 3.11" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/prac2.iml" filepath="$PROJECT_DIR$/.idea/prac2.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/prac2.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="jdk" jdkName="Python 3.11" jdkType="Python SDK" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
app.py CHANGED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torch
5
+ import gradio as gr
6
+ from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
7
+
8
# Hugging Face Hub checkpoint id: SegFormer-b0 fine-tuned on Cityscapes (512x1024 input).
MODEL_ID = "nvidia/segformer-b0-finetuned-cityscapes-512-1024"
9
+
10
def make_palette(num_classes: int):
    """Return a list of ``num_classes`` RGB tuples.

    A fixed 20-color base palette is repeated as often as needed, so any
    class count is served; colors recur once ``num_classes`` exceeds 20.
    """
    base = [
        (255, 0, 0), (255, 255, 0), (0, 255, 0), (0, 0, 255),
        (255, 0, 255), (0, 255, 255), (255, 165, 0), (128, 0, 128),
        (255, 192, 203), (191, 255, 0), (0, 128, 128), (165, 42, 42),
        (0, 0, 128), (128, 128, 0), (128, 0, 0), (255, 215, 0),
        (192, 192, 192), (255, 127, 80), (75, 0, 130), (238, 130, 238),
    ]
    # Tile the base list past num_classes, then trim to the exact length.
    repeats = num_classes // len(base) + 1
    return (base * repeats)[:num_classes]
19
+
20
def colorize(mask: np.ndarray, palette):
    """Convert a 2-D integer class-id ``mask`` into an RGB ``PIL.Image``.

    Pixel value ``i`` is painted with ``palette[i]``; ids outside the
    palette range are left black (the zero-initialized background).
    """
    height, width = mask.shape
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id, color in enumerate(palette):
        rgb[mask == class_id] = color
    return Image.fromarray(rgb)
26
+
27
# Prefer the GPU when one is visible; model weights and inputs must share a device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Processor handles resizing/normalization; model is frozen in eval mode for inference.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)
model = model.to(device)
model.eval()

# The checkpoint's label mapping fixes how many palette colors we need
# (per MODEL_ID this is a Cityscapes checkpoint, presumably 19 classes).
id2label = model.config.id2label
NUM_CLASSES = len(id2label)
PALETTE = make_palette(NUM_CLASSES)
34
+
35
def segment(img: Image.Image, alpha: float = 0.5):
    """Segment ``img`` and return ``(color_mask, overlay)`` as PIL images.

    ``alpha`` weights the mask in the overlay blend (0 = image only,
    1 = mask only). Returns ``(None, None)`` when no image is supplied.
    """
    if img is None:
        return None, None

    # Inference only — no gradients needed.
    with torch.no_grad():
        batch = processor(images=img, return_tensors="pt").to(device)
        logits = model(**batch).logits
        # PIL's size is (width, height); interpolate expects (height, width).
        upsampled = torch.nn.functional.interpolate(
            logits, size=img.size[::-1], mode="bilinear", align_corners=False
        )

    # Per-pixel argmax over class logits -> uint8 class-id map on the CPU.
    class_map = upsampled.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
    mask_img = colorize(class_map, PALETTE)

    # Alpha-blend the RGB input with the colorized mask.
    rgb = np.array(img.convert("RGB"))
    blended = (rgb * (1 - alpha) + np.array(mask_img) * alpha).astype(np.uint8)
    return mask_img, Image.fromarray(blended)
50
+
51
def list_examples():
    """Collect example image paths from a local ``examples`` directory.

    Returns a list of single-element rows (the shape ``gr.Examples``
    expects), sorted by filename; empty when the directory is missing.
    """
    exdir = "examples"
    if not os.path.isdir(exdir):
        return []
    image_names = sorted(
        name for name in os.listdir(exdir)
        if name.lower().endswith((".jpg", ".jpeg", ".png"))
    )
    return [[os.path.join(exdir, name)] for name in image_names]
58
+
59
title = "Cityscapes Segmentation (SegFormer-b0)"
# User-facing description (Korean), rendered below the page title.
desc = (
    "Cityscapes(19 classes)둜 ν•™μŠ΅λœ SegFormer-b0 λͺ¨λΈ 데λͺ¨μž…λ‹ˆλ‹€. "
    "λ„μ‹œ/λ„λ‘œ μž₯λ©΄μ—μ„œ μ°¨λŸ‰, λ³΄ν–‰μž, λ„λ‘œ, 건물, ν•˜λŠ˜ 등을 λΆ„ν• ν•©λ‹ˆλ‹€."
)

with gr.Blocks(title=title) as demo:
    # Page header: emoji + title, description underneath.
    gr.Markdown(f"# 🚦 {title}\n{desc}")

    with gr.Row():
        # Left column: input image, blend slider, and the run button.
        with gr.Column(scale=1):
            inp = gr.Image(type="pil", label="Input Image")
            alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Overlay Transparency")
            btn = gr.Button("Submit", variant="primary")
        # Right column: raw mask and the alpha-blended overlay.
        with gr.Column(scale=1):
            out_mask = gr.Image(type="pil", label="Segmentation Mask")
            out_overlay = gr.Image(type="pil", label="Overlay (Image + Mask)")

    # Only render the examples gallery when sample images are bundled.
    example_rows = list_examples()
    if example_rows:
        gr.Examples(examples=example_rows, inputs=[inp], examples_per_page=6, label="Examples")

    btn.click(segment, inputs=[inp, alpha], outputs=[out_mask, out_overlay])

demo.launch()
examples/city-1.jpg ADDED

Git LFS Details

  • SHA256: bf4a4358e7413bb8e634f6b5bbaa2e40406c96768abb6b8f5cbce863d83ad054
  • Pointer size: 130 Bytes
  • Size of remote file: 65.8 kB
examples/city-2.jpg ADDED

Git LFS Details

  • SHA256: b593289addf03dbf8ddf9183a9d067738c6da6733d7d815f49b6a61bd6166ab8
  • Pointer size: 130 Bytes
  • Size of remote file: 76.8 kB
examples/city-3.jpg ADDED

Git LFS Details

  • SHA256: 16c08a15566ef8bfc45ceba63c1250b1df958aebf1dd3b8c773a38d413fd4602
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
examples/city-4.jpg ADDED

Git LFS Details

  • SHA256: 7d2af9cb7920a712f742284c5b938a74ceb5d177e87f010b8e9cd26661b10e72
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB
examples/city-5.jpg ADDED

Git LFS Details

  • SHA256: 9603f5245007a8ce630535766c19c4dc7c2f526bbd859c604a860331c08be2cc
  • Pointer size: 130 Bytes
  • Size of remote file: 92.2 kB
requirements.txt CHANGED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio==3.44.4
2
+ transformers==4.44.2
3
+ torch>=2.1.0
4
+ Pillow
5
+ numpy