PensionBot / enhanced_websocket_handler.py
ChAbhishek28's picture
Deploy clean Voice Bot backend to HF Spaces
cf02b2b
raw
history blame
14.9 kB
"""
Enhanced WebSocket handler with hybrid LLM and optional voice features
"""
from fastapi import WebSocket, WebSocketDisconnect
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
import logging
import json
import asyncio
import uuid
import tempfile
import base64
from pathlib import Path
from llm_service import create_graph, create_basic_graph
from lancedb_service import lancedb_service
from hybrid_llm_service import HybridLLMService
from voice_service import voice_service
from rag_service import search_government_docs
# Initialize hybrid LLM service
hybrid_llm_service = HybridLLMService()
logger = logging.getLogger("voicebot")
async def handle_enhanced_websocket_connection(websocket: WebSocket):
    """Enhanced WebSocket handler with hybrid LLM and voice features.

    Protocol (all frames are JSON):
      1. The client first sends an initial payload that may carry
         ``user_id``, ``knowledge_base`` and ``preferences``.
      2. It then sends frames whose ``type`` field selects a handler:
         ``text_message``, ``voice_message``, ``preferences_update``,
         ``get_voice_status`` or ``get_llm_status``.
    """
    await websocket.accept()
    logger.info("πŸ”Œ Enhanced WebSocket client connected.")

    # Per-connection session state: chat history, user prefs, rolling RAG context.
    session_data = {
        "messages": [],
        "user_preferences": {
            "voice_enabled": False,
            "preferred_voice": "en-US-AriaNeural",
            "response_mode": "text"  # text, voice, both
        },
        "context": ""
    }

    try:
        # Get initial connection data and merge any client preferences.
        initial_data = await websocket.receive_json()
        if "preferences" in initial_data:
            session_data["user_preferences"].update(initial_data["preferences"])

        # Resolve LLM routing (hybrid service vs. LangGraph) for this session.
        use_hybrid, graph, config, knowledge_base = await _setup_session(initial_data)

        # Send initial greeting with voice/hybrid capabilities.
        await send_enhanced_greeting(websocket, session_data)

        # Main message handling loop.
        while True:
            try:
                data = await websocket.receive_json()
                message_type = data.get("type")

                if message_type == "text_message":
                    await handle_text_message(
                        websocket, data, session_data,
                        use_hybrid, config, knowledge_base, graph
                    )
                elif message_type == "voice_message":
                    await handle_voice_message(
                        websocket, data, session_data,
                        use_hybrid, config, knowledge_base, graph
                    )
                elif message_type == "preferences_update":
                    await handle_preferences_update(websocket, data, session_data)
                elif message_type == "get_voice_status":
                    await websocket.send_json({
                        "type": "voice_status",
                        "data": voice_service.get_voice_status()
                    })
                elif message_type == "get_llm_status":
                    await websocket.send_json({
                        "type": "llm_status",
                        "data": hybrid_llm_service.get_provider_info()
                    })
                else:
                    # FIX: unknown/missing "type" used to be silently dropped;
                    # surface it so client-side protocol bugs show up quickly.
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Unknown message type: {message_type}"
                    })
            except WebSocketDisconnect:
                logger.info("πŸ”Œ WebSocket client disconnected.")
                break
            except Exception as e:
                logger.error(f"❌ Error handling message: {e}")
                await websocket.send_json({
                    "type": "error",
                    "message": f"An error occurred: {str(e)}"
                })
    except WebSocketDisconnect:
        logger.info("πŸ”Œ WebSocket client disconnected during setup.")
    except Exception as e:
        logger.error(f"❌ WebSocket error: {e}")
        try:
            await websocket.send_json({
                "type": "error",
                "message": f"Connection error: {str(e)}"
            })
        except Exception:
            # FIX: was a bare `except:`, which also swallowed BaseException
            # (e.g. asyncio.CancelledError). Best-effort send stays best-effort.
            pass


async def _setup_session(initial_data: dict):
    """Resolve per-connection LLM routing from the client's first payload.

    Returns a ``(use_hybrid, graph, config, knowledge_base)`` tuple; ``graph``
    is None whenever the hybrid LLM service handles responses.
    """
    graph = None
    if "user_id" in initial_data:
        # Authenticated user: the conversation thread persists under their id.
        thread_id = initial_data.get("user_id")
        knowledge_base = initial_data.get("knowledge_base", "government_docs")
        if hybrid_llm_service.use_hybrid:
            logger.info("πŸ€– Using Hybrid LLM Service")
            use_hybrid = True
        else:
            graph = await create_graph(kb_tool=True, mcp_config=None)
            use_hybrid = False
        config = {
            "configurable": {
                "thread_id": thread_id,
                "knowledge_base": knowledge_base,
            }
        }
    else:
        # Anonymous user: ephemeral thread id, default knowledge base.
        thread_id = str(uuid.uuid4())
        knowledge_base = "government_docs"
        use_hybrid = hybrid_llm_service.use_hybrid
        if not use_hybrid:
            graph = create_basic_graph()
        config = {"configurable": {"thread_id": thread_id}}
    return use_hybrid, graph, config, knowledge_base
