File size: 25,480 Bytes
cf02b2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
#!/usr/bin/env python3
"""
Voice-enabled WebSocket server that combines the full voice backend with our document search
"""
import ast
import asyncio
import json
import logging
import os
from dataclasses import asdict, is_dataclass

import lancedb
import pandas as pd
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware

# Try to import voice services, fallback if not available.
# These are project-local modules; if ANY of them is missing the whole
# server degrades to text-only mode (VOICE_AVAILABLE gates every voice path).
try:
    from hybrid_llm_service import HybridLLMService
    from voice_service import VoiceService
    from settings_api import router as settings_router
    from policy_simulator_api import router as policy_simulator_router
    VOICE_AVAILABLE = True
except ImportError:
    VOICE_AVAILABLE = False
    logging.warning("Voice services not available, text-only mode")

# Load .env before reading any environment variables (e.g. ALLOWED_ORIGINS)
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Simple response cache for common queries.
# Plain dict used as FIFO: oldest insertion is evicted once it reaches
# MAX_CACHE_SIZE (see get_llm_response).
response_cache = {}
MAX_CACHE_SIZE = 100

app = FastAPI()

# Include API routers (only present when the optional voice stack imported)
if VOICE_AVAILABLE:
    app.include_router(settings_router)
    app.include_router(policy_simulator_router)

# Enable CORS - Include both local development and production origins.
# NOTE(review): Starlette's CORSMiddleware matches allow_origins literally;
# the "https://*.vercel.app" / "https://*.hf.space" wildcard entries will not
# match real subdomains — allow_origin_regex would be needed for that.
allowed_origins = [
    "http://localhost:5176", "http://localhost:5177", 
    "http://127.0.0.1:5176", "http://127.0.0.1:5177",
    "http://localhost:3000", "http://localhost:5173",
    "https://*.vercel.app", "https://*.hf.space"
]

# Add any custom origins from environment.
# SECURITY FIX: this value was previously passed through eval(), which
# executes arbitrary code taken from the environment. ast.literal_eval only
# parses Python literals (e.g. a list of strings) and cannot run code.
_custom = os.getenv("ALLOWED_ORIGINS")
if _custom:
    try:
        custom_origins = ast.literal_eval(_custom)
        if isinstance(custom_origins, list):
            allowed_origins.extend(custom_origins)
    except (ValueError, SyntaxError):
        logging.warning("Ignoring malformed ALLOWED_ORIGINS environment value")

