PensionBot / voice_websocket_server.py
ChAbhishek28's picture
Deploy clean Voice Bot backend to HF Spaces
cf02b2b
raw
history blame
25.5 kB
#!/usr/bin/env python3
"""
Voice-enabled WebSocket server that combines the full voice backend with our document search
"""
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import json
import logging
import lancedb
import pandas as pd
import asyncio
import os
from dotenv import load_dotenv
from dataclasses import asdict, is_dataclass
# Try to import voice services, fallback if not available
try:
    from hybrid_llm_service import HybridLLMService
    from voice_service import VoiceService
    from settings_api import router as settings_router
    from policy_simulator_api import router as policy_simulator_router
    VOICE_AVAILABLE = True
except ImportError:
    # Any missing optional module downgrades the server to text-only mode;
    # the rest of the file guards on VOICE_AVAILABLE before touching these names.
    VOICE_AVAILABLE = False
    logging.warning("Voice services not available, text-only mode")
# Load .env configuration before anything reads environment variables.
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Simple response cache for common queries. Keyed in get_llm_response by
# query text plus hit count; oldest-inserted entry is evicted at capacity.
response_cache = {}
MAX_CACHE_SIZE = 100

app = FastAPI()

# Include API routers (these modules only imported when the voice stack is present)
if VOICE_AVAILABLE:
    app.include_router(settings_router)
    app.include_router(policy_simulator_router)
# Enable CORS - Include both local development and production origins
allowed_origins = [
"http://localhost:5176", "http://localhost:5177",
"http://127.0.0.1:5176", "http://127.0.0.1:5177",
"http://localhost:3000", "http://localhost:5173",
"https://*.vercel.app", "https://*.hf.space"
]
# Add any custom origins from environment
if os.getenv("ALLOWED_ORIGINS"):
try:
custom_origins = eval(os.getenv("ALLOWED_ORIGINS"))
if isinstance(custom_origins, list):
allowed_origins.extend(custom_origins)
except:
pass
# Register the CORS middleware with the computed origin list.
# NOTE(review): str(allowed_origins) contains "*" whenever any entry has a
# wildcard (e.g. "https://*.vercel.app"), so with the defaults this branch
# always resolves to ["*"] — confirm that wide-open CORS is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins if "*" not in str(allowed_origins) else ["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize services if available
if VOICE_AVAILABLE:
    try:
        hybrid_llm_service = HybridLLMService()
        voice_service = VoiceService()
        logger.info("✅ Voice services initialized")
    except Exception as e:
        # Construction can fail even when the imports succeeded; flip the
        # flag so the rest of the app falls back to text-only behavior.
        logger.warning(f"⚠️ Voice services failed to initialize: {e}")
        VOICE_AVAILABLE = False
def serialize_for_json(obj):
    """Recursively convert *obj* into JSON-serializable primitives.

    Dataclass instances are expanded via ``asdict`` (which recurses on its
    own); lists/tuples become lists; dict values are converted in place;
    arbitrary objects are expanded through their ``__dict__`` with their
    attribute values converted recursively (the previous version returned
    the raw ``__dict__``, leaving nested dataclasses unserialized).
    Anything else is returned unchanged for ``json.dumps`` to handle.
    """
    if is_dataclass(obj):
        return asdict(obj)
    # Check container types before __dict__ so container subclasses keep
    # their element semantics.
    if isinstance(obj, (list, tuple)):
        return [serialize_for_json(item) for item in obj]
    if isinstance(obj, dict):
        return {key: serialize_for_json(value) for key, value in obj.items()}
    if hasattr(obj, '__dict__'):
        return {key: serialize_for_json(value) for key, value in obj.__dict__.items()}
    return obj