async def send_enhanced_greeting(websocket: WebSocket, session_data: dict):
    """Greet a freshly connected client, advertising LLM and voice capabilities.

    Always sends a text greeting frame; additionally sends a synthesized
    audio greeting when both the user preference and the voice service
    allow it.
    """
    llm_info = hybrid_llm_service.get_provider_info()
    voice_status = voice_service.get_voice_status()

    # Human-readable capability summaries for the greeting body.
    if llm_info['hybrid_enabled']:
        llm_summary = 'Hybrid (' + llm_info['fast_provider'] + '/' + llm_info['complex_provider'] + ')'
    else:
        llm_summary = 'Single provider'
    voice_summary = 'Enabled' if voice_status['voice_enabled'] else 'Disabled'

    greeting_text = f"""πŸ€– Welcome to the Government Document Assistant!
I'm powered by a hybrid AI system that can help you with:
β€’ Government policies and procedures
β€’ Document search and analysis
β€’ Scenario analysis with visualizations
β€’ Quick answers and detailed explanations
Current capabilities:
β€’ LLM: {llm_summary}
β€’ Voice features: {voice_summary}
How can I assist you today? You can ask me about any government policies, procedures, or documents!"""

    # Text greeting always goes out first.
    await websocket.send_json({
        "type": "message_response",
        "message": greeting_text,
        "provider_used": "system",
        "capabilities": {
            "hybrid_llm": llm_info['hybrid_enabled'],
            "voice_features": voice_status['voice_enabled'],
            "scenario_analysis": True
        }
    })

    # Spoken greeting only when the user opted in AND the service is up.
    wants_voice = session_data["user_preferences"]["voice_enabled"]
    if wants_voice and voice_status['voice_enabled']:
        voice_greeting = (
            "Welcome to the Government Document Assistant! I can help you with "
            "policies, procedures, and document analysis. How can I assist you today?"
        )
        audio_data = await voice_service.text_to_speech(voice_greeting)
        if audio_data:
            await websocket.send_json({
                "type": "audio_response",
                "audio_data": base64.b64encode(audio_data).decode(),
                "format": "mp3"
            })
async def handle_text_message(websocket: WebSocket, data: dict, session_data: dict,
                              use_hybrid: bool, config: dict, knowledge_base: str, graph=None):
    """Process one text frame from the client.

    Routes the message through either the hybrid LLM service or the
    LangGraph pipeline, then delivers the reply — splitting out embedded
    scenario-analysis charts when present.
    """
    user_message = data["message"]
    logger.info(f"πŸ’¬ Received text message: {user_message}")

    # Immediate ack so the client can show a "working" state.
    await websocket.send_json({
        "type": "message_received",
        "message": "Processing your message..."
    })

    try:
        if use_hybrid:
            # Hybrid path: document search + provider routing happen inside.
            response_text, provider_used = await get_hybrid_response(
                user_message, session_data["context"], config, knowledge_base
            )
        else:
            # LangGraph path: append the turn and invoke the graph.
            session_data["messages"].append(HumanMessage(content=user_message))
            result = await graph.ainvoke({"messages": session_data["messages"]}, config)
            response_text = result["messages"][-1].content
            provider_used = "traditional"

        # Replies carrying chart payloads are streamed piecewise.
        if "SCENARIO_ANALYSIS_IMAGE:" in response_text:
            await handle_scenario_response(websocket, response_text, provider_used)
        else:
            await send_text_response(websocket, response_text, provider_used, session_data)
    except Exception as e:
        logger.error(f"❌ Error processing text message: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error processing your message: {str(e)}"
        })
async def handle_voice_message(websocket: WebSocket, data: dict, session_data: dict,
                               use_hybrid: bool, config: dict, knowledge_base: str, graph=None):
    """Handle a voice frame: ASR -> LLM -> text (and optionally TTS) reply.

    Expects ``data["audio_data"]`` to be base64-encoded WAV audio.
    Sends, in order: a transcription frame, a text response frame, and —
    when the user's response_mode is "voice" or "both" — an mp3
    audio_response frame.
    """
    if not voice_service.is_voice_enabled():
        await websocket.send_json({
            "type": "error",
            "message": "Voice features are not enabled"
        })
        return

    try:
        # Decode audio and persist it to a temp file for the ASR backend.
        audio_data = base64.b64decode(data["audio_data"])
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_file.write(audio_data)
            temp_file_path = temp_file.name

        try:
            transcribed_text = await voice_service.speech_to_text(temp_file_path)
        finally:
            # FIX: the temp file used to leak whenever speech_to_text raised;
            # clean up on every path (missing_ok guards double-unlink).
            Path(temp_file_path).unlink(missing_ok=True)

        if not transcribed_text:
            await websocket.send_json({
                "type": "error",
                "message": "Could not transcribe audio"
            })
            return

        logger.info(f"🎀 Transcribed: {transcribed_text}")
        # Echo the transcription so the client can display what was heard.
        await websocket.send_json({
            "type": "transcription",
            "text": transcribed_text
        })

        # Process exactly like a text message from here on.
        if use_hybrid:
            response_text, provider_used = await get_hybrid_response(
                transcribed_text, session_data["context"], config, knowledge_base
            )
        else:
            session_data["messages"].append(HumanMessage(content=transcribed_text))
            result = await graph.ainvoke({"messages": session_data["messages"]}, config)
            response_text = result["messages"][-1].content
            provider_used = "traditional"

        await send_text_response(websocket, response_text, provider_used, session_data)

        # Optional spoken reply, per user preference.
        if session_data["user_preferences"]["response_mode"] in ["voice", "both"]:
            voice_text = voice_service.create_voice_response_with_guidance(
                response_text,
                suggested_resources=["Government portal", "Local offices"],
                redirect_info="contact your local government office for personalized assistance"
            )
            audio_response = await voice_service.text_to_speech(
                voice_text,
                session_data["user_preferences"]["preferred_voice"]
            )
            if audio_response:
                await websocket.send_json({
                    "type": "audio_response",
                    "audio_data": base64.b64encode(audio_response).decode(),
                    "format": "mp3"
                })
    except Exception as e:
        logger.error(f"❌ Error processing voice message: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error processing voice message: {str(e)}"
        })
