import os
import sys
import gradio as gr
import numpy as np
import subprocess
import scipy.io.wavfile
from pathlib import Path
# Repository with the Myanmar TTS implementation (cloned by setup() if missing)
REPO_URL = "https://github.com/hpbyte/myanmar-tts.git"
REPO_DIR = "myanmar-tts"
def setup():
    """Set up the environment by cloning the repository if needed."""
    if not os.path.exists(REPO_DIR):
        print(f"Cloning {REPO_URL}...")
        subprocess.run(["git", "clone", REPO_URL], check=True)

    # Add the repository to the Python path
    repo_path = os.path.abspath(REPO_DIR)
    if repo_path not in sys.path:
        sys.path.append(repo_path)

    # Create the model directory if it doesn't exist
    if not os.path.exists("trained_model"):
        os.makedirs("trained_model")
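
    # Optional: the cloned repository may declare its own Python dependencies.
    # A hedged sketch of installing them, assuming a requirements.txt exists at
    # the repository root (left commented out so setup() keeps its original behavior):
    # req_file = os.path.join(repo_path, "requirements.txt")
    # if os.path.exists(req_file):
    #     subprocess.run([sys.executable, "-m", "pip", "install", "-r", req_file], check=True)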
def text_to_speech(text):
    """Convert text to speech using the Myanmar TTS model."""
    if not text.strip():
        return None, "Please enter some text."

    try:
        # Try to import the modules exposed by the cloned repository
        try:
            import torch
            from text import text_to_sequence
            from utils.hparams import create_hparams
            from train import load_model
            from synthesis import generate_speech
        except ImportError:
            # If the direct imports fail, fall back to the local wrapper module
            import torch
            from myanmar_tts import synthesize

            # Use the simplified wrapper function
            waveform, sample_rate = synthesize(text)
            output_path = "output.wav"
            scipy.io.wavfile.write(output_path, sample_rate, waveform)
            return output_path, "Speech generated successfully!"

        # The direct imports worked, so continue with the standard approach
        checkpoint_path = os.path.join("trained_model", "checkpoint_latest.pth.tar")
        config_path = os.path.join("trained_model", "hparams.yml")

        if not os.path.exists(checkpoint_path) or not os.path.exists(config_path):
            return None, f"""Model files not found. Please upload:
1. The checkpoint file to: {checkpoint_path}
2. The hparams.yml file to: {config_path}

You can obtain these files from the original repository."""

        # Load the hyperparameters and the model checkpoint (CPU only)
        hparams = create_hparams(config_path)
        model = load_model(hparams)
        model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu'))['state_dict'])
        model.eval()

        # Convert the input text to a sequence of symbol IDs
        sequence = np.array(text_to_sequence(text, ['burmese_cleaners']))[None, :]
        sequence = torch.from_numpy(sequence).cpu().long()

        # Generate mel spectrograms
        mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)

        # Generate the waveform
        with torch.no_grad():
            waveform = generate_speech(mel_outputs_postnet, hparams)

        # Save and return the audio
        output_path = "output.wav"
        scipy.io.wavfile.write(output_path, hparams.sampling_rate, waveform)
        return output_path, "Speech generated successfully!"

    except Exception as e:
        error_msg = str(e)
        detailed_msg = f"""Error: {error_msg}

Make sure you have:
1. Uploaded the model files to the 'trained_model' directory
2. Named the files 'checkpoint_latest.pth.tar' and 'hparams.yml'

If you're still seeing this error, please check the repository for any specific setup instructions."""
        return None, detailed_msg
# Set up the environment
setup()
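
# Optional: a minimal sketch of fetching the model files automatically instead of
# uploading them by hand. MODEL_REPO_ID is a hypothetical environment variable
# (this app does not say where the checkpoint is hosted), and huggingface_hub is
# an extra dependency assumed here; by default this block does nothing.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID")
if MODEL_REPO_ID:
    import shutil
    from huggingface_hub import hf_hub_download

    for filename in ["checkpoint_latest.pth.tar", "hparams.yml"]:
        target = os.path.join("trained_model", filename)
        if not os.path.exists(target):
            # Download the file from the Hub cache and copy it to where
            # text_to_speech() expects to find it
            downloaded = hf_hub_download(repo_id=MODEL_REPO_ID, filename=filename)
            shutil.copy(downloaded, target)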
# Create Gradio interface
demo = gr.Interface(
    fn=text_to_speech,
    inputs=[
        gr.Textbox(
            lines=3,
            placeholder="Enter Burmese text here...",
            label="Text"
        )
    ],
    outputs=[
        gr.Audio(label="Generated Speech"),
        gr.Textbox(label="Status", max_lines=10)
    ],
    title="Myanmar (Burmese) Text-to-Speech",
    description="""
This is a demo of the Myanmar Text-to-Speech system developed by hpbyte.
Enter Burmese text in the box below and click 'Submit' to generate speech.

**Important**: You need to upload the model files to the 'trained_model' directory:
- checkpoint_latest.pth.tar (the model checkpoint)
- hparams.yml (hyperparameters configuration)

Source: [GitHub Repository](https://github.com/hpbyte/myanmar-tts)
""",
    examples=[
        ["မင်္ဂလာပါ"],
        ["မြန်မာစကားပြောစနစ်ကို ကြိုဆိုပါတယ်"],
        ["ဒီစနစ်ဟာ မြန်မာစာကို အသံအဖြစ် ပြောင်းပေးနိုင်ပါတယ်"],
    ]
)
if __name__ == "__main__":
    demo.launch()
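
# To try the demo locally, something along these lines should work (package names
# come from the imports above; exact versions are not pinned by this file, and the
# cloned myanmar-tts repository may need additional dependencies of its own):
#   pip install gradio numpy scipy torch
#   python app_simple.py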