Elliot Sones commited on
Commit
30ecfe2
·
1 Parent(s): f5801b6

Fix label errors + predict per stroke

Browse files
Files changed (1) hide show
  1. app.py +29 -18
app.py CHANGED
@@ -196,13 +196,24 @@ except Exception as e:
196
 
197
  def predict(strokes_json):
198
  """Predict from JSON stroke data."""
199
- if LOAD_ERROR:
200
- return {"error": f"Model load failed: {LOAD_ERROR}"}
201
- if MODEL is None:
202
- return {"error": "Model not loaded (unknown reason)"}
203
-
204
  try:
205
- raw_strokes = json.loads(strokes_json) if isinstance(strokes_json, str) else strokes_json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  if not raw_strokes:
207
  return {a: 0.0 for a in ANIMALS}
208
 
@@ -226,7 +237,8 @@ def predict(strokes_json):
226
 
227
  return {IDX_TO_CLASS.get(i, f"class_{i}"): float(probs[i]) for i in range(len(ANIMALS))}
228
  except Exception as e:
229
- return {"error": str(e)}
 
230
 
231
  # ============================================================================
232
  # Custom Canvas HTML
@@ -303,6 +315,14 @@ CANVAS_JS = r"""() => {
303
  strokes.push([currentStroke.x, currentStroke.y]);
304
  }
305
  isDrawing = false;
 
 
 
 
 
 
 
 
306
  };
307
 
308
  canvas.addEventListener("mousedown", (e) => {
@@ -347,20 +367,11 @@ CANVAS_JS = r"""() => {
347
  clearBtn.addEventListener("click", () => {
348
  ctx.clearRect(0, 0, canvas.width, canvas.height);
349
  strokes = [];
350
- const textbox = getTextInput();
351
- if (textbox) {
352
- textbox.value = "";
353
- textbox.dispatchEvent(new Event("input", { bubbles: true }));
354
- }
355
  });
356
 
357
  predictBtn.addEventListener("click", () => {
358
- const strokesJson = JSON.stringify(strokes);
359
- const textbox = getTextInput();
360
- if (textbox) {
361
- textbox.value = strokesJson;
362
- textbox.dispatchEvent(new Event("input", { bubbles: true }));
363
- }
364
  const btn = getGradioPredictButton();
365
  if (btn) btn.click();
366
  });
 
196
 
197
  def predict(strokes_json):
198
  """Predict from JSON stroke data."""
 
 
 
 
 
199
  try:
200
+ if LOAD_ERROR or MODEL is None:
201
+ return {a: 0.0 for a in ANIMALS}
202
+
203
+ if strokes_json is None:
204
+ return {a: 0.0 for a in ANIMALS}
205
+
206
+ if isinstance(strokes_json, str):
207
+ s = strokes_json.strip()
208
+ if not s:
209
+ return {a: 0.0 for a in ANIMALS}
210
+ try:
211
+ raw_strokes = json.loads(s)
212
+ except Exception:
213
+ return {a: 0.0 for a in ANIMALS}
214
+ else:
215
+ raw_strokes = strokes_json
216
+
217
  if not raw_strokes:
218
  return {a: 0.0 for a in ANIMALS}
219
 
 
237
 
238
  return {IDX_TO_CLASS.get(i, f"class_{i}"): float(probs[i]) for i in range(len(ANIMALS))}
239
  except Exception as e:
240
+ print(f"Prediction failed: {e}")
241
+ return {a: 0.0 for a in ANIMALS}
242
 
243
  # ============================================================================
244
  # Custom Canvas HTML
 
315
  strokes.push([currentStroke.x, currentStroke.y]);
316
  }
317
  isDrawing = false;
318
+ syncToTextbox();
319
+ };
320
+
321
+ const syncToTextbox = () => {
322
+ const textbox = getTextInput();
323
+ if (!textbox) return;
324
+ textbox.value = JSON.stringify(strokes);
325
+ textbox.dispatchEvent(new Event("input", { bubbles: true }));
326
  };
327
 
328
  canvas.addEventListener("mousedown", (e) => {
 
367
  clearBtn.addEventListener("click", () => {
368
  ctx.clearRect(0, 0, canvas.width, canvas.height);
369
  strokes = [];
370
+ syncToTextbox();
 
 
 
 
371
  });
372
 
373
  predictBtn.addEventListener("click", () => {
374
+ syncToTextbox();
 
 
 
 
 
375
  const btn = getGradioPredictButton();
376
  if (btn) btn.click();
377
  });