Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Voice-enabled WebSocket server that combines the full voice backend with our document search | |
| """ | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
| import json | |
| import logging | |
| import lancedb | |
| import pandas as pd | |
| import asyncio | |
| import os | |
| from dotenv import load_dotenv | |
| from dataclasses import asdict, is_dataclass | |
# Try to import voice services, fallback if not available
try:
    from hybrid_llm_service import HybridLLMService
    from voice_service import VoiceService
    from settings_api import router as settings_router
    from policy_simulator_api import router as policy_simulator_router
    # All optional voice/policy modules imported cleanly.
    VOICE_AVAILABLE = True
except ImportError:
    # Any missing optional module degrades the server to text-only mode.
    VOICE_AVAILABLE = False
    logging.warning("Voice services not available, text-only mode")

# Load environment variables from a local .env file (e.g. ALLOWED_ORIGINS).
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Simple response cache for common queries.
# NOTE: plain insertion-ordered dict keyed by query text + result count;
# get_llm_response evicts the oldest entry once MAX_CACHE_SIZE is reached.
response_cache = {}
MAX_CACHE_SIZE = 100
app = FastAPI()

# Include API routers (settings + policy simulator) only when the optional
# voice stack imported successfully at module load.
if VOICE_AVAILABLE:
    app.include_router(settings_router)
    app.include_router(policy_simulator_router)
| # Enable CORS - Include both local development and production origins | |
| allowed_origins = [ | |
| "http://localhost:5176", "http://localhost:5177", | |
| "http://127.0.0.1:5176", "http://127.0.0.1:5177", | |
| "http://localhost:3000", "http://localhost:5173", | |
| "https://*.vercel.app", "https://*.hf.space" | |
| ] | |
| # Add any custom origins from environment | |
| if os.getenv("ALLOWED_ORIGINS"): | |
| try: | |
| custom_origins = eval(os.getenv("ALLOWED_ORIGINS")) | |
| if isinstance(custom_origins, list): | |
| allowed_origins.extend(custom_origins) | |
| except: | |
| pass | |
# NOTE(review): str(allowed_origins) always contains "*" because of the
# hard-coded wildcard entries ("https://*.vercel.app", "https://*.hf.space"),
# so the condition below is always true and CORS is effectively open to all
# origins. Combined with allow_credentials=True this is very permissive —
# confirm whether explicit origins (or allow_origin_regex) were intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"] if "*" in str(allowed_origins) else allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize services if available
if VOICE_AVAILABLE:
    try:
        # Instantiate the LLM router and the speech (STT/TTS) service.
        hybrid_llm_service = HybridLLMService()
        voice_service = VoiceService()
        logger.info("✅ Voice services initialized")
    except Exception as e:
        # Construction can fail at import time (e.g. missing credentials);
        # degrade to text-only mode instead of crashing the server.
        logger.warning(f"⚠️ Voice services failed to initialize: {e}")
        VOICE_AVAILABLE = False
def serialize_for_json(obj):
    """Recursively convert policy-simulation objects into JSON-safe values.

    Dataclass instances become dicts via asdict(); objects exposing a
    __dict__ are flattened to their attribute dict; dicts and lists/tuples
    are walked recursively; any other value is returned unchanged.
    """
    if is_dataclass(obj):
        return asdict(obj)
    if hasattr(obj, '__dict__'):
        return obj.__dict__
    if isinstance(obj, dict):
        return {k: serialize_for_json(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [serialize_for_json(element) for element in obj]
    return obj
def search_documents_simple(query: str):
    """Keyword-based lookup over the local LanceDB document store.

    No embeddings are used: rows of the 'rajasthan_documents' table are
    filtered with a pension/scheme regex, ranked by keyword hit count, and
    the top five are returned as {'content', 'filename'} dicts with content
    truncated to 800 characters.

    Returns:
        (results, source) where source is one of
        'rajasthan_pension_documents', 'none' or 'error'.
    """
    try:
        db = lancedb.connect('./lancedb_data')
        # Only the Rajasthan table is searched; anything else yields "none".
        if 'rajasthan_documents' not in db.table_names():
            return [], "none"
        df = db.open_table('rajasthan_documents').to_pandas()

        lowered = query.lower()
        pension_terms = [
            'pension', 'पेंशन', 'वृद्धावस्था', 'सामाजिक', 'भत्ता', 'allowance',
            'old age', 'social security', 'retirement', 'सेवानिवृत्ति'
        ]
        looks_like_pension = any(term in lowered for term in pension_terms)
        if not (looks_like_pension or 'rajasthan' in lowered):
            return [], "none"

        # Broad bilingual regex over document text (case-insensitive).
        mask = df['content'].str.contains(
            'pension|Pension|पेंशन|वृद्धावस्था|सामाजिक|भत्ता|allowance|old.age|social.security|retirement|सेवानिवृत्ति|scheme|योजना',
            case=False, na=False, regex=True
        )
        matches = df[mask]
        if matches.empty:
            return [], "none"

        scoring_terms = ['pension', 'पेंशन', 'वृद्धावस्था', 'सामाजिक', 'भत्ता', 'allowance', 'old age']

        def hit_count(content):
            # Relevance = how many scoring terms appear in the text.
            text = content.lower()
            return sum(1 for term in scoring_terms if term in text)

        matches = matches.copy()
        matches['relevance_score'] = matches['content'].apply(hit_count)
        ranked = matches.sort_values('relevance_score', ascending=False)

        top_hits = [
            {"content": row['content'][:800], "filename": row['filename']}
            for _, row in ranked.head(5).iterrows()
        ]
        return top_hits, "rajasthan_pension_documents"
    except Exception as e:
        logger.error(f"Search error: {e}")
        return [], "error"
def _cache_put(key, value):
    """Insert into the response cache, evicting the oldest entry at capacity."""
    if len(response_cache) >= MAX_CACHE_SIZE:
        # dicts preserve insertion order, so next(iter(...)) is the oldest key.
        response_cache.pop(next(iter(response_cache)))
    response_cache[key] = value


async def get_llm_response(query: str, search_results: list):
    """Get response using available LLM service with caching.

    Args:
        query: The user's question.
        search_results: Document dicts with 'filename'/'content' keys used
            as context for the LLM, or an empty list.

    Returns:
        The response text. Successful responses are cached keyed by the
        query text plus hit count; on failure a generic apology is returned.
    """
    # Create cache key based on query and search results
    cache_key = f"{query}_{len(search_results) if search_results else 0}"

    # Check cache first
    if cache_key in response_cache:
        logger.info(f"📦 Cache hit for query: {query[:50]}...")
        return response_cache[cache_key]

    try:
        if VOICE_AVAILABLE and hybrid_llm_service:
            # Use the hybrid LLM service.
            # NOTE(review): the doubled backslashes below ("\\n") produce the
            # literal two characters backslash+n in the prompt, not newlines —
            # confirm this is intentional.
            if search_results:
                context = "\\n\\n".join([f"Document: {doc['filename']}\\nContent: {doc['content']}" for doc in search_results])
                enhanced_query = f"Based on these Rajasthan government documents, please answer: {query}\\n\\nDocuments:\\n{context}"
            else:
                enhanced_query = query
            response = await hybrid_llm_service.get_response(enhanced_query)
        else:
            # Fallback canned replies when the LLM stack is unavailable.
            if search_results:
                response = f"Based on the Rajasthan government documents, I found information about {query}. However, voice processing is currently limited. Please use text chat for detailed responses."
            else:
                response = f"I received your query about '{query}' but couldn't find specific documents. Please try using text chat for better results."

        # Bug fix: the fallback path previously inserted into the cache
        # without the MAX_CACHE_SIZE eviction check, letting response_cache
        # grow without bound; both paths now share the bounded insert.
        _cache_put(cache_key, response)
        return response
    except Exception as e:
        logger.error(f"LLM error: {e}")
        return "I'm having trouble processing your request. Please try using the text chat."
# NOTE(review): no @app.websocket(...) route decorator is visible on this
# handler, so as shown it is never registered with the FastAPI app —
# confirm whether a decorator was lost from this file.
async def websocket_endpoint(websocket: WebSocket):
    """Main chat handler: accepts JSON text frames and binary audio frames.

    A "text_message" is routed to one of three flows, in priority order:
    interactive scenario form, policy simulation (regex-detected), or plain
    document search + LLM answer. Binary frames are treated as audio and
    transcribed when the voice stack is available.
    """
    await websocket.accept()
    logger.info("🔌 WebSocket client connected")

    # Store user session info
    user_language = "english"  # Default language, used for speech-to-text

    try:
        # Send initial greeting
        await websocket.send_json({
            "type": "connection_successful",
            "message": "Hello! I'm your Rajasthan government document assistant. I can help with text and voice queries about pension schemes and government policies."
        })

        while True:
            try:
                # Receive message with better error handling
                message = await websocket.receive()

                # Handle different message types
                if message["type"] == "websocket.receive":
                    if "text" in message:
                        # Parse JSON text message
                        try:
                            data = json.loads(message["text"])
                        except json.JSONDecodeError:
                            logger.warning(f"⚠️ Invalid JSON received: {message['text']}")
                            continue

                        # Process text message
                        if isinstance(data, dict) and data.get("type") == "text_message":
                            user_message = data.get("message", "")
                            if not user_message.strip():
                                # Ignore empty/whitespace-only messages.
                                continue
                            logger.info(f"💬 Text received: {user_message}")

                            # Check for interactive scenario form triggers
                            form_triggers = ["start scenario analysis", "scenario form", "interactive analysis", "step by step analysis", "guided analysis", "form analysis", "scenario chat form", "interactive scenario"]
                            is_form_request = any(trigger in user_message.lower() for trigger in form_triggers)

                            # Check if this is a policy simulation query (robust regex patterns)
                            import re
                            POLICY_PATTERNS = [
                                r"policy.*simulation|simulation.*policy",
                                r"policy.*scenario|scenario.*policy",
                                r"policy.*analysis|analysis.*policy",
                                r"pension.*simulation|simulation.*pension",
                                r"pension.*analysis|analysis.*pension",
                                r"pension.*scenario|scenario.*pension",
                                r"dearness.*relief|dr.*increase|dr.*adjustment",
                                r"dearness.*allowance|da.*increase|da.*adjustment",
                                r"minimum.*pension.*increase|increase.*minimum.*pension",
                                r"calculate.*pension|pension.*calculation",
                                r"impact.*dr|dr.*impact|impact.*da|da.*impact",
                                r"show.*impact.*da|show.*impact.*dr",
                                r"impact.*\d+.*da|impact.*\d+.*dr",
                                r"\d+.*da.*increase|da.*\d+.*increase",
                                r"\d+.*dr.*increase|dr.*\d+.*increase",
                                r"inflation.*adjustment|adjustment.*inflation",
                                r"scenario.*analysis|analysis.*scenario",
                                r"what.*if.*dr|what.*if.*pension|what.*if.*da",
                                r"compare.*scenario|scenario.*comparison",
                                r"show.*chart|chart.*show",
                                r"explain.*chart|chart.*explain",
                                r"using.*chart|chart.*using",
                                r"dr.*\d+.*increase|increase.*dr.*\d+",
                                r"da.*\d+.*increase|increase.*da.*\d+",
                                r"analyze.*minimum.*pension",
                                r"pension.*change",
                                r"make.*chart|chart.*make",
                                r"pension.*value|value.*pension",
                                r"basic.*pension.*\d+|pension.*\d+",
                                r"simulate.*dr|simulate.*pension|simulate.*da"
                            ]

                            def is_policy_simulation_query(message: str) -> bool:
                                """Check if the message is a policy simulation query"""
                                message_lower = message.lower()
                                logger.info(f"🔍 Checking policy patterns for: '{message_lower}'")
                                for i, pattern in enumerate(POLICY_PATTERNS):
                                    if re.search(pattern, message_lower, re.IGNORECASE):
                                        logger.info(f"✅ Pattern {i+1} matched: {pattern}")
                                        return True
                                logger.info("❌ No policy patterns matched")
                                return False

                            is_policy_query = is_policy_simulation_query(user_message)

                            # Handle interactive scenario form request
                            if is_form_request:
                                logger.info("📋 Interactive scenario form requested")
                                try:
                                    # Imported lazily so the server still runs if the module is absent.
                                    from scenario_chat_form import start_scenario_analysis_form
                                    form_response = start_scenario_analysis_form(data.get("user_id", "default"))

                                    # Format form response for chat
                                    form_message = f"""🎯 **{form_response.get('title', 'Interactive Scenario Analysis')}**
{form_response.get('message', '')}
**{form_response.get('step_title', 'Step 1')}** ({form_response.get('current_step', 1)}/{form_response.get('total_steps', 4)})
{form_response['form_data']['question']}
**Available Options:**"""

                                    # Add form options
                                    if form_response['form_data']['input_type'] == 'select':
                                        for i, option in enumerate(form_response['form_data']['options'], 1):
                                            form_message += f"\n{i}. {option['label']}"

                                    form_message += "\n\n**Quick Actions:**"
                                    for action in form_response.get('quick_actions', []):
                                        form_message += f"\n• {action['text']}"

                                    form_message += "\n\n💡 **Next:** Choose an option above or type your selection!"

                                    await websocket.send_json({
                                        "type": "interactive_form",
                                        "message": form_message,
                                        "form_data": form_response
                                    })
                                    continue
                                except Exception as e:
                                    logger.error(f"Form initialization failed: {str(e)}")
                                    await websocket.send_json({
                                        "type": "error_message",
                                        "message": f"Sorry, I couldn't start the interactive scenario analysis. Error: {str(e)}"
                                    })
                                    continue

                            # Handle policy queries
                            elif is_policy_query:
                                logger.info("🎯 Detected policy simulation query")
                                try:
                                    # Import policy chat interface
                                    from policy_chat_interface import PolicySimulatorChatInterface

                                    # Send acknowledgment for policy simulation
                                    await websocket.send_json({
                                        "type": "message_received",
                                        "message": "🎯 Analyzing Rajasthan policy impact..."
                                    })

                                    # Initialize and process policy simulation
                                    policy_simulator = PolicySimulatorChatInterface()
                                    policy_result = policy_simulator.process_policy_query(user_message)

                                    # Format policy response - use same format as working simple backend
                                    if policy_result.get("type") == "policy_simulation":
                                        # Serialize the response for JSON
                                        serialized_response = serialize_for_json(policy_result)
                                        # Send policy simulation response
                                        await websocket.send_json({
                                            "type": "policy_simulation",
                                            "data": serialized_response
                                        })
                                        logger.info("📤 Policy simulation response sent to client")
                                    else:
                                        # Handle other policy responses (errors, help, etc.)
                                        await websocket.send_json({
                                            "type": "text_response",
                                            "message": policy_result.get('message', 'Policy analysis completed')
                                        })
                                    continue
                                except Exception as e:
                                    logger.error(f"Policy simulation failed: {str(e)}")
                                    await websocket.send_json({
                                        "type": "error_message",
                                        "message": f"Sorry, policy analysis failed. Using document search instead."
                                    })
                                    # Fall through to regular document search

                            # Regular document search (fallback). Reached when
                            # neither the form nor the policy branch consumed the
                            # message with a `continue` above.
                            # Send acknowledgment
                            await websocket.send_json({
                                "type": "message_received",
                                "message": "🔍 Searching Rajasthan government documents..."
                            })

                            # Search for relevant documents
                            search_results, source = search_documents_simple(user_message)
                            logger.info(f"🔍 Found {len(search_results)} documents from {source}")

                            # Get LLM response
                            llm_response = await get_llm_response(user_message, search_results)

                            # Send response
                            await websocket.send_json({
                                "type": "text_response",
                                "message": llm_response
                            })

                        elif isinstance(data, dict) and data.get("type") == "user_info":
                            # Client self-identification; logged only.
                            user_name = data.get("user_name", "Unknown")
                            logger.info(f"👤 User connected: {user_name}")

                        elif isinstance(data, dict) and data.get("lang"):
                            # Session language preference used by speech-to-text.
                            new_language = data.get("lang", "english")
                            if new_language != user_language:
                                user_language = new_language
                                logger.info(f"🌍 Language preference updated: {user_language}")
                            # Avoid logging if language hasn't changed

                    elif "bytes" in message:
                        # Handle binary message (audio data)
                        audio_data = message["bytes"]
                        logger.info(f"🎤 Received audio data: {len(audio_data)} bytes")

                        if VOICE_AVAILABLE and voice_service:
                            try:
                                # Save audio data to temporary file for processing
                                import tempfile
                                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                                    temp_file.write(audio_data)
                                    temp_file_path = temp_file.name

                                # Process audio with voice service using user's language preference
                                text = await voice_service.speech_to_text(temp_file_path, user_language)

                                # Clean up temp file
                                os.unlink(temp_file_path)

                                if text and text.strip():
                                    logger.info(f"🎤 Transcribed: {text}")

                                    # Search documents
                                    search_results, source = search_documents_simple(text)
                                    logger.info(f"🔍 Found {len(search_results)} documents from {source}")

                                    # Get LLM response
                                    llm_response = await get_llm_response(text, search_results)

                                    # Send text response
                                    await websocket.send_json({
                                        "type": "text_response",
                                        "message": llm_response
                                    })

                                    # Try to send voice response; TTS failure is
                                    # non-fatal since the text reply was sent.
                                    try:
                                        audio_response = await voice_service.text_to_speech(llm_response)
                                        if audio_response:
                                            await websocket.send_bytes(audio_response)
                                    except Exception as tts_error:
                                        logger.warning(f"TTS failed: {tts_error}")
                                else:
                                    # Transcription produced nothing usable.
                                    await websocket.send_json({
                                        "type": "text_response",
                                        "message": "I couldn't understand what you said. Please try speaking more clearly or use text chat."
                                    })
                            except Exception as voice_error:
                                logger.error(f"Voice processing error: {voice_error}")
                                await websocket.send_json({
                                    "type": "text_response",
                                    "message": "Sorry, I couldn't process your voice input. Please try speaking again or use text chat."
                                })
                        else:
                            # Voice services not available
                            await websocket.send_json({
                                "type": "text_response",
                                "message": "Voice processing is currently unavailable. Please use the text chat to ask about Rajasthan pension schemes and government policies."
                            })

                elif message["type"] == "websocket.disconnect":
                    # Client closed the connection; leave the receive loop.
                    break

            except json.JSONDecodeError as e:
                logger.warning(f"⚠️ JSON decode error: {e}")
                continue
            except KeyError as e:
                logger.warning(f"⚠️ Missing key in message: {e}")
                continue

    except WebSocketDisconnect:
        logger.info("🔌 WebSocket client disconnected")
    except Exception as e:
        # Last-resort guard so an unexpected error doesn't kill the server task.
        logger.error(f"❌ WebSocket error: {e}")
# NOTE(review): no @app.get(...) route decorator is visible here, so this
# endpoint does not appear to be registered — confirm one wasn't lost.
async def health_check():
    """Report server health: reachable LanceDB tables and voice availability.

    Returns a dict with status "healthy" plus the table names, or status
    "error" with the exception text if the database cannot be opened.
    """
    try:
        connection = lancedb.connect('./lancedb_data')
        return {
            "status": "healthy",
            "tables": connection.table_names(),
            "voice_available": VOICE_AVAILABLE,
        }
    except Exception as e:
        return {"status": "error", "error": str(e)}
# Script entry point: serve the ASGI app directly with uvicorn on all
# interfaces, port 8000.
if __name__ == "__main__":
    print("🚀 Starting voice-enabled WebSocket server on port 8000...")
    uvicorn.run(app, host="0.0.0.0", port=8000)