def search_documents_simple(query: str):
    """Keyword-based lookup over the local LanceDB store (no embeddings).

    Returns a tuple ``(results, source)``: up to five dicts with truncated
    ``content`` and ``filename``, plus a label describing where they came
    from ("rajasthan_pension_documents", "none", or "error").
    """
    try:
        db = lancedb.connect('./lancedb_data')
        # Without the Rajasthan table there is nothing to search.
        if 'rajasthan_documents' not in db.table_names():
            return [], "none"
        df = db.open_table('rajasthan_documents').to_pandas()
        lowered = query.lower()
        # Bilingual (English/Hindi) keyword screen for pension-related intent.
        pension_terms = (
            'pension', 'पेंशन', 'वृद्धावस्था', 'सामाजिक', 'भत्ता', 'allowance',
            'old age', 'social security', 'retirement', 'सेवानिवृत्ति'
        )
        mentions_pension = any(term in lowered for term in pension_terms)
        if mentions_pension or 'rajasthan' in lowered:
            # Broad regex filter over document bodies, then rank the hits.
            mask = df['content'].str.contains(
                'pension|Pension|पेंशन|वृद्धावस्था|सामाजिक|भत्ता|allowance|old.age|social.security|retirement|सेवानिवृत्ति|scheme|योजना',
                case=False, na=False, regex=True
            )
            matched = df[mask].copy()
            if not matched.empty:
                # Rank by how many core keywords each document mentions.
                scoring_terms = ['pension', 'पेंशन', 'वृद्धावस्था', 'सामाजिक', 'भत्ता', 'allowance', 'old age']
                matched['relevance_score'] = matched['content'].apply(
                    lambda text: sum(term in text.lower() for term in scoring_terms)
                )
                matched = matched.sort_values('relevance_score', ascending=False)
                top_hits = [
                    {"content": row['content'][:800], "filename": row['filename']}
                    for _, row in matched.head(5).iterrows()
                ]
                return top_hits, "rajasthan_pension_documents"
        return [], "none"
    except Exception as e:
        logger.error(f"Search error: {e}")
        return [], "error"
async def get_llm_response(query: str, search_results: list):
    """Answer *query* via the hybrid LLM when available, else a canned notice.

    Responses are memoized in the module-level ``response_cache``, keyed by
    the query text plus the number of search hits; the oldest-inserted entry
    is evicted once MAX_CACHE_SIZE is reached.
    """
    cache_key = f"{query}_{len(search_results) if search_results else 0}"
    if cache_key in response_cache:
        logger.info(f"📦 Cache hit for query: {query[:50]}...")
        return response_cache[cache_key]
    try:
        if not (VOICE_AVAILABLE and hybrid_llm_service):
            # Degraded mode: no LLM backend available, reply with guidance.
            if search_results:
                response = f"Based on the Rajasthan government documents, I found information about {query}. However, voice processing is currently limited. Please use text chat for detailed responses."
            else:
                response = f"I received your query about '{query}' but couldn't find specific documents. Please try using text chat for better results."
            response_cache[cache_key] = response
            return response
        # Build the prompt, inlining any retrieved documents.
        # NOTE(review): the doubled backslashes ("\\n") put literal
        # backslash-n text into the prompt, not real newlines — confirm
        # this is intended before changing it.
        if search_results:
            context = "\\n\\n".join(
                [f"Document: {doc['filename']}\\nContent: {doc['content']}" for doc in search_results]
            )
            enhanced_query = f"Based on these Rajasthan government documents, please answer: {query}\\n\\nDocuments:\\n{context}"
        else:
            enhanced_query = query
        response = await hybrid_llm_service.get_response(enhanced_query)
        # Evict the oldest entry before inserting when at capacity.
        if len(response_cache) >= MAX_CACHE_SIZE:
            response_cache.pop(next(iter(response_cache)))
        response_cache[cache_key] = response
        return response
    except Exception as e:
        logger.error(f"LLM error: {e}")
        return "I'm having trouble processing your request. Please try using the text chat."
