import torch
import torchaudio
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
import numpy as np
from typing import Optional, Union
import librosa
import soundfile as sf
import os


class KyutaiSTTProcessor:
    """Processor for Kyutai Speech-to-Text model"""

    def __init__(self, device: str = "cuda"):
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model = None
        self.processor = None
        self.model_id = "kyutai/stt-2.6b-en"  # English-only model for better accuracy

        # Audio processing parameters
        self.sample_rate = 16000
        self.chunk_length_s = 30  # Process in 30-second chunks
        self.max_duration = 120   # Maximum 2 minutes of audio
    def load_model(self):
        """Lazy load the STT model"""
        if self.model is None:
            try:
                # Load processor and model
                self.processor = AutoProcessor.from_pretrained(self.model_id)

                # Half precision on GPU for low VRAM usage
                torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

                self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                    self.model_id,
                    torch_dtype=torch_dtype,
                    low_cpu_mem_usage=True,
                    use_safetensors=True
                )
                self.model.to(self.device)

                # Pin generation to English transcription
                self.model.generation_config.language = "english"
                self.model.generation_config.task = "transcribe"
                self.model.generation_config.forced_decoder_ids = None
            except Exception as e:
                print(f"Failed to load STT model: {e}")
                raise
    def preprocess_audio(self, audio_path: str) -> np.ndarray:
        """Preprocess audio file for transcription"""
        try:
            # Load audio file as mono at its native sample rate
            audio, sr = librosa.load(audio_path, sr=None, mono=True)

            # Resample if necessary
            if sr != self.sample_rate:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)

            # Limit duration
            max_samples = self.max_duration * self.sample_rate
            if len(audio) > max_samples:
                audio = audio[:max_samples]

            # Peak-normalize audio (epsilon avoids division by zero on silence)
            audio = audio / (np.max(np.abs(audio)) + 1e-7)
            return audio
        except Exception as e:
            print(f"Error preprocessing audio: {e}")
            raise
    def transcribe(self, audio_input: Union[str, np.ndarray]) -> str:
        """Transcribe audio to text"""
        try:
            # Load model if not already loaded
            self.load_model()

            # Accept either a file path or a preprocessed waveform
            if isinstance(audio_input, str):
                audio = self.preprocess_audio(audio_input)
            else:
                audio = audio_input

            # Extract model input features
            inputs = self.processor(
                audio,
                sampling_rate=self.sample_rate,
                return_tensors="pt"
            ).to(self.device)

            # Cast features to the model's dtype (float16 on GPU);
            # fp16 weights reject fp32 inputs otherwise
            input_features = inputs["input_features"].to(self.model.dtype)

            # Generate transcription
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_features,
                    max_new_tokens=128,
                    do_sample=False,
                    num_beams=1  # Greedy decoding for speed
                )

            # Decode token ids back to text
            transcription = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )[0]

            # Clean up transcription
            return self._clean_transcription(transcription)
        except Exception as e:
            print(f"Transcription error: {e}")
            # Return a default description on error
            return "Create a unique digital monster companion"
    def _clean_transcription(self, text: str) -> str:
        """Clean up transcription output"""
        # Collapse extra whitespace
        text = " ".join(text.split())

        # Ensure proper capitalization
        if text and text[0].islower():
            text = text[0].upper() + text[1:]

        # Add terminal punctuation if missing
        if text and text[-1] not in ".!?":
            text += "."
        return text
    def transcribe_streaming(self, audio_stream):
        """Streaming transcription (for future implementation)"""
        # This would handle real-time audio streams; until then,
        # fail loudly rather than return a placeholder result
        raise NotImplementedError("Streaming transcription not yet implemented")
    def to(self, device: str):
        """Move model to specified device"""
        self.device = device
        if self.model is not None:
            self.model.to(device)
    def __del__(self):
        """Cleanup when object is destroyed"""
        # getattr guards: attributes may already be gone during interpreter shutdown
        if getattr(self, "model", None) is not None:
            del self.model
        if getattr(self, "processor", None) is not None:
            del self.processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
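

# A minimal usage sketch (not part of the original class): it assumes this
# file is run directly and that "sample.wav" is a placeholder path to any
# audio file librosa can read. Both supported input types are exercised:
# a file path and a preprocessed numpy waveform.
if __name__ == "__main__":
    stt = KyutaiSTTProcessor(device="cuda")

    # Transcribe from a file path (preprocessing happens internally)
    print(stt.transcribe("sample.wav"))

    # Or pass a 16 kHz mono waveform directly; one second of silence here,
    # purely to exercise the array code path
    waveform = np.zeros(stt.sample_rate, dtype=np.float32)
    print(stt.transcribe(waveform))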