# Spaces:
# Running
# Running
| from pathlib import Path | |
| import json | |
| import pandas as pd | |
| import numpy as np | |
| import gradio as gr | |
| from datasets import load_dataset | |
| from gradio_leaderboard import Leaderboard | |
| from datetime import datetime | |
| import os | |
| from about import ( | |
| PROBLEM_TYPES, TOKEN, CACHE_PATH, API, submissions_repo, results_repo, | |
| COLUMN_DISPLAY_NAMES, COUNT_BASED_METRICS, METRIC_GROUPS, | |
| METRIC_GROUP_COLORS, COLUMN_TO_GROUP | |
| ) | |
def get_leaderboard():
    """Fetch the latest results dataset and return it as a DataFrame.

    Forces a fresh download so the leaderboard always reflects the newest
    results. Returns a placeholder frame with the expected columns when
    the results dataset has no rows yet.
    """
    ds = load_dataset(results_repo, split='train', download_mode="force_redownload")
    full_df = pd.DataFrame(ds)
    # Removed stray debug print of the column index left over from development.
    if full_df.empty:
        # Empty placeholder schema so downstream display code has columns to render.
        return pd.DataFrame({'date': [], 'model': [], 'score': [], 'verified': []})
    return full_df
def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True):
    """Prepare a results frame for display.

    Chooses columns (compact preset or selected metric groups), tags
    already-relaxed models, optionally rewrites count metrics as
    percentages, rounds floats, renames columns to display names, and
    returns a color-styled pandas Styler.
    """
    if len(df) == 0:
        return df

    if compact_view:
        # Compact mode: fixed, curated column list from the config module.
        from about import COMPACT_VIEW_COLUMNS
        columns = [c for c in COMPACT_VIEW_COLUMNS if c in df.columns]
    else:
        # Full mode: model name, structure count, then the requested groups
        # (all groups when none are selected).
        columns = ['model_name']
        if 'n_structures' in df.columns:
            columns.append('n_structures')
        groups = selected_groups or list(METRIC_GROUPS.keys())
        for group in groups:
            for metric in METRIC_GROUPS.get(group, ()):
                if metric in df.columns and metric not in columns:
                    columns.append(metric)

    display_df = df[columns].copy()

    # Mark models whose structures were submitted already relaxed.
    if 'relaxed' in df.columns and 'model_name' in display_df.columns:
        def _tag_relaxed(row):
            return f"{row['model_name']} β‘" if row.get('relaxed', False) else row['model_name']
        display_df['model_name'] = df.apply(_tag_relaxed, axis=1)

    # Optionally express count-based metrics as a share of total structures.
    if show_percentage and 'n_structures' in df.columns:
        totals = df['n_structures']
        for metric in COUNT_BASED_METRICS:
            if metric in display_df.columns:
                display_df[metric] = (df[metric] / totals * 100).round(1).astype(str) + '%'

    # Trim float noise for readability.
    for column in display_df.columns:
        if display_df[column].dtype in ['float64', 'float32']:
            display_df[column] = display_df[column].round(4)

    display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES)
    # Color columns by metric group; `columns` still holds the raw names.
    return apply_color_styling(display_df, columns)
def apply_color_styling(display_df, original_cols):
    """Color each displayed column by its metric group via a pandas Styler.

    Display columns line up positionally with `original_cols`, so the
    raw-name lookup into the group tables is done by zipping the two.
    """
    # Precompute display-column -> CSS once, outside the styling callback.
    css_by_column = {}
    for raw_name, shown_name in zip(original_cols, display_df.columns):
        if raw_name in COLUMN_TO_GROUP:
            color = METRIC_GROUP_COLORS.get(COLUMN_TO_GROUP[raw_name], '')
            if color:
                css_by_column[shown_name] = f'background-color: {color}'

    def _styler(frame):
        # Start from no styling, then paint whole columns that have a group color.
        styles = pd.DataFrame('', index=frame.index, columns=frame.columns)
        for shown_name, css in css_by_column.items():
            styles[shown_name] = css
        return styles

    return display_df.style.apply(_styler, axis=None)