app.add_middleware(
    CORSMiddleware,
    # BUG FIX: the old check was `"*" in str(allowed_origins)`, which was
    # always true because the stringified list contains '*' from the wildcard
    # entries — so the server silently ran wide-open. Only use "*" when it is
    # explicitly listed as an origin.
    allow_origins=["*"] if "*" in allowed_origins else allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize services if available.
# Constructor failures (missing API keys, model downloads, etc.) must not
# crash the server at import time — flip VOICE_AVAILABLE off instead so the
# WebSocket handler falls back to text-only responses.
if VOICE_AVAILABLE:
    try:
        hybrid_llm_service = HybridLLMService()
        voice_service = VoiceService()
        logger.info("✅ Voice services initialized")
    except Exception as e:
        logger.warning(f"⚠️ Voice services failed to initialize: {e}")
        VOICE_AVAILABLE = False

def serialize_for_json(obj):
    """Recursively convert *obj* into JSON-serializable primitives.

    Dataclass instances become dicts (via ``asdict``, which already recurses),
    dicts/lists/tuples are walked element by element, arbitrary objects are
    serialized through their ``__dict__``, and anything else (str, int, None,
    ...) is returned unchanged.

    Args:
        obj: any Python object produced by the policy simulator.

    Returns:
        A structure made only of dicts, lists, and scalars.
    """
    # is_dataclass() is also True for dataclass *classes*, but asdict() only
    # accepts instances — guard with isinstance(obj, type).
    if is_dataclass(obj) and not isinstance(obj, type):
        return asdict(obj)
    elif isinstance(obj, dict):
        return {key: serialize_for_json(value) for key, value in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [serialize_for_json(item) for item in obj]
    elif hasattr(obj, '__dict__'):
        # BUG FIX: the original returned obj.__dict__ as-is, leaving nested
        # non-dataclass objects unserialized; recurse into the values.
        return {key: serialize_for_json(value) for key, value in obj.__dict__.items()}
    else:
        return obj

def search_documents_simple(query: str):
    """Keyword-based lookup over the local LanceDB document store.

    No embeddings are used: the query is matched against pension-related
    keywords (English and Hindi), and matching rows of the
    'rajasthan_documents' table are ranked by a simple term-count score.

    Args:
        query: raw user query text.

    Returns:
        (results, source): ``results`` is a list of ``{"content", "filename"}``
        dicts (content truncated to 800 chars, at most 5 entries); ``source``
        is "rajasthan_pension_documents", "none", or "error".
    """
    try:
        db = lancedb.connect('./lancedb_data')

        # Nothing to search if the table hasn't been ingested yet
        if 'rajasthan_documents' not in db.table_names():
            return [], "none"

        frame = db.open_table('rajasthan_documents').to_pandas()

        lowered = query.lower()
        pension_terms = [
            'pension', 'पेंशन', 'वृद्धावस्था', 'सामाजिक', 'भत्ता', 'allowance', 
            'old age', 'social security', 'retirement', 'सेवानिवृत्ति'
        ]
        wants_pension = any(term in lowered for term in pension_terms)

        if not (wants_pension or 'rajasthan' in lowered):
            return [], "none"

        # Cast a broad regex net over document text (English + Hindi keywords)
        mask = frame['content'].str.contains(
            'pension|Pension|पेंशन|वृद्धावस्था|सामाजिक|भत्ता|allowance|old.age|social.security|retirement|सेवानिवृत्ति|scheme|योजना', 
            case=False, na=False, regex=True
        )
        hits = frame[mask]
        if hits.empty:
            return [], "none"

        score_terms = ['pension', 'पेंशन', 'वृद्धावस्था', 'सामाजिक', 'भत्ता', 'allowance', 'old age']

        def relevance(text):
            # Count how many of the scoring terms appear in the document
            body = text.lower()
            return sum(term in body for term in score_terms)

        hits = hits.copy()
        hits['relevance_score'] = hits['content'].apply(relevance)
        hits = hits.sort_values('relevance_score', ascending=False)

        top = [
            {"content": row['content'][:800], "filename": row['filename']}
            for _, row in hits.head(5).iterrows()
        ]
        return top, "rajasthan_pension_documents"

    except Exception as e:
        logger.error(f"Search error: {e}")
        return [], "error"

async def get_llm_response(query: str, search_results: list):
    """Get a chat response for *query*, grounding it in *search_results*.

    Uses the hybrid LLM service when available, otherwise a canned fallback.
    Responses are memoized in the module-level ``response_cache`` (FIFO
    eviction at MAX_CACHE_SIZE).

    Args:
        query: the user's question.
        search_results: list of {"content", "filename"} dicts from
            search_documents_simple (may be empty).

    Returns:
        The response text (never raises; errors yield an apology string).
    """
    # BUG FIX: the old key was f"{query}_{len(search_results)}", so two
    # different result sets of the same size collided. Key on the actual
    # filenames instead.
    doc_names = tuple(doc.get('filename', '') for doc in search_results) if search_results else ()
    cache_key = (query, doc_names)

    # Check cache first
    if cache_key in response_cache:
        logger.info(f"📦 Cache hit for query: {query[:50]}...")
        return response_cache[cache_key]

    try:
        if VOICE_AVAILABLE and hybrid_llm_service:
            # Use the hybrid LLM service
            if search_results:
                # BUG FIX: the prompt was previously joined with the literal
                # two-character text "\n" (escaped backslash-n in the source),
                # so the LLM saw "\n" instead of real line breaks.
                context = "\n\n".join(
                    f"Document: {doc['filename']}\nContent: {doc['content']}"
                    for doc in search_results
                )
                enhanced_query = (
                    f"Based on these Rajasthan government documents, please answer: {query}"
                    f"\n\nDocuments:\n{context}"
                )
            else:
                enhanced_query = query

            response = await hybrid_llm_service.get_response(enhanced_query)
        else:
            # Fallback to simple canned responses when no LLM is available
            if search_results:
                response = f"Based on the Rajasthan government documents, I found information about {query}. However, voice processing is currently limited. Please use text chat for detailed responses."
            else:
                response = f"I received your query about '{query}' but couldn't find specific documents. Please try using text chat for better results."

        # Bound the cache: evict the oldest insertion (dicts preserve order)
        if len(response_cache) >= MAX_CACHE_SIZE:
            response_cache.pop(next(iter(response_cache)))
        response_cache[cache_key] = response
        return response

    except Exception as e:
        logger.error(f"LLM error: {e}")
        return "I'm having trouble processing your request. Please try using the text chat."

@app.websocket("/ws/stream")
async def websocket_endpoint(websocket: WebSocket):
    """Bidirectional chat endpoint handling both text and voice messages.

    Incoming frames (per iteration of the receive loop):
      * JSON text with ``type == "text_message"`` — routed, in priority
        order, to (1) the interactive scenario form, (2) the policy
        simulator, or (3) plain document search + LLM response.
      * JSON text with ``type == "user_info"`` — logged only.
      * JSON text with a ``lang`` key — updates the session language used
        for speech-to-text.
      * Binary frames — treated as audio; transcribed, answered, and
        (best-effort) spoken back via TTS when voice services are up.

    Outgoing JSON ``type`` values: connection_successful, message_received,
    interactive_form, policy_simulation, text_response, error_message;
    plus raw bytes for synthesized speech.
    """
    await websocket.accept()
    logger.info("🔌 WebSocket client connected")
    
    # Store user session info
    user_language = "english"  # Default language
    
    try:
        # Send initial greeting
        await websocket.send_json({
            "type": "connection_successful",
            "message": "Hello! I'm your Rajasthan government document assistant. I can help with text and voice queries about pension schemes and government policies."
        })
        
        while True:
            try:
                # Receive raw ASGI message so we can handle both text and
                # binary frames on the same socket
                message = await websocket.receive()
                
                # Handle different message types
                if message["type"] == "websocket.receive":
                    if "text" in message:
                        # Parse JSON text message
                        try:
                            data = json.loads(message["text"])
                        except json.JSONDecodeError:
                            logger.warning(f"⚠️ Invalid JSON received: {message['text']}")
                            continue
                            
                        # Process text message
                        if isinstance(data, dict) and data.get("type") == "text_message":
                            user_message = data.get("message", "")
                            if not user_message.strip():
                                continue
                                
                            logger.info(f"💬 Text received: {user_message}")
                            
                            # Check for interactive scenario form triggers
                            form_triggers = ["start scenario analysis", "scenario form", "interactive analysis", "step by step analysis", "guided analysis", "form analysis", "scenario chat form", "interactive scenario"]
                            is_form_request = any(trigger in user_message.lower() for trigger in form_triggers)
                            
                            # Check if this is a policy simulation query (robust regex patterns).
                            # NOTE(review): this import and pattern list are re-created on
                            # every message; hoisting them to module level would avoid that.
                            import re
                            POLICY_PATTERNS = [
                                r"policy.*simulation|simulation.*policy",
                                r"policy.*scenario|scenario.*policy", 
                                r"policy.*analysis|analysis.*policy",
                                r"pension.*simulation|simulation.*pension",
                                r"pension.*analysis|analysis.*pension",
                                r"pension.*scenario|scenario.*pension",
                                r"dearness.*relief|dr.*increase|dr.*adjustment",
                                r"dearness.*allowance|da.*increase|da.*adjustment",
                                r"minimum.*pension.*increase|increase.*minimum.*pension",
                                r"calculate.*pension|pension.*calculation",
                                r"impact.*dr|dr.*impact|impact.*da|da.*impact",
                                r"show.*impact.*da|show.*impact.*dr",
                                r"impact.*\d+.*da|impact.*\d+.*dr",
                                r"\d+.*da.*increase|da.*\d+.*increase",
                                r"\d+.*dr.*increase|dr.*\d+.*increase",
                                r"inflation.*adjustment|adjustment.*inflation",
                                r"scenario.*analysis|analysis.*scenario",
                                r"what.*if.*dr|what.*if.*pension|what.*if.*da",
                                r"compare.*scenario|scenario.*comparison",
                                r"show.*chart|chart.*show",
                                r"explain.*chart|chart.*explain", 
                                r"using.*chart|chart.*using",
                                r"dr.*\d+.*increase|increase.*dr.*\d+",
                                r"da.*\d+.*increase|increase.*da.*\d+",
                                r"analyze.*minimum.*pension",
                                r"pension.*change",
                                r"make.*chart|chart.*make",
                                r"pension.*value|value.*pension",
                                r"basic.*pension.*\d+|pension.*\d+",
                                r"simulate.*dr|simulate.*pension|simulate.*da"
                            ]
                            
                            def is_policy_simulation_query(message: str) -> bool:
                                """Check if the message is a policy simulation query"""
                                message_lower = message.lower()
                                logger.info(f"🔍 Checking policy patterns for: '{message_lower}'")
                                
                                for i, pattern in enumerate(POLICY_PATTERNS):
                                    if re.search(pattern, message_lower, re.IGNORECASE):
                                        logger.info(f"✅ Pattern {i+1} matched: {pattern}")
                                        return True
                                
                                logger.info("❌ No policy patterns matched")
                                return False
                            
                            is_policy_query = is_policy_simulation_query(user_message)
                            
                            # Handle interactive scenario form request
                            # (takes priority over the policy-simulator path)
                            if is_form_request:
                                logger.info("📋 Interactive scenario form requested")
                                
                                try:
                                    # Lazy import: scenario form module is optional
                                    from scenario_chat_form import start_scenario_analysis_form
                                    form_response = start_scenario_analysis_form(data.get("user_id", "default"))
                                    
                                    # Format form response for chat
                                    form_message = f"""🎯 **{form_response.get('title', 'Interactive Scenario Analysis')}**

{form_response.get('message', '')}

**{form_response.get('step_title', 'Step 1')}** ({form_response.get('current_step', 1)}/{form_response.get('total_steps', 4)})

{form_response['form_data']['question']}

**Available Options:**"""
                                    
                                    # Add form options
                                    if form_response['form_data']['input_type'] == 'select':
                                        for i, option in enumerate(form_response['form_data']['options'], 1):
                                            form_message += f"\n{i}. {option['label']}"
                                    
                                    form_message += "\n\n**Quick Actions:**"
                                    for action in form_response.get('quick_actions', []):
                                        form_message += f"\n• {action['text']}"
                                        
                                    form_message += "\n\n💡 **Next:** Choose an option above or type your selection!"
                                    
                                    await websocket.send_json({
                                        "type": "interactive_form",
                                        "message": form_message,
                                        "form_data": form_response
                                    })
                                    continue
                                    
                                except Exception as e:
                                    logger.error(f"Form initialization failed: {str(e)}")
                                    await websocket.send_json({
                                        "type": "error_message",
                                        "message": f"Sorry, I couldn't start the interactive scenario analysis. Error: {str(e)}"
                                    })
                                    continue
                            
                            # Handle policy queries
                            elif is_policy_query:
                                logger.info("🎯 Detected policy simulation query")
                                
                                try:
                                    # Import policy chat interface (lazy, optional module)
                                    from policy_chat_interface import PolicySimulatorChatInterface
                                    
                                    # Send acknowledgment for policy simulation
                                    await websocket.send_json({
                                        "type": "message_received", 
                                        "message": "🎯 Analyzing Rajasthan policy impact..."
                                    })
                                    
                                    # Initialize and process policy simulation
                                    policy_simulator = PolicySimulatorChatInterface()
                                    policy_result = policy_simulator.process_policy_query(user_message)
                                    
                                    # Format policy response - use same format as working simple backend
                                    if policy_result.get("type") == "policy_simulation":
                                        # Serialize the response for JSON
                                        serialized_response = serialize_for_json(policy_result)
                                        
                                        # Send policy simulation response
                                        await websocket.send_json({
                                            "type": "policy_simulation",
                                            "data": serialized_response
                                        })
                                        logger.info("📤 Policy simulation response sent to client")
                                    else:
                                        # Handle other policy responses (errors, help, etc.)
                                        await websocket.send_json({
                                            "type": "text_response",
                                            "message": policy_result.get('message', 'Policy analysis completed')
                                        })
                                    
                                    continue
                                    
                                except Exception as e:
                                    logger.error(f"Policy simulation failed: {str(e)}")
                                    await websocket.send_json({
                                        "type": "error_message",
                                        "message": f"Sorry, policy analysis failed. Using document search instead."
                                    })
                                    # Fall through to regular document search
                            
                            # Regular document search (fallback)
                            # Send acknowledgment
                            await websocket.send_json({
                                "type": "message_received",
                                "message": "🔍 Searching Rajasthan government documents..."
                            })
                            
                            # Search for relevant documents
                            search_results, source = search_documents_simple(user_message)
                            logger.info(f"🔍 Found {len(search_results)} documents from {source}")
                            
                            # Get LLM response
                            llm_response = await get_llm_response(user_message, search_results)
                            
                            # Send response
                            await websocket.send_json({
                                "type": "text_response",
                                "message": llm_response
                            })
                            
                        elif isinstance(data, dict) and data.get("type") == "user_info":
                            user_name = data.get("user_name", "Unknown")
                            logger.info(f"👤 User connected: {user_name}")
                            
                        elif isinstance(data, dict) and data.get("lang"):
                            new_language = data.get("lang", "english")
                            if new_language != user_language:
                                user_language = new_language
                                logger.info(f"🌍 Language preference updated: {user_language}")
                            # Avoid logging if language hasn't changed
                            
                    elif "bytes" in message:
                        # Handle binary message (audio data)
                        audio_data = message["bytes"]
                        logger.info(f"🎤 Received audio data: {len(audio_data)} bytes")
                        
                        if VOICE_AVAILABLE and voice_service:
                            try:
                                # Save audio data to temporary file for processing.
                                # NOTE(review): .wav suffix assumes the client sends WAV
                                # audio — confirm against the frontend recorder.
                                import tempfile
                                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                                    temp_file.write(audio_data)
                                    temp_file_path = temp_file.name
                                
                                # Process audio with voice service using user's language preference
                                text = await voice_service.speech_to_text(temp_file_path, user_language)
                                
                                # Clean up temp file
                                os.unlink(temp_file_path)
                                
                                if text and text.strip():
                                    logger.info(f"🎤 Transcribed: {text}")
                                    
                                    # Search documents
                                    search_results, source = search_documents_simple(text)
                                    logger.info(f"🔍 Found {len(search_results)} documents from {source}")
                                    
                                    # Get LLM response
                                    llm_response = await get_llm_response(text, search_results)
                                    
                                    # Send text response
                                    await websocket.send_json({
                                        "type": "text_response",
                                        "message": llm_response
                                    })
                                    
                                    # Try to send voice response (best-effort: TTS
                                    # failure is logged, text reply already sent)
                                    try:
                                        audio_response = await voice_service.text_to_speech(llm_response)
                                        if audio_response:
                                            await websocket.send_bytes(audio_response)
                                    except Exception as tts_error:
                                        logger.warning(f"TTS failed: {tts_error}")
                                else:
                                    await websocket.send_json({
                                        "type": "text_response",
                                        "message": "I couldn't understand what you said. Please try speaking more clearly or use text chat."
                                    })
                                        
                            except Exception as voice_error:
                                logger.error(f"Voice processing error: {voice_error}")
                                await websocket.send_json({
                                    "type": "text_response",
                                    "message": "Sorry, I couldn't process your voice input. Please try speaking again or use text chat."
                                })
                        else:
                            # Voice services not available
                            await websocket.send_json({
                                "type": "text_response", 
                                "message": "Voice processing is currently unavailable. Please use the text chat to ask about Rajasthan pension schemes and government policies."
                            })
                            
                elif message["type"] == "websocket.disconnect":
                    break
                    
            except json.JSONDecodeError as e:
                logger.warning(f"⚠️ JSON decode error: {e}")
                continue
            except KeyError as e:
                logger.warning(f"⚠️ Missing key in message: {e}")
                continue
                
    except WebSocketDisconnect:
        logger.info("🔌 WebSocket client disconnected")
    except Exception as e:
        logger.error(f"❌ WebSocket error: {e}")

@app.get("/health")
async def health_check():
    """Liveness probe: reports DB tables and whether voice mode is active.

    Returns a status dict; any failure (e.g. LanceDB unreachable) is
    reported in-band rather than raised.
    """
    try:
        db = lancedb.connect('./lancedb_data')
        return {
            "status": "healthy",
            "tables": db.table_names(),
            "voice_available": VOICE_AVAILABLE,
        }
    except Exception as e:
        return {"status": "error", "error": str(e)}

if __name__ == "__main__":
    # Dev entry point: serve the API + WebSocket endpoint on all interfaces
    print("🚀 Starting voice-enabled WebSocket server on port 8000...")
    uvicorn.run(app, host="0.0.0.0", port=8000)