munem420 commited on
Commit
4ef68bd
·
verified ·
1 Parent(s): e892b2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -53
app.py CHANGED
@@ -8,14 +8,10 @@ from huggingface_hub import hf_hub_download
8
 
9
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
10
 
11
-
12
-
13
  MODEL_REPO = "munem420/stock-forecaster-lstm"
14
  MODEL_FILENAME = "model_lstm.h5"
15
  SCALER_FILENAME = "scalers.joblib"
16
 
17
-
18
-
19
  print("--- Downloading model and scalers ---")
20
  try:
21
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
@@ -28,7 +24,6 @@ except Exception as e:
28
  loaded_model_lstm = None
29
  loaded_scalers = None
30
 
31
-
32
  if model_path and os.path.exists(model_path):
33
  try:
34
  loaded_model_lstm = tf.keras.models.load_model(
@@ -48,73 +43,53 @@ if scalers_path and os.path.exists(scalers_path):
48
 
49
  ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
50
 
51
-
52
-
53
-
54
-
55
-
56
-
57
-
58
-
59
 
60
  def get_ticker_from_input(input_name):
61
- return input_name.upper()
62
-
63
-
64
-
65
 
 
 
 
 
66
 
67
-
68
- def forecast_stock(input_name, model, scalers_dict, input_width=60):
69
  if not model or not scalers_dict:
70
- return "Error: Model or scalers not loaded. The backend may have failed to start."
71
- ticker = get_ticker_from_input(input_name)
72
- if not ticker:
73
- return "Error: Invalid stock ticker."
74
- print(f"\n--- Generating forecast for {ticker} ---")
75
-
76
-
77
 
78
-
79
  if len(data_df) < input_width:
80
  return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."
 
81
  recent_data = data_df.tail(input_width)
82
- close_prices = recent_data['Close'].values.reshape(-input, 1)
83
  scaler = scalers_dict.get(ticker)
 
84
  if not scaler:
85
- print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
86
  scaler = scalers_dict.get('ZURVY')
87
  if not scaler:
88
- return "Error: Default scaler 'ZURVY' not found."
 
89
  scaled_data = scaler.transform(close_prices)
90
  X_pred = scaled_data.reshape(1, input_width, 1)
91
  prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
92
  prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
93
  last_close = recent_data['Close'].iloc[-1]
94
- result = (
95
- f"Last known close for {ticker}: ${last_close:.2f}\n"
96
- f"Predicted next day's close price: ${prediction_actual:.2f}"
97
- )
98
- print(result)
99
- return result
100
-
101
- def predict_api(ticker_symbol):
102
- return forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
103
-
104
- with gr.Blocks() as app:
105
- gr.Markdown("This is the backend for the React Stock Forecaster App.")
106
- ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
107
- output_text = gr.Textbox(label="Forecast", visible=False)
108
- ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
109
-
110
- app = gr.mount_static_directory(app, "build")
111
-
112
-
113
-
114
-
115
-
116
 
 
117
 
 
 
 
 
 
 
 
 
118
 
119
  if __name__ == "__main__":
120
- app.launch()
 
8
 
9
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
10
 
 
 
11
  MODEL_REPO = "munem420/stock-forecaster-lstm"
12
  MODEL_FILENAME = "model_lstm.h5"
13
  SCALER_FILENAME = "scalers.joblib"
14
 
 
 
15
  print("--- Downloading model and scalers ---")
16
  try:
17
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
 
24
  loaded_model_lstm = None
25
  loaded_scalers = None
26
 
 
27
  if model_path and os.path.exists(model_path):
28
  try:
29
  loaded_model_lstm = tf.keras.models.load_model(
 
43
 
44
  ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
45
 
46
+ # Example placeholder DataFrame (replace with your actual data)
47
+ data_df = pd.DataFrame({
48
+ "Date": pd.date_range(start="2024-01-01", periods=100),
49
+ "Close": np.linspace(100, 200, 100)
50
+ })
 
 
 
51
 
52
  def get_ticker_from_input(input_name):
53
+ return input_name.upper().strip()
 
 
 
54
 
55
+ def forecast_stock(input_name):
56
+ model = loaded_model_lstm
57
+ scalers_dict = loaded_scalers
58
+ input_width = 60
59
 
 
 
60
  if not model or not scalers_dict:
61
+ return "Error: Model or scalers not loaded."
 
 
 
 
 
 
62
 
63
+ ticker = get_ticker_from_input(input_name)
64
  if len(data_df) < input_width:
65
  return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."
66
+
67
  recent_data = data_df.tail(input_width)
68
+ close_prices = recent_data['Close'].values.reshape(-1, 1)
69
  scaler = scalers_dict.get(ticker)
70
+
71
  if not scaler:
72
+ print(f"⚠️ No specific scaler for {ticker}. Using fallback.")
73
  scaler = scalers_dict.get('ZURVY')
74
  if not scaler:
75
+ return "Error: No default scaler found."
76
+
77
  scaled_data = scaler.transform(close_prices)
78
  X_pred = scaled_data.reshape(1, input_width, 1)
79
  prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
80
  prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
81
  last_close = recent_data['Close'].iloc[-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ return f"Last close for {ticker}: ${last_close:.2f}\nPredicted next day close: ${prediction_actual:.2f}"
84
 
85
+ # ✅ Simple Gradio interface
86
+ iface = gr.Interface(
87
+ fn=forecast_stock,
88
+ inputs=gr.Textbox(label="Enter Ticker or Company Name"),
89
+ outputs=gr.Textbox(label="Predicted Next Day Close"),
90
+ title="Stock Price Forecaster (LSTM)",
91
+ description="Enter a stock ticker or company name to predict the next day's close price."
92
+ )
93
 
94
  if __name__ == "__main__":
95
+ iface.launch()