Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -65,8 +65,8 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
|
|
| 65 |
if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
|
| 66 |
and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
|
| 67 |
|
| 68 |
-
# Calculate how many tokens to mask
|
| 69 |
-
num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))
|
| 70 |
# Randomly select indices to mask
|
| 71 |
indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
|
| 72 |
|
|
@@ -87,6 +87,11 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
|
|
| 87 |
# Convert back to text with masks
|
| 88 |
masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
return masked_text, indices_to_mask, original_tokens
|
| 91 |
|
| 92 |
def prepare_ntp_sample(text, cut_ratio=0.3):
|
|
@@ -150,18 +155,33 @@ def check_mlm_answer(user_answers):
|
|
| 150 |
"""Check user MLM answers against the masked tokens."""
|
| 151 |
global user_stats
|
| 152 |
|
| 153 |
-
#
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
#
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
# Ensure we have the same number of answers as masks
|
| 163 |
if len(user_tokens) != len(masked_tokens):
|
| 164 |
-
return f"Please provide {len(masked_tokens)} answers. You provided {len(user_tokens)}.\nFormat: word1, word2, word3"
|
| 165 |
|
| 166 |
# Compare each answer
|
| 167 |
correct = 0
|
|
@@ -338,6 +358,9 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
| 338 |
info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
|
| 339 |
)
|
| 340 |
|
|
|
|
|
|
|
|
|
|
| 341 |
sample_text = gr.Textbox(
|
| 342 |
label="Text Sample",
|
| 343 |
placeholder="Click 'New Sample' to get started",
|
|
@@ -351,12 +374,20 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
| 351 |
reset_button = gr.Button("Reset Stats")
|
| 352 |
|
| 353 |
with gr.Group() as mlm_group:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
mlm_answer = gr.Textbox(
|
| 355 |
-
label="Your
|
| 356 |
-
placeholder="word1, word2, word3
|
| 357 |
lines=1
|
| 358 |
)
|
| 359 |
-
gr.Markdown("**Example input format:** finding, its, phishing, in, links, 49, and, it")
|
| 360 |
|
| 361 |
with gr.Group(visible=False) as ntp_group:
|
| 362 |
ntp_answer = gr.Textbox(
|
|
@@ -372,7 +403,27 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
| 372 |
|
| 373 |
# Set up event handlers
|
| 374 |
task_radio.change(switch_task, inputs=[task_radio], outputs=[mlm_group, ntp_group])
|
| 375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
reset_button.click(reset_stats, inputs=None, outputs=[result])
|
| 377 |
|
| 378 |
check_button.click(
|
|
|
|
| 65 |
if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
|
| 66 |
and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
|
| 67 |
|
| 68 |
+
# Calculate how many tokens to mask, but ensure at least 1 and at most 8
|
| 69 |
+
num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
|
| 70 |
# Randomly select indices to mask
|
| 71 |
indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
|
| 72 |
|
|
|
|
| 87 |
# Convert back to text with masks
|
| 88 |
masked_text = tokenizer.convert_tokens_to_string(masked_tokens_list)
|
| 89 |
|
| 90 |
+
# Print debugging info
|
| 91 |
+
print(f"Original tokens: {original_tokens}")
|
| 92 |
+
print(f"Masked indices: {indices_to_mask}")
|
| 93 |
+
print(f"Number of masks: {len(original_tokens)}")
|
| 94 |
+
|
| 95 |
return masked_text, indices_to_mask, original_tokens
|
| 96 |
|
| 97 |
def prepare_ntp_sample(text, cut_ratio=0.3):
|
|
|
|
| 155 |
"""Check user MLM answers against the masked tokens."""
|
| 156 |
global user_stats
|
| 157 |
|
| 158 |
+
# Print for debugging
|
| 159 |
+
print(f"Original user input: '{user_answers}'")
|
| 160 |
+
|
| 161 |
+
# Handle the case where input is empty
|
| 162 |
+
if not user_answers or user_answers.isspace():
|
| 163 |
+
return "Please provide your answers. No input was detected."
|
| 164 |
+
|
| 165 |
+
# Basic cleanup - trim and lowercase
|
| 166 |
+
user_answers = user_answers.strip().lower()
|
| 167 |
+
print(f"After basic cleanup: '{user_answers}'")
|
| 168 |
+
|
| 169 |
+
# Explicit comma-based splitting with protection for empty entries
|
| 170 |
+
if ',' in user_answers:
|
| 171 |
+
# Split by commas and strip each item
|
| 172 |
+
user_tokens = [token.strip() for token in user_answers.split(',')]
|
| 173 |
+
# Filter out empty tokens
|
| 174 |
+
user_tokens = [token for token in user_tokens if token]
|
| 175 |
+
else:
|
| 176 |
+
# If no commas, split by whitespace
|
| 177 |
+
user_tokens = [token for token in user_answers.split() if token]
|
| 178 |
+
|
| 179 |
+
print(f"Parsed tokens: {user_tokens}, count: {len(user_tokens)}")
|
| 180 |
+
print(f"Expected tokens: {masked_tokens}, count: {len(masked_tokens)}")
|
| 181 |
|
| 182 |
# Ensure we have the same number of answers as masks
|
| 183 |
if len(user_tokens) != len(masked_tokens):
|
| 184 |
+
return f"Please provide exactly {len(masked_tokens)} answers (one for each [MASK]). You provided {len(user_tokens)}.\n\nFormat example: word1, word2, word3"
|
| 185 |
|
| 186 |
# Compare each answer
|
| 187 |
correct = 0
|
|
|
|
| 358 |
info="Percentage of tokens to mask (MLM) or text to hide (NTP)"
|
| 359 |
)
|
| 360 |
|
| 361 |
+
# Display how many [MASK] tokens the user has to guess (shows 0 until a sample is generated)
|
| 362 |
+
mask_count = gr.Markdown("**Number of [MASK] tokens to guess: 0**")
|
| 363 |
+
|
| 364 |
sample_text = gr.Textbox(
|
| 365 |
label="Text Sample",
|
| 366 |
placeholder="Click 'New Sample' to get started",
|
|
|
|
| 374 |
reset_button = gr.Button("Reset Stats")
|
| 375 |
|
| 376 |
with gr.Group() as mlm_group:
|
| 377 |
+
mlm_instructions = gr.Markdown("""
|
| 378 |
+
### MLM Instructions
|
| 379 |
+
1. For each [MASK] token, provide your guess for the original word.
|
| 380 |
+
2. Separate your answers with commas.
|
| 381 |
+
3. Make sure you provide exactly the same number of answers as [MASK] tokens.
|
| 382 |
+
|
| 383 |
+
**Example format:** `word1, word2, word3` or `word1,word2,word3`
|
| 384 |
+
""")
|
| 385 |
+
|
| 386 |
mlm_answer = gr.Textbox(
|
| 387 |
+
label="Your answers (comma-separated)",
|
| 388 |
+
placeholder="word1, word2, word3",
|
| 389 |
lines=1
|
| 390 |
)
|
|
|
|
| 391 |
|
| 392 |
with gr.Group(visible=False) as ntp_group:
|
| 393 |
ntp_answer = gr.Textbox(
|
|
|
|
| 403 |
|
| 404 |
# Set up event handlers
|
| 405 |
task_radio.change(switch_task, inputs=[task_radio], outputs=[mlm_group, ntp_group])
|
| 406 |
+
|
| 407 |
+
# Update the sample text and also update the mask count
|
| 408 |
+
def new_sample_with_count(mask_ratio_pct, task):
    """Fetch a new sample and build the status banner that goes with it.

    Args:
        mask_ratio_pct: Ratio expressed as a percentage; it is converted
            with ``float(...) / 100.0`` before being handed to
            ``get_new_sample``.
        task: Task identifier. The value ``"mlm"`` selects the mask-count
            banner; any other value selects the next-token-prediction banner.

    Returns:
        A 3-tuple ``(sample, banner_markdown, "")``. The trailing empty
        string is always returned as the third element.
    """
    fraction = float(mask_ratio_pct) / 100.0
    new_text = get_new_sample(task, fraction)

    # Guard clause: anything that is not MLM gets the fixed NTP banner.
    if task != "mlm":
        return (
            new_text,
            "**Next Token Prediction mode - guess one token at a time**",
            "",
        )

    # ``masked_tokens`` is read from the enclosing/module scope *after*
    # ``get_new_sample`` has run.
    # NOTE(review): presumably ``get_new_sample`` refreshes it - confirm.
    return (
        new_text,
        f"**Number of [MASK] tokens to guess: {len(masked_tokens)}**",
        "",
    )
|
| 420 |
+
|
| 421 |
+
new_button.click(
|
| 422 |
+
new_sample_with_count,
|
| 423 |
+
inputs=[mask_ratio, task_radio],
|
| 424 |
+
outputs=[sample_text, mask_count, result]
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
reset_button.click(reset_stats, inputs=None, outputs=[result])
|
| 428 |
|
| 429 |
check_button.click(
|