# NOTE: the page-capture residue that stood here ("Spaces: Sleeping") is not part of this module.
import asyncio
import json
import logging
import re
import uuid
from typing import Dict, Any

from fastapi import WebSocket, WebSocketDisconnect
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage

from hybrid_llm_service import HybridLLMService  # Fixed import
from lancedb_service import lancedb_service
from llm_service import create_graph, create_basic_graph
from policy_chat_interface import PolicySimulatorChatInterface
from rag_service import search_documents
from voice_service import VoiceService
# Route INFO-and-above records through the root logger and grab this
# module's own logger for the handlers below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-level service instances, created once at import time and used by
# the WebSocket handlers defined in this file.
hybrid_llm_service = HybridLLMService()
voice_service = VoiceService()
policy_simulator = PolicySimulatorChatInterface()
# "DR" (dearness relief) as a standalone abbreviation.  The lookarounds stop
# the two letters from matching inside ordinary words such as "address",
# "children" or "hundred", while still accepting forms like "DR", "dr4%" or
# "dr." (digits and punctuation may touch the abbreviation, letters may not).
_DR_ABBREVIATION = r"(?<![a-z])dr(?![a-z])"

# Regular expressions (plain strings, searched case-insensitively) that mark
# a chat message as a policy-simulation request.  Entries stay uncompiled
# because the matcher passes ``re.IGNORECASE`` to ``re.search`` itself.
POLICY_PATTERNS = [
    r"scenario.*analy",
    r"policy.*simulat",
    r"pension.*analy",
    rf"simulate.*{_DR_ABBREVIATION}|{_DR_ABBREVIATION}.*simulat",
    r"simulate.*pension|pension.*simulat",
    r"impact.*analy",
    r"dearness.*relief",
    r"basic.*pension",
    r"medical.*allowance",
    r"chart.*pension|pension.*chart",
    r"visual.*analy|analy.*visual",
    r"show.*chart|chart.*show",
    r"explain.*chart|chart.*explain",
    r"using.*chart|chart.*using",
    rf"{_DR_ABBREVIATION}.*\d+.*increase|increase.*{_DR_ABBREVIATION}.*\d+",
    r"analyze.*minimum.*pension",
    r"pension.*change",
]
def is_policy_simulation_query(message: str) -> bool:
    """Return True when *message* matches any pattern in ``POLICY_PATTERNS``.

    Matching is case-insensitive.  Anything that is not a non-empty string
    (for example ``None`` or a number taken from client JSON) is reported as
    "not a policy query" instead of raising.
    """
    if not isinstance(message, str) or not message:
        return False
    # re.IGNORECASE already covers casing, so no separate lower-casing pass.
    return any(
        re.search(pattern, message, re.IGNORECASE) for pattern in POLICY_PATTERNS
    )
# System prompt installed for sessions whose handshake carries a "user_id".
_GOVERNMENT_DOCS_SYSTEM_PROMPT = """You are a helpful assistant that can answer questions about government documents, policies, and procedures.
Keep your responses clear and concise. When referencing specific documents or policies, mention the source.
If you're uncertain about information, clearly state that and suggest where the user might find authoritative information."""

# Synthetic first user turn that asks the model to produce the greeting.
_GREETING_INSTRUCTION = "Generate a brief greeting for the user, introduce yourself as a government document assistant, and explain how you can help them find information from government policies and documents."

# Greeting sent when the model call for the greeting fails.
_FALLBACK_GREETING = "Hello! I'm your government document assistant. How can I help you today?"

_SCENARIO_IMAGES_START = "**SCENARIO_IMAGES_START**"
_SCENARIO_IMAGES_END = "**SCENARIO_IMAGES_END**"


def _build_context_prompt(user_message, search_results):
    """Wrap the user's query with excerpts from the first three search hits."""
    sections = [f"User query: {user_message}\n\nRelevant documents found:\n"]
    for rank, doc in enumerate(search_results[:3], 1):
        sections.append(
            f"\n{rank}. Source: {doc.get('filename', 'Unknown')}\n"
            f"Content: {doc.get('content', '')[:400]}...\n"
        )
    sections.append(
        "\nBased on the above documents, please provide a helpful response "
        f"to the user's query: {user_message}"
    )
    return "".join(sections)


def _split_scenario_images(llm_response):
    """Return ``(text, images)`` if the reply embeds a JSON image payload, else None."""
    if _SCENARIO_IMAGES_START not in llm_response or _SCENARIO_IMAGES_END not in llm_response:
        return None
    segments = llm_response.split(_SCENARIO_IMAGES_START)
    image_json = segments[1].split(_SCENARIO_IMAGES_END)[0].strip()
    try:
        return segments[0].strip(), json.loads(image_json)
    except json.JSONDecodeError:
        # Malformed payload: the caller sends the raw reply as plain text.
        return None