async def get_hybrid_response(user_message: str, context: str, config: dict, knowledge_base: str):
    """Answer ``user_message`` via the hybrid LLM with document-search grounding.

    Tries a RAG lookup first; on failure, falls back to the caller-supplied
    ``context``. Returns ``(response_text, provider_used)``.
    NOTE(review): ``knowledge_base`` is currently unused here — presumably the
    search tool reads it from ``config``; confirm before removing.
    """
    # Best-effort document search; the existing context is the fallback.
    try:
        search_results = await search_government_docs.ainvoke(
            {"query": user_message},
            config=config
        )
        context = search_results if search_results else context
    except Exception as e:
        # FIX: was a bare `except:`, which also trapped KeyboardInterrupt and
        # asyncio.CancelledError and hid the failure reason from the logs.
        logger.warning(f"Document search failed, using existing context: {e}")

    response_text = await hybrid_llm_service.get_response(
        user_message,
        context=context,
        system_prompt="""You are a helpful government document assistant. Provide accurate, helpful responses based on the context provided. When appropriate, suggest additional resources or redirect users to relevant departments for more assistance."""
    )

    # Report which provider the router would pick for this complexity.
    complexity = hybrid_llm_service.determine_task_complexity(user_message, context)
    provider_used = hybrid_llm_service.choose_llm_provider(complexity)
    return response_text, provider_used
async def send_text_response(websocket: WebSocket, response_text: str, provider_used: str, session_data: dict):
    """Send a text reply frame and update the rolling session context.

    Args:
        websocket: Active connection; only ``send_json`` is used.
        response_text: Assistant reply to deliver.
        provider_used: Label of the LLM provider that produced the reply.
        session_data: Per-connection state; its "context" entry is replaced
            with the tail of this response.
    """
    await websocket.send_json({
        "type": "message_response",
        "message": response_text,
        "provider_used": provider_used,
        # FIX: asyncio.get_event_loop() inside a coroutine is deprecated
        # (3.10+); get_running_loop() is the supported way to read loop time.
        "timestamp": asyncio.get_running_loop().time()
    })
    # Keep only the last 1000 chars as lightweight conversational context.
    session_data["context"] = response_text[-1000:]
async def handle_scenario_response(websocket: WebSocket, response_text: str, provider_used: str):
    """Split a scenario-analysis reply on the image marker and stream each
    segment to the client as its own frame.

    The narrative before the first marker becomes a message_response frame
    (skipped when empty); every following segment becomes a scenario_image
    frame, 1-indexed.
    """
    segments = response_text.split("SCENARIO_ANALYSIS_IMAGE:")
    narrative = segments[0].strip()

    # Leading narrative (may be absent when the reply is images only).
    if narrative:
        await websocket.send_json({
            "type": "message_response",
            "message": narrative,
            "provider_used": provider_used
        })

    # One frame per embedded chart; a failed send must not block the rest.
    for index, segment in enumerate(segments[1:], 1):
        try:
            await websocket.send_json({
                "type": "scenario_image",
                "image_data": segment.strip(),
                "image_index": index,
                "chart_type": "analysis"
            })
        except Exception as e:
            logger.error(f"Error sending scenario image {index}: {e}")
async def handle_preferences_update(websocket: WebSocket, data: dict, session_data: dict):
    """Merge client-sent preference overrides into the session and echo the
    resulting preference set back to the client."""
    try:
        prefs = session_data["user_preferences"]
        prefs.update(data["preferences"])
        await websocket.send_json({
            "type": "preferences_updated",
            "preferences": prefs
        })
        logger.info(f"πŸ”§ Updated user preferences: {session_data['user_preferences']}")
    except Exception as e:
        logger.error(f"❌ Error updating preferences: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error updating preferences: {str(e)}"
        })
# Keep the original function for backward compatibility
async def handle_websocket_connection(websocket: WebSocket):
    """Backward-compatible alias that delegates to the enhanced handler."""
    return await handle_enhanced_websocket_connection(websocket)