nate commited on
Commit
ae43f7b
Β·
1 Parent(s): a054d09
Files changed (1) hide show
  1. src/streamlit_app.py +30 -44
src/streamlit_app.py CHANGED
@@ -114,7 +114,13 @@ def information_conservation_score(x, x_hat, lambda1=0.5):
114
 
115
  return ((lambda1 * ms_ssim_score.mean() + (1 - lambda1) * freq_score.mean())).item()
116
 
117
-
 
 
 
 
 
 
118
 
119
  def get_image(scene, frame, kind):
120
 
@@ -126,19 +132,24 @@ def get_image(scene, frame, kind):
126
  )
127
  return Image.open(image)
128
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  def validate_inputs(full_name):
131
  """Validate user inputs"""
132
  if not full_name.strip():
133
  return False, "Please enter your full name."
134
 
135
- # if not user_id.strip():
136
- # return False, "Please enter your ID number."
137
-
138
- # Basic validation for ID (you can customize this)
139
- # if not re.match(r'^[a-zA-Z0-9_-]+$', user_id.strip()):
140
- # return False, "ID number can only contain letters, numbers, hyphens, and underscores."
141
-
142
  return True, ""
143
 
144
  def random_crop(gt_array, a_array, b_array, crop_size=512):
@@ -262,6 +273,7 @@ if not st.session_state.user_authenticated:
262
  st.error(error_message)
263
 
264
  else:
 
265
  # Main image comparison interface
266
  st.header(f"Welcome, {st.session_state.full_name}")
267
  st.write(f"User ID: {st.session_state.user_id}")
@@ -313,9 +325,7 @@ else:
313
  a_crop = st.session_state.current_crops['a']
314
  b_crop = st.session_state.current_crops['b']
315
 
316
- st.write(f"### Image set {st.session_state.index + 1}/{st.session_state.target_responses}")
317
- # st.write(f"Scene: {st.session_state.current_scene}, Frame: {st.session_state.current_frame}")
318
- # st.write("Compare Image A and Image B to the Ground Truth. Which one is closer to the Ground Truth?")
319
 
320
  # Display images
321
  cols = st.columns(3)
@@ -331,10 +341,6 @@ else:
331
  cols[1].image(gt_crop, caption="Ground Truth", use_container_width=True)
332
  cols[2].image(b_crop, caption="Image B", use_container_width=True)
333
 
334
- # Choice selection
335
- # choice = st.radio("Which image is closer to the Ground Truth?",
336
- # ["Image A (Left)", "Image B (Right)"],
337
- # key=f"choice_{st.session_state.index}")
338
 
339
  col1, col2 = st.columns([1, 1])
340
 
@@ -476,35 +482,15 @@ else:
476
  st.success("Choice and metrics recorded!")
477
  st.rerun()
478
 
479
-
480
- # with col3:
481
- # if st.button("Reset Study"):
482
- # # Reset session state
483
- # for key in list(st.session_state.keys()):
484
- # del st.session_state[key]
485
- # st.rerun()
486
-
487
- # Progress and data display
488
  st.write(f"Completed comparisons: {st.session_state.index}")
489
-
490
- # Show collected responses
491
- # if st.session_state.responses_data and st.checkbox("Show collected responses", value=False):
492
- # import pandas as pd
493
- # df = pd.DataFrame(st.session_state.responses_data, columns=st.session_state.csv_headers)
494
- # st.dataframe(df)
495
-
496
- # Display current metrics (optional, for debugging)
497
- # if st.checkbox("Show current metrics", value=False):
498
- # with st.spinner("Computing current metrics..."):
499
- # current_metrics = compute_metrics(gt_crop, a_crop, b_crop, ics_lambda=0.5)
500
- # if current_metrics:
501
- # st.json(current_metrics)
502
 
503
  except Exception as e:
504
- st.success("πŸŽ‰ Study ended early!")
505
- st.session_state.study_completed = True
506
- st.rerun()
507
- # st.error(f"Error loading images: {str(e)}")
508
- # st.write("This might be due to network issues, the Hugging Face repository being unavailable, or missing dependencies.")
509
- # st.write("Make sure you have the following packages installed:")
510
- # st.code("pip install torch torchvision torchmetrics scikit-image opencv-python")
 
114
 
115
  return ((lambda1 * ms_ssim_score.mean() + (1 - lambda1) * freq_score.mean())).item()
116
 
117
+ @st.cache_resource
118
+ def load_my_dataset():
119
+ return load_dataset(
120
+ "rain-maker/RAW-RAIN-sample",
121
+ split="test", # or use "imagefolder" if repo is raw files
122
+ repo_type="dataset"
123
+ )
124
 
125
  def get_image(scene, frame, kind):
126
 
 
132
  )
133
  return Image.open(image)
134
 
135
+ # def get_image(scene, frame, kind):
136
+ # # Construct relative path exactly like in your hf_hub_download version
137
+ # rel_path = f"{kind}_test/{scene}/rgb_output/output_{frame}.png"
138
+
139
+ # # Find the entry with that filename
140
+ # record = next((item for item in ds if item["image"].filename.endswith(rel_path)), None)
141
+
142
+ # if record is None:
143
+ # raise FileNotFoundError(f"{rel_path} not found in dataset")
144
+
145
+ # return record["image"] # already a PIL.Image
146
+
147
 
148
  def validate_inputs(full_name):
149
  """Validate user inputs"""
150
  if not full_name.strip():
151
  return False, "Please enter your full name."
152
 
 
 
 
 
 
 
 
153
  return True, ""
154
 
155
  def random_crop(gt_array, a_array, b_array, crop_size=512):
 
273
  st.error(error_message)
274
 
275
  else:
276
+ ds = load_my_dataset()
277
  # Main image comparison interface
278
  st.header(f"Welcome, {st.session_state.full_name}")
279
  st.write(f"User ID: {st.session_state.user_id}")
 
325
  a_crop = st.session_state.current_crops['a']
326
  b_crop = st.session_state.current_crops['b']
327
 
328
+ # st.write(f"### Image set {st.session_state.index + 1}/{st.session_state.target_responses}")
 
 
329
 
330
  # Display images
331
  cols = st.columns(3)
 
341
  cols[1].image(gt_crop, caption="Ground Truth", use_container_width=True)
342
  cols[2].image(b_crop, caption="Image B", use_container_width=True)
343
 
 
 
 
 
344
 
345
  col1, col2 = st.columns([1, 1])
346
 
 
482
  st.success("Choice and metrics recorded!")
483
  st.rerun()
484
 
485
+
 
 
 
 
 
 
 
 
486
  st.write(f"Completed comparisons: {st.session_state.index}")
487
+
 
 
 
 
 
 
 
 
 
 
 
 
488
 
489
  except Exception as e:
490
+ # st.success("πŸŽ‰ Study ended early!")
491
+ # st.session_state.study_completed = True
492
+ # st.rerun()
493
+ st.error(f"Error loading images: {str(e)}")
494
+ st.write("This might be due to network issues, the Hugging Face repository being unavailable, or missing dependencies.")
495
+ st.write("Make sure you have the following packages installed:")
496
+ st.code("pip install torch torchvision torchmetrics scikit-image opencv-python")