async def _open_session(initial_data):
    """Create the graph, run config, thread id and seed history for a connection."""
    messages = []
    # Only a JSON object can carry credentials; any other shape is anonymous.
    if isinstance(initial_data, dict) and "user_id" in initial_data:
        thread_id = initial_data.get("user_id")
        knowledge_base = initial_data.get("knowledge_base", "government_docs")
        # Graph built with the knowledge-base tool enabled.
        graph = await create_graph(kb_tool=True, mcp_config=None)
        config = {
            "configurable": {
                "thread_id": thread_id,
                "knowledge_base": knowledge_base,
            }
        }
        messages.append(SystemMessage(content=_GOVERNMENT_DOCS_SYSTEM_PROMPT))
    else:
        graph = create_basic_graph()
        thread_id = str(uuid.uuid4())
        config = {"configurable": {"thread_id": thread_id}}
    return graph, config, thread_id, messages


async def _send_greeting(websocket, graph, config, messages):
    """Ask the model for a greeting and send it, falling back to a canned one."""
    messages.append(HumanMessage(content=_GREETING_INSTRUCTION))
    try:
        response = await graph.ainvoke({"messages": messages}, config=config)
        greeting_response = response["messages"][-1].content
        messages.append(AIMessage(content=greeting_response))
        await websocket.send_json({
            "type": "connection_successful",
            "message": greeting_response,
        })
    except WebSocketDisconnect:
        raise
    except Exception as e:
        logger.error(f"❌ Error generating greeting: {e}")
        await websocket.send_json({
            "type": "connection_successful",
            "message": _FALLBACK_GREETING,
        })


async def _try_policy_simulation(websocket, user_message, messages):
    """Answer via the policy simulator; return True when a reply was delivered."""
    logger.info("🎯 Detected policy simulation query")
    try:
        policy_response = policy_simulator.process_policy_query(user_message)
        await websocket.send_json({
            "type": "policy_simulation",
            "data": policy_response,
        })
        messages.append(
            AIMessage(content=policy_response.get('message', 'Policy simulation completed'))
        )
        return True
    except WebSocketDisconnect:
        raise
    except Exception as policy_error:
        logger.error(f"❌ Policy simulation failed: {policy_error}")
        # The caller falls back to the normal search + LLM path.
        return False


async def _handle_text_turn(websocket, data, messages, graph, config, thread_id):
    """Process one "text_message" payload and send the reply to the client."""
    user_message = data.get("message")
    if not isinstance(user_message, str) or not user_message.strip():
        await websocket.send_json({
            "type": "error",
            "message": "No message text received",
        })
        return

    logger.info(f"💬 Received text message: {user_message}")
    messages.append(HumanMessage(content=user_message))
    await websocket.send_json({
        "type": "message_received",
        "message": "Processing your message...",
    })

    if is_policy_simulation_query(user_message) and await _try_policy_simulation(
        websocket, user_message, messages
    ):
        return

    # Document search is best-effort; the LLM is still asked without context.
    search_results = None
    try:
        search_results = search_documents(user_message, limit=5)
        logger.info(f"🔍 Found {len(search_results) if search_results else 0} documents for query")
    except Exception as search_error:
        logger.warning(f"⚠️ Document search failed: {search_error}")

    try:
        if search_results:
            # Replace the stored user turn with the context-enriched version.
            messages[-1] = HumanMessage(
                content=_build_context_prompt(user_message, search_results)
            )
        result = await graph.ainvoke({"messages": messages}, config=config)
        llm_response = result["messages"][-1].content

        scenario = _split_scenario_images(llm_response)
        if scenario is None:
            await websocket.send_json({
                "type": "text_response",
                "message": llm_response,
            })
        else:
            text_response, images = scenario
            # Text first, then the images as a separate event.
            await websocket.send_json({
                "type": "text_response",
                "message": text_response,
            })
            await websocket.send_json({
                "type": "scenario_images",
                "images": images,
            })

        messages.append(AIMessage(content=llm_response))
        logger.info(f"✅ Sent response to user: {thread_id}")
    except WebSocketDisconnect:
        # A vanished client is not a processing error; let the main loop end.
        raise
    except Exception as e:
        logger.error(f"❌ Error processing message: {e}")
        await websocket.send_json({
            "type": "error",
            "message": "Sorry, I encountered an error processing your message.",
        })


