import streamlit as st
import torch
import random
import numpy as np
import yaml
import logging
import os
import matplotlib.pyplot as plt
from pathlib import Path
import tempfile
import traceback

from data_utils import (
    save_uploaded_files,
    load_dataset,
)
from inference_utils import run_inference
from config_utils import load_config
from plot_utils import plot_prithvi_output, plot_aurora_output
from prithvi_utils import (
    prithvi_config_ui,
    initialize_prithvi_model,
    prepare_prithvi_batch,
)
from aurora_utils import aurora_config_ui, prepare_aurora_batch, initialize_aurora_model
from pangu_utils import (
    pangu_config_data,
    inference_1hr,
    inference_3hrs,
    inference_6hrs,
    inference_24hrs,
    inference_custom_hrs,
    plot_pangu_output,
)
from fengwu_utils import (
    fengwu_config_data,
    inference_6hrs_fengwu,
    inference_12hrs_fengwu,
    inference_custom_hrs_fengwu,
    plot_fengwu_output,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set page configuration
st.set_page_config(
    page_title="Weather Data Processor",
    layout="wide",
    initial_sidebar_state="expanded",
)
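
# Header row: application title on the left, model selector on the right.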
header_col1, header_col2 = st.columns([4, 1])
with header_col1:
    st.title("🌦️ Weather & Climate Data Processor and Forecaster")
with header_col2:
    st.markdown("### Select a Model")
    selected_model = st.selectbox(
        "",
        options=["Pangu-Weather", "FengWu", "Aurora", "Climax", "Prithvi", "GEOS-Specific-LSTM", "GEOS-Finetuned-Climax"],
        index=0,
        key="model_selector",
        help="Select the model you want to use.",
    )

st.write("---")
# --- Layout: Two Columns ---
left_col, right_col = st.columns([1, 2])

with left_col:
    st.header("🔧 Configuration")

    # Dynamically show configuration UI based on selected model
    if selected_model == "Prithvi":
        (config, uploaded_surface_files, uploaded_vertical_files,
         clim_surf_path, clim_vert_path, config_path, weights_path) = prithvi_config_ui()
    elif selected_model == "Climax":
        st.info("Climax model is not yet available.")
        st.stop()
    elif selected_model == "GEOS-Specific-LSTM":
        st.info("GEOS-Specific-LSTM model is not yet available.")
        st.stop()
    elif selected_model == "GEOS-Finetuned-Climax":
        st.info("GEOS-Finetuned-Climax model is not yet available.")
        st.stop()
    elif selected_model == "Aurora":
        uploaded_files = aurora_config_ui()
    elif selected_model == "Pangu-Weather":
        input_surface_file, input_upper_file = pangu_config_data()
    elif selected_model == "FengWu":
        input_file1_fengwu, input_file2_fengwu = fengwu_config_data()
    else:
        # Generic data upload for other models
        st.subheader(f"{selected_model} Model Data Upload")
        st.markdown("### Drag and Drop Your Data Files Here")
        uploaded_files = st.file_uploader(
            f"Upload Data Files for {selected_model}",
            accept_multiple_files=True,
            key=f"{selected_model.lower()}_uploader",
            type=["nc", "netcdf", "nc4"],
        )
| st.write("---") | |
| # --- Forecast Duration Selection --- | |
| st.subheader("Forecast Duration") | |
| forecast_options = ["1 hour", "3 hours", "6 hours", "24 hours", "Custom"] | |
| selected_duration = st.selectbox( | |
| "Select forecast duration", | |
| forecast_options, | |
| index=3, # Default to 24 hours | |
| help="Select how many hours to forecast." | |
| ) | |
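
    # Only ask for a custom forecast horizon when "Custom" is selected.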
    custom_hours = None
    if selected_duration == "Custom":
        custom_hours = st.number_input(
            "Enter custom forecast hours",
            min_value=24,
            max_value=480,
            value=48,
            step=24,
            help="Enter the number of hours you want to forecast.",
        )

    st.write("---")
    # Run Inference button
    if st.button("🚀 Run Inference"):
        with right_col:
            st.header("📊 Inference Progress & Visualization")

            # Set seeds and device
            try:
                torch.jit.enable_onednn_fusion(True)
                if torch.cuda.is_available():
                    device = torch.device("cuda")
                    st.write(f"Using device: **{torch.cuda.get_device_name()}**")
                    torch.backends.cudnn.benchmark = True
                    torch.backends.cudnn.deterministic = True
                else:
                    device = torch.device("cpu")
                    st.write("Using device: **CPU**")
                random.seed(42)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(42)
                torch.manual_seed(42)
                np.random.seed(42)
            except Exception:
                st.error("Error initializing device:")
                st.error(traceback.format_exc())
                st.stop()

            # Use a spinner while running inference
            with st.spinner("Running inference, please wait..."):
                # Initialize and run inference for the selected model
                if selected_model == "Prithvi":
                    model, in_mu, in_sig, output_sig, static_mu, static_sig = initialize_prithvi_model(
                        config, config_path, weights_path, device
                    )
                    batch = prepare_prithvi_batch(
                        uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, device
                    )
                    out = run_inference(selected_model, model, batch, device)
                    # Store results
                    st.session_state['prithvi_out'] = out
                    st.session_state['prithvi_done'] = True
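                # Aurora: persist the uploaded files to disk, load them as a dataset,
                # and build a batch before running the model.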
| elif selected_model == "Aurora": | |
| if uploaded_files: | |
| save_uploaded_files(uploaded_files) | |
| ds = load_dataset(st.session_state.temp_file_paths) | |
| if ds is not None: | |
| batch = prepare_aurora_batch(ds) | |
| model = initialize_aurora_model(device) | |
| out = run_inference(selected_model, model, batch, device) | |
| st.session_state['aurora_out'] = out | |
| st.session_state['aurora_ds_subset'] = ds | |
| st.session_state['aurora_done'] = True | |
| else: | |
| st.error("Failed to load dataset for Aurora.") | |
| st.stop() | |
| else: | |
| st.error("Please upload data files for Aurora.") | |
| st.stop() | |
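                # FengWu: takes two NumPy input files; the second is also kept in session
                # state for plotting alongside the forecast output.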
| elif selected_model == "FengWu": | |
| if input_file1_fengwu and input_file2_fengwu: | |
| try: | |
| input1 = np.load(input_file1_fengwu) | |
| input2 = np.load(input_file2_fengwu) | |
| if selected_duration == "1 hour": | |
| st.warning("1hr inference is not yet available on this model.") | |
| elif selected_duration == "3 hours": | |
| st.warning("3hrs inference is not yet available on this model.") | |
| elif selected_duration == "6 hours": | |
| output_fengwu = inference_6hrs_fengwu(input1, input2) | |
| elif selected_duration == "12 hours": | |
| output_fengwu = inference_12hrs_fengwu(input1, input2) | |
| else: | |
| output_fengwu = inference_custom_hrs_fengwu(input1, input2, custom_hours) | |
| st.session_state['output_fengwu'] = output_fengwu | |
| st.session_state['fengwu_done'] = True | |
| st.session_state['input_fengwu'] = input_file2_fengwu | |
| except Exception as e: | |
| st.error(f"An error occurred: {e}") | |
| else: | |
| st.error("Please upload data files for Aurora.") | |
| st.stop() | |
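                # Pangu-Weather: load surface and upper-air NumPy arrays and dispatch to the
                # helper matching the selected forecast horizon.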
| elif selected_model == "Pangu-Weather": | |
| if input_surface_file and input_upper_file: | |
| try: | |
| surface_data = np.load(input_surface_file) | |
| upper_data = np.load(input_upper_file) | |
| # Decide which inference function to use based on selection | |
| if selected_duration == "1 hour": | |
| out_upper, out_surface = inference_1hr(upper_data, surface_data) | |
| elif selected_duration == "3 hours": | |
| out_upper, out_surface = inference_3hrs(upper_data, surface_data) | |
| elif selected_duration == "6 hours": | |
| out_upper, out_surface = inference_6hrs(upper_data, surface_data) | |
| elif selected_duration == "24 hours": | |
| out_upper, out_surface = inference_24hrs(upper_data, surface_data) | |
| else: | |
| out_upper, out_surface = inference_custom_hrs(upper_data, surface_data, custom_hours) | |
| # Store results in session_state | |
| st.session_state['pangu_upper_data'] = upper_data | |
| st.session_state['pangu_surface_data'] = surface_data | |
| st.session_state['pangu_out_upper'] = out_upper | |
| st.session_state['pangu_out_surface'] = out_surface | |
| st.session_state['pangu_done'] = True | |
| st.write("**Forecast Results:**") | |
| st.write("Upper Data Forecast Shape:", out_upper.shape) | |
| st.write("Surface Data Forecast Shape:", out_surface.shape) | |
| except Exception as e: | |
| st.error(f"An error occurred: {e}") | |
| else: | |
| st.error("Please upload data files for Pangu-Weather.") | |
| st.stop() | |
| else: | |
| st.warning("Inference not implemented for this model.") | |
| st.stop() | |

            # Visualization after inference is done
            if selected_model == "Prithvi":
                if 'prithvi_done' in st.session_state and st.session_state['prithvi_done']:
                    plot_prithvi_output(st.session_state['prithvi_out'])
            elif selected_model == "Aurora":
                if 'aurora_done' in st.session_state and st.session_state['aurora_done']:
                    plot_aurora_output(st.session_state['aurora_out'], st.session_state['aurora_ds_subset'])
            elif selected_model == "FengWu":
                if 'fengwu_done' in st.session_state and st.session_state['fengwu_done']:
                    plot_fengwu_output(st.session_state['input_fengwu'], st.session_state['output_fengwu'])
            elif selected_model == "Pangu-Weather":
                if 'pangu_done' in st.session_state and st.session_state['pangu_done']:
                    plot_pangu_output(
                        st.session_state['pangu_upper_data'],
                        st.session_state['pangu_surface_data'],
                        st.session_state['pangu_out_upper'],
                        st.session_state['pangu_out_surface'],
                    )
            else:
                st.info("No visualization implemented for this model.")
    else:
        # If not running inference now, but we have previously computed results, show them
        with right_col:
            st.header("🖥️ Visualization & Progress")

            # Check which model was selected and if we have done inference before
            if selected_model == "Prithvi" and 'prithvi_done' in st.session_state and st.session_state['prithvi_done']:
                plot_prithvi_output(st.session_state['prithvi_out'])
            elif selected_model == "Aurora" and 'aurora_done' in st.session_state and st.session_state['aurora_done']:
                plot_aurora_output(st.session_state['aurora_out'], st.session_state['aurora_ds_subset'])
            elif selected_model == "Pangu-Weather" and 'pangu_done' in st.session_state and st.session_state['pangu_done']:
                plot_pangu_output(
                    st.session_state['pangu_upper_data'],
                    st.session_state['pangu_surface_data'],
                    st.session_state['pangu_out_upper'],
                    st.session_state['pangu_out_surface'],
                )
            elif selected_model == "FengWu" and 'output_fengwu' in st.session_state and st.session_state['fengwu_done']:
                plot_fengwu_output(st.session_state['input_fengwu'], st.session_state['output_fengwu'])
            else:
                st.info("Awaiting inference to display results.")