import gc
import re
from itertools import cycle

import torch
from ncodec.codec import TTSCodec
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
from mira.utils import clear_cache, split_text


class MiraTTS:
    """Text-to-speech wrapper pairing an lmdeploy pipeline with a ``TTSCodec``.

    The pipeline generates text containing speech-token markers; the codec
    formats prompts and decodes those markers (together with reference
    context tokens) into audio.
    """

    # Captures the numeric id of every speech-token marker in generated text.
    # Compiled once so the streaming loop does not look the pattern up again
    # for each response.
    _SPEECH_TOKEN_PATTERN = re.compile(r"speech_token_(\d+)")

    def __init__(
        self,
        model_dir="YatharthS/MiraTTS",
        tp=1,
        enable_prefix_caching=True,
        cache_max_entry_count=0.2,
        default_chunk_size=50,
    ):
        """Build the inference pipeline, sampling config and codec.

        Args:
            model_dir: Model identifier/path handed to ``lmdeploy.pipeline``.
            tp: Forwarded to ``TurbomindEngineConfig(tp=...)``.
            enable_prefix_caching: Forwarded to ``TurbomindEngineConfig``.
            cache_max_entry_count: Forwarded to ``TurbomindEngineConfig``.
            default_chunk_size: Number of speech tokens per streamed audio
                chunk when ``generate_stream`` is called without an explicit
                ``chunk_size``.
        """
        backend_config = TurbomindEngineConfig(
            cache_max_entry_count=cache_max_entry_count,
            tp=tp,
            dtype="bfloat16",
            enable_prefix_caching=enable_prefix_caching,
        )
        self.pipe = pipeline(model_dir, backend_config=backend_config)

        # The default sampling parameters live in set_params(); calling it
        # here keeps a single source of truth instead of a duplicated list.
        self.set_params()

        self.codec = TTSCodec()
        self.default_chunk_size = default_chunk_size

        # Set to True by warmup_decoder() once the one-off warm-up decode has
        # run, so later calls can return immediately.
        self._decoder_warmed = False

    def set_params(
        self,
        top_p=0.95,
        top_k=50,
        temperature=0.8,
        max_new_tokens=1024,
        repetition_penalty=1.2,
        min_p=0.05,
    ):
        """Replace ``self.gen_config`` with a new sampling configuration.

        Every argument is forwarded to ``GenerationConfig`` under the same
        name; ``do_sample`` is always set to ``True``.
        """
        self.gen_config = GenerationConfig(
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            min_p=min_p,
            do_sample=True,
        )

    def c_cache(self):
        """Invoke the ``clear_cache`` helper imported from ``mira.utils``."""
        clear_cache()

    def split_text(self, text):
        """Return the result of ``mira.utils.split_text`` applied to ``text``."""
        return split_text(text)

    def encode_audio(self, audio_file):
        """Encode reference audio into context tokens via ``codec.encode``."""
        context_tokens = self.codec.encode(audio_file)
        return context_tokens

    def warmup_decoder(self, context_tokens=None):
        """Run one throwaway ``decode_chunk`` call to reduce first-chunk latency.

        Does nothing if the warm-up already ran on this instance.

        Args:
            context_tokens: Context tokens to decode against. When falsy, a
                synthetic context built from ten ``<|context_token_i|>``
                markers is used instead.
        """
        if self._decoder_warmed:
            return

        if not context_tokens:
            context_tokens = "".join(
                [f"<|context_token_{i}|>" for i in range(10)]
            )

        # The decoded audio is discarded; only the side effect of having run
        # the decoder once matters.
        self.codec.decode_chunk(
            "<|speech_token_0|><|speech_token_1|>", context_tokens
        )
        self._decoder_warmed = True

    def generate(self, text, context_tokens):
        """Generate speech for ``text`` in a single, non-streaming call.

        Args:
            text: Input text to synthesize.
            context_tokens: Reference audio context tokens.

        Returns:
            The audio returned by ``codec.decode`` for the generated text.
        """
        formatted_prompt = self.codec.format_prompt(text, context_tokens, None)
        response = self.pipe(
            [formatted_prompt], gen_config=self.gen_config, do_preprocess=False
        )
        audio = self.codec.decode(response[0].text, context_tokens)
        return audio

    def _decode_speech_tokens(self, token_ids, context_tokens):
        """Render token ids as speech-token markers and decode them as a chunk."""
        token_str = "".join([f"<|speech_token_{t}|>" for t in token_ids])
        return self.codec.decode_chunk(token_str, context_tokens)

    def generate_stream(self, text, context_tokens, chunk_size=None):
        """Generate speech for ``text``, yielding audio as it becomes available.

        Args:
            text: Input text to synthesize.
            context_tokens: Reference audio context tokens.
            chunk_size: Number of speech tokens to decode before yielding
                audio (default from __init__ or 50 = ~1 sec at 20ms/token).

        Yields:
            Audio chunks as torch tensors (48kHz). A final, shorter chunk is
            yielded for any tokens left over when generation stops.

        Raises:
            ValueError: If the effective ``chunk_size`` is smaller than 1.
                Because this is a generator, the error surfaces on the first
                iteration rather than at call time.
        """
        if chunk_size is None:
            chunk_size = self.default_chunk_size
        # A zero chunk size would divide by zero and a negative one would
        # never produce a full chunk; reject both up front.
        if chunk_size < 1:
            raise ValueError(
                f"chunk_size must be a positive integer, got {chunk_size!r}"
            )

        self.warmup_decoder(context_tokens)

        formatted_prompt = self.codec.format_prompt(text, context_tokens, None)
        responses = self.pipe.stream_infer(
            [formatted_prompt],
            gen_config=self.gen_config,
            do_preprocess=False,
            stream_response=True,
        )

        pending = []
        for response in responses:
            # NOTE(review): assumes each streamed response.text holds only
            # newly generated text with whole speech-token markers - confirm
            # against the lmdeploy streaming API.
            matches = self._SPEECH_TOKEN_PATTERN.findall(response.text)
            pending.extend(int(token_id) for token_id in matches)

            # Emit every complete chunk currently buffered; keep the rest.
            while len(pending) >= chunk_size:
                chunk_ids = pending[:chunk_size]
                del pending[:chunk_size]
                yield self._decode_speech_tokens(chunk_ids, context_tokens)

            if response.finish_reason:
                break

        # Flush the tokens that did not fill a whole chunk.
        if pending:
            yield self._decode_speech_tokens(pending, context_tokens)

    def batch_generate(self, prompts, context_tokens):
        """Generate speech for several prompts in one pipeline call.

        Args:
            prompts: Iterable of input texts to synthesize.
            context_tokens: Iterable of reference context tokens. It is
                cycled, so when it is shorter than ``prompts`` its entries
                are reused in order.

        Returns:
            The decoded audios of all prompts concatenated along dimension 0
            with ``torch.cat``.

        Raises:
            ValueError: If ``context_tokens`` is empty.
        """
        # Materialise once: the contexts are walked twice (prompt formatting
        # and decoding), which would exhaust a one-shot iterator.
        contexts = list(context_tokens)
        if not contexts:
            raise ValueError("context_tokens must contain at least one entry")

        formatted_prompts = [
            self.codec.format_prompt(prompt, context_token, None)
            for prompt, context_token in zip(prompts, cycle(contexts))
        ]

        responses = self.pipe(
            formatted_prompts, gen_config=self.gen_config, do_preprocess=False
        )

        # Pair each response with the same context its prompt was built from.
        audios = [
            self.codec.decode(response.text, context_token)
            for response, context_token in zip(responses, cycle(contexts))
        ]
        return torch.cat(audios, dim=0)

    def generate_and_play(
        self, text, context_tokens, chunk_size=None, samplerate=48000
    ):
        """Generate audio with streaming and play each chunk as it arrives.

        Requires sounddevice: pip install sounddevice

        Args:
            text: Input text to synthesize.
            context_tokens: Reference audio context tokens.
            chunk_size: Number of tokens per chunk (default from __init__ or
                50 = ~1 sec).
            samplerate: Audio sample rate passed to ``sounddevice.play``
                (default 48000).

        Raises:
            ImportError: If ``sounddevice`` is not installed.
        """
        try:
            import sounddevice as sd
        except ImportError as err:
            # Chain the original error so the underlying import failure stays
            # visible in the traceback.
            raise ImportError(
                "sounddevice required for playback. Install with: pip install sounddevice"
            ) from err

        for audio_chunk in self.generate_stream(
            text, context_tokens, chunk_size=chunk_size
        ):
            sd.play(audio_chunk.cpu().numpy().flatten(), samplerate=samplerate)
            # Block until this chunk has finished before starting the next.
            sd.wait()