# Page residue from the Hugging Face file view (not code):
#   author: munem420 · commit message: "Update app.py" · commit dfb42e6 (verified)
import os
import gradio as gr
import tensorflow as tf
import joblib
import numpy as np
# -------------------------------------------------------
# CONFIG
# -------------------------------------------------------
# Hugging Face Hub repository ("<user>/<repo>") hosting the trained artifacts.
MODEL_REPO = "munem420/stock-forecaster-lstm"

# Artifact file names inside MODEL_REPO: the Keras model (HDF5 file) and the
# joblib-serialized per-ticker scalers.
MODEL_FILENAME, SCALER_FILENAME = "model_lstm.h5", "scalers.joblib"

# Hide every CUDA device so TensorFlow runs on the CPU.
# NOTE(review): this executes after `import tensorflow`; it only takes effect
# if TensorFlow has not initialised its GPU devices yet — confirm.
os.environ.update(CUDA_VISIBLE_DEVICES="-1")
# -------------------------------------------------------
# LOAD MODEL AND SCALERS
# -------------------------------------------------------
def _hub_url(filename):
    """Return the direct-download URL of *filename* on MODEL_REPO's main branch."""
    return f"https://huggingface.co/{MODEL_REPO}/resolve/main/{filename}"


print("📦 Loading model and scalers...")
try:
    # NOTE(review): get_file reuses an existing cached copy with the same file
    # name, so a newer upload to the Hub is not picked up until the cached file
    # is removed — consider passing `file_hash` to detect changes.
    model_path = tf.keras.utils.get_file(MODEL_FILENAME, _hub_url(MODEL_FILENAME))
    scalers_path = tf.keras.utils.get_file(SCALER_FILENAME, _hub_url(SCALER_FILENAME))
    # The app only runs inference (predict), so skip restoring the training
    # configuration (loss / metrics / optimizer). This avoids having to map
    # legacy serialized names such as "mse" through custom_objects.
    model = tf.keras.models.load_model(model_path, compile=False)
    # NOTE(review): joblib.load unpickles arbitrary objects; acceptable only
    # because the file is fetched from this project's own Hub repository.
    scalers = joblib.load(scalers_path)
    print("✅ Model and scalers loaded successfully.")
except Exception as e:
    # Top-level best-effort boundary: keep the UI importable and let the
    # forecast function report that the artifacts are missing.
    print(f"❌ Error loading model or scalers: {e}")
    model, scalers = None, None
# -------------------------------------------------------
# FORECAST FUNCTION
# -------------------------------------------------------
def forecast_stock(ticker):
    """Return a Markdown message with the next-day close forecast for *ticker*.

    Args:
        ticker: Ticker symbol entered by the user. Surrounding whitespace is
            stripped and the symbol is upper-cased before it is looked up in
            the module-level ``scalers`` mapping. ``None`` is treated as empty.

    Returns:
        A Markdown string. On success it holds the forecast. Otherwise it holds
        a readable error: artifacts not loaded, empty input, or unknown ticker.
        Errors are returned rather than raised, so the caller can display them
        as ordinary output.

    Note:
        The model input is currently SYNTHETIC: a fixed ramp of 60 values from
        0.9 to 1.1. Every ticker therefore receives the same input, and only
        the per-ticker scaler's inverse transform differs. Replace the ramp
        with the ticker's real, scaler-normalised history before trusting the
        output.
    """
    # The loader sets both globals to None on failure; test for exactly that
    # rather than relying on the truthiness of a Keras model or a mapping.
    if model is None or scalers is None:
        return "❌ Model or scalers not loaded properly."

    # Treat a missing (None) value like an empty string instead of crashing
    # on .strip().
    ticker = (ticker or "").strip().upper()
    if not ticker:
        return "⚠️ Please enter a stock ticker symbol."
    if ticker not in scalers:
        return f"⚠️ No scaler found for ticker '{ticker}'. Please check spelling."
    # Only `inverse_transform` is used here; presumably a fitted
    # scikit-learn-style scaler — NOTE(review): confirm against training code.
    scaler = scalers[ticker]

    # Placeholder input window shaped (batch=1, timesteps=60, features=1).
    # TODO: replace with the ticker's latest 60 closes normalised by `scaler`.
    window = np.linspace(0.9, 1.1, 60).reshape(1, 60, 1)

    # First output of the single sample in the batch, still in scaled units.
    pred_scaled = model.predict(window, verbose=0)[0][0]
    # Map the scaled prediction back to a price with this ticker's scaler.
    pred_actual = scaler.inverse_transform(np.array([[pred_scaled]]))[0][0]
    return f"🔮 Predicted next day close for **{ticker}**: ${pred_actual:.2f}"
# -------------------------------------------------------
# GRADIO INTERFACE
# -------------------------------------------------------
# The input is a ticker symbol only: forecast_stock upper-cases it and looks
# it up in the per-ticker scaler mapping, so company names cannot match.
iface = gr.Interface(
    fn=forecast_stock,
    inputs=gr.Textbox(label="Enter Ticker Symbol"),
    outputs=gr.Markdown(label="Prediction Result"),
    title="📊 Stock Price Forecaster (LSTM)",
    description=(
        "Enter a stock ticker symbol to predict the next day's closing price "
        "using a trained LSTM model."
    ),
)

if __name__ == "__main__":
    iface.launch()