async def _send_knowledge_bases(websocket):
    """Send the list of available knowledge bases, or an error event on failure."""
    try:
        kb_list = await lancedb_service.get_knowledge_bases()
        await websocket.send_json({
            "type": "knowledge_bases",
            "knowledge_bases": kb_list,
        })
    except WebSocketDisconnect:
        raise
    except Exception as e:
        logger.error(f"❌ Error getting knowledge bases: {e}")
        await websocket.send_json({
            "type": "error",
            "message": "Error retrieving knowledge bases",
        })


async def handle_websocket_connection(websocket: WebSocket):
    """Handle WebSocket connection for the voice bot.

    Accepts the socket, reads one handshake JSON message, and opens a session:
    a handshake object containing "user_id" gets the knowledge-base graph and
    a system prompt, anything else gets the basic graph with a random thread
    id.  A greeting is sent as a "connection_successful" event, then JSON
    messages are dispatched on their "type" field:

    * "text_message"        - answered by the policy simulator or by the
                              graph (optionally enriched with search hits)
    * "ping"                - answered with {"type": "pong"}
    * "get_knowledge_bases" - answered with the knowledge-base list
    * "end_session"         - closes the socket and ends the loop

    Payloads that are not JSON objects, or that carry an unknown "type", are
    logged and ignored instead of terminating the session.

    Args:
        websocket: The not-yet-accepted client connection.
    """
    await websocket.accept()
    logger.info("🔌 WebSocket client connected.")

    try:
        initial_data = await websocket.receive_json()
    except WebSocketDisconnect:
        # Client left before sending the handshake; nothing to clean up.
        logger.info("🔌 WebSocket client disconnected.")
        return

    graph, config, thread_id, messages = await _open_session(initial_data)

    try:
        await _send_greeting(websocket, graph, config, messages)

        while True:
            data = await websocket.receive_json()
            message_type = data.get("type") if isinstance(data, dict) else None

            if message_type == "text_message":
                await _handle_text_turn(
                    websocket, data, messages, graph, config, thread_id
                )
            elif message_type == "ping":
                # Keep-alive round trip.
                await websocket.send_json({"type": "pong"})
            elif message_type == "get_knowledge_bases":
                await _send_knowledge_bases(websocket)
            elif message_type == "end_session":
                logger.info("📞 Session ended by client")
                await websocket.close()
                break
            else:
                logger.warning(f"⚠️ Ignoring unsupported message type: {message_type!r}")
    except WebSocketDisconnect:
        logger.info("🔌 WebSocket client disconnected.")
    except Exception as e:
        logger.error(f"❌ WebSocket error: {e}")
        try:
            await websocket.send_json({
                "type": "error",
                "message": "Connection error occurred",
            })
        except Exception as send_error:
            # Best-effort notification: the socket is usually already closed.
            logger.debug(f"Could not deliver connection error to client: {send_error}")
    finally:
        logger.info(f"🔄 Session {thread_id} ended")
async def send_welcome_message(websocket: WebSocket):
    """Push the static welcome banner to the client as a "bot_message" event.

    The event is a JSON text frame with "type", "content" and "timestamp"
    (the event loop's clock reading).  Any failure is logged and swallowed.
    """
    banner = """🇮🇳 Welcome to the Government Services AI Assistant!
I'm here to help you with:
• Government policies and procedures
• Document information and guidance
• Service-specific questions and redirects
• Voice or text interaction (your choice!)
How can I assist you today?"""
    try:
        event = {
            "type": "bot_message",
            "content": banner,
            "timestamp": asyncio.get_event_loop().time(),
        }
        frame = json.dumps(event)
        await websocket.send_text(frame)
    except Exception as e:
        logger.error(f"❌ Error sending welcome message: {e}")