def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction):
    """Re-render the leaderboard from the cached frame.

    Operates on the cached DataFrame so option changes never trigger a
    re-download. The sort dropdown carries display names, which are
    translated back to raw column names before sorting.
    """
    working = cached_df.copy()
    if sort_by and sort_by != "None":
        # Invert the display-name mapping to recover the raw column name.
        raw_by_display = {display: raw for raw, display in COLUMN_DISPLAY_NAMES.items()}
        sort_column = raw_by_display.get(sort_by, sort_by)
        if sort_column in working.columns:
            working = working.sort_values(by=sort_column, ascending=sort_direction == "Ascending")
    return format_dataframe(working, show_percentage, selected_groups, compact_view)
def show_output_box(message):
    """Reveal the status textbox, populated with *message*."""
    return gr.update(visible=True, value=message)
def submit_cif_files(problem_type, cif_files, relaxed, profile: gr.OAuthProfile | None):
    """Handle a structures submission from the Submit tab (stub).

    TODO: implement the real submission pipeline, including honoring the
    `relaxed` flag when recording the entry.

    Returns a ``(status_message, filename)`` pair. The click handler is
    wired with two outputs (``[message, filename]``), so the previous bare
    ``return`` (i.e. ``None``) could not be unpacked into them.
    """
    # NOTE(review): rejecting anonymous submissions here — confirm desired policy.
    if profile is None:
        return "Please log in before submitting.", None
    if cif_files is None:
        return "Please upload a CSV, pkl, or ZIP of CIF files.", None
    return "Submission received, but processing is not implemented yet.", None
def generate_metric_legend_html():
    """Build the HTML legend table: one row per metric group with its
    color swatch, cleaned name, member metrics, and direction note."""
    metric_details = {
        'Validity β': ('Valid, Charge Neutral, Distance Valid, Plausibility Valid', 'β Higher is better'),
        'Uniqueness & Novelty β': ('Unique, Novel', 'β Higher is better'),
        'Energy Metrics β': ('E Above Hull, Formation Energy, Relaxation RMSD (with std)', 'β Lower is better'),
        'Stability β': ('Stable, Unique in Stable, SUN', 'β Higher is better'),
        'Metastability β': ('Metastable, Unique in Metastable, MSUN', 'β Higher is better'),
        'Distribution β': ('JS Distance, MMD, FID', 'β Lower is better'),
        'Diversity β': ('Element, Space Group, Atomic Site, Crystal Size', 'β Higher is better'),
        'HHI β': ('HHI Production, HHI Reserve', 'β Lower is better'),
    }
    # Assemble fragments in a list and join once at the end.
    parts = ['<table style="width: 100%; border-collapse: collapse;">', '<thead><tr>']
    for heading in ('Color', 'Group', 'Metrics', 'Direction'):
        parts.append(f'<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">{heading}</th>')
    parts.append('</tr></thead><tbody>')
    for group, color in METRIC_GROUP_COLORS.items():
        metrics, direction = metric_details.get(group, ('', ''))
        # Strip the direction-arrow suffix from the group key for display.
        group_name = group.replace('β', '').replace('β', '').strip()
        parts.append('<tr>')
        parts.append(f'<td style="border: 1px solid #ddd; padding: 8px;"><div style="width: 30px; height: 20px; background-color: {color}; border: 1px solid #999;"></div></td>')
        parts.append(f'<td style="border: 1px solid #ddd; padding: 8px;"><strong>{group_name}</strong></td>')
        parts.append(f'<td style="border: 1px solid #ddd; padding: 8px;">{metrics}</td>')
        parts.append(f'<td style="border: 1px solid #ddd; padding: 8px;">{direction}</td>')
        parts.append('</tr>')
    parts.append('</tbody></table>')
    return ''.join(parts)
