Spaces: nate committed · ae43f7b
Parent: a054d09
sync

Files changed: src/streamlit_app.py (+30 -44)
src/streamlit_app.py CHANGED
@@ -114,7 +114,13 @@ def information_conservation_score(x, x_hat, lambda1=0.5):
 
     return ((lambda1 * ms_ssim_score.mean() + (1 - lambda1) * freq_score.mean())).item()
 
-
+@st.cache_resource
+def load_my_dataset():
+    return load_dataset(
+        "rain-maker/RAW-RAIN-sample",
+        split="test",  # or use "imagefolder" if repo is raw files
+        repo_type="dataset"
+    )
 
 def get_image(scene, frame, kind):
 
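Note on the new loader: @st.cache_resource memoizes the return value for the lifetime of the Streamlit process, so the Hugging Face dataset is downloaded once rather than on every rerun. One hedged caveat: repo_type looks like a huggingface_hub argument; as far as I know, datasets.load_dataset does not accept it and will reject it as an unknown config key. A minimal standalone sketch of the intended pattern, with that kwarg dropped (the repo id and split come from the diff; nothing else is assumed):

import streamlit as st
from datasets import load_dataset

@st.cache_resource          # one shared copy per server process, reused across reruns
def load_my_dataset():
    # repo id taken verbatim from the commit; split="test" as above
    return load_dataset("rain-maker/RAW-RAIN-sample", split="test")

ds = load_my_dataset()      # first call downloads; later calls return the cached object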
@@ -126,19 +132,24 @@ def get_image(scene, frame, kind):
     )
     return Image.open(image)
 
+# def get_image(scene, frame, kind):
+#     # Construct relative path exactly like in your hf_hub_download version
+#     rel_path = f"{kind}_test/{scene}/rgb_output/output_{frame}.png"
+
+#     # Find the entry with that filename
+#     record = next((item for item in ds if item["image"].filename.endswith(rel_path)), None)
+
+#     if record is None:
+#         raise FileNotFoundError(f"{rel_path} not found in dataset")
+
+#     return record["image"]  # already a PIL.Image
+
 
 def validate_inputs(full_name):
     """Validate user inputs"""
     if not full_name.strip():
         return False, "Please enter your full name."
 
-    # if not user_id.strip():
-    #     return False, "Please enter your ID number."
-
-    # Basic validation for ID (you can customize this)
-    # if not re.match(r'^[a-zA-Z0-9_-]+$', user_id.strip()):
-    #     return False, "ID number can only contain letters, numbers, hyphens, and underscores."
-
    return True, ""
 
 def random_crop(gt_array, a_array, b_array, crop_size=512):
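Note on the commented-out get_image variant: it scans the whole dataset with next(...) on every call, an O(n) lookup that also decodes each image it touches, and it reads a module-level ds that is only assigned later inside the else: branch. A hedged sketch of an indexed alternative, assuming (as the commented code does) that each record's ["image"].filename is populated and ends with the kind_test/scene/rgb_output/output_frame.png shape:

@st.cache_resource
def build_filename_index():
    # Built once per process: key each record by its last four path
    # components, which is exactly the shape of rel_path below.
    ds = load_my_dataset()
    index = {}
    for i, item in enumerate(ds):
        key = "/".join(item["image"].filename.split("/")[-4:])
        index[key] = i
    return ds, index

def get_image_indexed(scene, frame, kind):
    ds, index = build_filename_index()
    rel_path = f"{kind}_test/{scene}/rgb_output/output_{frame}.png"
    if rel_path not in index:
        raise FileNotFoundError(f"{rel_path} not found in dataset")
    return ds[index[rel_path]]["image"]  # a PIL.Image, as in the original

Building the index still decodes every image once; if that is too slow, datasets can cast the column with Image(decode=False) so records carry paths and bytes instead, in which case the key would come from the record's path rather than .filename.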
@@ -262,6 +273,7 @@ if not st.session_state.user_authenticated:
         st.error(error_message)
 
 else:
+    ds = load_my_dataset()
     # Main image comparison interface
     st.header(f"Welcome, {st.session_state.full_name}")
     st.write(f"User ID: {st.session_state.user_id}")
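Because load_my_dataset() is cached, calling it at the top of the else: branch is effectively free after the first run. A small optional variant (not in the commit) that makes the one slow first call visible to the user:

with st.spinner("Loading RAW-RAIN-sample (first run may take a while)..."):
    ds = load_my_dataset()   # later reruns hit the cache and return immediately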
@@ -313,9 +325,7 @@ else:
         a_crop = st.session_state.current_crops['a']
         b_crop = st.session_state.current_crops['b']
 
-        st.write(f"### Image set {st.session_state.index + 1}/{st.session_state.target_responses}")
-        # st.write(f"Scene: {st.session_state.current_scene}, Frame: {st.session_state.current_frame}")
-        # st.write("Compare Image A and Image B to the Ground Truth. Which one is closer to the Ground Truth?")
+        # st.write(f"### Image set {st.session_state.index + 1}/{st.session_state.target_responses}")
 
         # Display images
         cols = st.columns(3)
@@ -331,10 +341,6 @@ else:
         cols[1].image(gt_crop, caption="Ground Truth", use_container_width=True)
         cols[2].image(b_crop, caption="Image B", use_container_width=True)
 
-        # Choice selection
-        # choice = st.radio("Which image is closer to the Ground Truth?",
-        #                   ["Image A (Left)", "Image B (Right)"],
-        #                   key=f"choice_{st.session_state.index}")
 
         col1, col2 = st.columns([1, 1])
 
@@ -476,35 +482,15 @@ else:
                 st.success("Choice and metrics recorded!")
                 st.rerun()
 
-
-        # with col3:
-        #     if st.button("Reset Study"):
-        #         # Reset session state
-        #         for key in list(st.session_state.keys()):
-        #             del st.session_state[key]
-        #         st.rerun()
-
-        # Progress and data display
+
         st.write(f"Completed comparisons: {st.session_state.index}")
-
-        # Show collected responses
-        # if st.session_state.responses_data and st.checkbox("Show collected responses", value=False):
-        #     import pandas as pd
-        #     df = pd.DataFrame(st.session_state.responses_data, columns=st.session_state.csv_headers)
-        #     st.dataframe(df)
-
-        # Display current metrics (optional, for debugging)
-        # if st.checkbox("Show current metrics", value=False):
-        #     with st.spinner("Computing current metrics..."):
-        #         current_metrics = compute_metrics(gt_crop, a_crop, b_crop, ics_lambda=0.5)
-        #         if current_metrics:
-        #             st.json(current_metrics)
+
 
     except Exception as e:
-        st.success("π Study ended early!")
-        st.session_state.study_completed = True
-        st.rerun()
-
-
-
-
+        # st.success("π Study ended early!")
+        # st.session_state.study_completed = True
+        # st.rerun()
+
+        st.error(f"Error loading images: {str(e)}")
+        st.write("This might be due to network issues, the Hugging Face repository being unavailable, or missing dependencies.")
+        st.write("Make sure you have the following packages installed:")
+        st.code("pip install torch torchvision torchmetrics scikit-image opencv-python")
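The reworked handler is a real fix: the old code greeted any exception with a success banner and marked the study complete, hiding failures. One further step worth considering, sketched under assumptions: splitting import failures from everything else and rendering the traceback with st.exception. Here run_comparison_ui is a hypothetical stand-in for the body of the try block, not a function in this app:

import streamlit as st

try:
    run_comparison_ui()   # hypothetical placeholder for the try body above
except ImportError as e:
    # Missing package: tell the user exactly what to install
    st.error(f"Missing dependency: {e.name}")
    st.code("pip install torch torchvision torchmetrics scikit-image opencv-python")
except Exception as e:
    st.error(f"Error loading images: {e}")
    st.exception(e)       # renders the full traceback in the app, handy on a Space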
|