ockkjs committed
Commit 61db721 · 1 Parent(s): 43d9531

change details

Files changed (1)
  1. app.py +64 -16
app.py CHANGED
@@ -10,8 +10,8 @@ MODEL_ID = "nvidia/segformer-b4-finetuned-cityscapes-1024-1024"
 processor = AutoImageProcessor.from_pretrained(MODEL_ID)
 model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)
 
+
 def ade_palette():
-    """ADE20K palette that maps each class to RGB values."""
     return [
         [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153],
         [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], [152, 251, 152],
@@ -19,6 +19,7 @@ def ade_palette():
         [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32],
     ]
 
+
 labels_list = []
 with open("labels.txt", "r", encoding="utf-8") as fp:
     for line in fp:
@@ -26,6 +27,7 @@ with open("labels.txt", "r", encoding="utf-8") as fp:
 
 colormap = np.asarray(ade_palette(), dtype=np.uint8)
 
+
 def label_to_color_image(label):
     if label.ndim != 2:
         raise ValueError("Expect 2-D input label")
@@ -33,6 +35,7 @@ def label_to_color_image(label):
         raise ValueError("label value too large.")
     return colormap[label]
 
+
 def draw_plot(pred_img, seg_np):
     fig = plt.figure(figsize=(20, 15))
     grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
@@ -40,6 +43,7 @@ def draw_plot(pred_img, seg_np):
     plt.subplot(grid_spec[0])
     plt.imshow(pred_img)
     plt.axis('off')
+    plt.title('Segmentation Result', fontsize=20, pad=20)
 
     LABEL_NAMES = np.asarray(labels_list)
     FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
@@ -52,8 +56,10 @@ def draw_plot(pred_img, seg_np):
     plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
     plt.xticks([], [])
     ax.tick_params(width=0.0, labelsize=25)
+    plt.title('Detected Classes', fontsize=20, pad=20)
     return fig
 
+
 def run_inference(input_img):
     # input: numpy array from gradio -> PIL
     img = Image.fromarray(input_img.astype(np.uint8)) if isinstance(input_img, np.ndarray) else input_img
@@ -63,32 +69,74 @@ def run_inference(input_img):
     inputs = processor(images=img, return_tensors="pt")
     with torch.no_grad():
         outputs = model(**inputs)
-    logits = outputs.logits  # (1, C, h/4, w/4)
+    logits = outputs.logits
 
     # resize to original
     upsampled = torch.nn.functional.interpolate(
         logits, size=img.size[::-1], mode="bilinear", align_corners=False
     )
-    seg = upsampled.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)  # (H,W)
+    seg = upsampled.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
 
     # colorize & overlay
-    color_seg = colormap[seg]  # (H,W,3)
+    color_seg = colormap[seg]
     pred_img = (np.array(img) * 0.5 + color_seg * 0.5).astype(np.uint8)
 
     fig = draw_plot(pred_img, seg)
     return fig
 
-demo = gr.Interface(
-    fn=run_inference,
-    inputs=gr.Image(type="numpy", label="Input Image"),
-    outputs=gr.Plot(label="Overlay + Legend"),
-    examples=[
-        "road-2.jpg",
-        "road-3.jpeg",
-    ],
-    flagging_mode="never",
-    cache_examples=False,
-)
+
+# ===== Improved interface (using Blocks) =====
+with gr.Blocks(theme=gr.themes.Soft(), title="City Scene Segmentation") as demo:
+    gr.Markdown(
+        """
+        # 🏙️ City Scene Segmentation
+        A road and urban scene segmentation demo powered by a **SegFormer model trained on the Cityscapes dataset**.
+
+        It automatically recognizes and segments 19 classes such as roads, buildings, vehicles, and pedestrians.
+        """
+    )
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            input_img = gr.Image(
+                type="numpy",
+                label="📷 Input Image",
+                height=400
+            )
+            submit_btn = gr.Button(
+                "🎯 Run Segmentation",
+                variant="primary",
+                size="lg"
+            )
+
+            gr.Markdown("### 📌 Example Images")
+            gr.Examples(
+                examples=[
+                    "road-2.jpg",
+                    "road-3.jpeg",
+                ],
+                inputs=input_img,
+                label="City/road scene samples"
+            )
+
+        with gr.Column(scale=1):
+            output_plot = gr.Plot(label="✨ Segmentation Result and Legend")
+
+    gr.Markdown(
+        """
+        ---
+        ### 📊 Detectable Classes (19)
+        `road`, `sidewalk`, `building`, `wall`, `fence`, `pole`, `traffic light`, `traffic sign`, `vegetation`,
+        `terrain`, `sky`, `person`, `rider`, `car`, `truck`, `bus`, `train`, `motorcycle`, `bicycle`
+        """
+    )
+
+    # Event handler
+    submit_btn.click(
+        fn=run_inference,
+        inputs=input_img,
+        outputs=output_plot
+    )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
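
Since run_inference itself is unchanged apart from the dropped shape comments, the new Blocks UI can be smoke-tested locally without launching the server. A minimal sketch, assuming app.py, labels.txt, and the bundled example image road-2.jpg are in the working directory (the output filename is illustrative):

```python
# Local smoke test (a sketch; not part of this commit).
# Assumes app.py, labels.txt, and road-2.jpg sit in the current directory.
import numpy as np
from PIL import Image

from app import run_inference  # importing app.py loads the SegFormer model but does not call demo.launch()

img = np.asarray(Image.open("road-2.jpg").convert("RGB"))
fig = run_inference(img)  # returns a matplotlib Figure: overlay plus class legend
fig.savefig("segmentation_result.png", bbox_inches="tight")  # illustrative output path
```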