""" Enhanced WebSocket handler with hybrid LLM and optional voice features """ from fastapi import WebSocket, WebSocketDisconnect from langchain_core.messages import HumanMessage, SystemMessage, AIMessage import logging import json import asyncio import uuid import tempfile import base64 from pathlib import Path from llm_service import create_graph, create_basic_graph from lancedb_service import lancedb_service from hybrid_llm_service import HybridLLMService from voice_service import voice_service from rag_service import search_government_docs # Initialize hybrid LLM service hybrid_llm_service = HybridLLMService() logger = logging.getLogger("voicebot") async def handle_enhanced_websocket_connection(websocket: WebSocket): """Enhanced WebSocket handler with hybrid LLM and voice features""" await websocket.accept() logger.info("🔌 Enhanced WebSocket client connected.") # Initialize session data session_data = { "messages": [], "user_preferences": { "voice_enabled": False, "preferred_voice": "en-US-AriaNeural", "response_mode": "text" # text, voice, both }, "context": "" } try: # Get initial connection data initial_data = await websocket.receive_json() # Extract user preferences if "preferences" in initial_data: session_data["user_preferences"].update(initial_data["preferences"]) # Setup user session flag = "user_id" in initial_data graph = None # Initialize graph variable if flag: thread_id = initial_data.get("user_id") knowledge_base = initial_data.get("knowledge_base", "government_docs") # Use hybrid LLM or traditional graph based on configuration if hybrid_llm_service.use_hybrid: logger.info("🤖 Using Hybrid LLM Service") use_hybrid = True else: graph = await create_graph(kb_tool=True, mcp_config=None) use_hybrid = False config = { "configurable": { "thread_id": thread_id, "knowledge_base": knowledge_base, } } else: # Basic setup for unauthenticated users thread_id = str(uuid.uuid4()) knowledge_base = "government_docs" use_hybrid = hybrid_llm_service.use_hybrid if not use_hybrid: graph = create_basic_graph() config = {"configurable": {"thread_id": thread_id}} # Send initial greeting with voice/hybrid capabilities await send_enhanced_greeting(websocket, session_data) # Main message handling loop while True: try: data = await websocket.receive_json() if data["type"] == "text_message": await handle_text_message( websocket, data, session_data, use_hybrid, config, knowledge_base, graph ) elif data["type"] == "voice_message": await handle_voice_message( websocket, data, session_data, use_hybrid, config, knowledge_base, graph ) elif data["type"] == "preferences_update": await handle_preferences_update(websocket, data, session_data) elif data["type"] == "get_voice_status": await websocket.send_json({ "type": "voice_status", "data": voice_service.get_voice_status() }) elif data["type"] == "get_llm_status": await websocket.send_json({ "type": "llm_status", "data": hybrid_llm_service.get_provider_info() }) except WebSocketDisconnect: logger.info("🔌 WebSocket client disconnected.") break except Exception as e: logger.error(f"❌ Error handling message: {e}") await websocket.send_json({ "type": "error", "message": f"An error occurred: {str(e)}" }) except WebSocketDisconnect: logger.info("🔌 WebSocket client disconnected during setup.") except Exception as e: logger.error(f"❌ WebSocket error: {e}") try: await websocket.send_json({ "type": "error", "message": f"Connection error: {str(e)}" }) except: pass async def send_enhanced_greeting(websocket: WebSocket, session_data: dict): """Send enhanced greeting with 
system capabilities""" # Get system status llm_info = hybrid_llm_service.get_provider_info() voice_status = voice_service.get_voice_status() greeting_text = f"""🤖 Welcome to the Government Document Assistant! I'm powered by a hybrid AI system that can help you with: • Government policies and procedures • Document search and analysis • Scenario analysis with visualizations • Quick answers and detailed explanations Current capabilities: • LLM: {'Hybrid (' + llm_info['fast_provider'] + '/' + llm_info['complex_provider'] + ')' if llm_info['hybrid_enabled'] else 'Single provider'} • Voice features: {'Enabled' if voice_status['voice_enabled'] else 'Disabled'} How can I assist you today? You can ask me about any government policies, procedures, or documents!""" # Send text greeting await websocket.send_json({ "type": "message_response", "message": greeting_text, "provider_used": "system", "capabilities": { "hybrid_llm": llm_info['hybrid_enabled'], "voice_features": voice_status['voice_enabled'], "scenario_analysis": True } }) # Send voice greeting if enabled if session_data["user_preferences"]["voice_enabled"] and voice_status['voice_enabled']: voice_greeting = "Welcome to the Government Document Assistant! I can help you with policies, procedures, and document analysis. How can I assist you today?" audio_data = await voice_service.text_to_speech(voice_greeting) if audio_data: await websocket.send_json({ "type": "audio_response", "audio_data": base64.b64encode(audio_data).decode(), "format": "mp3" }) async def handle_text_message(websocket: WebSocket, data: dict, session_data: dict, use_hybrid: bool, config: dict, knowledge_base: str, graph=None): """Handle text message with hybrid LLM""" user_message = data["message"] logger.info(f"💬 Received text message: {user_message}") # Send acknowledgment await websocket.send_json({ "type": "message_received", "message": "Processing your message..." 
async def handle_voice_message(websocket: WebSocket, data: dict, session_data: dict,
                               use_hybrid: bool, config: dict, knowledge_base: str,
                               graph=None):
    """Handle a voice message with ASR (speech-to-text) and TTS (text-to-speech)."""
    if not voice_service.is_voice_enabled():
        await websocket.send_json({
            "type": "error",
            "message": "Voice features are not enabled",
        })
        return

    try:
        # Decode audio data
        audio_data = base64.b64decode(data["audio_data"])

        # Save to a temporary file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_file.write(audio_data)
            temp_file_path = temp_file.name

        try:
            # Convert speech to text
            transcribed_text = await voice_service.speech_to_text(temp_file_path)
        finally:
            # Clean up the temp file even if transcription fails
            Path(temp_file_path).unlink(missing_ok=True)

        if not transcribed_text:
            await websocket.send_json({
                "type": "error",
                "message": "Could not transcribe audio",
            })
            return

        logger.info(f"🎤 Transcribed: {transcribed_text}")

        # Send transcription
        await websocket.send_json({
            "type": "transcription",
            "text": transcribed_text,
        })

        # Process as a text message
        if use_hybrid:
            response_text, provider_used = await get_hybrid_response(
                transcribed_text, session_data["context"], config, knowledge_base
            )
        else:
            session_data["messages"].append(HumanMessage(content=transcribed_text))
            result = await graph.ainvoke({"messages": session_data["messages"]}, config)
            response_text = result["messages"][-1].content
            provider_used = "traditional"

        # Send text response
        await send_text_response(websocket, response_text, provider_used, session_data)

        # Send voice response if enabled
        if session_data["user_preferences"]["response_mode"] in ["voice", "both"]:
            voice_text = voice_service.create_voice_response_with_guidance(
                response_text,
                suggested_resources=["Government portal", "Local offices"],
                redirect_info="contact your local government office for personalized assistance",
            )
            audio_response = await voice_service.text_to_speech(
                voice_text, session_data["user_preferences"]["preferred_voice"]
            )
            if audio_response:
                await websocket.send_json({
                    "type": "audio_response",
                    "audio_data": base64.b64encode(audio_response).decode(),
                    "format": "mp3",
                })

    except Exception as e:
        logger.error(f"❌ Error processing voice message: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error processing voice message: {str(e)}",
        })
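# Wire-protocol sketch for voice messages (frame shapes are taken from
# handle_voice_message above; payload contents are illustrative assumptions):
#
#     client -> server: {"type": "voice_message", "audio_data": "<base64 WAV bytes>"}
#     server -> client: {"type": "transcription", "text": "..."}
#     server -> client: {"type": "message_response", "message": "...", ...}
#     server -> client: {"type": "audio_response", "audio_data": "<base64 MP3 bytes>",
#                        "format": "mp3"}  # only when response_mode is "voice" or "both"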
async def get_hybrid_response(user_message: str, context: str, config: dict,
                              knowledge_base: str):
    """Get a response from the hybrid LLM, using document search for context."""
    # Search for relevant documents
    try:
        search_results = await search_government_docs.ainvoke(
            {"query": user_message}, config=config
        )
        context = search_results if search_results else context
    except Exception as e:
        logger.warning(f"Document search failed, using existing context: {e}")

    # Get the hybrid LLM response
    response_text = await hybrid_llm_service.get_response(
        user_message,
        context=context,
        system_prompt="""You are a helpful government document assistant.
Provide accurate, helpful responses based on the context provided.
When appropriate, suggest additional resources or redirect users to
relevant departments for more assistance.""",
    )

    # Determine which provider was used
    provider = hybrid_llm_service.choose_llm_provider(user_message)
    provider_used = provider.value if provider else "unknown"

    return response_text, provider_used


async def send_text_response(websocket: WebSocket, response_text: str,
                             provider_used: str, session_data: dict):
    """Send a text response to the client and update the session context."""
    await websocket.send_json({
        "type": "message_response",
        "message": response_text,
        "provider_used": provider_used,
        "timestamp": asyncio.get_running_loop().time(),
    })

    # Keep the last 1000 characters of the response as rolling context
    session_data["context"] = response_text[-1000:]


async def handle_scenario_response(websocket: WebSocket, response_text: str,
                                   provider_used: str):
    """Handle a scenario-analysis response that embeds one or more images."""
    parts = response_text.split("SCENARIO_ANALYSIS_IMAGE:")
    text_part = parts[0].strip()

    # Send the text part
    if text_part:
        await websocket.send_json({
            "type": "message_response",
            "message": text_part,
            "provider_used": provider_used,
        })

    # Send the image parts
    for i, part in enumerate(parts[1:], 1):
        try:
            image_data = part.strip()
            await websocket.send_json({
                "type": "scenario_image",
                "image_data": image_data,
                "image_index": i,
                "chart_type": "analysis",
            })
        except Exception as e:
            logger.error(f"Error sending scenario image {i}: {e}")


async def handle_preferences_update(websocket: WebSocket, data: dict, session_data: dict):
    """Handle a user-preferences update."""
    try:
        session_data["user_preferences"].update(data["preferences"])

        await websocket.send_json({
            "type": "preferences_updated",
            "preferences": session_data["user_preferences"],
        })

        logger.info(f"🔧 Updated user preferences: {session_data['user_preferences']}")

    except Exception as e:
        logger.error(f"❌ Error updating preferences: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error updating preferences: {str(e)}",
        })


# Keep the original entry point for backward compatibility
async def handle_websocket_connection(websocket: WebSocket):
    """Original WebSocket handler, kept for backward compatibility."""
    await handle_enhanced_websocket_connection(websocket)
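# Wiring sketch (assumption: this module is mounted from a FastAPI app defined
# elsewhere; the module name "enhanced_websocket_handler" and the route path
# "/ws" are illustrative, not confirmed by this file):
#
#     from fastapi import FastAPI, WebSocket
#     from enhanced_websocket_handler import handle_enhanced_websocket_connection
#
#     app = FastAPI()
#
#     @app.websocket("/ws")
#     async def websocket_endpoint(websocket: WebSocket):
#         await handle_enhanced_websocket_connection(websocket)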