Spaces:
Sleeping
Sleeping
| """ | |
| Enhanced WebSocket Handler with Groq ASR integration | |
| Based on friend's superior implementation with /ws/stream endpoint | |
| Provides real-time voice processing with superior transcription accuracy | |
| """ | |
| import logging | |
| import json | |
| import asyncio | |
| import tempfile | |
| import os | |
| import time | |
| from typing import Dict, Any, Optional | |
| from pathlib import Path | |
| import uuid | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| from groq_voice_service import groq_voice_service | |
| from rag_service import hybrid_rag_service | |
# Logger used by the handler methods below, registered under the name "voicebot".
| logger = logging.getLogger("voicebot") | |
class GroqWebSocketHandler:
    """Track WebSocket sessions and route their messages.

    Audio messages are transcribed through ``groq_voice_service``; the
    resulting text (or a plain text query) is answered through
    ``hybrid_rag_service``.  Every outgoing message is a JSON object with a
    ``type`` field and a ``timestamp``.

    Attributes are plain dictionaries keyed by session id:
    ``active_connections`` maps to the WebSocket object and ``user_sessions``
    maps to per-session bookkeeping (connection time, message count, last
    activity, conversation history and, once set, voice settings).
    """

    # Number of most recent history entries forwarded to the RAG service.
    _HISTORY_WINDOW = 5

    # Government/pension related keywords that trigger document search.
    # Order matters: matched keywords are reported in this order.
    _GOVT_KEYWORDS = (
        "pension", "retirement", "pf", "provident fund", "gratuity", "benefits",
        "government", "policy", "rules", "regulation", "scheme", "allowance",
        "service", "employee", "officer", "department", "ministry", "board",
        "application", "form", "procedure", "process", "eligibility", "criteria",
        "amount", "calculation", "rate", "percentage", "salary", "pay",
        "medical", "health", "insurance", "coverage", "reimbursement",
        "leave", "vacation", "sick", "maternity", "paternity",
        "transfer", "posting", "promotion", "increment", "grade",
        "tax", "income", "deduction", "exemption", "investment",
        "documents", "certificate", "verification", "approval",
    )

    # Simple greetings and casual phrases that need no document search.
    _CASUAL_PHRASES = (
        "hello", "hi", "hey", "good morning", "good afternoon", "good evening",
        "how are you", "what's up", "thanks", "thank you", "bye", "goodbye",
        "what is your name", "who are you", "what can you do",
    )

    def __init__(self):
        self.active_connections: Dict[str, "WebSocket"] = {}
        self.user_sessions: Dict[str, Dict[str, Any]] = {}

    async def connect(self, websocket: "WebSocket", session_id: Optional[str] = None) -> str:
        """Accept the WebSocket, register a session and confirm it to the client.

        Args:
            websocket: The connection to accept.
            session_id: Session id to register under; a random UUID4 string is
                generated when it is missing or empty.

        Returns:
            The session id the connection was registered under.
        """
        await websocket.accept()
        if not session_id:
            session_id = str(uuid.uuid4())

        now = time.time()
        self.active_connections[session_id] = websocket
        self.user_sessions[session_id] = {
            "connected_at": now,
            "message_count": 0,
            "last_activity": now,
            "conversation_history": [],
        }
        logger.info(f"🔗 WebSocket connected - Session: {session_id}")

        # Send initial connection confirmation
        await self.send_message(session_id, {
            "type": "connection_established",
            "session_id": session_id,
            "voice_status": groq_voice_service.get_voice_status(),
            "timestamp": time.time(),
        })
        return session_id

    async def disconnect(self, session_id: str):
        """Forget the connection and session state stored for ``session_id``.

        Unknown session ids are ignored.  When session state existed, its
        duration and message count are logged.
        """
        self.active_connections.pop(session_id, None)
        session = self.user_sessions.pop(session_id, None)
        if session is not None:
            session_duration = time.time() - session["connected_at"]
            message_count = session["message_count"]
            logger.info(
                f"🔌 Session {session_id} ended - Duration: {session_duration:.1f}s, Messages: {message_count}"
            )

    async def send_message(self, session_id: str, message: Dict[str, Any]) -> bool:
        """Send ``message`` as JSON text to the connection of ``session_id``.

        Returns:
            True when the message was handed to the WebSocket, False when the
            session has no registered connection or sending/serialising failed
            (the failure is logged, not raised).
        """
        if session_id not in self.active_connections:
            return False
        try:
            await self.active_connections[session_id].send_text(json.dumps(message))
        except Exception as e:
            logger.error(f"❌ Failed to send message to {session_id}: {e}")
            return False
        return True

    async def handle_stream_message(self, websocket: "WebSocket", session_id: str, message: Dict[str, Any]):
        """Dispatch one incoming message by its ``type`` field.

        Recognised types are ``audio_data``, ``text_query``,
        ``conversation_state`` and ``voice_settings``; anything else is
        answered with an ``error`` message.  Exceptions raised while handling
        are logged with a traceback and reported to the client as an
        ``error`` message instead of propagating.
        """
        try:
            message_type = message.get("type", "unknown")
            if message_type == "audio_data":
                await self._process_audio_stream(websocket, session_id, message)
            elif message_type == "text_query":
                await self._process_text_query(websocket, session_id, message)
            elif message_type == "conversation_state":
                await self._handle_conversation_state(websocket, session_id, message)
            elif message_type == "voice_settings":
                await self._handle_voice_settings(websocket, session_id, message)
            else:
                logger.warning(f"⚠️ Unknown message type: {message_type}")
                await self.send_message(session_id, {
                    "type": "error",
                    "message": f"Unknown message type: {message_type}",
                    "timestamp": time.time(),
                })
        except Exception as e:
            # Top-level boundary for a single message: keep the traceback in the log.
            logger.exception(f"❌ Error handling stream message: {e}")
            await self.send_message(session_id, {
                "type": "error",
                "message": f"Internal error: {str(e)}",
                "timestamp": time.time(),
            })

    async def _process_audio_stream(self, websocket: "WebSocket", session_id: str, message: Dict[str, Any]):
        """Transcribe a base64 audio message and answer the transcribed query.

        Reads ``audio_data`` (base64 text) and ``language`` (default ``"en"``)
        from ``message``.  Progress and failures are reported to the client
        as ``audio_processing_started``, ``error``, ``transcription_failed``
        and ``transcription_complete`` messages.  ``websocket`` is only
        passed through; replies are sent via ``session_id``.
        """
        import base64

        try:
            # Send processing acknowledgment
            await self.send_message(session_id, {
                "type": "audio_processing_started",
                "timestamp": time.time(),
            })

            audio_data = message.get("audio_data")
            user_language = message.get("language") or "en"
            if not audio_data:
                await self.send_message(session_id, {
                    "type": "error",
                    "message": "No audio data provided",
                    "timestamp": time.time(),
                })
                return

            # b64decode raises binascii.Error (a ValueError) or TypeError on bad input.
            try:
                audio_bytes = base64.b64decode(audio_data)
            except (ValueError, TypeError) as decode_error:
                logger.error(f"❌ Audio decode error: {decode_error}")
                audio_bytes = b""
            else:
                if not audio_bytes:
                    # Non-alphabet characters are discarded by default, so junk
                    # input can decode to nothing; do not send that to ASR.
                    logger.error("❌ Audio decode error: decoded payload is empty")

            if not audio_bytes:
                await self.send_message(session_id, {
                    "type": "error",
                    "message": "Invalid audio data format",
                    "timestamp": time.time(),
                })
                return

            logger.info(f"🎤 Processing audio with Groq ASR - Language: {user_language}")
            transcription_start = time.time()
            transcribed_text = await groq_voice_service.groq_asr_bytes(audio_bytes, user_language)
            transcription_time = time.time() - transcription_start
            logger.info(f"🎤 Groq ASR completed in {transcription_time:.2f}s")

            if not transcribed_text:
                await self.send_message(session_id, {
                    "type": "transcription_failed",
                    "message": "Could not transcribe audio",
                    "timestamp": time.time(),
                })
                return

            await self.send_message(session_id, {
                "type": "transcription_complete",
                "transcribed_text": transcribed_text,
                "processing_time": transcription_time,
                "language": user_language,
                "timestamp": time.time(),
            })

            await self._process_transcribed_query(websocket, session_id, transcribed_text, user_language)
        except Exception as e:
            logger.exception(f"❌ Audio processing error: {e}")
            await self.send_message(session_id, {
                "type": "error",
                "message": f"Audio processing failed: {str(e)}",
                "timestamp": time.time(),
            })

    async def _process_transcribed_query(self, websocket: "WebSocket", session_id: str, query: str, language: str = "en"):
        """Answer ``query`` through the RAG service and record the exchange.

        Sends ``query_processing_started``, ``query_analysis`` and
        ``response_complete`` messages, and appends the query and the answer
        to the session's conversation history when the session still exists.
        Any exception is logged and reported as an ``error`` message.
        """
        try:
            session = self.user_sessions.get(session_id)
            if session is not None:
                session["last_activity"] = time.time()
                session["message_count"] += 1
                session["conversation_history"].append({
                    "type": "user_voice",
                    "content": query,
                    "timestamp": time.time(),
                    "language": language,
                })

            await self.send_message(session_id, {
                "type": "query_processing_started",
                "query": query,
                "timestamp": time.time(),
            })

            # Analyze query context for better response routing
            query_context = await self._analyze_query_context(query)
            await self.send_message(session_id, {
                "type": "query_analysis",
                "context": query_context,
                "timestamp": time.time(),
            })

            processing_start = time.time()
            if query_context["requires_documents"]:
                logger.info(f"📄 Document search required for: {query}")
                # Look the session up again: it may have been removed by
                # disconnect() while the awaits above were pending.
                session = self.user_sessions.get(session_id)
                if session is not None:
                    recent_history = session["conversation_history"][-self._HISTORY_WINDOW:]
                else:
                    recent_history = []
                response_data = await hybrid_rag_service.search_and_generate_response(
                    query=query,
                    user_language=language,
                    conversation_history=recent_history,
                )
            else:
                logger.info(f"💬 General query: {query}")
                response_data = await hybrid_rag_service.generate_simple_response(
                    query=query,
                    user_language=language,
                )
            processing_time = time.time() - processing_start

            sources = response_data.get("sources", [])
            await self.send_message(session_id, {
                "type": "response_complete",
                "response": response_data.get("response", "I couldn't generate a response."),
                "sources": sources,
                "processing_time": processing_time,
                "query_context": query_context,
                "timestamp": time.time(),
            })

            session = self.user_sessions.get(session_id)
            if session is not None:
                session["conversation_history"].append({
                    "type": "assistant",
                    "content": response_data.get("response", ""),
                    "sources": sources,
                    "timestamp": time.time(),
                })

            # NOTE: TTS is not wired in yet; _generate_audio_response() can be
            # called here once clients are able to request audio replies.
        except Exception as e:
            logger.exception(f"❌ Query processing error: {e}")
            await self.send_message(session_id, {
                "type": "error",
                "message": f"Query processing failed: {str(e)}",
                "timestamp": time.time(),
            })

    async def _process_text_query(self, websocket: "WebSocket", session_id: str, message: Dict[str, Any]):
        """Answer a text query read from ``message["query"]``.

        A missing, null or whitespace-only query is answered with an
        ``error`` message; a missing or null ``language`` falls back to
        ``"en"``.
        """
        # ``or`` guards against an explicit JSON null, which .get() would return as None.
        query = str(message.get("query") or "").strip()
        language = message.get("language") or "en"
        if not query:
            await self.send_message(session_id, {
                "type": "error",
                "message": "Empty query provided",
                "timestamp": time.time(),
            })
            return
        await self._process_transcribed_query(websocket, session_id, query, language)

    @staticmethod
    def _mentions(text: str, term: str, whole_word: bool = False) -> bool:
        """Return True when ``term`` occurs in ``text`` starting at a word boundary.

        With ``whole_word`` the term must also end at a word boundary, so
        ``"hi"`` does not match ``"high"``.  Without it any suffix is allowed,
        so ``"pension"`` still matches ``"pensions"``.
        """
        import re  # function-scope import, like base64 elsewhere in this class

        pattern = r"\b" + re.escape(term) + (r"\b" if whole_word else "")
        return re.search(pattern, text) is not None

    async def _analyze_query_context(self, query: str) -> Dict[str, Any]:
        """Classify ``query`` to decide whether document search is needed.

        Checks, in order:
          1. government/policy keywords (matched at the start of a word) ->
             ``government_policy``, documents required;
          2. casual phrases (matched as whole words) -> ``casual``;
          3. more than two words -> ``information_request``, documents required;
          4. otherwise -> ``general``.

        Keywords are checked before casual phrases so that a greeting
        followed by a real question still reaches document search.  Matching
        is word-boundary based; plain substring tests would see "hi" inside
        "this" or "pf" inside "helpful".

        Returns:
            A dict with ``requires_documents``, ``query_type``, ``confidence``
            and ``reason``; ``matched_keywords`` is added for
            ``government_policy``.
        """
        query_lower = query.lower().strip()

        matched_keywords = [kw for kw in self._GOVT_KEYWORDS if self._mentions(query_lower, kw)]
        if matched_keywords:
            return {
                "requires_documents": True,
                "query_type": "government_policy",
                "confidence": 0.8,
                "matched_keywords": matched_keywords,
                "reason": f"Contains government/policy keywords: {', '.join(matched_keywords)}",
            }

        if any(self._mentions(query_lower, phrase, whole_word=True) for phrase in self._CASUAL_PHRASES):
            return {
                "requires_documents": False,
                "query_type": "casual",
                "confidence": 0.9,
                "reason": "Casual greeting or simple query",
            }

        # Default: treat as document search unless clearly casual
        if len(query.split()) > 2:
            return {
                "requires_documents": True,
                "query_type": "information_request",
                "confidence": 0.6,
                "reason": "Multi-word query likely needs document search",
            }

        return {
            "requires_documents": False,
            "query_type": "general",
            "confidence": 0.5,
            "reason": "Simple query, may not need documents",
        }

    async def _generate_audio_response(self, websocket: "WebSocket", session_id: str, text: str):
        """Synthesise ``text`` via ``groq_voice_service`` and send it as base64 audio.

        Sends ``audio_generation_started`` followed by ``audio_response``, or
        ``audio_generation_failed`` when no audio was produced.  Exceptions
        are logged and reported as an ``error`` message.
        """
        import base64

        try:
            await self.send_message(session_id, {
                "type": "audio_generation_started",
                "timestamp": time.time(),
            })
            audio_data = await groq_voice_service.text_to_speech(text)
            if not audio_data:
                await self.send_message(session_id, {
                    "type": "audio_generation_failed",
                    "message": "Could not generate audio",
                    "timestamp": time.time(),
                })
                return

            audio_base64 = base64.b64encode(audio_data).decode('utf-8')
            await self.send_message(session_id, {
                "type": "audio_response",
                "audio_data": audio_base64,
                "text": text,
                "timestamp": time.time(),
            })
        except Exception as e:
            logger.exception(f"❌ Audio generation error: {e}")
            await self.send_message(session_id, {
                "type": "error",
                "message": f"Audio generation failed: {str(e)}",
                "timestamp": time.time(),
            })

    async def _handle_conversation_state(self, websocket: "WebSocket", session_id: str, message: Dict[str, Any]):
        """Serve ``get_history`` and ``clear_history`` actions for a session.

        ``get_history`` replies with the stored history (empty for an unknown
        session); ``clear_history`` empties it and replies ``history_cleared``.
        Other actions get no reply and are logged as a warning.
        """
        action = message.get("action", "")
        if action == "get_history":
            history = self.user_sessions.get(session_id, {}).get("conversation_history", [])
            await self.send_message(session_id, {
                "type": "conversation_history",
                "history": history,
                "timestamp": time.time(),
            })
        elif action == "clear_history":
            if session_id in self.user_sessions:
                self.user_sessions[session_id]["conversation_history"] = []
            await self.send_message(session_id, {
                "type": "history_cleared",
                "timestamp": time.time(),
            })
        else:
            logger.warning(f"⚠️ Unknown conversation_state action: {action}")

    async def _handle_voice_settings(self, websocket: "WebSocket", session_id: str, message: Dict[str, Any]):
        """Store ``message["settings"]`` on the session and echo it back.

        The settings are stored only when the session exists; the
        ``voice_settings_updated`` reply is sent either way.
        """
        settings = message.get("settings", {})
        if session_id in self.user_sessions:
            self.user_sessions[session_id]["voice_settings"] = settings
        await self.send_message(session_id, {
            "type": "voice_settings_updated",
            "settings": settings,
            "timestamp": time.time(),
        })

    def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return a shallow copy of the session state, or None if unknown.

        The copy additionally carries ``session_id`` and ``is_active``
        (whether a connection is still registered for the session).
        """
        if session_id not in self.user_sessions:
            return None
        session = self.user_sessions[session_id].copy()
        session["session_id"] = session_id
        session["is_active"] = session_id in self.active_connections
        return session

    def get_active_sessions_count(self) -> int:
        """Return the number of registered connections."""
        return len(self.active_connections)
| # Module-level GroqWebSocketHandler instance, created when this module is imported | |
| groq_websocket_handler = GroqWebSocketHandler() | |