amitlals commited on
Commit
974a628
·
1 Parent(s): fc8c40e

Add models directory with RPT model wrapper

Browse files
Files changed (3) hide show
  1. .gitignore +9 -5
  2. models/__init__.py +2 -0
  3. models/rpt_model.py +210 -0
.gitignore CHANGED
@@ -37,12 +37,16 @@ ENV/
37
  .env
38
  .env.local
39
 
40
- # Model cache
41
  .cache/
42
- models/
43
- *.pth
44
- *.pt
45
- *.ckpt
 
 
 
 
46
 
47
  # Data files (optional - uncomment if you don't want to track data)
48
  # data/*.csv
 
37
  .env
38
  .env.local
39
 
40
+ # Model cache and downloaded models
41
  .cache/
42
+ models/*.pth
43
+ models/*.pt
44
+ models/*.ckpt
45
+ models/*.bin
46
+ models/*.safetensors
47
+ # But keep Python source files in models/
48
+ !models/*.py
49
+ !models/__init__.py
50
 
51
  # Data files (optional - uncomment if you don't want to track data)
52
  # data/*.csv
models/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Models package
2
+
models/rpt_model.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAP-RPT-1-OSS Model Wrapper
3
+
4
+ Provides a wrapper for SAP-RPT-OSS-Classifier and Regressor with
5
+ authentication handling and CPU fallback options.
6
+ """
7
+
8
+ import os
9
+ import logging
10
+ from typing import Optional, Union
11
+ import pandas as pd
12
+ import numpy as np
13
+ from huggingface_hub import login as hf_login
14
+ from dotenv import load_dotenv
15
+
16
+ # Try to import SAP-RPT-OSS models
17
+ try:
18
+ from sap_rpt_oss import SAP_RPT_OSS_Classifier, SAP_RPT_OSS_Regressor
19
+ SAP_RPT_AVAILABLE = True
20
+ except ImportError:
21
+ SAP_RPT_AVAILABLE = False
22
+ logging.warning("sap-rpt-oss package not installed. Install with: pip install git+https://github.com/SAP-samples/sap-rpt-1-oss")
23
+
24
+ load_dotenv()
25
+
26
+ logging.basicConfig(level=logging.INFO)
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ class RPTModelWrapper:
31
+ """Wrapper for SAP-RPT-1-OSS models with authentication and resource management."""
32
+
33
+ def __init__(self, model_type: str = "classifier", max_context_size: int = 2048, bagging: int = 1):
34
+ """
35
+ Initialize the RPT model wrapper.
36
+
37
+ Args:
38
+ model_type: "classifier" or "regressor"
39
+ max_context_size: Maximum context size (8192 for best performance, 2048 for CPU)
40
+ bagging: Bagging factor (8 for best performance, 1 for lightweight)
41
+ """
42
+ if not SAP_RPT_AVAILABLE:
43
+ raise ImportError("sap-rpt-oss package is not installed. Please install it first.")
44
+
45
+ self.model_type = model_type.lower()
46
+ self.max_context_size = max_context_size
47
+ self.bagging = bagging
48
+ self.model = None
49
+ self.is_fitted = False
50
+
51
+ # Check for Hugging Face token
52
+ self._check_hf_authentication()
53
+
54
+ # Initialize model
55
+ self._initialize_model()
56
+
57
+ def _check_hf_authentication(self):
58
+ """Check and handle Hugging Face authentication."""
59
+ hf_token = os.getenv("HUGGINGFACE_TOKEN")
60
+
61
+ if hf_token:
62
+ try:
63
+ hf_login(token=hf_token)
64
+ logger.info("Hugging Face authentication successful using token from environment.")
65
+ except Exception as e:
66
+ logger.warning(f"Failed to login with token: {e}. Trying interactive login...")
67
+ try:
68
+ hf_login()
69
+ except Exception as e2:
70
+ logger.error(f"Hugging Face authentication failed: {e2}")
71
+ else:
72
+ logger.warning("HUGGINGFACE_TOKEN not found in environment. Attempting interactive login...")
73
+ try:
74
+ hf_login()
75
+ except Exception as e:
76
+ logger.error(f"Hugging Face authentication failed: {e}")
77
+ logger.info("Please set HUGGINGFACE_TOKEN in .env file or run: huggingface-cli login")
78
+
79
+ def _initialize_model(self):
80
+ """Initialize the appropriate model based on type."""
81
+ try:
82
+ if self.model_type == "classifier":
83
+ self.model = SAP_RPT_OSS_Classifier(
84
+ max_context_size=self.max_context_size,
85
+ bagging=self.bagging
86
+ )
87
+ logger.info(f"Initialized SAP-RPT-OSS-Classifier with context_size={self.max_context_size}, bagging={self.bagging}")
88
+ elif self.model_type == "regressor":
89
+ self.model = SAP_RPT_OSS_Regressor(
90
+ max_context_size=self.max_context_size,
91
+ bagging=self.bagging
92
+ )
93
+ logger.info(f"Initialized SAP-RPT-OSS-Regressor with context_size={self.max_context_size}, bagging={self.bagging}")
94
+ else:
95
+ raise ValueError(f"Invalid model_type: {self.model_type}. Must be 'classifier' or 'regressor'")
96
+ except Exception as e:
97
+ logger.error(f"Failed to initialize model: {e}")
98
+ raise
99
+
100
+ def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray]):
101
+ """
102
+ Fit the model on training data.
103
+
104
+ Args:
105
+ X: Feature data (DataFrame or array)
106
+ y: Target data (Series or array)
107
+ """
108
+ try:
109
+ if isinstance(X, np.ndarray):
110
+ # Convert to DataFrame if needed
111
+ X = pd.DataFrame(X)
112
+ if isinstance(y, np.ndarray):
113
+ y = pd.Series(y)
114
+
115
+ logger.info(f"Fitting model on {len(X)} samples...")
116
+ self.model.fit(X, y)
117
+ self.is_fitted = True
118
+ logger.info("Model fitting completed successfully.")
119
+ except Exception as e:
120
+ logger.error(f"Error during model fitting: {e}")
121
+ raise
122
+
123
+ def predict(self, X: Union[pd.DataFrame, np.ndarray]):
124
+ """
125
+ Make predictions.
126
+
127
+ Args:
128
+ X: Feature data (DataFrame or array)
129
+
130
+ Returns:
131
+ Predictions (array)
132
+ """
133
+ if not self.is_fitted:
134
+ raise ValueError("Model must be fitted before making predictions. Call fit() first.")
135
+
136
+ try:
137
+ if isinstance(X, np.ndarray):
138
+ X = pd.DataFrame(X)
139
+
140
+ logger.info(f"Making predictions on {len(X)} samples...")
141
+ predictions = self.model.predict(X)
142
+ return predictions
143
+ except Exception as e:
144
+ logger.error(f"Error during prediction: {e}")
145
+ raise
146
+
147
+ def predict_proba(self, X: Union[pd.DataFrame, np.ndarray]):
148
+ """
149
+ Predict class probabilities (classification only).
150
+
151
+ Args:
152
+ X: Feature data (DataFrame or array)
153
+
154
+ Returns:
155
+ Probability predictions (array)
156
+ """
157
+ if self.model_type != "classifier":
158
+ raise ValueError("predict_proba() is only available for classifiers.")
159
+
160
+ if not self.is_fitted:
161
+ raise ValueError("Model must be fitted before making predictions. Call fit() first.")
162
+
163
+ try:
164
+ if isinstance(X, np.ndarray):
165
+ X = pd.DataFrame(X)
166
+
167
+ logger.info(f"Predicting probabilities on {len(X)} samples...")
168
+ probabilities = self.model.predict_proba(X)
169
+ return probabilities
170
+ except Exception as e:
171
+ logger.error(f"Error during probability prediction: {e}")
172
+ raise
173
+
174
+ def get_model_info(self):
175
+ """Get information about the current model configuration."""
176
+ return {
177
+ "model_type": self.model_type,
178
+ "max_context_size": self.max_context_size,
179
+ "bagging": self.bagging,
180
+ "is_fitted": self.is_fitted,
181
+ "sap_rpt_available": SAP_RPT_AVAILABLE
182
+ }
183
+
184
+
185
+ def create_model(model_type: str = "classifier", use_gpu: bool = True):
186
+ """
187
+ Factory function to create a model with appropriate settings.
188
+
189
+ Args:
190
+ model_type: "classifier" or "regressor"
191
+ use_gpu: Whether to use GPU-optimized settings (requires 80GB GPU memory)
192
+
193
+ Returns:
194
+ RPTModelWrapper instance
195
+ """
196
+ if use_gpu:
197
+ # Best performance settings (requires 80GB GPU)
198
+ return RPTModelWrapper(
199
+ model_type=model_type,
200
+ max_context_size=8192,
201
+ bagging=8
202
+ )
203
+ else:
204
+ # CPU-friendly settings
205
+ return RPTModelWrapper(
206
+ model_type=model_type,
207
+ max_context_size=2048,
208
+ bagging=1
209
+ )
210
+