@app.websocket("/ws/stream")
async def websocket_endpoint(websocket: WebSocket):
    """Main bidirectional chat endpoint.

    Accepts JSON text frames (chat messages, user info, language preference)
    and binary frames (raw audio). Text queries route, in priority order, to:
    the interactive scenario form, the policy simulator, and finally document
    search + LLM. Audio is transcribed first (when voice services are up) and
    then follows the same search/LLM path, with an optional TTS reply.
    """
    await websocket.accept()
    logger.info("🔌 WebSocket client connected")
    # Store user session info
    user_language = "english"  # Default language
    try:
        # Send initial greeting
        await websocket.send_json({
            "type": "connection_successful",
            "message": "Hello! I'm your Rajasthan government document assistant. I can help with text and voice queries about pension schemes and government policies."
        })
        while True:
            try:
                # Receive message with better error handling
                message = await websocket.receive()
                # Handle different message types
                if message["type"] == "websocket.receive":
                    if "text" in message:
                        # Parse JSON text message
                        try:
                            data = json.loads(message["text"])
                        except json.JSONDecodeError:
                            logger.warning(f"⚠️ Invalid JSON received: {message['text']}")
                            continue
                        # Process text message
                        if isinstance(data, dict) and data.get("type") == "text_message":
                            user_message = data.get("message", "")
                            if not user_message.strip():
                                continue
                            logger.info(f"💬 Text received: {user_message}")
                            # Check for interactive scenario form triggers
                            form_triggers = ["start scenario analysis", "scenario form", "interactive analysis", "step by step analysis", "guided analysis", "form analysis", "scenario chat form", "interactive scenario"]
                            is_form_request = any(trigger in user_message.lower() for trigger in form_triggers)
                            # Check if this is a policy simulation query (robust regex patterns)
                            # NOTE(review): this pattern list and the helper below are
                            # rebuilt on every incoming text message; they could be
                            # hoisted to module level.
                            import re
                            POLICY_PATTERNS = [
                                r"policy.*simulation|simulation.*policy",
                                r"policy.*scenario|scenario.*policy",
                                r"policy.*analysis|analysis.*policy",
                                r"pension.*simulation|simulation.*pension",
                                r"pension.*analysis|analysis.*pension",
                                r"pension.*scenario|scenario.*pension",
                                r"dearness.*relief|dr.*increase|dr.*adjustment",
                                r"dearness.*allowance|da.*increase|da.*adjustment",
                                r"minimum.*pension.*increase|increase.*minimum.*pension",
                                r"calculate.*pension|pension.*calculation",
                                r"impact.*dr|dr.*impact|impact.*da|da.*impact",
                                r"show.*impact.*da|show.*impact.*dr",
                                r"impact.*\d+.*da|impact.*\d+.*dr",
                                r"\d+.*da.*increase|da.*\d+.*increase",
                                r"\d+.*dr.*increase|dr.*\d+.*increase",
                                r"inflation.*adjustment|adjustment.*inflation",
                                r"scenario.*analysis|analysis.*scenario",
                                r"what.*if.*dr|what.*if.*pension|what.*if.*da",
                                r"compare.*scenario|scenario.*comparison",
                                r"show.*chart|chart.*show",
                                r"explain.*chart|chart.*explain",
                                r"using.*chart|chart.*using",
                                r"dr.*\d+.*increase|increase.*dr.*\d+",
                                r"da.*\d+.*increase|increase.*da.*\d+",
                                r"analyze.*minimum.*pension",
                                r"pension.*change",
                                r"make.*chart|chart.*make",
                                r"pension.*value|value.*pension",
                                r"basic.*pension.*\d+|pension.*\d+",
                                r"simulate.*dr|simulate.*pension|simulate.*da"
                            ]
                            def is_policy_simulation_query(message: str) -> bool:
                                """Check if the message is a policy simulation query"""
                                message_lower = message.lower()
                                logger.info(f"🔍 Checking policy patterns for: '{message_lower}'")
                                for i, pattern in enumerate(POLICY_PATTERNS):
                                    if re.search(pattern, message_lower, re.IGNORECASE):
                                        logger.info(f"✅ Pattern {i+1} matched: {pattern}")
                                        return True
                                logger.info("❌ No policy patterns matched")
                                return False
                            is_policy_query = is_policy_simulation_query(user_message)
                            # Routing priority: form request, then policy query,
                            # then plain document search.
                            # Handle interactive scenario form request
                            if is_form_request:
                                logger.info("📋 Interactive scenario form requested")
                                try:
                                    from scenario_chat_form import start_scenario_analysis_form
                                    form_response = start_scenario_analysis_form(data.get("user_id", "default"))
                                    # Format form response for chat
                                    form_message = f"""🎯 **{form_response.get('title', 'Interactive Scenario Analysis')}**
{form_response.get('message', '')}
**{form_response.get('step_title', 'Step 1')}** ({form_response.get('current_step', 1)}/{form_response.get('total_steps', 4)})
{form_response['form_data']['question']}
**Available Options:**"""
                                    # Add form options
                                    if form_response['form_data']['input_type'] == 'select':
                                        for i, option in enumerate(form_response['form_data']['options'], 1):
                                            form_message += f"\n{i}. {option['label']}"
                                    form_message += "\n\n**Quick Actions:**"
                                    for action in form_response.get('quick_actions', []):
                                        form_message += f"\n• {action['text']}"
                                    form_message += "\n\n💡 **Next:** Choose an option above or type your selection!"
                                    await websocket.send_json({
                                        "type": "interactive_form",
                                        "message": form_message,
                                        "form_data": form_response
                                    })
                                    continue
                                except Exception as e:
                                    logger.error(f"Form initialization failed: {str(e)}")
                                    await websocket.send_json({
                                        "type": "error_message",
                                        "message": f"Sorry, I couldn't start the interactive scenario analysis. Error: {str(e)}"
                                    })
                                    continue
                            # Handle policy queries
                            elif is_policy_query:
                                logger.info("🎯 Detected policy simulation query")
                                try:
                                    # Import policy chat interface
                                    from policy_chat_interface import PolicySimulatorChatInterface
                                    # Send acknowledgment for policy simulation
                                    await websocket.send_json({
                                        "type": "message_received",
                                        "message": "🎯 Analyzing Rajasthan policy impact..."
                                    })
                                    # Initialize and process policy simulation
                                    policy_simulator = PolicySimulatorChatInterface()
                                    policy_result = policy_simulator.process_policy_query(user_message)
                                    # Format policy response - use same format as working simple backend
                                    if policy_result.get("type") == "policy_simulation":
                                        # Serialize the response for JSON
                                        serialized_response = serialize_for_json(policy_result)
                                        # Send policy simulation response
                                        await websocket.send_json({
                                            "type": "policy_simulation",
                                            "data": serialized_response
                                        })
                                        logger.info("📤 Policy simulation response sent to client")
                                    else:
                                        # Handle other policy responses (errors, help, etc.)
                                        await websocket.send_json({
                                            "type": "text_response",
                                            "message": policy_result.get('message', 'Policy analysis completed')
                                        })
                                    continue
                                except Exception as e:
                                    logger.error(f"Policy simulation failed: {str(e)}")
                                    await websocket.send_json({
                                        "type": "error_message",
                                        "message": f"Sorry, policy analysis failed. Using document search instead."
                                    })
                                    # Fall through to regular document search
                            # Regular document search (fallback)
                            # Send acknowledgment
                            await websocket.send_json({
                                "type": "message_received",
                                "message": "🔍 Searching Rajasthan government documents..."
                            })
                            # Search for relevant documents
                            search_results, source = search_documents_simple(user_message)
                            logger.info(f"🔍 Found {len(search_results)} documents from {source}")
                            # Get LLM response
                            llm_response = await get_llm_response(user_message, search_results)
                            # Send response
                            await websocket.send_json({
                                "type": "text_response",
                                "message": llm_response
                            })
                        elif isinstance(data, dict) and data.get("type") == "user_info":
                            user_name = data.get("user_name", "Unknown")
                            logger.info(f"👤 User connected: {user_name}")
                        elif isinstance(data, dict) and data.get("lang"):
                            new_language = data.get("lang", "english")
                            if new_language != user_language:
                                user_language = new_language
                                logger.info(f"🌍 Language preference updated: {user_language}")
                            # Avoid logging if language hasn't changed
                    elif "bytes" in message:
                        # Handle binary message (audio data)
                        audio_data = message["bytes"]
                        logger.info(f"🎤 Received audio data: {len(audio_data)} bytes")
                        if VOICE_AVAILABLE and voice_service:
                            try:
                                # Save audio data to temporary file for processing
                                import tempfile
                                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                                    temp_file.write(audio_data)
                                    temp_file_path = temp_file.name
                                # Process audio with voice service using user's language preference
                                text = await voice_service.speech_to_text(temp_file_path, user_language)
                                # Clean up temp file
                                os.unlink(temp_file_path)
                                if text and text.strip():
                                    logger.info(f"🎤 Transcribed: {text}")
                                    # Search documents
                                    search_results, source = search_documents_simple(text)
                                    logger.info(f"🔍 Found {len(search_results)} documents from {source}")
                                    # Get LLM response
                                    llm_response = await get_llm_response(text, search_results)
                                    # Send text response
                                    await websocket.send_json({
                                        "type": "text_response",
                                        "message": llm_response
                                    })
                                    # Try to send voice response
                                    try:
                                        audio_response = await voice_service.text_to_speech(llm_response)
                                        if audio_response:
                                            await websocket.send_bytes(audio_response)
                                    except Exception as tts_error:
                                        # TTS failure is non-fatal: the text answer was already sent.
                                        logger.warning(f"TTS failed: {tts_error}")
                                else:
                                    await websocket.send_json({
                                        "type": "text_response",
                                        "message": "I couldn't understand what you said. Please try speaking more clearly or use text chat."
                                    })
                            except Exception as voice_error:
                                logger.error(f"Voice processing error: {voice_error}")
                                await websocket.send_json({
                                    "type": "text_response",
                                    "message": "Sorry, I couldn't process your voice input. Please try speaking again or use text chat."
                                })
                        else:
                            # Voice services not available
                            await websocket.send_json({
                                "type": "text_response",
                                "message": "Voice processing is currently unavailable. Please use the text chat to ask about Rajasthan pension schemes and government policies."
                            })
                elif message["type"] == "websocket.disconnect":
                    break
            except json.JSONDecodeError as e:
                logger.warning(f"⚠️ JSON decode error: {e}")
                continue
            except KeyError as e:
                logger.warning(f"⚠️ Missing key in message: {e}")
                continue
    except WebSocketDisconnect:
        logger.info("🔌 WebSocket client disconnected")
    except Exception as e:
        logger.error(f"❌ WebSocket error: {e}")
@app.get("/health")
async def health_check():
    """Report server liveness, the LanceDB tables on disk, and voice status."""
    try:
        tables = lancedb.connect('./lancedb_data').table_names()
        return {
            "status": "healthy",
            "tables": tables,
            "voice_available": VOICE_AVAILABLE
        }
    except Exception as exc:
        # Surface the failure in the payload rather than as an HTTP error.
        return {"status": "error", "error": str(exc)}
if __name__ == "__main__":
    # Direct-execution entry point: serve the app on all interfaces, port 8000.
    print("🚀 Starting voice-enabled WebSocket server on port 8000...")
    uvicorn.run(app, host="0.0.0.0", port=8000)