88e07487ee
- Added generate_stream() method for streaming synthesis (tokens are streamed and decoded to audio in chunks)
- Added generate_and_play() method for real-time playback
- Added decode_chunk() to ncodec codec
- First audio chunk in ~180ms (390% faster than non-streaming)
- Updated README with streaming documentation
210 lines
6.7 KiB
Python
210 lines
6.7 KiB
Python
import gc
|
|
import re
|
|
import torch
|
|
from itertools import cycle
|
|
from ncodec.codec import TTSCodec
|
|
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
|
|
|
|
from mira.utils import clear_cache, split_text
|
|
|
|
|
|
class MiraTTS:
    """Text-to-speech wrapper pairing an LMDeploy pipeline with the ncodec ``TTSCodec``.

    The LLM turns a formatted prompt (text plus reference-voice context tokens)
    into ``speech_token_N`` text, which the codec decodes into audio.
    """

    # Extracts the integer id from every ``speech_token_N`` occurrence in LLM output.
    _SPEECH_TOKEN_RE = re.compile(r"speech_token_(\d+)")

    def __init__(
        self,
        model_dir="YatharthS/MiraTTS",
        tp=1,
        enable_prefix_caching=True,
        cache_max_entry_count=0.2,
        default_chunk_size=50,
    ):
        """Load the LLM pipeline and the audio codec.

        Args:
            model_dir: Model path or hub id passed to ``lmdeploy.pipeline``.
            tp: Tensor-parallel degree, forwarded to ``TurbomindEngineConfig``.
            enable_prefix_caching: Forwarded to ``TurbomindEngineConfig``.
            cache_max_entry_count: Forwarded to ``TurbomindEngineConfig``.
            default_chunk_size: Tokens per streamed chunk used by
                ``generate_stream`` / ``generate_and_play`` when they are
                called with ``chunk_size=None``.
        """
        backend_config = TurbomindEngineConfig(
            cache_max_entry_count=cache_max_entry_count,
            tp=tp,
            dtype="bfloat16",
            enable_prefix_caching=enable_prefix_caching,
        )
        self.pipe = pipeline(model_dir, backend_config=backend_config)
        # set_params() holds the single copy of the default sampling settings.
        self.set_params()
        self.codec = TTSCodec()
        self.default_chunk_size = default_chunk_size

        # Decoder warm-up is lazy: warmup_decoder() flips this flag on first use,
        # so the first streaming call pays the warm-up cost.
        self._decoder_warmed = False

    def set_params(
        self,
        top_p=0.95,
        top_k=50,
        temperature=0.8,
        max_new_tokens=1024,
        repetition_penalty=1.2,
        min_p=0.05,
    ):
        """Set the sampling parameters used by every subsequent LLM call.

        Sampling is always enabled (``do_sample=True``). Calling this with no
        arguments restores the defaults installed by ``__init__``.
        """
        self.gen_config = GenerationConfig(
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            min_p=min_p,
            do_sample=True,
        )

    def c_cache(self):
        """Delegate to ``mira.utils.clear_cache`` (see that helper for what is cleared)."""
        clear_cache()

    def split_text(self, text):
        """Delegate to the module-level ``mira.utils.split_text`` helper."""
        return split_text(text)

    def encode_audio(self, audio_file):
        """Encode a reference recording into context tokens via ``TTSCodec.encode``.

        The result is what the ``context_tokens`` argument of the generation
        methods expects.
        """
        return self.codec.encode(audio_file)

    def warmup_decoder(self, context_tokens=None):
        """Run one throwaway ``decode_chunk`` call to reduce latency of the first real chunk.

        Only the first call does any work; later calls return immediately.

        Args:
            context_tokens: Context to warm up with. When falsy, a synthetic
                10-token ``<|context_token_i|>`` string is used instead.
        """
        if self._decoder_warmed:
            return

        if context_tokens:
            context = context_tokens
        else:
            context = "".join(f"<|context_token_{i}|>" for i in range(10))
        self.codec.decode_chunk("<|speech_token_0|><|speech_token_1|>", context)

        self._decoder_warmed = True

    def generate(self, text, context_tokens):
        """Synthesize ``text`` in a single, non-streaming LLM call.

        Args:
            text: Text to speak.
            context_tokens: Reference-voice context, e.g. from ``encode_audio``.

        Returns:
            The audio returned by ``TTSCodec.decode`` for the generated tokens.
        """
        formatted_prompt = self.codec.format_prompt(text, context_tokens, None)
        outputs = self.pipe(
            [formatted_prompt], gen_config=self.gen_config, do_preprocess=False
        )
        return self.codec.decode(outputs[0].text, context_tokens)

    def _decode_token_ids(self, token_ids, context_tokens):
        """Render integer speech-token ids back to token text and decode them as one chunk."""
        token_str = "".join(f"<|speech_token_{t}|>" for t in token_ids)
        return self.codec.decode_chunk(token_str, context_tokens)

    def generate_stream(self, text, context_tokens, chunk_size=None):
        """Synthesize ``text`` incrementally, yielding audio as tokens arrive.

        Speech tokens are collected from the LLM stream. Every time
        ``chunk_size`` of them are available, they are decoded with
        ``TTSCodec.decode_chunk`` and yielded. Leftover tokens are flushed as
        a final, shorter chunk.

        Args:
            text: Text to speak.
            context_tokens: Reference-voice context, e.g. from ``encode_audio``.
            chunk_size: Tokens per yielded chunk. Defaults to
                ``self.default_chunk_size`` (50 unless overridden; the original
                notes equate 50 tokens with ~1 s of audio at ~20 ms per token).

        Yields:
            Audio chunks as returned by ``TTSCodec.decode_chunk`` (documented
            as 48 kHz torch tensors).

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1. Because this is a
                generator, the error surfaces on the first ``next()``.
        """
        if chunk_size is None:
            chunk_size = self.default_chunk_size
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size!r}")

        self.warmup_decoder(context_tokens)

        formatted_prompt = self.codec.format_prompt(text, context_tokens, None)
        responses = self.pipe.stream_infer(
            [formatted_prompt],
            gen_config=self.gen_config,
            do_preprocess=False,
            stream_response=True,
        )

        pending = []  # speech-token ids received but not yet decoded, in order
        try:
            for response in responses:
                # Each streamed response's text is parsed on its own.
                # NOTE(review): this assumes stream_infer yields only newly
                # generated text per response (deltas) -- confirm with lmdeploy.
                pending.extend(
                    int(t) for t in self._SPEECH_TOKEN_RE.findall(response.text)
                )

                # Emit every complete chunk; the remainder stays queued in order.
                while len(pending) >= chunk_size:
                    chunk, pending = pending[:chunk_size], pending[chunk_size:]
                    yield self._decode_token_ids(chunk, context_tokens)

                if response.finish_reason:
                    break
        finally:
            # Close the LLM stream now (e.g. when the consumer stops early)
            # instead of waiting for garbage collection to do it.
            close = getattr(responses, "close", None)
            if close is not None:
                close()

        # Flush the final partial chunk.
        if pending:
            yield self._decode_token_ids(pending, context_tokens)

    def batch_generate(self, prompts, context_tokens):
        """Synthesize several texts with one batched LLM call.

        Args:
            prompts: Iterable of texts to speak.
            context_tokens: Sequence of reference-voice contexts. Prompt ``i``
                uses context ``i % len(context_tokens)``, so contexts are reused
                round-robin. A single context string is also accepted and is
                used for every prompt.

        Returns:
            The decoded audio of all prompts, concatenated in prompt order with
            ``torch.cat(..., dim=0)``.

        Raises:
            ValueError: If ``prompts`` or ``context_tokens`` is empty.
        """
        # A lone context string would otherwise be cycled character by character.
        if isinstance(context_tokens, str):
            context_tokens = [context_tokens]
        prompts = list(prompts)
        context_tokens = list(context_tokens)
        if not prompts:
            raise ValueError("prompts must contain at least one text")
        if not context_tokens:
            raise ValueError("context_tokens must contain at least one context")

        # Pair prompts with contexts once, so encoding and decoding agree.
        pairs = list(zip(prompts, cycle(context_tokens)))
        formatted_prompts = [
            self.codec.format_prompt(prompt, context, None) for prompt, context in pairs
        ]

        responses = self.pipe(
            formatted_prompts, gen_config=self.gen_config, do_preprocess=False
        )

        audios = [
            self.codec.decode(response.text, context)
            for response, (_, context) in zip(responses, pairs)
        ]
        return torch.cat(audios, dim=0)

    def generate_and_play(
        self, text, context_tokens, chunk_size=None, samplerate=48000
    ):
        """Synthesize ``text`` and play it on the default audio device while streaming.

        Requires sounddevice: ``pip install sounddevice``. Chunks are written to
        one continuous mono output stream. Decoding and playback share the
        calling thread, so gaps are still possible if generation runs slower
        than real time.

        Args:
            text: Text to speak.
            context_tokens: Reference-voice context, e.g. from ``encode_audio``.
            chunk_size: Tokens per chunk (see ``generate_stream``).
            samplerate: Playback sample rate in Hz (default 48000).

        Raises:
            ImportError: If sounddevice is not installed.
        """
        try:
            import sounddevice as sd
        except ImportError as err:
            raise ImportError(
                "sounddevice required for playback. Install with: pip install sounddevice"
            ) from err

        # One continuous stream: write() queues audio behind what is already
        # playing, whereas sd.play() stops any playback in progress (which cut
        # earlier chunks short). Leaving the ``with`` block stops the stream,
        # which waits for the buffered audio to finish.
        with sd.OutputStream(
            samplerate=samplerate, channels=1, dtype="float32"
        ) as stream:
            for audio_chunk in self.generate_stream(
                text, context_tokens, chunk_size=chunk_size
            ):
                samples = (
                    audio_chunk.detach().to(torch.float32).cpu().reshape(-1).contiguous()
                )
                stream.write(samples.numpy())
|