async def handle_text_message(websocket: WebSocket, message_data: Dict[str, Any]):
    """Handle text-based messages.

    Reads the user's text from ``message_data["content"]``, looks up related
    documents (best-effort), asks the hybrid LLM service for an answer and
    sends it to the client.  When ``message_data["stream"]`` is truthy (the
    default) the answer is streamed as "bot_message_start" /
    "bot_message_chunk" / "bot_message_end" events; otherwise one
    "bot_message" event is sent.  Redirect suggestions, if any, follow as a
    "redirect_suggestions" event.

    Args:
        websocket: Connection the events are written to as JSON text frames.
        message_data: Decoded client payload; reads "content" and "stream".

    Raises:
        WebSocketDisconnect: Re-raised when the client goes away, so the
            caller can end its receive loop instead of sending to a dead
            socket.
    """
    loop = asyncio.get_running_loop()

    async def send_event(event_type, *, stamped=True, **fields):
        """Serialise one protocol event and send it as a text frame."""
        payload = {"type": event_type, **fields}
        if stamped:
            payload["timestamp"] = loop.time()
        await websocket.send_text(json.dumps(payload))

    try:
        user_message = message_data.get("content", "")
        if not isinstance(user_message, str) or not user_message.strip():
            # Nothing to search for or answer.
            await send_event("error", stamped=False, content="No message content received")
            return
        logger.info(f"💬 Processing text message: {user_message}")

        # Document search is best-effort; an empty context is acceptable.
        context = ""
        try:
            search_results = search_documents(user_message, limit=3)
            if search_results:
                context = "\n".join(doc.get("content", "") for doc in search_results)
                logger.info(f"📚 Found {len(search_results)} relevant documents")
        except Exception as e:
            logger.warning(f"⚠️ Document search failed: {e}")

        stream_open = False
        try:
            if message_data.get("stream", True):
                await send_event("bot_message_start")
                stream_open = True
                async for chunk in hybrid_llm_service.get_streaming_response(user_message, context):
                    await send_event("bot_message_chunk", content=chunk)
                    await asyncio.sleep(0.01)  # Small delay for better streaming
                await send_event("bot_message_end")
                stream_open = False
            else:
                response_text = await hybrid_llm_service.get_response(user_message, context)
                await send_event("bot_message", content=response_text)
        except WebSocketDisconnect:
            raise
        except Exception as e:
            logger.error(f"❌ Error getting LLM response: {e}")
            if stream_open:
                # Close the stream the client was told about before apologising.
                await send_event("bot_message_end")
            await send_event(
                "bot_message",
                content=f"I apologize, but I encountered an error processing your request: {str(e)}",
            )

        # Redirect suggestions are optional extras; failures are only logged.
        try:
            redirect_suggestions = voice_service.generate_redirect_suggestions(user_message, "text")
            if redirect_suggestions:
                await send_event("redirect_suggestions", content=redirect_suggestions)
        except WebSocketDisconnect:
            raise
        except Exception as e:
            logger.warning(f"⚠️ Could not generate redirect suggestions: {e}")
    except WebSocketDisconnect:
        raise
    except Exception as e:
        logger.error(f"❌ Error handling text message: {e}")
        try:
            await send_event(
                "error", stamped=False, content=f"Error processing your message: {str(e)}"
            )
        except Exception as send_error:
            # The socket may be the thing that failed; do not raise from here.
            logger.debug(f"Could not deliver error to client: {send_error}")
async def handle_voice_message(websocket: WebSocket, message_data: Dict[str, Any]):
    """Handle voice-based messages.

    Rejects the request with an "error" event when voice features are
    disabled or when ``message_data["audio_data"]`` is empty.  Otherwise the
    audio is passed to ``voice_service.speech_to_text`` and the result is
    sent back as a "transcription" event.  This handler only transcribes; it
    does not generate an answer to the transcribed text.

    Args:
        websocket: Connection the events are written to as JSON text frames.
        message_data: Decoded client payload; reads "audio_data".

    Raises:
        WebSocketDisconnect: Re-raised when the client goes away, so the
            caller can end its receive loop instead of sending to a dead
            socket.
    """
    loop = asyncio.get_running_loop()

    async def send_error(text):
        """Send an "error" event carrying *text*."""
        await websocket.send_text(json.dumps({"type": "error", "content": text}))

    try:
        if not voice_service.voice_enabled:
            await send_error("Voice features are currently disabled. Please use text input.")
            return

        audio_data = message_data.get("audio_data", "")
        if not audio_data:
            await send_error("No audio data received")
            return

        logger.info("🎤 Processing voice message")

        try:
            transcribed_text = await voice_service.speech_to_text(audio_data)
            logger.info(f"📝 Transcribed: {transcribed_text}")
            await websocket.send_text(json.dumps({
                "type": "transcription",
                "content": transcribed_text,
                "timestamp": loop.time(),
            }))
        except WebSocketDisconnect:
            raise
        except Exception as e:
            logger.error(f"❌ Speech-to-text failed: {e}")
            await send_error(f"Speech recognition failed: {str(e)}")
    except WebSocketDisconnect:
        raise
    except Exception as e:
        logger.error(f"❌ Error handling voice message: {e}")
        try:
            await send_error(f"Error processing voice message: {str(e)}")
        except Exception as send_error_failure:
            # The socket may be the thing that failed; do not raise from here.
            logger.debug(f"Could not deliver error to client: {send_error_failure}")