def gradio_interface() -> gr.Blocks:
    """Build the full Gradio application.

    Three tabs: the leaderboard (display/sort controls re-render from a
    cached DataFrame, so option changes never re-download the dataset),
    an about page, and the CIF/CSV/pkl submission form.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Welcome to the LeMaterial Generative Benchmark Leaderboard!")
        with gr.Tabs(elem_classes="tab-buttons"):
            with gr.TabItem("π Leaderboard", elem_id="boundary-benchmark-tab-table"):
                gr.Markdown("# LeMat-GenBench")
                # Display options
                with gr.Row():
                    with gr.Column(scale=1):
                        compact_view = gr.Checkbox(
                            value=True,
                            label="Compact View",
                            info="Show only key metrics"
                        )
                        show_percentage = gr.Checkbox(
                            value=True,
                            label="Show as Percentages",
                            info="Display count-based metrics as percentages of total structures"
                        )
                    with gr.Column(scale=1):
                        # Dropdown shows display names; update_leaderboard maps
                        # them back to raw column names before sorting.
                        sort_choices = ["None"] + list(COLUMN_DISPLAY_NAMES.values())
                        sort_by = gr.Dropdown(
                            choices=sort_choices,
                            value="None",
                            label="Sort By",
                            info="Select column to sort by"
                        )
                        sort_direction = gr.Radio(
                            choices=["Ascending", "Descending"],
                            value="Descending",
                            label="Sort Direction"
                        )
                    with gr.Column(scale=2):
                        selected_groups = gr.CheckboxGroup(
                            choices=list(METRIC_GROUPS.keys()),
                            value=list(METRIC_GROUPS.keys()),
                            label="Metric Families (only active when Compact View is off)",
                            info="Select which metric groups to display"
                        )
                # Metric legend with color coding
                with gr.Accordion("Metric Groups Legend", open=False):
                    gr.HTML(generate_metric_legend_html())
                try:
                    # Load once and cache in gr.State; all later refreshes
                    # re-format this cached frame instead of re-downloading.
                    initial_df = get_leaderboard()
                    cached_df_state = gr.State(initial_df)
                    formatted_df = format_dataframe(initial_df, show_percentage=True, selected_groups=list(METRIC_GROUPS.keys()), compact_view=True)
                    leaderboard_table = gr.Dataframe(
                        label="GenBench Leaderboard",
                        value=formatted_df,
                        interactive=False,
                        wrap=True,
                        column_widths=["180px"] + ["160px"] * (len(formatted_df.columns) - 1) if len(formatted_df.columns) > 0 else None,
                        show_fullscreen_button=True
                    )
                    # Every control triggers the same refresh with the same
                    # inputs/outputs, so wire them in one loop instead of five
                    # duplicated .change() blocks.
                    for control in (show_percentage, selected_groups, compact_view, sort_by, sort_direction):
                        control.change(
                            fn=update_leaderboard,
                            inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
                            outputs=leaderboard_table
                        )
                except Exception as e:
                    # Surface load/format failures in the UI instead of crashing the app.
                    gr.Markdown(f"Leaderboard is empty or error loading: {str(e)}")
                gr.Markdown("Verified submissions mean the results came from a model submission rather than a CIF submission.")
            with gr.TabItem("βAbout", elem_id="boundary-benchmark-tab-table"):
                gr.Markdown(
                    """
                    ## About LeMat-Gen-Bench
                    **Welcome to the LeMat-Bench Leaderboard**, There are unconditional generation and conditional generation components of this leaderboard.
                    """)
            with gr.TabItem("βοΈ Submit", elem_id="boundary-benchmark-tab-table"):
                gr.Markdown(
                    """
                    # Materials Submission
                    Upload a CSV, pkl, or a ZIP of CIFs with your structures.
                    """
                )
                filename = gr.State(value=None)
                gr.LoginButton()
                with gr.Row():
                    with gr.Column():
                        problem_type = gr.Dropdown(PROBLEM_TYPES, label="Problem Type")
                    with gr.Column():
                        cif_file = gr.File(label="Upload a CSV, a pkl, or a ZIP of CIF files.")
                        relaxed = gr.Checkbox(
                            value=False,
                            label="Structures are already relaxed",
                            info="Check this box if your submitted structures have already been relaxed"
                        )
                submit_btn = gr.Button("Submission")
                message = gr.Textbox(label="Status", lines=1, visible=False)
                # help message
                gr.Markdown("If you have issues with submission or using the leaderboard, please start a discussion in the Community tab of this Space.")
                # Submit, then reveal the (initially hidden) status box.
                submit_btn.click(
                    submit_cif_files,
                    inputs=[problem_type, cif_file, relaxed],
                    outputs=[message, filename],
                ).then(
                    fn=show_output_box,
                    inputs=[message],
                    outputs=[message],
                )
    return demo
| if __name__ == "__main__": | |
| gradio_interface().launch() | |