mmtts / myanmar_tts.py
aungkomyat's picture
Create myanmar_tts.py
b677e2c verified
"""
This is a simplified wrapper for myanmar-tts to handle import issues.
It's intended to make the HuggingFace Space deployment easier.
"""
import os
import sys
import importlib.util
# Make the cloned repository importable. The membership test uses the same
# absolute path that gets appended, so executing this module more than once
# (e.g. on reload) does not pile up duplicate sys.path entries.
REPO_DIR = "myanmar-tts"
_repo_abs_path = os.path.abspath(REPO_DIR)
if os.path.exists(_repo_abs_path) and _repo_abs_path not in sys.path:
    sys.path.append(_repo_abs_path)
# Resolve the four myanmar-tts entry points (text_to_sequence, create_hparams,
# load_model, generate_speech) with progressively more permissive strategies:
# plain imports, package-qualified imports, then loading source files by path.
try:
    # First attempt: direct imports
    from text import text_to_sequence
    from utils.hparams import create_hparams
    from train import load_model
    from synthesis import generate_speech
except ImportError:
    try:
        # Second attempt: repository imports
        from myanmar_tts.text import text_to_sequence
        from myanmar_tts.utils.hparams import create_hparams
        from myanmar_tts.train import load_model
        from myanmar_tts.synthesis import generate_speech
    except ImportError:
        # If still failing, try to load modules dynamically
        def load_module(module_name, file_path):
            """
            Load a Python source file as a module named ``module_name``.

            Args:
                module_name (str): Name to give the module; it is also the
                    key under which the module is stored in ``sys.modules``
                file_path (str): Path to the source file to execute

            Returns:
                module: The executed module object

            Raises:
                ImportError: If the file does not exist or no import spec
                    with a loader can be built for it
            """
            if not os.path.exists(file_path):
                raise ImportError(f"Module file not found: {file_path}")
            spec = importlib.util.spec_from_file_location(module_name, file_path)
            # spec_from_file_location returns None when it cannot pick a
            # loader for the path; report that as an import failure instead
            # of an AttributeError on the next line.
            if spec is None or spec.loader is None:
                raise ImportError(f"Cannot create an import spec for: {file_path}")
            module = importlib.util.module_from_spec(spec)
            # Register the module before executing it, as the importlib
            # recipe for importing a source file does, so that imports which
            # refer to ``module_name`` while the file runs find this module.
            previous = sys.modules.get(module_name)
            sys.modules[module_name] = module
            try:
                spec.loader.exec_module(module)
            except BaseException:
                # Do not leave a half-initialised module registered.
                if previous is None:
                    sys.modules.pop(module_name, None)
                else:
                    sys.modules[module_name] = previous
                raise
            return module

        # Try to load critical modules
        try:
            text_module = load_module("text", os.path.join(REPO_DIR, "text", "__init__.py"))
            text_to_sequence = text_module.text_to_sequence
            hparams_module = load_module("hparams", os.path.join(REPO_DIR, "utils", "hparams.py"))
            create_hparams = hparams_module.create_hparams
            train_module = load_module("train", os.path.join(REPO_DIR, "train.py"))
            load_model = train_module.load_model
            synthesis_module = load_module("synthesis", os.path.join(REPO_DIR, "synthesis.py"))
            generate_speech = synthesis_module.generate_speech
        except Exception as e:
            # Log for the Space's console, then let the original error propagate.
            print(f"Failed to import myanmar-tts modules: {e}")
            raise
# Define a simple synthesis function
def synthesize(text, model_dir="trained_model"):
    """
    Synthesize speech from the given text using the Myanmar TTS model.

    The model is rebuilt and its checkpoint reloaded on every call.

    Args:
        text (str): The Burmese text to synthesize
        model_dir (str): Directory containing ``checkpoint_latest.pth.tar``
            and ``hparams.yml``

    Returns:
        tuple: (waveform, sample_rate), where waveform is whatever
            ``generate_speech`` returns and sample_rate is
            ``hparams.sampling_rate``

    Raises:
        FileNotFoundError: If the checkpoint or the hparams file is missing
            from ``model_dir``
    """
    checkpoint_path = os.path.join(model_dir, "checkpoint_latest.pth.tar")
    config_path = os.path.join(model_dir, "hparams.yml")
    # Validate the model files before the heavy imports so a missing model is
    # reported immediately.
    if not (os.path.exists(checkpoint_path) and os.path.exists(config_path)):
        raise FileNotFoundError(f"Model files not found in {model_dir}")

    import numpy as np
    import torch

    # Load the model
    hparams = create_hparams(config_path)
    model = load_model(hparams)
    # NOTE(review): torch.load unpickles the checkpoint file; only point
    # model_dir at checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # Process text: [None, :] adds a leading axis so the model receives a
    # batch containing this single sequence.
    token_ids = np.array(text_to_sequence(text, ['burmese_cleaners']))[None, :]
    sequence = torch.from_numpy(token_ids).cpu().long()

    # Run both the spectrogram model and the waveform generation without
    # autograd tracking; nothing here is trained.
    with torch.no_grad():
        _, mel_outputs_postnet, _, _ = model.inference(sequence)
        waveform = generate_speech(mel_outputs_postnet, hparams)
    return waveform, hparams.sampling_rate