| """ | |
| Enhanced WebSocket handler with hybrid LLM and optional voice features | |
| """ | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| from langchain_core.messages import HumanMessage, SystemMessage, AIMessage | |
| import logging | |
| import json | |
| import asyncio | |
| import uuid | |
| import tempfile | |
| import base64 | |
| from pathlib import Path | |
| from llm_service import create_graph, create_basic_graph | |
| from lancedb_service import lancedb_service | |
| from hybrid_llm_service import HybridLLMService | |
| from voice_service import voice_service | |
| from rag_service import search_government_docs | |
| # Initialize hybrid LLM service | |
| hybrid_llm_service = HybridLLMService() | |
| logger = logging.getLogger("voicebot") | |
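# ---------------------------------------------------------------------------
# Wire protocol, summarized from the handlers below. The shapes are
# illustrative sketches reconstructed from this module's send/receive calls,
# not a formal schema.
#
# Client -> server:
#   {"type": "text_message", "message": "..."}
#   {"type": "voice_message", "audio_data": "<base64-encoded WAV>"}
#   {"type": "preferences_update", "preferences": {...}}
#   {"type": "get_voice_status"}
#   {"type": "get_llm_status"}
#
# Server -> client:
#   {"type": "message_response", "message": "...", "provider_used": "..."}
#   {"type": "audio_response", "audio_data": "<base64-encoded MP3>", "format": "mp3"}
#   {"type": "transcription", "text": "..."}
#   {"type": "scenario_image", "image_data": "...", "image_index": 1, "chart_type": "analysis"}
#   {"type": "voice_status" | "llm_status" | "preferences_updated", ...}
#   {"type": "error", "message": "..."}
# ---------------------------------------------------------------------------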

async def handle_enhanced_websocket_connection(websocket: WebSocket):
    """Enhanced WebSocket handler with hybrid LLM and voice features"""
    await websocket.accept()
    logger.info("🔌 Enhanced WebSocket client connected.")

    # Initialize session data
    session_data = {
        "messages": [],
        "user_preferences": {
            "voice_enabled": False,
            "preferred_voice": "en-US-AriaNeural",
            "response_mode": "text",  # text, voice, or both
        },
        "context": "",
    }

    try:
        # Get initial connection data
        initial_data = await websocket.receive_json()

        # Extract user preferences
        if "preferences" in initial_data:
            session_data["user_preferences"].update(initial_data["preferences"])

        # Set up the user session
        graph = None
        if "user_id" in initial_data:
            thread_id = initial_data.get("user_id")
            knowledge_base = initial_data.get("knowledge_base", "government_docs")
            # Use the hybrid LLM or the traditional graph, based on configuration
            if hybrid_llm_service.use_hybrid:
                logger.info("🤖 Using Hybrid LLM Service")
                use_hybrid = True
            else:
                graph = await create_graph(kb_tool=True, mcp_config=None)
                use_hybrid = False
            config = {
                "configurable": {
                    "thread_id": thread_id,
                    "knowledge_base": knowledge_base,
                }
            }
        else:
            # Basic setup for unauthenticated users
            thread_id = str(uuid.uuid4())
            knowledge_base = "government_docs"
            use_hybrid = hybrid_llm_service.use_hybrid
            if not use_hybrid:
                graph = create_basic_graph()
            config = {"configurable": {"thread_id": thread_id}}

        # Send initial greeting with voice/hybrid capabilities
        await send_enhanced_greeting(websocket, session_data)

        # Main message handling loop
        while True:
            try:
                data = await websocket.receive_json()

                if data["type"] == "text_message":
                    await handle_text_message(
                        websocket, data, session_data,
                        use_hybrid, config, knowledge_base, graph
                    )
                elif data["type"] == "voice_message":
                    await handle_voice_message(
                        websocket, data, session_data,
                        use_hybrid, config, knowledge_base, graph
                    )
                elif data["type"] == "preferences_update":
                    await handle_preferences_update(websocket, data, session_data)
                elif data["type"] == "get_voice_status":
                    await websocket.send_json({
                        "type": "voice_status",
                        "data": voice_service.get_voice_status()
                    })
                elif data["type"] == "get_llm_status":
                    await websocket.send_json({
                        "type": "llm_status",
                        "data": hybrid_llm_service.get_provider_info()
                    })
            except WebSocketDisconnect:
                logger.info("🔌 WebSocket client disconnected.")
                break
            except Exception as e:
                logger.error(f"❌ Error handling message: {e}")
                await websocket.send_json({
                    "type": "error",
                    "message": f"An error occurred: {str(e)}"
                })
    except WebSocketDisconnect:
        logger.info("🔌 WebSocket client disconnected during setup.")
    except Exception as e:
        logger.error(f"❌ WebSocket error: {e}")
        try:
            await websocket.send_json({
                "type": "error",
                "message": f"Connection error: {str(e)}"
            })
        except Exception:
            # The socket may already be closed; nothing more to do.
            pass
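
# Example initial handshake payload (illustrative; the field names match the
# lookups in handle_enhanced_websocket_connection, the values are made up):
#
#   {
#       "user_id": "citizen-42",                 # optional: enables a persistent thread
#       "knowledge_base": "government_docs",     # optional: this is also the default
#       "preferences": {
#           "voice_enabled": True,
#           "preferred_voice": "en-US-AriaNeural",
#           "response_mode": "both"              # text, voice, or both
#       }
#   }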

async def send_enhanced_greeting(websocket: WebSocket, session_data: dict):
    """Send enhanced greeting with system capabilities"""
    # Get system status
    llm_info = hybrid_llm_service.get_provider_info()
    voice_status = voice_service.get_voice_status()

    greeting_text = f"""🤖 Welcome to the Government Document Assistant!

I'm powered by a hybrid AI system that can help you with:
• Government policies and procedures
• Document search and analysis
• Scenario analysis with visualizations
• Quick answers and detailed explanations

Current capabilities:
• LLM: {'Hybrid (' + llm_info['fast_provider'] + '/' + llm_info['complex_provider'] + ')' if llm_info['hybrid_enabled'] else 'Single provider'}
• Voice features: {'Enabled' if voice_status['voice_enabled'] else 'Disabled'}

How can I assist you today? You can ask me about any government policies, procedures, or documents!"""

    # Send text greeting
    await websocket.send_json({
        "type": "message_response",
        "message": greeting_text,
        "provider_used": "system",
        "capabilities": {
            "hybrid_llm": llm_info['hybrid_enabled'],
            "voice_features": voice_status['voice_enabled'],
            "scenario_analysis": True
        }
    })

    # Send voice greeting if enabled
    if session_data["user_preferences"]["voice_enabled"] and voice_status['voice_enabled']:
        voice_greeting = (
            "Welcome to the Government Document Assistant! I can help you with "
            "policies, procedures, and document analysis. How can I assist you today?"
        )
        audio_data = await voice_service.text_to_speech(voice_greeting)
        if audio_data:
            await websocket.send_json({
                "type": "audio_response",
                "audio_data": base64.b64encode(audio_data).decode(),
                "format": "mp3"
            })

async def handle_text_message(websocket: WebSocket, data: dict, session_data: dict,
                              use_hybrid: bool, config: dict, knowledge_base: str, graph=None):
    """Handle text message with hybrid LLM"""
    user_message = data["message"]
    logger.info(f"💬 Received text message: {user_message}")

    # Send acknowledgment
    await websocket.send_json({
        "type": "message_received",
        "message": "Processing your message..."
    })

    try:
        if use_hybrid:
            # Use hybrid LLM service
            response_text, provider_used = await get_hybrid_response(
                user_message, session_data["context"], config, knowledge_base
            )
        else:
            # Use traditional graph approach
            session_data["messages"].append(HumanMessage(content=user_message))
            result = await graph.ainvoke({"messages": session_data["messages"]}, config)
            response_text = result["messages"][-1].content
            provider_used = "traditional"

        # Handle scenario analysis images
        if "SCENARIO_ANALYSIS_IMAGE:" in response_text:
            await handle_scenario_response(websocket, response_text, provider_used)
        else:
            await send_text_response(websocket, response_text, provider_used, session_data)
    except Exception as e:
        logger.error(f"❌ Error processing text message: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error processing your message: {str(e)}"
        })

async def handle_voice_message(websocket: WebSocket, data: dict, session_data: dict,
                               use_hybrid: bool, config: dict, knowledge_base: str, graph=None):
    """Handle voice message with ASR and TTS"""
    if not voice_service.is_voice_enabled():
        await websocket.send_json({
            "type": "error",
            "message": "Voice features are not enabled"
        })
        return

    try:
        # Decode audio data
        audio_data = base64.b64decode(data["audio_data"])

        # Save to a temporary file for the ASR engine
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_file.write(audio_data)
            temp_file_path = temp_file.name

        # Convert speech to text, cleaning up the temp file even if ASR fails
        try:
            transcribed_text = await voice_service.speech_to_text(temp_file_path)
        finally:
            Path(temp_file_path).unlink(missing_ok=True)

        if not transcribed_text:
            await websocket.send_json({
                "type": "error",
                "message": "Could not transcribe audio"
            })
            return

        logger.info(f"🎤 Transcribed: {transcribed_text}")

        # Send transcription
        await websocket.send_json({
            "type": "transcription",
            "text": transcribed_text
        })

        # Process as text message
        if use_hybrid:
            response_text, provider_used = await get_hybrid_response(
                transcribed_text, session_data["context"], config, knowledge_base
            )
        else:
            session_data["messages"].append(HumanMessage(content=transcribed_text))
            result = await graph.ainvoke({"messages": session_data["messages"]}, config)
            response_text = result["messages"][-1].content
            provider_used = "traditional"

        # Send text response
        await send_text_response(websocket, response_text, provider_used, session_data)

        # Send voice response if enabled
        if session_data["user_preferences"]["response_mode"] in ["voice", "both"]:
            voice_text = voice_service.create_voice_response_with_guidance(
                response_text,
                suggested_resources=["Government portal", "Local offices"],
                redirect_info="contact your local government office for personalized assistance"
            )
            audio_response = await voice_service.text_to_speech(
                voice_text,
                session_data["user_preferences"]["preferred_voice"]
            )
            if audio_response:
                await websocket.send_json({
                    "type": "audio_response",
                    "audio_data": base64.b64encode(audio_response).decode(),
                    "format": "mp3"
                })
    except Exception as e:
        logger.error(f"❌ Error processing voice message: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error processing voice message: {str(e)}"
        })
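
# Client-side sketch for building a voice message (illustrative; assumes the
# recording already exists as a WAV file on disk -- check voice_service for
# the audio formats it actually accepts):
#
#   with open("question.wav", "rb") as f:
#       payload = {
#           "type": "voice_message",
#           "audio_data": base64.b64encode(f.read()).decode(),
#       }
#   # then, from the client: await websocket.send_json(payload)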

async def get_hybrid_response(user_message: str, context: str, config: dict, knowledge_base: str):
    """Get response using hybrid LLM with document search"""
    # Search for relevant documents; fall back to the existing context on failure
    try:
        search_results = await search_government_docs.ainvoke(
            {"query": user_message},
            config=config
        )
        context = search_results if search_results else context
    except Exception as e:
        logger.warning(f"Document search failed, using existing context: {e}")

    # Get hybrid LLM response
    response_text = await hybrid_llm_service.get_response(
        user_message,
        context=context,
        system_prompt=(
            "You are a helpful government document assistant. Provide accurate, "
            "helpful responses based on the context provided. When appropriate, "
            "suggest additional resources or redirect users to relevant departments "
            "for more assistance."
        )
    )

    # Re-run the complexity classification to report which provider was used;
    # this assumes the classification is deterministic, so it matches the
    # routing decision made inside get_response
    complexity = hybrid_llm_service.determine_task_complexity(user_message, context)
    provider_used = hybrid_llm_service.choose_llm_provider(complexity)

    return response_text, provider_used

async def send_text_response(websocket: WebSocket, response_text: str, provider_used: str, session_data: dict):
    """Send text response to client"""
    await websocket.send_json({
        "type": "message_response",
        "message": response_text,
        "provider_used": provider_used,
        # Monotonic loop time, not wall-clock time
        "timestamp": asyncio.get_running_loop().time()
    })

    # Update session context: keep the last 1000 characters as rolling context
    session_data["context"] = response_text[-1000:]

async def handle_scenario_response(websocket: WebSocket, response_text: str, provider_used: str):
    """Handle scenario analysis response with images"""
    parts = response_text.split("SCENARIO_ANALYSIS_IMAGE:")
    text_part = parts[0].strip()

    # Send text part
    if text_part:
        await websocket.send_json({
            "type": "message_response",
            "message": text_part,
            "provider_used": provider_used
        })

    # Send image parts
    for i, part in enumerate(parts[1:], 1):
        try:
            image_data = part.strip()
            await websocket.send_json({
                "type": "scenario_image",
                "image_data": image_data,
                "image_index": i,
                "chart_type": "analysis"
            })
        except Exception as e:
            logger.error(f"Error sending scenario image {i}: {e}")

async def handle_preferences_update(websocket: WebSocket, data: dict, session_data: dict):
    """Handle user preferences update"""
    try:
        session_data["user_preferences"].update(data["preferences"])
        await websocket.send_json({
            "type": "preferences_updated",
            "preferences": session_data["user_preferences"]
        })
        logger.info(f"🔧 Updated user preferences: {session_data['user_preferences']}")
    except Exception as e:
        logger.error(f"❌ Error updating preferences: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error updating preferences: {str(e)}"
        })
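
# Example update (illustrative): switch responses to voice-only mid-session.
#
#   {"type": "preferences_update", "preferences": {"response_mode": "voice"}}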

# Keep the original function name for backward compatibility
async def handle_websocket_connection(websocket: WebSocket):
    """Original WebSocket handler, kept for backward compatibility"""
    await handle_enhanced_websocket_connection(websocket)
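
# ---------------------------------------------------------------------------
# Minimal manual test client -- a sketch, not part of the server. It assumes
# the handler above is mounted at ws://localhost:8000/ws (adjust to match the
# actual FastAPI route) and that the third-party `websockets` package is
# installed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import websockets  # pip install websockets

    async def _demo():
        async with websockets.connect("ws://localhost:8000/ws") as ws:
            # Handshake without a user_id, so the server builds a basic session
            await ws.send(json.dumps({"preferences": {"response_mode": "text"}}))
            print(json.loads(await ws.recv()))  # greeting

            await ws.send(json.dumps({
                "type": "text_message",
                "message": "How do I renew a passport?"
            }))
            print(json.loads(await ws.recv()))  # acknowledgment
            print(json.loads(await ws.recv()))  # answer

    asyncio.run(_demo())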