PensionBot / groq_websocket_handler.py
ChAbhishek28's picture
Add 8999999999999999999999999999
4e6d880
raw
history blame
17.7 kB
"""
Enhanced WebSocket Handler with Groq ASR integration
Based on friend's superior implementation with /ws/stream endpoint
Provides real-time voice processing with superior transcription accuracy
"""
import asyncio
import base64
import json
import logging
import os
import re
import tempfile
import time
import uuid
from pathlib import Path
from typing import Dict, Any, Optional

from fastapi import WebSocket, WebSocketDisconnect

from groq_voice_service import groq_voice_service
from rag_service import hybrid_rag_service
# Named "voicebot" logger that every log call in this module writes to.
logger = logging.getLogger(name="voicebot")
class GroqWebSocketHandler:
    """Manage ``/ws/stream`` WebSocket sessions using Groq ASR and the hybrid RAG service.

    Connections are tracked per session id in two dicts:

    * ``active_connections`` maps a session id to its ``WebSocket``.
    * ``user_sessions`` maps a session id to a bookkeeping dict with the keys
      ``connected_at`` and ``last_activity`` (``time.time()`` values),
      ``message_count`` (int) and ``conversation_history`` (list of turn
      dicts). ``voice_settings`` is added once the client sends settings.

    Every message this class sends to a client is a JSON object carrying a
    ``type`` and a ``timestamp`` field.
    """

    # Government/pension vocabulary that routes a query to document search.
    # Matched as plain substrings on purpose so inflected forms ("pensions",
    # "pensioner", "allowances") still count -- this favours recall.
    _GOVT_KEYWORDS = (
        "pension", "retirement", "pf", "provident fund", "gratuity", "benefits",
        "government", "policy", "rules", "regulation", "scheme", "allowance",
        "service", "employee", "officer", "department", "ministry", "board",
        "application", "form", "procedure", "process", "eligibility", "criteria",
        "amount", "calculation", "rate", "percentage", "salary", "pay",
        "medical", "health", "insurance", "coverage", "reimbursement",
        "leave", "vacation", "sick", "maternity", "paternity",
        "transfer", "posting", "promotion", "increment", "grade",
        "tax", "income", "deduction", "exemption", "investment",
        "documents", "certificate", "verification", "approval",
    )

    # Greetings and small talk that can be answered without documents.
    _CASUAL_QUERIES = (
        "hello", "hi", "hey", "good morning", "good afternoon", "good evening",
        "how are you", "what's up", "thanks", "thank you", "bye", "goodbye",
        "what is your name", "who are you", "what can you do",
    )

    # Casual phrases must match as whole words/phrases. A plain substring test
    # lets "hi" fire inside "which"/"this" and "hey" inside "they", which sent
    # genuine policy questions down the casual (no-documents) path.
    _CASUAL_PATTERN = re.compile(
        r"\b(?:" + "|".join(map(re.escape, _CASUAL_QUERIES)) + r")\b"
    )

    def __init__(self):
        # session id -> live WebSocket for that session
        self.active_connections: Dict[str, WebSocket] = {}
        # session id -> per-session bookkeeping (see class docstring)
        self.user_sessions: Dict[str, Dict] = {}

    async def connect(self, websocket: WebSocket, session_id: Optional[str] = None) -> str:
        """Accept a WebSocket, register its session and announce it to the client.

        Args:
            websocket: The incoming connection; ``accept()`` is awaited here.
            session_id: Caller-supplied id. When empty or ``None``, a random
                UUID4 string is generated. An existing session registered
                under the same id is overwritten.

        Returns:
            The session id the connection was registered under.
        """
        await websocket.accept()
        if not session_id:
            session_id = str(uuid.uuid4())

        now = time.time()
        self.active_connections[session_id] = websocket
        self.user_sessions[session_id] = {
            "connected_at": now,
            "message_count": 0,
            "last_activity": now,
            "conversation_history": [],
        }
        logger.info(f"🔗 WebSocket connected - Session: {session_id}")

        # Tell the client which session it got and the current voice-service status.
        await self.send_message(session_id, {
            "type": "connection_established",
            "session_id": session_id,
            "voice_status": groq_voice_service.get_voice_status(),
            "timestamp": time.time(),
        })
        return session_id

    async def disconnect(self, session_id: str):
        """Forget ``session_id`` and log its duration and message count.

        Unknown ids are ignored, so calling this twice is harmless.
        """
        self.active_connections.pop(session_id, None)
        session = self.user_sessions.pop(session_id, None)
        if session is not None:
            session_duration = time.time() - session["connected_at"]
            message_count = session["message_count"]
            logger.info(
                f"🔌 Session {session_id} ended - Duration: {session_duration:.1f}s, "
                f"Messages: {message_count}"
            )

    async def send_message(self, session_id: str, message: Dict[str, Any]) -> bool:
        """Serialize ``message`` to JSON and send it as a text frame to ``session_id``.

        Returns:
            True if the frame was sent. False if the session has no active
            connection, or if serialization or sending raised; that error is
            logged, not re-raised.
        """
        websocket = self.active_connections.get(session_id)
        if websocket is None:
            return False
        try:
            await websocket.send_text(json.dumps(message))
            return True
        except Exception as e:
            logger.error(f"❌ Failed to send message to {session_id}: {e}")
            return False

    async def _send_error(self, session_id: str, text: str) -> bool:
        """Send an ``error`` message whose ``message`` field is ``text``."""
        return await self.send_message(session_id, {
            "type": "error",
            "message": text,
            "timestamp": time.time(),
        })

    async def handle_stream_message(self, websocket: WebSocket, session_id: str, message: Dict[str, Any]):
        """Route one decoded client message by its ``type`` field.

        Supported types are ``audio_data``, ``text_query``,
        ``conversation_state`` and ``voice_settings``. An unknown type, or any
        exception escaping a handler, is reported to the client as an
        ``error`` message instead of being raised.
        """
        # NOTE(review): the exception text is echoed to the client below;
        # confirm that exposing internal error details is acceptable.
        try:
            message_type = message.get("type", "unknown")
            handlers = {
                "audio_data": self._process_audio_stream,
                "text_query": self._process_text_query,
                "conversation_state": self._handle_conversation_state,
                "voice_settings": self._handle_voice_settings,
            }
            # Non-string types (e.g. a JSON list) are unhashable, so treat them as unknown.
            handler = handlers.get(message_type) if isinstance(message_type, str) else None
            if handler is None:
                logger.warning(f"⚠️ Unknown message type: {message_type}")
                await self._send_error(session_id, f"Unknown message type: {message_type}")
                return
            await handler(websocket, session_id, message)
        except Exception as e:
            logger.exception(f"❌ Error handling stream message: {e}")
            await self._send_error(session_id, f"Internal error: {str(e)}")

    async def _process_audio_stream(self, websocket: WebSocket, session_id: str, message: Dict[str, Any]):
        """Transcribe a base64 audio payload with Groq ASR, then answer the transcript.

        Reads ``message["audio_data"]`` (base64-encoded audio) and the optional
        ``message["language"]`` (default ``"en"``). Sends
        ``audio_processing_started`` first. It then sends one of:

        * an ``error`` message (missing or undecodable audio), or
        * ``transcription_failed`` (empty transcript), or
        * ``transcription_complete``, followed by the messages sent by
          ``_process_transcribed_query``.
        """
        try:
            await self.send_message(session_id, {
                "type": "audio_processing_started",
                "timestamp": time.time(),
            })

            audio_data = message.get("audio_data")
            user_language = message.get("language", "en")
            if not audio_data:
                await self._send_error(session_id, "No audio data provided")
                return

            try:
                audio_bytes = base64.b64decode(audio_data)
            except (ValueError, TypeError) as decode_error:
                # binascii.Error (bad padding) and non-ASCII input raise
                # ValueError; TypeError covers payloads that are not str/bytes.
                logger.error(f"❌ Audio decode error: {decode_error}")
                await self._send_error(session_id, "Invalid audio data format")
                return

            logger.info(f"🎤 Processing audio with Groq ASR - Language: {user_language}")
            # perf_counter is monotonic, so the duration is immune to clock adjustments.
            transcription_start = time.perf_counter()
            transcribed_text = await groq_voice_service.groq_asr_bytes(audio_bytes, user_language)
            transcription_time = time.perf_counter() - transcription_start
            logger.info(f"🎤 Groq ASR completed in {transcription_time:.2f}s")

            # A whitespace-only transcript is as useless as no transcript.
            transcribed_text = transcribed_text.strip() if transcribed_text else ""
            if not transcribed_text:
                await self.send_message(session_id, {
                    "type": "transcription_failed",
                    "message": "Could not transcribe audio",
                    "timestamp": time.time(),
                })
                return

            await self.send_message(session_id, {
                "type": "transcription_complete",
                "transcribed_text": transcribed_text,
                "processing_time": transcription_time,
                "language": user_language,
                "timestamp": time.time(),
            })

            await self._process_transcribed_query(websocket, session_id, transcribed_text, user_language)
        except Exception as e:
            logger.exception(f"❌ Audio processing error: {e}")
            await self._send_error(session_id, f"Audio processing failed: {str(e)}")

    async def _process_transcribed_query(self, websocket: WebSocket, session_id: str, query: str, language: str = "en"):
        """Answer ``query`` via the hybrid RAG service and record the exchange.

        Sends ``query_processing_started``, ``query_analysis`` and
        ``response_complete`` in that order. Queries classified as needing
        documents go to ``search_and_generate_response`` with up to the last
        five history entries; all others go to ``generate_simple_response``.
        Failures are reported as an ``error`` message. Works even when the
        session has no ``user_sessions`` entry; in that case no history is
        recorded or passed on.
        """
        try:
            session = self.user_sessions.get(session_id)
            if session is not None:
                session["last_activity"] = time.time()
                session["message_count"] += 1
                # NOTE(review): text queries are recorded as "user_voice" too.
                session["conversation_history"].append({
                    "type": "user_voice",
                    "content": query,
                    "timestamp": time.time(),
                    "language": language,
                })

            await self.send_message(session_id, {
                "type": "query_processing_started",
                "query": query,
                "timestamp": time.time(),
            })

            # Decide whether the query needs a document search.
            query_context = await self._analyze_query_context(query)
            await self.send_message(session_id, {
                "type": "query_analysis",
                "context": query_context,
                "timestamp": time.time(),
            })

            processing_start = time.perf_counter()
            if query_context["requires_documents"]:
                logger.info(f"📄 Document search required for: {query}")
                # NOTE(review): this window already contains the user turn
                # appended above (the current query); confirm the RAG service
                # expects that.
                recent_history = session["conversation_history"][-5:] if session is not None else []
                response_data = await hybrid_rag_service.search_and_generate_response(
                    query=query,
                    user_language=language,
                    conversation_history=recent_history,
                )
            else:
                logger.info(f"💬 General query: {query}")
                response_data = await hybrid_rag_service.generate_simple_response(
                    query=query,
                    user_language=language,
                )
            processing_time = time.perf_counter() - processing_start

            # Tolerate a None result or missing/None fields from the service.
            response_data = response_data or {}
            response_text = response_data.get("response") or ""
            sources = response_data.get("sources") or []

            await self.send_message(session_id, {
                "type": "response_complete",
                "response": response_text or "I couldn't generate a response.",
                "sources": sources,
                "processing_time": processing_time,
                "query_context": query_context,
                "timestamp": time.time(),
            })

            # Re-check: the session may have been removed while we awaited the RAG call.
            if session_id in self.user_sessions:
                self.user_sessions[session_id]["conversation_history"].append({
                    "type": "assistant",
                    "content": response_text,
                    "sources": sources,
                    "timestamp": time.time(),
                })

            # TODO: call self._generate_audio_response(websocket, session_id, response_text)
            # once clients can request spoken replies; it is not invoked anywhere yet.
        except Exception as e:
            logger.exception(f"❌ Query processing error: {e}")
            await self._send_error(session_id, f"Query processing failed: {str(e)}")

    async def _process_text_query(self, websocket: WebSocket, session_id: str, message: Dict[str, Any]):
        """Answer a typed query from ``message["query"]`` in ``message["language"]`` (default ``"en"``).

        A missing, non-string or blank query is rejected with an ``error``
        message instead of raising.
        """
        raw_query = message.get("query")
        query = raw_query.strip() if isinstance(raw_query, str) else ""
        language = message.get("language", "en")
        if not query:
            await self._send_error(session_id, "Empty query provided")
            return
        await self._process_transcribed_query(websocket, session_id, query, language)

    async def _analyze_query_context(self, query: str) -> Dict[str, Any]:
        """Classify ``query`` to decide whether it needs a document search.

        Rules are applied in order:

        1. Any government/pension keyword (substring match) gives
           ``government_policy`` with documents required. This is checked
           first so a greeting prefix ("hi, what is my pension?") cannot hide
           a policy question.
        2. A casual greeting or phrase, matched as whole words, gives
           ``casual`` with no documents.
        3. More than two words gives ``information_request`` with documents
           required.
        4. Anything else gives ``general`` with no documents.

        Returns:
            Dict with ``requires_documents``, ``query_type``, ``confidence``
            and ``reason``. Rule 1 also includes ``matched_keywords``.
        """
        query_lower = query.lower().strip()

        matched_keywords = [kw for kw in self._GOVT_KEYWORDS if kw in query_lower]
        if matched_keywords:
            return {
                "requires_documents": True,
                "query_type": "government_policy",
                "confidence": 0.8,
                "matched_keywords": matched_keywords,
                "reason": f"Contains government/policy keywords: {', '.join(matched_keywords)}",
            }

        if self._CASUAL_PATTERN.search(query_lower):
            return {
                "requires_documents": False,
                "query_type": "casual",
                "confidence": 0.9,
                "reason": "Casual greeting or simple query",
            }

        # Multi-word queries with no other signal still lean towards document search.
        if len(query.split()) > 2:
            return {
                "requires_documents": True,
                "query_type": "information_request",
                "confidence": 0.6,
                "reason": "Multi-word query likely needs document search",
            }

        return {
            "requires_documents": False,
            "query_type": "general",
            "confidence": 0.5,
            "reason": "Simple query, may not need documents",
        }

    async def _generate_audio_response(self, websocket: WebSocket, session_id: str, text: str):
        """Synthesize ``text`` with ``groq_voice_service.text_to_speech`` and send it base64-encoded.

        Sends ``audio_generation_started``, then either ``audio_response``
        (audio as base64 text plus the source ``text``) or
        ``audio_generation_failed`` when the service returns no audio.
        Exceptions are reported as an ``error`` message.
        """
        try:
            await self.send_message(session_id, {
                "type": "audio_generation_started",
                "timestamp": time.time(),
            })

            audio_data = await groq_voice_service.text_to_speech(text)
            if not audio_data:
                await self.send_message(session_id, {
                    "type": "audio_generation_failed",
                    "message": "Could not generate audio",
                    "timestamp": time.time(),
                })
                return

            await self.send_message(session_id, {
                "type": "audio_response",
                "audio_data": base64.b64encode(audio_data).decode('utf-8'),
                "text": text,
                "timestamp": time.time(),
            })
        except Exception as e:
            logger.exception(f"❌ Audio generation error: {e}")
            await self._send_error(session_id, f"Audio generation failed: {str(e)}")

    async def _handle_conversation_state(self, websocket: WebSocket, session_id: str, message: Dict[str, Any]):
        """Serve ``message["action"]``: ``get_history`` or ``clear_history``.

        ``get_history`` sends the session's history (empty for an unknown
        session). ``clear_history`` empties it and acknowledges with
        ``history_cleared``. Any other action is ignored without a reply.
        """
        action = message.get("action", "")
        if action == "get_history":
            history = self.user_sessions.get(session_id, {}).get("conversation_history", [])
            await self.send_message(session_id, {
                "type": "conversation_history",
                "history": history,
                "timestamp": time.time(),
            })
        elif action == "clear_history":
            if session_id in self.user_sessions:
                self.user_sessions[session_id]["conversation_history"] = []
            await self.send_message(session_id, {
                "type": "history_cleared",
                "timestamp": time.time(),
            })

    async def _handle_voice_settings(self, websocket: WebSocket, session_id: str, message: Dict[str, Any]):
        """Store ``message["settings"]`` on the session and echo it back.

        The settings are stored as-is and are not validated. They are echoed
        back as ``voice_settings_updated`` even if the session is unknown.
        """
        settings = message.get("settings", {})
        if session_id in self.user_sessions:
            self.user_sessions[session_id]["voice_settings"] = settings
        await self.send_message(session_id, {
            "type": "voice_settings_updated",
            "settings": settings,
            "timestamp": time.time(),
        })

    def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return a shallow copy of the session record plus ``session_id`` and ``is_active``.

        Returns None for unknown sessions. The copy is shallow, so
        ``conversation_history`` is still the live list.
        """
        session = self.user_sessions.get(session_id)
        if session is None:
            return None
        info = dict(session)
        info["session_id"] = session_id
        info["is_active"] = session_id in self.active_connections
        return info

    def get_active_sessions_count(self) -> int:
        """Return how many sessions currently have a registered WebSocket."""
        return len(self.active_connections)
# Module-level handler instance; every importer of this module shares it,
# together with its connection and session state.
groq_websocket_handler: GroqWebSocketHandler = GroqWebSocketHandler()