First commit
This commit is contained in:
0
vibevoice/__init__.py
Normal file
0
vibevoice/__init__.py
Normal file
112
vibevoice/configs/qwen2.5_1.5b_64k.json
Normal file
112
vibevoice/configs/qwen2.5_1.5b_64k.json
Normal file
@ -0,0 +1,112 @@
|
||||
{
|
||||
"_attn_implementation_autoset": true,
|
||||
"acoustic_vae_dim": 64,
|
||||
"acoustic_tokenizer_config": {
|
||||
"causal": true,
|
||||
"channels": 1,
|
||||
"conv_bias": true,
|
||||
"conv_norm": "none",
|
||||
"corpus_normalize": 0.0,
|
||||
"decoder_depths": null,
|
||||
"decoder_n_filters": 32,
|
||||
"decoder_ratios": [
|
||||
8,
|
||||
5,
|
||||
5,
|
||||
4,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"disable_last_norm": true,
|
||||
"encoder_depths": "3-3-3-3-3-3-8",
|
||||
"encoder_n_filters": 32,
|
||||
"encoder_ratios": [
|
||||
8,
|
||||
5,
|
||||
5,
|
||||
4,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"fix_std": 0.5,
|
||||
"layer_scale_init_value": 1e-06,
|
||||
"layernorm": "RMSNorm",
|
||||
"layernorm_elementwise_affine": true,
|
||||
"layernorm_eps": 1e-05,
|
||||
"mixer_layer": "depthwise_conv",
|
||||
"model_type": "vibepod_acoustic_tokenizer",
|
||||
"pad_mode": "constant",
|
||||
"std_dist_type": "gaussian",
|
||||
"vae_dim": 64,
|
||||
"weight_init_value": 0.01
|
||||
},
|
||||
"decoder_config": {
|
||||
"attention_dropout": 0.0,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 1536,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 8960,
|
||||
"max_position_embeddings": 65536,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen2",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 2,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": null,
|
||||
"rope_theta": 1000000.0,
|
||||
"sliding_window": null,
|
||||
"tie_word_embeddings": true,
|
||||
"torch_dtype": "bfloat16",
|
||||
"use_cache": true,
|
||||
"use_sliding_window": false,
|
||||
"vocab_size": 151936
|
||||
},
|
||||
"diffusion_head_config": {
|
||||
"ddpm_batch_mul": 4,
|
||||
"ddpm_beta_schedule": "cosine",
|
||||
"ddpm_num_inference_steps": 20,
|
||||
"ddpm_num_steps": 1000,
|
||||
"diffusion_type": "ddpm",
|
||||
"head_ffn_ratio": 3.0,
|
||||
"head_layers": 4,
|
||||
"hidden_size": 1536,
|
||||
"latent_size": 64,
|
||||
"model_type": "vibepod_diffusion_head",
|
||||
"prediction_type": "v_prediction",
|
||||
"rms_norm_eps": 1e-05,
|
||||
"speech_vae_dim": 64
|
||||
},
|
||||
"model_type": "vibepod",
|
||||
"semantic_tokenizer_config": {
|
||||
"causal": true,
|
||||
"channels": 1,
|
||||
"conv_bias": true,
|
||||
"conv_norm": "none",
|
||||
"corpus_normalize": 0.0,
|
||||
"disable_last_norm": true,
|
||||
"encoder_depths": "3-3-3-3-3-3-8",
|
||||
"encoder_n_filters": 32,
|
||||
"encoder_ratios": [
|
||||
8,
|
||||
5,
|
||||
5,
|
||||
4,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"fix_std": 0,
|
||||
"layer_scale_init_value": 1e-06,
|
||||
"layernorm": "RMSNorm",
|
||||
"layernorm_elementwise_affine": true,
|
||||
"layernorm_eps": 1e-05,
|
||||
"mixer_layer": "depthwise_conv",
|
||||
"model_type": "vibepod_semantic_tokenizer",
|
||||
"pad_mode": "constant",
|
||||
"std_dist_type": "none",
|
||||
"vae_dim": 128,
|
||||
"weight_init_value": 0.01
|
||||
},
|
||||
"semantic_vae_dim": 128,
|
||||
"torch_dtype": "bfloat16"
|
||||
}
|
||||
113
vibevoice/configs/qwen2.5_7b_32k.json
Normal file
113
vibevoice/configs/qwen2.5_7b_32k.json
Normal file
@ -0,0 +1,113 @@
|
||||
{
|
||||
"_attn_implementation_autoset": true,
|
||||
"acoustic_vae_dim": 64,
|
||||
"acoustic_tokenizer_config": {
|
||||
"causal": true,
|
||||
"channels": 1,
|
||||
"conv_bias": true,
|
||||
"conv_norm": "none",
|
||||
"corpus_normalize": 0.0,
|
||||
"decoder_depths": null,
|
||||
"decoder_n_filters": 32,
|
||||
"decoder_ratios": [
|
||||
8,
|
||||
5,
|
||||
5,
|
||||
4,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"disable_last_norm": true,
|
||||
"encoder_depths": "3-3-3-3-3-3-8",
|
||||
"encoder_n_filters": 32,
|
||||
"encoder_ratios": [
|
||||
8,
|
||||
5,
|
||||
5,
|
||||
4,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"fix_std": 0.5,
|
||||
"layer_scale_init_value": 1e-06,
|
||||
"layernorm": "RMSNorm",
|
||||
"layernorm_elementwise_affine": true,
|
||||
"layernorm_eps": 1e-05,
|
||||
"mixer_layer": "depthwise_conv",
|
||||
"model_type": "vibepod_acoustic_tokenizer",
|
||||
"pad_mode": "constant",
|
||||
"std_dist_type": "gaussian",
|
||||
"vae_dim": 64,
|
||||
"weight_init_value": 0.01
|
||||
},
|
||||
"decoder_config": {
|
||||
"attention_dropout": 0.0,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3584,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 18944,
|
||||
"max_position_embeddings": 32768,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen2",
|
||||
"num_attention_heads": 28,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 4,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_theta": 1000000.0,
|
||||
"sliding_window": null,
|
||||
"tie_word_embeddings": false,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.40.1",
|
||||
"use_cache": true,
|
||||
"use_mrope": false,
|
||||
"use_sliding_window": false,
|
||||
"vocab_size": 152064
|
||||
},
|
||||
"diffusion_head_config": {
|
||||
"ddpm_batch_mul": 4,
|
||||
"ddpm_beta_schedule": "cosine",
|
||||
"ddpm_num_inference_steps": 20,
|
||||
"ddpm_num_steps": 1000,
|
||||
"diffusion_type": "ddpm",
|
||||
"head_ffn_ratio": 3.0,
|
||||
"head_layers": 4,
|
||||
"hidden_size": 3584,
|
||||
"latent_size": 64,
|
||||
"model_type": "vibepod_diffusion_head",
|
||||
"prediction_type": "v_prediction",
|
||||
"rms_norm_eps": 1e-05,
|
||||
"speech_vae_dim": 64
|
||||
},
|
||||
"model_type": "vibepod",
|
||||
"semantic_tokenizer_config": {
|
||||
"causal": true,
|
||||
"channels": 1,
|
||||
"conv_bias": true,
|
||||
"conv_norm": "none",
|
||||
"corpus_normalize": 0.0,
|
||||
"disable_last_norm": true,
|
||||
"encoder_depths": "3-3-3-3-3-3-8",
|
||||
"encoder_n_filters": 32,
|
||||
"encoder_ratios": [
|
||||
8,
|
||||
5,
|
||||
5,
|
||||
4,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"fix_std": 0,
|
||||
"layer_scale_init_value": 1e-06,
|
||||
"layernorm": "RMSNorm",
|
||||
"layernorm_elementwise_affine": true,
|
||||
"layernorm_eps": 1e-05,
|
||||
"mixer_layer": "depthwise_conv",
|
||||
"model_type": "vibepod_semantic_tokenizer",
|
||||
"pad_mode": "constant",
|
||||
"std_dist_type": "none",
|
||||
"vae_dim": 128,
|
||||
"weight_init_value": 0.01
|
||||
},
|
||||
"semantic_vae_dim": 128,
|
||||
"torch_dtype": "bfloat16"
|
||||
}
|
||||
0
vibevoice/modular/__init__.py
Normal file
0
vibevoice/modular/__init__.py
Normal file
248
vibevoice/modular/configuration_vibevoice.py
Normal file
248
vibevoice/modular/configuration_vibevoice.py
Normal file
@ -0,0 +1,248 @@
|
||||
""" VibeVoice_AcousticTokenizer model configuration"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VibeVoiceAcousticTokenizerConfig(PretrainedConfig):
|
||||
model_type = "vibevoice_acoustic_tokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int = 1,
|
||||
corpus_normalize: float = 0.0,
|
||||
causal: bool = True,
|
||||
vae_dim: int = 64,
|
||||
fix_std: float = 0.5,
|
||||
std_dist_type: str = 'gaussian',
|
||||
# common
|
||||
mixer_layer: str = 'depthwise_conv',
|
||||
conv_norm: str = 'none',
|
||||
pad_mode: str = 'constant',
|
||||
disable_last_norm: bool = True,
|
||||
layernorm: str = 'RMSNorm',
|
||||
layernorm_eps: float = 1e-5,
|
||||
layernorm_elementwise_affine: bool = True,
|
||||
conv_bias: bool = True,
|
||||
layer_scale_init_value: float = 1e-6,
|
||||
weight_init_value: float = 1e-2,
|
||||
# encoder specific
|
||||
encoder_n_filters: int = 32,
|
||||
encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
|
||||
encoder_depths: str = "3-3-3-3-3-3-8",
|
||||
# decoder specific
|
||||
decoder_n_filters: int = 32,
|
||||
decoder_ratios: Optional[List[int]] = None, # if None, same as encoder
|
||||
decoder_depths: Optional[str] = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.channels = channels
|
||||
self.corpus_normalize = corpus_normalize
|
||||
self.causal = causal
|
||||
self.vae_dim = vae_dim
|
||||
self.fix_std = fix_std
|
||||
self.std_dist_type = std_dist_type
|
||||
|
||||
# common parameters
|
||||
self.conv_norm = conv_norm
|
||||
self.pad_mode = pad_mode
|
||||
self.layernorm_eps = layernorm_eps
|
||||
self.disable_last_norm = disable_last_norm
|
||||
self.layernorm = layernorm
|
||||
self.layernorm_elementwise_affine = layernorm_elementwise_affine
|
||||
self.conv_bias = conv_bias
|
||||
self.layer_scale_init_value = layer_scale_init_value
|
||||
self.weight_init_value = weight_init_value
|
||||
self.mixer_layer = mixer_layer
|
||||
|
||||
# encoder specific parameters
|
||||
self.encoder_n_filters = encoder_n_filters
|
||||
self.encoder_ratios = encoder_ratios
|
||||
self.encoder_depths = encoder_depths
|
||||
|
||||
# decoder specific parameters
|
||||
self.decoder_ratios = decoder_ratios if decoder_ratios is not None else encoder_ratios
|
||||
self.decoder_n_filters = decoder_n_filters
|
||||
self.decoder_depths = decoder_depths
|
||||
|
||||
|
||||
class VibeVoiceSemanticTokenizerConfig(PretrainedConfig):
|
||||
model_type = "vibevoice_semantic_tokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int = 1,
|
||||
corpus_normalize: float = 0.0,
|
||||
causal: bool = True,
|
||||
vae_dim: int = 64,
|
||||
fix_std: float = 0,
|
||||
std_dist_type: str = 'none',
|
||||
# common
|
||||
mixer_layer: str = 'depthwise_conv',
|
||||
conv_norm: str = 'none',
|
||||
pad_mode: str = 'constant',
|
||||
disable_last_norm: bool = True,
|
||||
layernorm: str = 'RMSNorm',
|
||||
layernorm_eps: float = 1e-5,
|
||||
layernorm_elementwise_affine: bool = True,
|
||||
conv_bias: bool = True,
|
||||
layer_scale_init_value: float = 1e-6,
|
||||
weight_init_value: float = 1e-2,
|
||||
# encoder specific
|
||||
encoder_n_filters: int = 32,
|
||||
encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
|
||||
encoder_depths: str = "3-3-3-3-3-3-8",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.channels = channels
|
||||
self.corpus_normalize = corpus_normalize
|
||||
self.causal = causal
|
||||
self.vae_dim = vae_dim
|
||||
self.fix_std = fix_std
|
||||
self.std_dist_type = std_dist_type
|
||||
|
||||
# common parameters
|
||||
self.conv_norm = conv_norm
|
||||
self.pad_mode = pad_mode
|
||||
self.layernorm_eps = layernorm_eps
|
||||
self.disable_last_norm = disable_last_norm
|
||||
self.layernorm = layernorm
|
||||
self.layernorm_elementwise_affine = layernorm_elementwise_affine
|
||||
self.conv_bias = conv_bias
|
||||
self.layer_scale_init_value = layer_scale_init_value
|
||||
self.weight_init_value = weight_init_value
|
||||
self.mixer_layer = mixer_layer
|
||||
|
||||
# encoder specific parameters
|
||||
self.encoder_n_filters = encoder_n_filters
|
||||
self.encoder_ratios = encoder_ratios
|
||||
self.encoder_depths = encoder_depths
|
||||
|
||||
|
||||
class VibeVoiceDiffusionHeadConfig(PretrainedConfig):
|
||||
model_type = "vibevoice_diffusion_head"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=768,
|
||||
head_layers=4,
|
||||
head_ffn_ratio=3.0,
|
||||
rms_norm_eps=1e-5,
|
||||
latent_size=64,
|
||||
speech_vae_dim=None,
|
||||
prediction_type="v_prediction",
|
||||
diffusion_type="ddpm",
|
||||
ddpm_num_steps=1000,
|
||||
ddpm_num_inference_steps=20,
|
||||
ddpm_beta_schedule="cosine",
|
||||
ddpm_batch_mul=4,
|
||||
**kwargs
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.head_layers = head_layers
|
||||
self.head_ffn_ratio = head_ffn_ratio
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.latent_size = latent_size
|
||||
self.speech_vae_dim = speech_vae_dim
|
||||
self.prediction_type = prediction_type
|
||||
self.diffusion_type = diffusion_type
|
||||
self.ddpm_num_steps = ddpm_num_steps
|
||||
self.ddpm_num_inference_steps = ddpm_num_inference_steps
|
||||
self.ddpm_beta_schedule = ddpm_beta_schedule
|
||||
self.ddpm_batch_mul = ddpm_batch_mul
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
class VibeVoiceConfig(PretrainedConfig):
|
||||
model_type = "vibevoice"
|
||||
is_composition = True
|
||||
sub_configs = {
|
||||
"acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
|
||||
"semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
|
||||
"decoder_config": Qwen2Config,
|
||||
"diffusion_head_config": VibeVoiceDiffusionHeadConfig,
|
||||
}
|
||||
# keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `Qwen2`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
acoustic_tokenizer_config=None,
|
||||
semantic_tokenizer_config=None,
|
||||
decoder_config=None,
|
||||
diffusion_head_config=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
# kwargs["_attn_implementation"] = "flash_attention_2"
|
||||
kwargs["_attn_implementation_autoset"] = False
|
||||
|
||||
if acoustic_tokenizer_config is None:
|
||||
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
|
||||
elif isinstance(acoustic_tokenizer_config, dict):
|
||||
acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer"
|
||||
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config)
|
||||
elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
|
||||
# If an instance of the config class is provided
|
||||
self.acoustic_tokenizer_config = acoustic_tokenizer_config
|
||||
|
||||
if semantic_tokenizer_config is None:
|
||||
self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]()
|
||||
elif isinstance(semantic_tokenizer_config, dict):
|
||||
semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer"
|
||||
self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config)
|
||||
elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig):
|
||||
# If an instance of the config class is provided
|
||||
self.semantic_tokenizer_config = semantic_tokenizer_config
|
||||
|
||||
if decoder_config is None:
|
||||
self.decoder_config = self.sub_configs["decoder_config"]()
|
||||
elif isinstance(decoder_config, dict):
|
||||
# If a dictionary is provided, instantiate the config class with it
|
||||
# self.decoder_config = self.sub_configs["decoder_config"](**decoder_config)
|
||||
if decoder_config.get("model_type", '') == "qwen2":
|
||||
self.decoder_config = Qwen2Config(**decoder_config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
|
||||
elif isinstance(decoder_config, (Qwen2Config,)):
|
||||
# If an instance of the config class is provided
|
||||
self.decoder_config = decoder_config
|
||||
|
||||
if diffusion_head_config is None:
|
||||
self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
|
||||
elif isinstance(diffusion_head_config, dict):
|
||||
diffusion_head_config["model_type"] = "vibevoice_diffusion_head"
|
||||
self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_head_config)
|
||||
elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig):
|
||||
# If an instance of the config class is provided
|
||||
self.diffusion_head_config = diffusion_head_config
|
||||
|
||||
# other parameters
|
||||
self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
|
||||
self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
__all__ = [
|
||||
"VibeVoiceAcousticTokenizerConfig",
|
||||
"VibeVoiceSemanticTokenizerConfig",
|
||||
"VibeVoiceDiffusionHeadConfig",
|
||||
"VibeVoiceConfig"
|
||||
]
|
||||
488
vibevoice/modular/modeling_vibevoice.py
Normal file
488
vibevoice/modular/modeling_vibevoice.py
Normal file
@ -0,0 +1,488 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union, Callable
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
|
||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||
from transformers import modeling_utils
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
|
||||
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
|
||||
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
|
||||
|
||||
from .configuration_vibevoice import VibeVoiceConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
|
||||
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
||||
|
||||
@dataclass
|
||||
class VibeVoiceCausalLMOutputWithPast(ModelOutput):
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
diffusion_loss: Optional[torch.FloatTensor] = None
|
||||
speech_token_num: Optional[int] = None
|
||||
logits: torch.FloatTensor = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VibeVoiceGenerationOutput(ModelOutput):
|
||||
"""
|
||||
Output type for VibeVoice generation.
|
||||
|
||||
Args:
|
||||
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The generated sequences.
|
||||
speech_outputs (`List[torch.FloatTensor]`, *optional*):
|
||||
List of generated speech waveforms or latents for each speech segment.
|
||||
"""
|
||||
sequences: torch.LongTensor = None
|
||||
speech_outputs: Optional[List[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class SpeechConnector(nn.Module):
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(input_dim, output_dim)
|
||||
self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
|
||||
self.fc2 = nn.Linear(output_dim, output_dim)
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
x = self.fc1(features)
|
||||
x = self.norm(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
# @auto_docstring
|
||||
class VibeVoicePreTrainedModel(PreTrainedModel):
|
||||
config_class = VibeVoiceConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, VibeVoiceDiffusionHead):
|
||||
module.initialize_weights()
|
||||
return
|
||||
|
||||
# Use the language model's initializer_range if available
|
||||
if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
|
||||
std = self.config.language_model_config.initializer_range
|
||||
elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
|
||||
std = self.config.decoder_config.initializer_range
|
||||
else:
|
||||
std = 0.02 # Default value
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
# @auto_docstring
|
||||
class VibeVoiceModel(VibeVoicePreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
|
||||
if isinstance(config.torch_dtype, str):
|
||||
dtype = getattr(torch, config.torch_dtype)
|
||||
else:
|
||||
dtype = config.torch_dtype
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
# Initialize Qwen2 model for language modeling
|
||||
lm_config = config.decoder_config
|
||||
self.language_model = AutoModel.from_config(lm_config)
|
||||
|
||||
# Initialize speech components if needed
|
||||
self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
|
||||
self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)
|
||||
|
||||
self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
|
||||
self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)
|
||||
|
||||
# Register scaling factors as buffers - use 1D tensors for FSDP compatibility
|
||||
self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
|
||||
self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))
|
||||
|
||||
# Initialize prediction head for speech generation
|
||||
self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)
|
||||
|
||||
# Initialize noise scheduler
|
||||
self.noise_scheduler = DPMSolverMultistepScheduler(
|
||||
num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
|
||||
beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
|
||||
prediction_type=config.diffusion_head_config.prediction_type
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
if hasattr(self.language_model, 'embed_tokens'):
|
||||
# If the language model has an embed_tokens attribute, return it
|
||||
return self.language_model.embed_tokens
|
||||
|
||||
for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed
|
||||
if attr.orig_name == 'embed_tokens.weight':
|
||||
return getattr(self.language_model, name)
|
||||
assert False, 'should not arrive here'
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.embed_tokens = value
|
||||
|
||||
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
|
||||
"""Set the speech tokenizers used for encoding and decoding speech."""
|
||||
self.acoustic_tokenizer = acoustic_tokenizer
|
||||
self.semantic_tokenizer = semantic_tokenizer
|
||||
|
||||
# Reset the encoder to evaluation mode
|
||||
if self.acoustic_tokenizer is not None:
|
||||
self.acoustic_tokenizer.eval()
|
||||
|
||||
if self.semantic_tokenizer is not None:
|
||||
self.semantic_tokenizer.eval()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# Forward through language model
|
||||
outputs = self.language_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return outputs
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = VibeVoiceModel(config)
|
||||
self.vocab_size = config.decoder_config.vocab_size
|
||||
self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model.language_model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.language_model
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Tie the weights between the input embeddings and the output embeddings.
|
||||
"""
|
||||
if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
|
||||
# The standard PreTrainedModel method will handle the tying.
|
||||
# It typically does a simple parameter object assignment, which is
|
||||
# CORRECT to do BEFORE FSDP wraps the model.
|
||||
output_embeddings = self.get_output_embeddings()
|
||||
input_embeddings = self.get_input_embeddings()
|
||||
if hasattr(input_embeddings, 'weight'):
|
||||
output_embeddings.weight = input_embeddings.weight
|
||||
else:
|
||||
# maybe returned input_embeddings a tensor directly
|
||||
output_embeddings.weight = input_embeddings
|
||||
|
||||
if getattr(output_embeddings, "bias", None) is not None:
|
||||
output_embeddings.bias.data = nn.functional.pad(
|
||||
output_embeddings.bias.data,
|
||||
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
|
||||
"constant",
|
||||
0,
|
||||
)
|
||||
print("✅ Tied input and output embeddings using standard assignment.")
|
||||
else:
|
||||
print("ℹ️ tie_word_embeddings is False, not tying weights.")
|
||||
|
||||
# Also, ensure set_output_embeddings is safe, though your implementation looks okay.
|
||||
# The key is to avoid calling it after accelerator.prepare().
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
# Your current implementation using data.copy_ is good practice,
|
||||
# but the best way is to not call this after prepare().
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def forward_speech_features(
|
||||
self,
|
||||
speech_tensors=None,
|
||||
speech_masks=None,
|
||||
speech_type="audio",
|
||||
return_unmask=False
|
||||
):
|
||||
if speech_tensors is None:
|
||||
# Use config to get vae_dim instead of non-existent self.args
|
||||
vae_dim = self.config.acoustic_tokenizer_config.vae_dim
|
||||
audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight)
|
||||
connect_features = self.model.acoustic_connector(audio_features)
|
||||
return audio_features, connect_features
|
||||
else:
|
||||
with torch.no_grad():
|
||||
if speech_type == "audio":
|
||||
with torch.no_grad():
|
||||
frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
|
||||
audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0]
|
||||
|
||||
elif speech_type == "vae":
|
||||
# Use config to get vae_dim instead of non-existent self.args
|
||||
vae_dim = self.config.acoustic_tokenizer_config.vae_dim
|
||||
speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim)
|
||||
|
||||
# gaussian sample from the speech_mode
|
||||
batch_size = speech_mode.size(0)
|
||||
value = self.model.acoustic_tokenizer.fix_std / 0.8
|
||||
std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value
|
||||
std = std.view(-1, *[1] * (speech_mode.dim() - 1))
|
||||
audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode)
|
||||
else:
|
||||
raise NotImplementedError(f"Speech type {speech_type} not implemented")
|
||||
|
||||
if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor):
|
||||
scaling_factor = 1. / audio_tokens[speech_masks].flatten().std()
|
||||
bias_factor = -audio_tokens[speech_masks].flatten().mean()
|
||||
|
||||
# Only use distributed operations if the process group is initialized
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
|
||||
world_size = dist.get_world_size()
|
||||
self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
|
||||
self.model.speech_bias_factor.copy_(bias_factor / world_size)
|
||||
print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
|
||||
else:
|
||||
# Single process case
|
||||
self.model.speech_scaling_factor.copy_(scaling_factor)
|
||||
self.model.speech_bias_factor.copy_(bias_factor)
|
||||
print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
|
||||
|
||||
audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor
|
||||
|
||||
connect_features = self.model.acoustic_connector(audio_features)
|
||||
if return_unmask:
|
||||
return audio_features, connect_features
|
||||
return audio_features[speech_masks], connect_features[speech_masks]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
# New arguments for speech processing and loss calculation
|
||||
speech_tensors: Optional[torch.FloatTensor] = None,
|
||||
speech_masks: Optional[torch.BoolTensor] = None,
|
||||
speeches_loss_input: Optional[torch.FloatTensor] = None,
|
||||
speech_semantic_tensors: Optional[torch.FloatTensor] = None,
|
||||
acoustic_input_mask: Optional[torch.BoolTensor] = None,
|
||||
acoustic_loss_mask: Optional[torch.BoolTensor] = None,
|
||||
ddpm_batch_mul: int = 1,
|
||||
**kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
|
||||
) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
x = self.get_input_embeddings()(input_ids)
|
||||
|
||||
semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors)
|
||||
if speeches_loss_input is not None:
|
||||
# only part audio need diffuse
|
||||
speech_all_features, speech_all_connect_features = self.forward_speech_features(
|
||||
speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
|
||||
speech_masks=speech_masks,
|
||||
speech_type=kwargs.get("speech_type", "audio"),
|
||||
return_unmask=True
|
||||
)
|
||||
if speech_tensors is not None:
|
||||
if semantic_speech_all_connect_features is not None:
|
||||
x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + semantic_speech_all_connect_features[speech_masks]
|
||||
else:
|
||||
x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
|
||||
speech_features = speech_all_features[speeches_loss_input.unsqueeze(-1) & speech_masks] # only part audio need diffuse
|
||||
speech_connect_features = speech_all_connect_features[speeches_loss_input.unsqueeze(-1) & speech_masks]
|
||||
else:
|
||||
speech_features, speech_connect_features = self.forward_speech_features(
|
||||
speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
|
||||
speech_masks=speech_masks,
|
||||
speech_type=kwargs.get("speech_type", "audio"),
|
||||
)
|
||||
if speech_tensors is not None:
|
||||
x[acoustic_input_mask] = speech_connect_features
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=None,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=x,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=False,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
logits = self.lm_head(hidden_states)
|
||||
# logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# The custom CE loss with masking is calculated in the training script.
|
||||
# We leave the standard loss calculation here as None.
|
||||
pass
|
||||
|
||||
# --- Diffusion Loss Calculation ---
|
||||
diffusion_loss = None
|
||||
# This block is executed only if we are in a context that involves speech.
|
||||
if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
|
||||
condition_features = hidden_states[acoustic_loss_mask]
|
||||
|
||||
speech_len, latent_size = speech_features.shape
|
||||
|
||||
noise = torch.randn(
|
||||
(speech_len * ddpm_batch_mul, latent_size),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
timesteps = torch.multinomial(
|
||||
torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
|
||||
speech_len * ddpm_batch_mul,
|
||||
replacement=True,
|
||||
).to(hidden_states.device)
|
||||
|
||||
speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0)
|
||||
condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0)
|
||||
|
||||
noisy_speech_features = self.model.noise_scheduler.add_noise(
|
||||
speech_features_repeated, noise, timesteps
|
||||
)
|
||||
|
||||
model_output = self.model.prediction_head(
|
||||
noisy_speech_features,
|
||||
timesteps.type_as(x),
|
||||
condition_features_repeated
|
||||
)
|
||||
|
||||
prediction_type = self.config.diffusion_head_config.prediction_type
|
||||
if prediction_type == "epsilon":
|
||||
target_for_loss = noise
|
||||
elif prediction_type == "v_prediction":
|
||||
target_for_loss = self.model.noise_scheduler.get_velocity(
|
||||
speech_features_repeated, noise, timesteps
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Prediction type {prediction_type} not implemented")
|
||||
|
||||
diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum')
|
||||
if latent_size > 0 and ddpm_batch_mul > 0:
|
||||
diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
|
||||
else:
|
||||
diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
|
||||
|
||||
else:
|
||||
# Dummy loss for DDP to work when there are no speech samples in a batch,
|
||||
# but we are in a speech context.
|
||||
diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
|
||||
diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
|
||||
diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
|
||||
# --- End Diffusion Loss Calculation ---
|
||||
|
||||
if not return_dict:
|
||||
output = (logits, speech_len) + outputs.to_tuple()[1:]
|
||||
return (loss, diffusion_loss) + output
|
||||
|
||||
return VibeVoiceCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
diffusion_loss=diffusion_loss,
|
||||
speech_token_num=speech_len if speech_tensors is not None else 0,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
AutoModel.register(VibeVoiceConfig, VibeVoiceModel)
|
||||
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration)
|
||||
|
||||
__all__ = [
|
||||
"VibeVoiceModel",
|
||||
"VibeVoicePreTrainedModel",
|
||||
"VibeVoiceForConditionalGeneration",
|
||||
"VibeVoiceCausalLMOutputWithPast",
|
||||
"VibeVoiceGenerationOutput",
|
||||
]
|
||||
731
vibevoice/modular/modeling_vibevoice_inference.py
Normal file
731
vibevoice/modular/modeling_vibevoice_inference.py
Normal file
@ -0,0 +1,731 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union, Callable
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
||||
|
||||
from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from transformers import modeling_utils
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
# from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
|
||||
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceTokenizerEncoderOutput
|
||||
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
|
||||
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
|
||||
|
||||
from .configuration_vibevoice import VibeVoiceConfig
|
||||
|
||||
from .modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizer, VibeVoiceTextTokenizerFast
|
||||
|
||||
from .modeling_vibevoice import VibeVoiceModel, VibeVoicePreTrainedModel
|
||||
from .streamer import AudioStreamer, AsyncAudioStreamer
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
|
||||
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
||||
|
||||
@dataclass
|
||||
class VibeVoiceCausalLMOutputWithPast(BaseModelOutputWithPast):
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
|
||||
@dataclass
|
||||
class VibeVoiceGenerationOutput(ModelOutput):
|
||||
"""
|
||||
Output type for VibeVoice generation.
|
||||
|
||||
Args:
|
||||
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The generated sequences.
|
||||
speech_outputs (`List[torch.FloatTensor]`, *optional*):
|
||||
List of generated speech waveforms or latents for each speech segment.
|
||||
"""
|
||||
sequences: torch.LongTensor = None
|
||||
speech_outputs: Optional[List[torch.FloatTensor]] = None
|
||||
reach_max_step_sample: Optional[torch.BoolTensor] = None
|
||||
|
||||
class VibeVoiceTokenConstraintProcessor(LogitsProcessor):
|
||||
"""Constrains token generation to only valid tokens during speech generation."""
|
||||
|
||||
def __init__(self, valid_token_ids: List[int], device: torch.device = None):
|
||||
self.valid_token_ids = torch.tensor(valid_token_ids, dtype=torch.long, device=device)
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Create a mask for valid tokens
|
||||
mask = torch.full_like(scores, float('-inf'))
|
||||
mask[:, self.valid_token_ids] = 0
|
||||
|
||||
# Apply mask to scores
|
||||
scores = scores + mask
|
||||
return scores
|
||||
|
||||
class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
# Initialize the base model
|
||||
self.model = VibeVoiceModel(config)
|
||||
|
||||
# LM head for text generation
|
||||
self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.decoder_config.vocab_size, bias=False)
|
||||
|
||||
# inference configuration
|
||||
self.ddpm_inference_steps = config.diffusion_head_config.ddpm_num_inference_steps
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@property
|
||||
def noise_scheduler(self):
|
||||
return self.model.noise_scheduler
|
||||
|
||||
@property
|
||||
def prediction_head(self):
|
||||
return self.model.prediction_head
|
||||
|
||||
@property
|
||||
def speech_scaling_factor(self):
|
||||
return self.model.speech_scaling_factor
|
||||
|
||||
@property
|
||||
def speech_bias_factor(self):
|
||||
return self.model.speech_bias_factor
|
||||
|
||||
@property
|
||||
def acoustic_tokenizer(self):
|
||||
return self.model.acoustic_tokenizer
|
||||
|
||||
@property
|
||||
def semantic_tokenizer(self):
|
||||
return self.model.semantic_tokenizer
|
||||
|
||||
@property
|
||||
def acoustic_connector(self):
|
||||
return self.model.acoustic_connector
|
||||
|
||||
@property
|
||||
def semantic_connector(self):
|
||||
return self.model.semantic_connector
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Tie the weights between the input embeddings and the output embeddings.
|
||||
"""
|
||||
# Tie lm_head.weight to language_model.embed_tokens.weight
|
||||
if not getattr(self.config, 'tie_word_embeddings', False):
|
||||
return
|
||||
|
||||
if hasattr(self, 'lm_head') and hasattr(self.model.language_model, 'embed_tokens'):
|
||||
self.lm_head.weight = self.model.language_model.embed_tokens.weight
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
|
||||
"""Set the speech tokenizers used for encoding and decoding speech."""
|
||||
self.model.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer)
|
||||
|
||||
def set_ddpm_inference_steps(self, num_steps=None):
|
||||
self.ddpm_inference_steps = num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
|
||||
|
||||
def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"):
|
||||
"""Process speech inputs through tokenizers and connectors."""
|
||||
with torch.no_grad():
|
||||
if speech_type == "audio":
|
||||
# Encode audio to acoustic latents
|
||||
encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
|
||||
acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
|
||||
|
||||
# Apply scaling and bias
|
||||
acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
|
||||
|
||||
# Connect to language model space
|
||||
acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
|
||||
|
||||
return acoustic_features, acoustic_connected
|
||||
elif speech_type == "pt":
|
||||
encoder_output = VibeVoiceTokenizerEncoderOutput(mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std)
|
||||
acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
|
||||
|
||||
# Apply scaling and bias
|
||||
acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
|
||||
|
||||
# Connect to language model space
|
||||
acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
|
||||
|
||||
return acoustic_features, acoustic_connected
|
||||
else:
|
||||
raise NotImplementedError(f"Speech type {speech_type} not implemented")
|
||||
|
||||
# @can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
speech_tensors: Optional[torch.FloatTensor] = None,
|
||||
speech_masks: Optional[torch.BoolTensor] = None,
|
||||
speech_input_mask: Optional[torch.BoolTensor] = None,
|
||||
logits_to_keep: Union[int, slice] = 0,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
|
||||
"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
speech_tensors (`torch.FloatTensor`, *optional*):
|
||||
Input speech waveforms for voice cloning or speech understanding.
|
||||
speech_masks (`torch.BoolTensor`, *optional*):
|
||||
Masks indicating valid speech frames.
|
||||
speech_input_mask (`torch.BoolTensor`, *optional*):
|
||||
Positions in the input sequence where speech embeddings should be inserted.
|
||||
|
||||
Returns:
|
||||
`VibeVoiceCausalLMOutputWithPast` or tuple
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# Get embeddings
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.model.get_input_embeddings()(input_ids)
|
||||
|
||||
# Process speech inputs if provided
|
||||
if speech_tensors is not None and speech_masks is not None:
|
||||
acoustic_features, speech_embeds = self._process_speech_inputs(speech_tensors.to(self.dtype), speech_masks)
|
||||
if speech_input_mask is not None:
|
||||
inputs_embeds[speech_input_mask] = speech_embeds
|
||||
|
||||
outputs = self.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
if labels is not None:
|
||||
raise NotImplementedError("Loss computation is not implemented in this version.")
|
||||
|
||||
return VibeVoiceCausalLMOutputWithPast(
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
last_hidden_state=hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def _build_generate_config_model_kwargs(self, generation_config, inputs, tokenizer, return_processors=False, **kwargs):
|
||||
if generation_config is None:
|
||||
generation_config = GenerationConfig(
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id = tokenizer.pad_token_id
|
||||
)
|
||||
else:
|
||||
generation_config = GenerationConfig(
|
||||
**generation_config,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id = tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
generation_config, model_kwargs = self._prepare_generation_config(
|
||||
generation_config,
|
||||
True,
|
||||
speech_start_id=tokenizer.speech_start_id,
|
||||
speech_end_id=tokenizer.speech_end_id,
|
||||
speech_diffusion_id=tokenizer.speech_diffusion_id,
|
||||
**kwargs
|
||||
)
|
||||
generation_config.speech_start_id = tokenizer.speech_start_id
|
||||
generation_config.speech_end_id = tokenizer.speech_end_id
|
||||
generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id
|
||||
|
||||
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs)
|
||||
batch_size = inputs_tensor.shape[0]
|
||||
device = self.device
|
||||
|
||||
self._prepare_special_tokens(generation_config, True, device=device)
|
||||
generation_config.use_cache = True
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
input_ids = inputs_tensor.to(self.device)
|
||||
|
||||
input_ids_length = input_ids.shape[1]
|
||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
|
||||
generation_config = self._prepare_generated_length(
|
||||
generation_config=generation_config,
|
||||
has_default_max_length=has_default_max_length,
|
||||
has_default_min_length=has_default_min_length,
|
||||
model_input_name=model_input_name,
|
||||
inputs_tensor=inputs_tensor,
|
||||
input_ids_length=input_ids_length,
|
||||
)
|
||||
|
||||
max_cache_length = generation_config.max_length - 1
|
||||
# Backwards compatible fix for _prepare_cache_for_generation method signature
|
||||
# New transformers version expects 5 args, old version expects 6
|
||||
import inspect
|
||||
try:
|
||||
sig = inspect.signature(self._prepare_cache_for_generation)
|
||||
if len(sig.parameters) == 5:
|
||||
# New transformers version (4.56+)
|
||||
self._prepare_cache_for_generation(generation_config, model_kwargs, batch_size, max_cache_length, device)
|
||||
else:
|
||||
# Old transformers version (pre-4.56)
|
||||
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
|
||||
except Exception as e:
|
||||
# Fallback to try both versions
|
||||
try:
|
||||
self._prepare_cache_for_generation(generation_config, model_kwargs, batch_size, max_cache_length, device)
|
||||
except TypeError:
|
||||
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
|
||||
model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
|
||||
for k, v in model_kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
model_kwargs[k] = v.to(device=device)
|
||||
|
||||
if return_processors:
|
||||
logits_processor = self._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_length,
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=LogitsProcessorList(),
|
||||
device=inputs_tensor.device,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=StoppingCriteriaList())
|
||||
|
||||
return generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria
|
||||
else:
|
||||
return generation_config, model_kwargs, input_ids
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
synced_gpus: Optional[bool] = None,
|
||||
assistant_model: Optional["PreTrainedModel"] = None,
|
||||
audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
|
||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
speech_tensors: Optional[torch.FloatTensor] = None,
|
||||
speech_masks: Optional[torch.BoolTensor] = None,
|
||||
speech_input_mask: Optional[torch.BoolTensor] = None,
|
||||
return_speech: bool = True,
|
||||
cfg_scale: float = 1.0,
|
||||
stop_check_fn: Optional[Callable[[], bool]] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
|
||||
"""
|
||||
Generates sequences of token ids and optionally speech outputs.
|
||||
|
||||
Args:
|
||||
All standard generation arguments from GenerationMixin
|
||||
negative_prompt_ids: Negative prompt for CFG in speech generation
|
||||
negative_prompt_attention_mask: Attention mask for negative prompt
|
||||
speech_tensors: Input speech for voice cloning
|
||||
speech_masks: Masks for speech tensors
|
||||
speech_input_mask: Positions to insert speech embeddings
|
||||
return_speech: Whether to decode and return speech outputs
|
||||
cfg_scale: CFG scale for speech generation
|
||||
stop_check_fn: Optional callable that returns True if generation should stop
|
||||
|
||||
Returns:
|
||||
Generated token sequences and optionally speech outputs
|
||||
"""
|
||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
||||
parsed_scripts = kwargs.pop("parsed_scripts", None)
|
||||
all_speakers_list = kwargs.pop("all_speakers_list", None)
|
||||
max_length_times = kwargs.pop("max_length_times", 2)
|
||||
|
||||
if kwargs.get('max_new_tokens', None) is None:
|
||||
kwargs['max_new_tokens'] = self.config.decoder_config.max_position_embeddings - kwargs['input_ids'].shape[-1]
|
||||
|
||||
generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = self._build_generate_config_model_kwargs(
|
||||
generation_config, inputs, tokenizer, return_processors=True, **kwargs
|
||||
)
|
||||
|
||||
negative_kwargs = {
|
||||
'input_ids': torch.full((kwargs['input_ids'].shape[0], 1), tokenizer.speech_start_id, dtype=torch.long, device=kwargs['input_ids'].device),
|
||||
'attention_mask': torch.ones((kwargs['input_ids'].shape[0], 1), dtype=torch.long, device=kwargs['input_ids'].device),
|
||||
'max_new_tokens': kwargs.get('max_new_tokens', 100)
|
||||
}
|
||||
negative_generation_config, negative_model_kwargs, negative_input_ids = self._build_generate_config_model_kwargs(
|
||||
None, None, tokenizer, return_processors=False, **negative_kwargs
|
||||
)
|
||||
|
||||
acoustic_cache = VibeVoiceTokenizerStreamingCache()
|
||||
semantic_cache = VibeVoiceTokenizerStreamingCache()
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
device = input_ids.device
|
||||
finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
||||
correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
is_prefill = True
|
||||
inputs_embeds = None
|
||||
verbose = kwargs.get("verbose", False)
|
||||
|
||||
# Initialize audio chunks storage for each sample
|
||||
audio_chunks = [[] for _ in range(batch_size)]
|
||||
|
||||
initial_length = input_ids.shape[-1]
|
||||
initial_length_per_sample = model_kwargs['attention_mask'].sum(dim=-1)
|
||||
|
||||
# Define all valid tokens that can be generated
|
||||
valid_tokens = [
|
||||
generation_config.speech_start_id,
|
||||
generation_config.speech_end_id,
|
||||
generation_config.speech_diffusion_id,
|
||||
generation_config.eos_token_id
|
||||
]
|
||||
# Add bos_token_id if it exists
|
||||
if hasattr(generation_config, 'bos_token_id') and generation_config.bos_token_id is not None:
|
||||
valid_tokens.append(generation_config.bos_token_id)
|
||||
|
||||
# Add custom processor to constrain token generation
|
||||
token_constraint_processor = VibeVoiceTokenConstraintProcessor(valid_tokens, device=device)
|
||||
if logits_processor is None:
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(token_constraint_processor)
|
||||
|
||||
max_steps = min(generation_config.max_length - initial_length, int(max_length_times * initial_length))
|
||||
max_step_per_sample = torch.min(generation_config.max_length - initial_length_per_sample, (max_length_times * initial_length_per_sample).long())
|
||||
reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
||||
|
||||
# Create progress iterator if verbose
|
||||
if kwargs.get("show_progress_bar", True):
|
||||
progress_bar = tqdm(range(max_steps), desc="Generating", leave=False)
|
||||
else:
|
||||
progress_bar = range(max_steps)
|
||||
|
||||
for step in progress_bar:
|
||||
# Check for external stop signal
|
||||
if stop_check_fn is not None and stop_check_fn():
|
||||
if verbose:
|
||||
print(f"Generation stopped externally at step {step + 1}")
|
||||
# End the audio streamer if it exists
|
||||
if audio_streamer is not None:
|
||||
audio_streamer.end()
|
||||
break
|
||||
|
||||
# Check if audio_streamer has been ended (stopped externally)
|
||||
if audio_streamer is not None and hasattr(audio_streamer, 'finished_flags'):
|
||||
if any(audio_streamer.finished_flags):
|
||||
if verbose:
|
||||
print(f"Audio generation stopped externally at step {step + 1}")
|
||||
break
|
||||
|
||||
if finished_tags.all():
|
||||
if hasattr(progress_bar, 'set_description'):
|
||||
progress_bar.set_description("Generation complete")
|
||||
break
|
||||
|
||||
if input_ids.shape[-1] >= generation_config.max_length:
|
||||
print(f"Reached maximum generation length {generation_config.max_length}, stopped it.")
|
||||
reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
|
||||
if reached_samples.numel() > 0:
|
||||
reach_max_step_sample[reached_samples] = True
|
||||
break
|
||||
|
||||
# Update progress bar description with active samples
|
||||
if hasattr(progress_bar, 'set_description'):
|
||||
active_samples = (~finished_tags).sum().item()
|
||||
progress_bar.set_description(f"Generating (active: {active_samples}/{batch_size})")
|
||||
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
if is_prefill:
|
||||
# we process the speech inputs only during the first generation step
|
||||
prefill_inputs = {
|
||||
"speech_tensors": speech_tensors.to(device=device),
|
||||
"speech_masks": speech_masks.to(device),
|
||||
"speech_input_mask": speech_input_mask.to(device),
|
||||
}
|
||||
is_prefill = False
|
||||
else:
|
||||
_ = model_inputs.pop('inputs_embeds', None)
|
||||
prefill_inputs = {'inputs_embeds': inputs_embeds}
|
||||
|
||||
# Forward pass through the model
|
||||
outputs = self(
|
||||
**model_inputs, **prefill_inputs, logits_to_keep=1, return_dict=True, output_attentions=False, output_hidden_states=False,
|
||||
)
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
outputs, model_kwargs, is_encoder_decoder=False,
|
||||
)
|
||||
|
||||
# Get logits and apply logits processor
|
||||
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
|
||||
# next_token_logits = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device)
|
||||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||
|
||||
# token selection
|
||||
if generation_config.do_sample:
|
||||
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
||||
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
else:
|
||||
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
||||
|
||||
next_tokens[finished_tags] = generation_config.eos_token_id
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||
|
||||
if not kwargs.get('refresh_negative', True):
|
||||
negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
|
||||
# Forward negative pass through the model
|
||||
if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
|
||||
negative_model_inputs['inputs_embeds'] = inputs_embeds
|
||||
negative_model_inputs['input_ids'] = None
|
||||
|
||||
negative_outputs = self(
|
||||
**negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
|
||||
)
|
||||
negative_model_kwargs = self._update_model_kwargs_for_generation(
|
||||
negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
|
||||
)
|
||||
negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
|
||||
|
||||
# reached end of generation
|
||||
if (next_tokens == generation_config.eos_token_id).any():
|
||||
eos_indices = (next_tokens == generation_config.eos_token_id).nonzero(as_tuple=False).squeeze(1)
|
||||
# Only print for samples that are newly finished (not already marked as finished)
|
||||
new_eos_indices = eos_indices[~finished_tags[eos_indices]]
|
||||
if new_eos_indices.numel() > 0:
|
||||
finished_tags[new_eos_indices] = True
|
||||
if verbose:
|
||||
print(f"Samples {new_eos_indices.tolist()} reached EOS token at step {step + 1}.", flush=True)
|
||||
if audio_streamer is not None:
|
||||
audio_streamer.end(new_eos_indices)
|
||||
|
||||
# Check if any sample reached its maximum generation length
|
||||
max_length_reached = step >= max_step_per_sample
|
||||
new_max_length_indices = torch.nonzero(max_length_reached & ~finished_tags, as_tuple=False).squeeze(1)
|
||||
if new_max_length_indices.numel() > 0:
|
||||
finished_tags[new_max_length_indices] = True
|
||||
reach_max_step_sample[new_max_length_indices] = True
|
||||
if verbose:
|
||||
print(f"Samples {new_max_length_indices.tolist()} reached max generation length at step {step + 1}.", flush=True)
|
||||
if audio_streamer is not None:
|
||||
audio_streamer.end(new_max_length_indices)
|
||||
|
||||
# speech_end
|
||||
diffusion_end_indices = (next_tokens == generation_config.speech_end_id).nonzero(as_tuple=False).squeeze(1)
|
||||
if diffusion_end_indices.numel() > 0:
|
||||
# Clear tokenizer caches for samples that reached speech end
|
||||
acoustic_cache.set_to_zero(diffusion_end_indices)
|
||||
semantic_cache.set_to_zero(diffusion_end_indices)
|
||||
|
||||
# speech_begin
|
||||
diffusion_start_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_start_id)]
|
||||
if diffusion_start_indices.numel() > 0 and kwargs.get('refresh_negative', True):
|
||||
# update attention mask
|
||||
for i, sample_idx in enumerate(diffusion_start_indices.tolist()):
|
||||
negative_model_kwargs['attention_mask'][sample_idx, :] = 0
|
||||
negative_model_kwargs['attention_mask'][sample_idx, -1] = 1
|
||||
# update past key values
|
||||
for layer_idx in range(len(negative_model_kwargs['past_key_values'])):
|
||||
k_cache, v_cache = negative_model_kwargs['past_key_values'][layer_idx]
|
||||
# Process each non-diffusion sample
|
||||
for sample_idx in diffusion_start_indices.tolist():
|
||||
# Shift cache for this sample
|
||||
k_cache[sample_idx, :, -1, :] = k_cache[sample_idx, :, 0, :].clone()
|
||||
v_cache[sample_idx, :, -1, :] = v_cache[sample_idx, :, 0, :].clone()
|
||||
# update negative_input_ids
|
||||
for sample_idx in diffusion_start_indices.tolist():
|
||||
negative_input_ids[sample_idx, -1] = generation_config.speech_start_id
|
||||
|
||||
# Prepare inputs_embeds for next iteration
|
||||
# Initialize with default embeddings for all tokens
|
||||
next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1) # [batch_size, 1, hidden_size]
|
||||
|
||||
# forward diffusion
|
||||
# Diffusion indices are those that are not finished and not special tokens
|
||||
diffusion_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_diffusion_id)]
|
||||
|
||||
if diffusion_indices.numel() > 0:
|
||||
if kwargs.get('refresh_negative', True):
|
||||
negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
|
||||
# Forward negative pass through the model
|
||||
if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
|
||||
negative_model_inputs['inputs_embeds'] = inputs_embeds
|
||||
negative_model_inputs['input_ids'] = None
|
||||
|
||||
negative_outputs = self(
|
||||
**negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
|
||||
)
|
||||
negative_model_kwargs = self._update_model_kwargs_for_generation(
|
||||
negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
|
||||
)
|
||||
negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
|
||||
# correct the non-diffusion indices
|
||||
# we forward all samples' negative outputs even if
|
||||
# they are not in diffusion mode to keep the cache consistent
|
||||
# So we need to correct the kv cache of non-diffusion samples
|
||||
non_diffusion_mask = ~finished_tags & (next_tokens != generation_config.speech_diffusion_id)
|
||||
if non_diffusion_mask.any():
|
||||
non_diffusion_indices = torch.arange(batch_size, device=device)[non_diffusion_mask]
|
||||
start_indices = correct_cnt[non_diffusion_indices]
|
||||
|
||||
# 1. Update attention_mask - need to handle each sample separately
|
||||
seq_len = negative_model_kwargs['attention_mask'].shape[1]
|
||||
for i, (sample_idx, start_idx) in enumerate(zip(non_diffusion_indices.tolist(), start_indices.tolist())):
|
||||
# Shift the attention mask for this sample
|
||||
if start_idx + 1 < seq_len - 1:
|
||||
negative_model_kwargs['attention_mask'][sample_idx, start_idx+1:] = \
|
||||
negative_model_kwargs['attention_mask'][sample_idx, start_idx:-1].clone()
|
||||
negative_model_kwargs['attention_mask'][sample_idx, start_idx] = 0
|
||||
|
||||
# 2. Update past_key_values
|
||||
for layer_idx in range(len(negative_model_kwargs['past_key_values'])):
|
||||
k_cache, v_cache = negative_model_kwargs['past_key_values'][layer_idx]
|
||||
# Process each non-diffusion sample
|
||||
for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
|
||||
if start_idx + 1 < k_cache.shape[2] - 1:
|
||||
# Shift cache for this sample
|
||||
k_cache[sample_idx, :, start_idx+1:, :] = k_cache[sample_idx, :, start_idx:-1, :].clone()
|
||||
v_cache[sample_idx, :, start_idx+1:, :] = v_cache[sample_idx, :, start_idx:-1, :].clone()
|
||||
|
||||
# 3. Update negative_input_ids
|
||||
for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
|
||||
if start_idx + 1 < negative_input_ids.shape[1] - 1:
|
||||
negative_input_ids[sample_idx, start_idx+1:] = \
|
||||
negative_input_ids[sample_idx, start_idx:-1].clone()
|
||||
|
||||
correct_cnt[non_diffusion_indices] += 1
|
||||
|
||||
positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]
|
||||
negative_condition = negative_outputs.last_hidden_state[diffusion_indices, -1, :]
|
||||
|
||||
speech_latent = self.sample_speech_tokens(
|
||||
positive_condition,
|
||||
negative_condition,
|
||||
cfg_scale=cfg_scale,
|
||||
).unsqueeze(1)
|
||||
|
||||
# Decode acoustic latent to audio using acoustic streaming cache
|
||||
scaled_latent = speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - self.model.speech_bias_factor.to(speech_latent.device)
|
||||
audio_chunk = self.model.acoustic_tokenizer.decode(
|
||||
scaled_latent.to(self.model.acoustic_tokenizer.device),
|
||||
cache=acoustic_cache, # Use acoustic-specific cache
|
||||
sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
|
||||
use_cache=True,
|
||||
debug=False
|
||||
)
|
||||
|
||||
# Store audio chunks for each sample
|
||||
for i, sample_idx in enumerate(diffusion_indices):
|
||||
idx = sample_idx.item()
|
||||
# Only append audio chunk if the sample is not finished
|
||||
if not finished_tags[idx]:
|
||||
audio_chunks[idx].append(audio_chunk[i])
|
||||
|
||||
# Add streaming support here
|
||||
if audio_streamer is not None:
|
||||
# Stream the audio chunks immediately
|
||||
audio_streamer.put(audio_chunk, diffusion_indices)
|
||||
|
||||
# Encode audio to semantic features using semantic streaming cache
|
||||
semantic_features = self.model.semantic_tokenizer.encode(
|
||||
audio_chunk,
|
||||
cache=semantic_cache, # Use semantic-specific cache
|
||||
sample_indices=diffusion_indices,
|
||||
use_cache=True,
|
||||
debug=False
|
||||
).mean # semantic tokenizer has no VAE.
|
||||
|
||||
# Combine acoustic and semantic features for next input
|
||||
acoustic_embed = self.model.acoustic_connector(speech_latent)
|
||||
semantic_embed = self.model.semantic_connector(semantic_features)
|
||||
diffusion_embeds = acoustic_embed + semantic_embed
|
||||
|
||||
# Update embeddings for diffusion indices
|
||||
next_inputs_embeds[diffusion_indices] = diffusion_embeds
|
||||
|
||||
# Set inputs_embeds for next iteration
|
||||
inputs_embeds = next_inputs_embeds
|
||||
|
||||
if audio_streamer is not None:
|
||||
audio_streamer.end()
|
||||
|
||||
# Concatenate audio chunks for each sample
|
||||
final_audio_outputs = []
|
||||
for sample_chunks in audio_chunks:
|
||||
if sample_chunks:
|
||||
# Concatenate all chunks along the time dimension (assumed to be the last dimension)
|
||||
concatenated_audio = torch.cat(sample_chunks, dim=-1)
|
||||
final_audio_outputs.append(concatenated_audio)
|
||||
else:
|
||||
# If no audio was generated for this sample, append None
|
||||
final_audio_outputs.append(None)
|
||||
|
||||
return VibeVoiceGenerationOutput(
|
||||
sequences=input_ids,
|
||||
speech_outputs=final_audio_outputs if return_speech else None,
|
||||
reach_max_step_sample=reach_max_step_sample,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
|
||||
self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps)
|
||||
condition = torch.cat([condition, neg_condition], dim=0).to(self.model.prediction_head.device)
|
||||
speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to(condition)
|
||||
for t in self.model.noise_scheduler.timesteps:
|
||||
half = speech[: len(speech) // 2]
|
||||
combined = torch.cat([half, half], dim=0)
|
||||
eps = self.model.prediction_head(combined, t.repeat(combined.shape[0]).to(combined), condition=condition)
|
||||
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
||||
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
||||
eps = torch.cat([half_eps, half_eps], dim=0)
|
||||
speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample
|
||||
return speech[: len(speech) // 2]
|
||||
|
||||
|
||||
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGenerationInference)
|
||||
|
||||
__all__ = [
|
||||
"VibeVoiceForConditionalGenerationInference",
|
||||
]
|
||||
287
vibevoice/modular/modular_vibevoice_diffusion_head.py
Normal file
287
vibevoice/modular/modular_vibevoice_diffusion_head.py
Normal file
@ -0,0 +1,287 @@
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers.models.auto import AutoModel
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
# from transformers.modeling_layers import GradientCheckpointingLayer
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.utils import logging
|
||||
|
||||
from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if self.elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
else:
|
||||
self.register_parameter('weight', None)
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
if self.weight is not None:
|
||||
output = output * self.weight
|
||||
return output
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
"""Apply modulation to input tensor."""
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`): Size of the output embedding
|
||||
frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
|
||||
"""
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=False),
|
||||
# nn.SiLU(),
|
||||
ACT2FN['silu'],
|
||||
nn.Linear(hidden_size, hidden_size, bias=False),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
|
||||
Args:
|
||||
t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
dim (`int`): The dimension of the output.
|
||||
max_period (`int`, optional): Controls the minimum frequency of the embeddings.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: An [N, D] Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding.to(t.dtype)
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class FeedForwardNetwork(nn.Module):
|
||||
"""
|
||||
Standard feed-forward network with SwiGLU activation.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`): Input dimension
|
||||
ffn_dim (`int`): Hidden dimension
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
ffn_dim,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
|
||||
self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
|
||||
self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
|
||||
self.act_fn = ACT2FN['silu'] # Using SiLU as the activation function
|
||||
|
||||
def forward(self, x):
|
||||
gate = self.gate_proj(x)
|
||||
up = self.up_proj(x)
|
||||
|
||||
# SwiGLU activation
|
||||
# gate = F.silu(gate)
|
||||
gate = self.act_fn(gate)
|
||||
return self.down_proj(gate * up)
|
||||
|
||||
|
||||
class HeadLayer(nn.Module):
|
||||
"""
|
||||
A layer in the diffusion head.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`): Input dimension
|
||||
ffn_dim (`int`): Hidden dimension
|
||||
cond_dim (`int`): Condition embedding dimension
|
||||
norm_eps (`float`, optional): Epsilon for normalization
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
ffn_dim,
|
||||
cond_dim,
|
||||
norm_eps=1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.cond_dim = cond_dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.ffn = FeedForwardNetwork(
|
||||
self.embed_dim,
|
||||
self.ffn_dim,
|
||||
)
|
||||
self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
# nn.SiLU(),
|
||||
ACT2FN['silu'],
|
||||
nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
|
||||
x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn))
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
Final layer in the diffusion head.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`): Input dimension
|
||||
output_size (`int`): Output dimension
|
||||
cond_size (`int`): Condition embedding dimension
|
||||
norm_eps (`float`, optional): Epsilon for normalization
|
||||
"""
|
||||
def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
|
||||
super().__init__()
|
||||
self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
|
||||
self.linear = nn.Linear(hidden_size, output_size, bias=False)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
# nn.SiLU(),
|
||||
ACT2FN['silu'],
|
||||
nn.Linear(cond_size, 2 * hidden_size, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class VibeVoiceDiffusionHead(PreTrainedModel):
|
||||
"""
|
||||
Diffusion head model for vibevoice.
|
||||
|
||||
Args:
|
||||
config (`VibeVoiceDiffusionHeadConfig`): Model configuration
|
||||
latent_size (`int`, optional): Size of the latent space. If not provided, uses `config.latent_size`.
|
||||
"""
|
||||
config_class = VibeVoiceDiffusionHeadConfig
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.cond_dim = config.hidden_size
|
||||
latent_size = config.latent_size
|
||||
|
||||
self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
|
||||
self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
|
||||
self.t_embedder = TimestepEmbedder(self.cond_dim)
|
||||
|
||||
ffn_dim = int(config.hidden_size * config.head_ffn_ratio)
|
||||
|
||||
# Create the intermediate layers
|
||||
self.layers = nn.ModuleList([
|
||||
HeadLayer(
|
||||
embed_dim=config.hidden_size,
|
||||
ffn_dim=ffn_dim,
|
||||
cond_dim=self.cond_dim,
|
||||
norm_eps=config.rms_norm_eps
|
||||
)
|
||||
for _ in range(config.head_layers)
|
||||
])
|
||||
|
||||
# Final layer for output
|
||||
self.final_layer = FinalLayer(
|
||||
hidden_size=config.hidden_size,
|
||||
output_size=latent_size,
|
||||
cond_size=self.cond_dim,
|
||||
norm_eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
"""Initialize the weights of the model."""
|
||||
# Initialize timestep embedder
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# Zero-out adaLN modulation layers
|
||||
for layer in self.layers:
|
||||
nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)
|
||||
|
||||
# Zero-out output layers
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
noisy_images,
|
||||
timesteps,
|
||||
condition,
|
||||
):
|
||||
"""
|
||||
Forward pass of the prediction head.
|
||||
|
||||
Args:
|
||||
noisy_images (`torch.Tensor`): Noisy images/latents to denoise
|
||||
timesteps (`torch.Tensor`): Timesteps for diffusion
|
||||
condition (`torch.Tensor`): Conditioning information
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The predicted noise/velocity
|
||||
"""
|
||||
x = self.noisy_images_proj(noisy_images)
|
||||
t = self.t_embedder(timesteps)
|
||||
condition = self.cond_proj(condition)
|
||||
c = condition + t
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, c)
|
||||
|
||||
x = self.final_layer(x, c)
|
||||
return x
|
||||
|
||||
|
||||
AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)
|
||||
|
||||
__all__ = [
|
||||
"VibeVoiceDiffusionHead",
|
||||
]
|
||||
214
vibevoice/modular/modular_vibevoice_text_tokenizer.py
Normal file
214
vibevoice/modular/modular_vibevoice_text_tokenizer.py
Normal file
@ -0,0 +1,214 @@
|
||||
"""Tokenization classes for vibevoice."""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from transformers.utils import logging
|
||||
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
|
||||
from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VibeVoiceTextTokenizer(Qwen2Tokenizer):
|
||||
"""
|
||||
Construct a VibeVoice tokenizer. Based on the Qwen2 tokenizer with additional special tokens for speech.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to the vocabulary file.
|
||||
merges_file (`str`):
|
||||
Path to the merges file.
|
||||
errors (`str`, *optional*, defaults to `"replace"`):
|
||||
Paradigm to follow when decoding bytes to UTF-8.
|
||||
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The unknown token.
|
||||
bos_token (`str`, *optional*):
|
||||
The beginning of sequence token. Not used for vibevoice.
|
||||
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The end of sequence token.
|
||||
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The token used for padding.
|
||||
add_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to add special tokens when encoding.
|
||||
"""
|
||||
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
merges_file,
|
||||
errors="replace",
|
||||
unk_token="<|endoftext|>",
|
||||
bos_token=None,
|
||||
eos_token="<|endoftext|>",
|
||||
pad_token="<|endoftext|>",
|
||||
add_prefix_space=False,
|
||||
add_special_tokens=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
vocab_file=vocab_file,
|
||||
merges_file=merges_file,
|
||||
errors=errors,
|
||||
unk_token=unk_token,
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
pad_token=pad_token,
|
||||
add_prefix_space=add_prefix_space,
|
||||
add_special_tokens=add_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Add VibeVoice-specific special tokens
|
||||
self._add_vibevoice_special_tokens()
|
||||
|
||||
def _add_vibevoice_special_tokens(self):
|
||||
"""Add VibeVoice-specific special tokens."""
|
||||
special_tokens = {
|
||||
"additional_special_tokens": [
|
||||
"<|vision_start|>", # Speech start (reusing vision tokens)
|
||||
"<|vision_end|>", # Speech end
|
||||
"<|vision_pad|>", # Speech diffusion pad
|
||||
]
|
||||
}
|
||||
num_added = self.add_special_tokens(special_tokens)
|
||||
|
||||
# Cache special token IDs
|
||||
self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
|
||||
self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
|
||||
self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
|
||||
|
||||
self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
|
||||
|
||||
return num_added
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
"""Id of the end of sequence token."""
|
||||
return self._eos_id
|
||||
|
||||
@property
|
||||
def speech_start_id(self) -> int:
|
||||
"""Id of the speech start token."""
|
||||
return self._speech_start_id
|
||||
|
||||
@property
|
||||
def speech_end_id(self) -> int:
|
||||
"""Id of the speech end token."""
|
||||
return self._speech_end_id
|
||||
|
||||
@property
|
||||
def speech_diffusion_id(self) -> int:
|
||||
"""Id of the speech diffusion token."""
|
||||
return self._speech_diffusion_id
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
"""Id used for padding (returns -100 for loss masking)."""
|
||||
return -100
|
||||
|
||||
|
||||
class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
|
||||
"""
|
||||
Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library).
|
||||
Based on the Qwen2 tokenizer with additional special tokens for speech.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`, *optional*):
|
||||
Path to the vocabulary file.
|
||||
merges_file (`str`, *optional*):
|
||||
Path to the merges file.
|
||||
tokenizer_file (`str`, *optional*):
|
||||
Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
|
||||
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The unknown token.
|
||||
bos_token (`str`, *optional*):
|
||||
The beginning of sequence token. Not used for vibevoice.
|
||||
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The end of sequence token.
|
||||
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The token used for padding.
|
||||
"""
|
||||
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file=None,
|
||||
merges_file=None,
|
||||
tokenizer_file=None,
|
||||
unk_token="<|endoftext|>",
|
||||
bos_token=None,
|
||||
eos_token="<|endoftext|>",
|
||||
pad_token="<|endoftext|>",
|
||||
add_prefix_space=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
vocab_file=vocab_file,
|
||||
merges_file=merges_file,
|
||||
tokenizer_file=tokenizer_file,
|
||||
unk_token=unk_token,
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
pad_token=pad_token,
|
||||
add_prefix_space=add_prefix_space,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Add VibeVoice-specific special tokens
|
||||
self._add_vibevoice_special_tokens()
|
||||
|
||||
def _add_vibevoice_special_tokens(self):
|
||||
"""Add VibeVoice-specific special tokens."""
|
||||
special_tokens = {
|
||||
"additional_special_tokens": [
|
||||
"<|vision_start|>", # Speech start (reusing vision tokens)
|
||||
"<|vision_end|>", # Speech end
|
||||
"<|vision_pad|>", # Speech diffusion pad
|
||||
]
|
||||
}
|
||||
num_added = self.add_special_tokens(special_tokens)
|
||||
|
||||
# Cache special token IDs
|
||||
self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
|
||||
self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
|
||||
self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
|
||||
|
||||
# self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
|
||||
self._eos_id = self.eos_token_id # qwen2 / qwen3
|
||||
self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')
|
||||
|
||||
return num_added
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
"""Id of the end of sequence token."""
|
||||
return self._eos_id
|
||||
|
||||
@property
|
||||
def speech_start_id(self) -> int:
|
||||
"""Id of the speech start token."""
|
||||
return self._speech_start_id
|
||||
|
||||
@property
|
||||
def speech_end_id(self) -> int:
|
||||
"""Id of the speech end token."""
|
||||
return self._speech_end_id
|
||||
|
||||
@property
|
||||
def speech_diffusion_id(self) -> int:
|
||||
"""Id of the speech diffusion token."""
|
||||
return self._speech_diffusion_id
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
"""Id used for padding (returns -100 for loss masking)."""
|
||||
return self._pad_id
|
||||
|
||||
|
||||
__all__ = [
|
||||
"VibeVoiceTextTokenizer",
|
||||
"VibeVoiceTextTokenizerFast",
|
||||
]
|
||||
1195
vibevoice/modular/modular_vibevoice_tokenizer.py
Normal file
1195
vibevoice/modular/modular_vibevoice_tokenizer.py
Normal file
File diff suppressed because it is too large
Load Diff
264
vibevoice/modular/streamer.py
Normal file
264
vibevoice/modular/streamer.py
Normal file
@ -0,0 +1,264 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
import asyncio
|
||||
from queue import Queue
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
|
||||
from transformers.generation import BaseStreamer
|
||||
|
||||
|
||||
class AudioStreamer(BaseStreamer):
|
||||
"""
|
||||
Audio streamer that stores audio chunks in queues for each sample in the batch.
|
||||
This allows streaming audio generation for multiple samples simultaneously.
|
||||
|
||||
Parameters:
|
||||
batch_size (`int`):
|
||||
The batch size for generation
|
||||
stop_signal (`any`, *optional*):
|
||||
The signal to put in the queue when generation ends. Defaults to None.
|
||||
timeout (`float`, *optional*):
|
||||
The timeout for the audio queue. If `None`, the queue will block indefinitely.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
stop_signal: Optional[any] = None,
|
||||
timeout: Optional[float] = None,
|
||||
):
|
||||
self.batch_size = batch_size
|
||||
self.stop_signal = stop_signal
|
||||
self.timeout = timeout
|
||||
|
||||
# Create a queue for each sample in the batch
|
||||
self.audio_queues = [Queue() for _ in range(batch_size)]
|
||||
self.finished_flags = [False for _ in range(batch_size)]
|
||||
self.sample_indices_map = {} # Maps from sample index to queue index
|
||||
|
||||
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
|
||||
"""
|
||||
Receives audio chunks and puts them in the appropriate queues.
|
||||
|
||||
Args:
|
||||
audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks
|
||||
sample_indices: Tensor indicating which samples these chunks belong to
|
||||
"""
|
||||
for i, sample_idx in enumerate(sample_indices):
|
||||
idx = sample_idx.item()
|
||||
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||
# Convert to numpy or keep as tensor based on preference
|
||||
audio_chunk = audio_chunks[i].detach().cpu()
|
||||
self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)
|
||||
|
||||
def end(self, sample_indices: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
Signals the end of generation for specified samples or all samples.
|
||||
|
||||
Args:
|
||||
sample_indices: Optional tensor of sample indices to end. If None, ends all.
|
||||
"""
|
||||
if sample_indices is None:
|
||||
# End all samples
|
||||
for idx in range(self.batch_size):
|
||||
if not self.finished_flags[idx]:
|
||||
self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
|
||||
self.finished_flags[idx] = True
|
||||
else:
|
||||
# End specific samples
|
||||
for sample_idx in sample_indices:
|
||||
idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx
|
||||
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||
self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
|
||||
self.finished_flags[idx] = True
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator over the batch of audio streams."""
|
||||
return AudioBatchIterator(self)
|
||||
|
||||
def get_stream(self, sample_idx: int):
|
||||
"""Get the audio stream for a specific sample."""
|
||||
if sample_idx >= self.batch_size:
|
||||
raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
|
||||
return AudioSampleIterator(self, sample_idx)
|
||||
|
||||
|
||||
class AudioSampleIterator:
|
||||
"""Iterator for a single audio stream from the batch."""
|
||||
|
||||
def __init__(self, streamer: AudioStreamer, sample_idx: int):
|
||||
self.streamer = streamer
|
||||
self.sample_idx = sample_idx
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
|
||||
if value == self.streamer.stop_signal:
|
||||
raise StopIteration()
|
||||
return value
|
||||
|
||||
|
||||
class AudioBatchIterator:
|
||||
"""Iterator that yields audio chunks for all samples in the batch."""
|
||||
|
||||
def __init__(self, streamer: AudioStreamer):
|
||||
self.streamer = streamer
|
||||
self.active_samples = set(range(streamer.batch_size))
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if not self.active_samples:
|
||||
raise StopIteration()
|
||||
|
||||
batch_chunks = {}
|
||||
samples_to_remove = set()
|
||||
|
||||
# Try to get chunks from all active samples
|
||||
for idx in self.active_samples:
|
||||
try:
|
||||
value = self.streamer.audio_queues[idx].get(block=False)
|
||||
if value == self.streamer.stop_signal:
|
||||
samples_to_remove.add(idx)
|
||||
else:
|
||||
batch_chunks[idx] = value
|
||||
except:
|
||||
# Queue is empty for this sample, skip it this iteration
|
||||
pass
|
||||
|
||||
# Remove finished samples
|
||||
self.active_samples -= samples_to_remove
|
||||
|
||||
if batch_chunks:
|
||||
return batch_chunks
|
||||
elif self.active_samples:
|
||||
# If no chunks were ready but we still have active samples,
|
||||
# wait a bit and try again
|
||||
import time
|
||||
time.sleep(0.01)
|
||||
return self.__next__()
|
||||
else:
|
||||
raise StopIteration()
|
||||
|
||||
|
||||
class AsyncAudioStreamer(AudioStreamer):
|
||||
"""
|
||||
Async version of AudioStreamer for use in async contexts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
stop_signal: Optional[any] = None,
|
||||
timeout: Optional[float] = None,
|
||||
):
|
||||
super().__init__(batch_size, stop_signal, timeout)
|
||||
# Replace regular queues with async queues
|
||||
self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
|
||||
self.loop = asyncio.get_running_loop()
|
||||
|
||||
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
|
||||
"""Put audio chunks in the appropriate async queues."""
|
||||
for i, sample_idx in enumerate(sample_indices):
|
||||
idx = sample_idx.item()
|
||||
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||
audio_chunk = audio_chunks[i].detach().cpu()
|
||||
self.loop.call_soon_threadsafe(
|
||||
self.audio_queues[idx].put_nowait, audio_chunk
|
||||
)
|
||||
|
||||
def end(self, sample_indices: Optional[torch.Tensor] = None):
|
||||
"""Signal the end of generation for specified samples."""
|
||||
if sample_indices is None:
|
||||
indices_to_end = range(self.batch_size)
|
||||
else:
|
||||
indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
|
||||
|
||||
for idx in indices_to_end:
|
||||
if idx < self.batch_size and not self.finished_flags[idx]:
|
||||
self.loop.call_soon_threadsafe(
|
||||
self.audio_queues[idx].put_nowait, self.stop_signal
|
||||
)
|
||||
self.finished_flags[idx] = True
|
||||
|
||||
async def get_stream(self, sample_idx: int):
|
||||
"""Get async iterator for a specific sample's audio stream."""
|
||||
if sample_idx >= self.batch_size:
|
||||
raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
|
||||
|
||||
while True:
|
||||
value = await self.audio_queues[sample_idx].get()
|
||||
if value == self.stop_signal:
|
||||
break
|
||||
yield value
|
||||
|
||||
def __aiter__(self):
|
||||
"""Returns an async iterator over all audio streams."""
|
||||
return AsyncAudioBatchIterator(self)
|
||||
|
||||
|
||||
class AsyncAudioBatchIterator:
|
||||
"""Async iterator for batch audio streaming."""
|
||||
|
||||
def __init__(self, streamer: AsyncAudioStreamer):
|
||||
self.streamer = streamer
|
||||
self.active_samples = set(range(streamer.batch_size))
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self.active_samples:
|
||||
raise StopAsyncIteration()
|
||||
|
||||
batch_chunks = {}
|
||||
samples_to_remove = set()
|
||||
|
||||
# Create tasks for all active samples
|
||||
tasks = {
|
||||
idx: asyncio.create_task(self._get_chunk(idx))
|
||||
for idx in self.active_samples
|
||||
}
|
||||
|
||||
# Wait for at least one chunk to be ready
|
||||
done, pending = await asyncio.wait(
|
||||
tasks.values(),
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
timeout=self.streamer.timeout
|
||||
)
|
||||
|
||||
# Cancel pending tasks
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
# Process completed tasks
|
||||
for idx, task in tasks.items():
|
||||
if task in done:
|
||||
try:
|
||||
value = await task
|
||||
if value == self.streamer.stop_signal:
|
||||
samples_to_remove.add(idx)
|
||||
else:
|
||||
batch_chunks[idx] = value
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self.active_samples -= samples_to_remove
|
||||
|
||||
if batch_chunks:
|
||||
return batch_chunks
|
||||
elif self.active_samples:
|
||||
# Try again if we still have active samples
|
||||
return await self.__anext__()
|
||||
else:
|
||||
raise StopAsyncIteration()
|
||||
|
||||
async def _get_chunk(self, idx):
|
||||
"""Helper to get a chunk from a specific queue."""
|
||||
return await self.streamer.audio_queues[idx].get()
|
||||
0
vibevoice/processor/__init__.py
Normal file
0
vibevoice/processor/__init__.py
Normal file
677
vibevoice/processor/vibevoice_processor.py
Normal file
677
vibevoice/processor/vibevoice_processor.py
Normal file
@ -0,0 +1,677 @@
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Union, Dict, Any, Tuple
|
||||
import os
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||
from transformers.utils import TensorType, logging
|
||||
from .vibevoice_tokenizer_processor import AudioNormalizer
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VibeVoiceProcessor:
|
||||
r"""
|
||||
Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor.
|
||||
|
||||
[`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`].
|
||||
See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):
|
||||
The tokenizer for text processing.
|
||||
audio_processor (`VibeVoiceTokenizerProcessor`):
|
||||
The audio processor for speech processing.
|
||||
speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):
|
||||
The compression ratio for speech tokenization.
|
||||
db_normalize (`bool`, *optional*, defaults to True):
|
||||
Whether to apply decibel normalization to audio inputs.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
|
||||
self.tokenizer = tokenizer
|
||||
self.audio_processor = audio_processor
|
||||
self.speech_tok_compress_ratio = speech_tok_compress_ratio
|
||||
self.db_normalize = db_normalize
|
||||
self.audio_normalizer = AudioNormalizer() if db_normalize else None
|
||||
self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
"""
|
||||
Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
This can be either:
|
||||
- a string, the *model id* of a pretrained model
|
||||
- a path to a *directory* containing processor config
|
||||
|
||||
Returns:
|
||||
[`VibeVoiceProcessor`]: The processor object instantiated from pretrained model.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
|
||||
from vibevoice.modular.modular_vibevoice_text_tokenizer import (
|
||||
VibeVoiceTextTokenizer,
|
||||
VibeVoiceTextTokenizerFast
|
||||
)
|
||||
|
||||
# Load processor configuration
|
||||
config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
logger.warning(f"No preprocessor_config.json found at {pretrained_model_name_or_path}, using defaults")
|
||||
config = {
|
||||
"speech_tok_compress_ratio": 3200,
|
||||
"db_normalize": True,
|
||||
}
|
||||
|
||||
# Extract main processor parameters
|
||||
speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
|
||||
db_normalize = config.get("db_normalize", True)
|
||||
|
||||
# Load tokenizer - try from model path first, then fallback to Qwen
|
||||
language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
|
||||
logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
|
||||
if 'qwen' in language_model_pretrained_name.lower():
|
||||
tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(
|
||||
language_model_pretrained_name,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Supported types: Qwen, Llama, Gemma.")
|
||||
|
||||
# Load audio processor
|
||||
if "audio_processor" in config:
|
||||
# Create audio processor from config
|
||||
audio_config = config["audio_processor"]
|
||||
audio_processor = VibeVoiceTokenizerProcessor(
|
||||
sampling_rate=audio_config.get("sampling_rate", 24000),
|
||||
normalize_audio=audio_config.get("normalize_audio", True),
|
||||
target_dB_FS=audio_config.get("target_dB_FS", -25),
|
||||
eps=audio_config.get("eps", 1e-6),
|
||||
)
|
||||
else:
|
||||
# Create default audio processor
|
||||
audio_processor = VibeVoiceTokenizerProcessor()
|
||||
|
||||
# Create and return the processor
|
||||
return cls(
|
||||
tokenizer=tokenizer,
|
||||
audio_processor=audio_processor,
|
||||
speech_tok_compress_ratio=speech_tok_compress_ratio,
|
||||
db_normalize=db_normalize,
|
||||
)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
|
||||
"""
|
||||
Save a processor to a directory, so that it can be re-loaded using the
|
||||
[`~VibeVoiceProcessor.from_pretrained`] class method.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the processor will be saved.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
# Save processor configuration
|
||||
processor_config = {
|
||||
"processor_class": "VibeVoiceProcessor",
|
||||
"speech_tok_compress_ratio": self.speech_tok_compress_ratio,
|
||||
"db_normalize": self.db_normalize,
|
||||
"audio_processor": {
|
||||
"feature_extractor_type": "VibeVoiceTokenizerProcessor",
|
||||
"sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
|
||||
"normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
|
||||
"target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
|
||||
"eps": getattr(self.audio_processor, 'eps', 1e-6),
|
||||
}
|
||||
}
|
||||
|
||||
config_path = os.path.join(save_directory, "preprocessor_config.json")
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(processor_config, f, indent=2)
|
||||
|
||||
logger.info(f"Processor configuration saved in {config_path}")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
|
||||
voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
|
||||
padding: Union[bool, str, PaddingStrategy] = True,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_attention_mask: bool = True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Main method to process one or more podcast scripts with optional voice samples.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`):
|
||||
The input text(s) to process. Can be:
|
||||
- A single script string
|
||||
- A list of script strings for batch processing
|
||||
- A path to a .json or .txt file
|
||||
- A list of paths
|
||||
voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*):
|
||||
Voice samples for each script. Can be:
|
||||
- A list of samples for a single script
|
||||
- A list of lists for batch processing
|
||||
padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
|
||||
Whether to pad sequences to the same length
|
||||
truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):
|
||||
Whether to truncate sequences
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned sequences
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
If set, will return tensors of a particular framework
|
||||
return_attention_mask (`bool`, defaults to `True`):
|
||||
Whether to return the attention mask
|
||||
|
||||
Returns:
|
||||
`BatchEncoding`: A BatchEncoding with the following fields:
|
||||
- **input_ids** -- List of token id sequences or tensor
|
||||
- **attention_mask** -- List of attention masks or tensor
|
||||
- **speech_tensors** -- Padded speech inputs (if voice_samples provided)
|
||||
- **speech_masks** -- Speech masks (if voice_samples provided)
|
||||
- **speech_input_mask** -- Boolean masks indicating speech token positions
|
||||
"""
|
||||
# Handle single vs batch input
|
||||
if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
|
||||
# Single input
|
||||
texts = [text]
|
||||
is_batched = False
|
||||
else:
|
||||
# Batch input
|
||||
texts = text
|
||||
is_batched = True
|
||||
|
||||
# Handle voice samples
|
||||
if voice_samples is not None:
|
||||
if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
|
||||
# Single set of voice samples
|
||||
voice_samples_list = [voice_samples]
|
||||
else:
|
||||
# Batch of voice samples
|
||||
voice_samples_list = voice_samples
|
||||
else:
|
||||
voice_samples_list = [None] * len(texts)
|
||||
|
||||
# Process each input
|
||||
all_encodings = []
|
||||
for text_input, voice_input in zip(texts, voice_samples_list):
|
||||
encoding = self._process_single(text_input, voice_input)
|
||||
all_encodings.append(encoding)
|
||||
|
||||
# Combine batch
|
||||
batch_encoding = self._batch_encode(
|
||||
all_encodings,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
return_tensors=return_tensors,
|
||||
return_attention_mask=return_attention_mask,
|
||||
)
|
||||
|
||||
return batch_encoding
|
||||
|
||||
def _process_single(
|
||||
self,
|
||||
text: Union[str, TextInput],
|
||||
voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Process a single podcast script."""
|
||||
# Determine if text is a file path or direct script
|
||||
script = None
|
||||
if isinstance(text, str):
|
||||
# Check if it's a file path
|
||||
if text.endswith('.json') and os.path.exists(text):
|
||||
script = self._convert_json_to_script(text)
|
||||
elif text.endswith('.txt') and os.path.exists(text):
|
||||
script = self._convert_text_to_script(text)
|
||||
else:
|
||||
# Assume it's the script content directly
|
||||
script = text
|
||||
|
||||
if script is None:
|
||||
raise ValueError(f"Could not process input text: {text}")
|
||||
|
||||
# Parse the script
|
||||
parsed_lines = self._parse_script(script)
|
||||
all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))
|
||||
|
||||
# Create system prompt
|
||||
# system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
|
||||
system_tokens = self.tokenizer.encode(self.system_prompt)
|
||||
|
||||
# Process voice samples if provided
|
||||
if voice_samples:
|
||||
voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
|
||||
else:
|
||||
voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []
|
||||
|
||||
# Build full token sequence
|
||||
full_tokens = system_tokens + voice_tokens
|
||||
speech_input_mask = [False] * len(system_tokens) + voice_speech_masks
|
||||
|
||||
# Add text input section
|
||||
full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
|
||||
speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False))
|
||||
|
||||
for speaker_id, speaker_text in parsed_lines:
|
||||
speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
|
||||
full_tokens += speaker_text_tokens
|
||||
speech_input_mask += [False] * len(speaker_text_tokens)
|
||||
|
||||
# Add speech output section
|
||||
full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
|
||||
speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1)
|
||||
|
||||
return {
|
||||
"input_ids": full_tokens,
|
||||
"speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
|
||||
"speech_input_mask": speech_input_mask,
|
||||
"parsed_script": parsed_lines,
|
||||
"all_speakers": all_speakers,
|
||||
}
|
||||
|
||||
def _batch_encode(
|
||||
self,
|
||||
encodings: List[Dict[str, Any]],
|
||||
padding: Union[bool, str, PaddingStrategy] = True,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_attention_mask: bool = True,
|
||||
) -> BatchEncoding:
|
||||
"""Combine multiple encodings into a batch with padding."""
|
||||
# Extract input_ids and create attention_mask
|
||||
input_ids_list = [enc["input_ids"] for enc in encodings]
|
||||
speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]
|
||||
|
||||
# Determine padding strategy
|
||||
if isinstance(padding, bool):
|
||||
padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
|
||||
elif isinstance(padding, str):
|
||||
padding_strategy = PaddingStrategy(padding)
|
||||
else:
|
||||
padding_strategy = padding
|
||||
|
||||
# Apply padding to input_ids
|
||||
if padding_strategy != PaddingStrategy.DO_NOT_PAD:
|
||||
if padding_strategy == PaddingStrategy.LONGEST:
|
||||
max_len = max(len(ids) for ids in input_ids_list)
|
||||
elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
|
||||
max_len = max_length
|
||||
else:
|
||||
max_len = max(len(ids) for ids in input_ids_list)
|
||||
|
||||
# Pad sequences
|
||||
padded_input_ids = []
|
||||
attention_masks = []
|
||||
padded_speech_input_masks = []
|
||||
|
||||
for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list):
|
||||
# Truncate if needed
|
||||
if truncation and len(input_ids) > max_len:
|
||||
input_ids = input_ids[:max_len]
|
||||
speech_mask = speech_mask[:max_len]
|
||||
|
||||
# Pad
|
||||
padding_length = max_len - len(input_ids)
|
||||
# padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids
|
||||
padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids
|
||||
attention_mask = [0] * padding_length + [1] * len(input_ids)
|
||||
padded_speech_mask = [False] * padding_length + speech_mask
|
||||
|
||||
padded_input_ids.append(padded_ids)
|
||||
attention_masks.append(attention_mask)
|
||||
padded_speech_input_masks.append(padded_speech_mask)
|
||||
|
||||
input_ids_list = padded_input_ids
|
||||
speech_input_masks_list = padded_speech_input_masks
|
||||
else:
|
||||
# No padding, just create attention masks
|
||||
attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None
|
||||
|
||||
# Process speech inputs
|
||||
all_speech_inputs = []
|
||||
has_speech = False
|
||||
for enc in encodings:
|
||||
if enc["speech_inputs"] is not None:
|
||||
all_speech_inputs.extend(enc["speech_inputs"])
|
||||
has_speech = True
|
||||
|
||||
# Prepare batch encoding
|
||||
batch_encoding = BatchEncoding()
|
||||
|
||||
# Handle tensor conversion
|
||||
if return_tensors is not None:
|
||||
batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
|
||||
if return_attention_mask and attention_masks is not None:
|
||||
batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
|
||||
batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
|
||||
else:
|
||||
batch_encoding["input_ids"] = input_ids_list
|
||||
if return_attention_mask and attention_masks is not None:
|
||||
batch_encoding["attention_mask"] = attention_masks
|
||||
batch_encoding["speech_input_mask"] = speech_input_masks_list
|
||||
|
||||
# Process speech tensors if present
|
||||
if has_speech:
|
||||
speech_dict = self.prepare_speech_inputs(
|
||||
all_speech_inputs,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
|
||||
batch_encoding["speech_masks"] = speech_dict["speech_masks"]
|
||||
else:
|
||||
batch_encoding["speech_tensors"] = None
|
||||
batch_encoding["speech_masks"] = None
|
||||
|
||||
# Add metadata
|
||||
batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
|
||||
batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]
|
||||
|
||||
return batch_encoding
|
||||
|
||||
def _create_voice_prompt(
|
||||
self,
|
||||
speaker_samples: List[Union[str, np.ndarray]]
|
||||
) -> Tuple[List[int], List[np.ndarray], List[bool]]:
|
||||
"""
|
||||
Create voice prompt tokens and process audio samples.
|
||||
|
||||
Returns:
|
||||
tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
|
||||
"""
|
||||
vae_token_id = self.tokenizer.speech_diffusion_id
|
||||
|
||||
voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
|
||||
voice_speech_inputs = []
|
||||
voice_speech_masks = [False] * len(voice_full_tokens)
|
||||
|
||||
for speaker_id, speaker_audio in enumerate(speaker_samples):
|
||||
prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)
|
||||
|
||||
# Process audio
|
||||
if isinstance(speaker_audio, str):
|
||||
# Load audio from file
|
||||
wav = self.audio_processor._load_audio_from_path(speaker_audio)
|
||||
else:
|
||||
wav = np.array(speaker_audio, dtype=np.float32)
|
||||
|
||||
# Apply normalization if needed
|
||||
if self.db_normalize and self.audio_normalizer:
|
||||
wav = self.audio_normalizer(wav)
|
||||
|
||||
# Calculate token length based on compression ratio
|
||||
# if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
|
||||
# vae_tok_len = wav.shape[0]
|
||||
# else:
|
||||
vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
|
||||
|
||||
# Build tokens and masks
|
||||
speaker_tokens = (prefix_tokens +
|
||||
[self.tokenizer.speech_start_id] +
|
||||
[vae_token_id] * vae_tok_len +
|
||||
[self.tokenizer.speech_end_id] +
|
||||
self.tokenizer.encode('\n', add_special_tokens=False))
|
||||
|
||||
vae_input_mask = ([False] * len(prefix_tokens) +
|
||||
[False] +
|
||||
[True] * vae_tok_len +
|
||||
[False] +
|
||||
[False])
|
||||
|
||||
voice_full_tokens.extend(speaker_tokens)
|
||||
voice_speech_masks.extend(vae_input_mask)
|
||||
voice_speech_inputs.append(wav)
|
||||
|
||||
return voice_full_tokens, voice_speech_inputs, voice_speech_masks
|
||||
|
||||
def prepare_speech_inputs(
|
||||
self,
|
||||
speech_inputs: List[np.ndarray],
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare speech inputs for model consumption.
|
||||
|
||||
Args:
|
||||
speech_inputs: List of speech arrays
|
||||
return_tensors: Output tensor type
|
||||
device: Device to place tensors on
|
||||
dtype: Data type for tensors
|
||||
|
||||
Returns:
|
||||
Dictionary with padded_speeches and speech_masks
|
||||
"""
|
||||
if not speech_inputs:
|
||||
return {"padded_speeches": None, "speech_masks": None}
|
||||
|
||||
# Calculate sequence lengths
|
||||
vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
|
||||
# vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs]
|
||||
max_speech_length = max(s.shape[0] for s in speech_inputs)
|
||||
|
||||
# Pad speeches
|
||||
if speech_inputs[0].ndim == 1:
|
||||
padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
|
||||
else:
|
||||
padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
|
||||
speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)
|
||||
|
||||
for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
|
||||
padded_speeches[i, :len(speech)] = speech
|
||||
speech_masks[i, :vae_tok_length] = True
|
||||
|
||||
result = {
|
||||
"padded_speeches": padded_speeches,
|
||||
"speech_masks": speech_masks,
|
||||
}
|
||||
|
||||
# Convert to tensors if requested
|
||||
if return_tensors == "pt":
|
||||
result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
|
||||
result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)
|
||||
|
||||
return result
|
||||
|
||||
def _convert_json_to_script(self, json_file: str) -> str:
|
||||
"""
|
||||
Convert JSON format to script format.
|
||||
Expected JSON format:
|
||||
[
|
||||
{"speaker": "1", "text": "Hello everyone..."},
|
||||
{"speaker": "2", "text": "Great to be here..."}
|
||||
]
|
||||
"""
|
||||
import json
|
||||
|
||||
with open(json_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("JSON file must contain a list of speaker entries")
|
||||
|
||||
script_lines = []
|
||||
for item in data:
|
||||
if not isinstance(item, dict):
|
||||
logger.warning(f"Skipping non-dict entry: {item}")
|
||||
continue
|
||||
|
||||
speaker = item.get('speaker')
|
||||
text = item.get('text')
|
||||
|
||||
if speaker is None or text is None:
|
||||
logger.warning(f"Skipping entry missing speaker or text: {item}")
|
||||
continue
|
||||
|
||||
# Ensure speaker ID is valid
|
||||
try:
|
||||
speaker_id = int(speaker)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"Invalid speaker ID: {speaker}, skipping entry")
|
||||
continue
|
||||
|
||||
# Clean up text
|
||||
text = text.strip()
|
||||
if text:
|
||||
script_lines.append(f"Speaker {speaker_id}: {text}")
|
||||
|
||||
if not script_lines:
|
||||
raise ValueError("No valid entries found in JSON file")
|
||||
|
||||
return "\n".join(script_lines)
|
||||
|
||||
def _convert_text_to_script(self, text_file: str) -> str:
|
||||
"""
|
||||
Convert text file to script format.
|
||||
Handles multiple formats:
|
||||
1. Already formatted as "Speaker X: text"
|
||||
2. Plain text (assigns to Speaker 1)
|
||||
|
||||
Handles edge cases like multiple colons in a line.
|
||||
"""
|
||||
with open(text_file, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
script_lines = []
|
||||
current_speaker = 1
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Try to parse as "Speaker X: text" format
|
||||
# Use regex to be more robust
|
||||
speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)
|
||||
|
||||
if speaker_match:
|
||||
speaker_id = int(speaker_match.group(1))
|
||||
text = speaker_match.group(2).strip()
|
||||
if text:
|
||||
script_lines.append(f"Speaker {speaker_id}: {text}")
|
||||
else:
|
||||
# Treat as plain text - assign to current speaker
|
||||
script_lines.append(f"Speaker {current_speaker}: {line}")
|
||||
|
||||
if not script_lines:
|
||||
raise ValueError("No valid content found in text file")
|
||||
|
||||
return "\n".join(script_lines)
|
||||
|
||||
def _parse_script(self, script: str) -> List[Tuple[int, str]]:
|
||||
"""Parse script into list of (speaker_id, text) tuples."""
|
||||
lines = script.strip().split("\n")
|
||||
parsed_lines = []
|
||||
speaker_ids = []
|
||||
|
||||
# First pass: parse all lines and collect speaker IDs
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
# Use regex to handle edge cases like multiple colons
|
||||
match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)
|
||||
|
||||
if match:
|
||||
speaker_id = int(match.group(1))
|
||||
text = ' ' + match.group(2).strip()
|
||||
parsed_lines.append((speaker_id, text))
|
||||
speaker_ids.append(speaker_id)
|
||||
else:
|
||||
logger.warning(f"Could not parse line: '{line}'")
|
||||
|
||||
if not parsed_lines:
|
||||
raise ValueError("No valid speaker lines found in script")
|
||||
|
||||
# Check if we need to normalize speaker IDs (only if all are > 0)
|
||||
min_speaker_id = min(speaker_ids)
|
||||
if min_speaker_id > 0:
|
||||
# Normalize to start from 0
|
||||
normalized_lines = []
|
||||
for speaker_id, text in parsed_lines:
|
||||
normalized_lines.append((speaker_id - 1, text))
|
||||
return normalized_lines
|
||||
else:
|
||||
# Keep original IDs
|
||||
return parsed_lines
|
||||
|
||||
def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding:
|
||||
"""Merge text and audio inputs into a single BatchEncoding."""
|
||||
# Start with text inputs
|
||||
merged = BatchEncoding(text_inputs)
|
||||
|
||||
# Add audio-specific fields
|
||||
if "audio" in audio_inputs:
|
||||
merged["speech_inputs"] = audio_inputs["audio"]
|
||||
if "streaming" in audio_inputs:
|
||||
merged["streaming"] = audio_inputs["streaming"]
|
||||
|
||||
return merged
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`].
|
||||
Please refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`].
|
||||
Please refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
"""
|
||||
Return the list of inputs accepted by the model.
|
||||
"""
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
audio_processor_input_names = self.audio_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))
|
||||
|
||||
def save_audio(self,
|
||||
audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
|
||||
output_path: str = "output.wav",
|
||||
sampling_rate: Optional[int] = None,
|
||||
normalize: bool = False,
|
||||
batch_prefix: str = "audio_",
|
||||
) -> str:
|
||||
"""
|
||||
Save audio data to a file.
|
||||
Args:
|
||||
audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
|
||||
The audio data to save. Can be a single tensor/array or a list of them.
|
||||
output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
|
||||
sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
|
||||
normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
|
||||
batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".
|
||||
Returns:
|
||||
str: The path to the saved audio file.
|
||||
"""
|
||||
return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix)
|
||||
|
||||
__all__ = [
|
||||
"VibeVoiceProcessor",
|
||||
]
|
||||
483
vibevoice/processor/vibevoice_tokenizer_processor.py
Normal file
483
vibevoice/processor/vibevoice_tokenizer_processor.py
Normal file
@ -0,0 +1,483 @@
|
||||
"""
|
||||
Processor class for VibeVoice models.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import warnings
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class AudioNormalizer:
|
||||
"""
|
||||
Audio normalization class for VibeVoice tokenizer.
|
||||
|
||||
This class provides audio normalization to ensure consistent input levels
|
||||
for the VibeVoice tokenizer while maintaining audio quality.
|
||||
"""
|
||||
|
||||
def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
|
||||
"""
|
||||
Initialize the audio normalizer.
|
||||
|
||||
Args:
|
||||
target_dB_FS (float): Target dB FS level for the audio. Default: -25
|
||||
eps (float): Small value to avoid division by zero. Default: 1e-6
|
||||
"""
|
||||
self.target_dB_FS = target_dB_FS
|
||||
self.eps = eps
|
||||
|
||||
def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
|
||||
"""
|
||||
Adjust the audio to the target dB FS level.
|
||||
|
||||
Args:
|
||||
audio (np.ndarray): Input audio signal
|
||||
|
||||
Returns:
|
||||
tuple: (normalized_audio, rms, scalar)
|
||||
"""
|
||||
rms = np.sqrt(np.mean(audio**2))
|
||||
scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
|
||||
normalized_audio = audio * scalar
|
||||
return normalized_audio, rms, scalar
|
||||
|
||||
def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
|
||||
"""
|
||||
Avoid clipping by scaling down if necessary.
|
||||
|
||||
Args:
|
||||
audio (np.ndarray): Input audio signal
|
||||
scalar (float, optional): Explicit scaling factor
|
||||
|
||||
Returns:
|
||||
tuple: (normalized_audio, scalar)
|
||||
"""
|
||||
if scalar is None:
|
||||
max_val = np.max(np.abs(audio))
|
||||
if max_val > 1.0:
|
||||
scalar = max_val + self.eps
|
||||
else:
|
||||
scalar = 1.0
|
||||
|
||||
return audio / scalar, scalar
|
||||
|
||||
def __call__(self, audio: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Normalize the audio by adjusting to target dB FS and avoiding clipping.
|
||||
|
||||
Args:
|
||||
audio (np.ndarray): Input audio signal
|
||||
|
||||
Returns:
|
||||
np.ndarray: Normalized audio signal
|
||||
"""
|
||||
# First adjust to target dB FS
|
||||
audio, _, _ = self.tailor_dB_FS(audio)
|
||||
# Then avoid clipping
|
||||
audio, _ = self.avoid_clipping(audio)
|
||||
return audio
|
||||
|
||||
|
||||
# Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components
|
||||
class VibeVoiceTokenizerProcessor(FeatureExtractionMixin):
|
||||
"""
|
||||
Processor for VibeVoice acoustic tokenizer models.
|
||||
|
||||
This processor handles audio preprocessing for VibeVoice models, including:
|
||||
- Audio format conversion (stereo to mono)
|
||||
- Optional audio normalization
|
||||
- Streaming support for infinite-length audio
|
||||
|
||||
Args:
|
||||
sampling_rate (int, optional): Expected sampling rate. Defaults to 24000.
|
||||
normalize_audio (bool, optional): Whether to normalize audio. Defaults to True.
|
||||
target_dB_FS (float, optional): Target dB FS for normalization. Defaults to -25.
|
||||
eps (float, optional): Small value for numerical stability. Defaults to 1e-6.
|
||||
"""
|
||||
model_input_names = ["input_features"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sampling_rate: int = 24000,
|
||||
normalize_audio: bool = True,
|
||||
target_dB_FS: float = -25,
|
||||
eps: float = 1e-6,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.sampling_rate = sampling_rate
|
||||
self.normalize_audio = normalize_audio
|
||||
|
||||
# Initialize audio normalizer if needed
|
||||
if self.normalize_audio:
|
||||
self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps)
|
||||
else:
|
||||
self.normalizer = None
|
||||
|
||||
# Save config
|
||||
self.feature_extractor_dict = {
|
||||
"sampling_rate": sampling_rate,
|
||||
"normalize_audio": normalize_audio,
|
||||
"target_dB_FS": target_dB_FS,
|
||||
"eps": eps,
|
||||
}
|
||||
|
||||
def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert stereo audio to mono if needed.
|
||||
|
||||
Args:
|
||||
audio (np.ndarray): Input audio array
|
||||
|
||||
Returns:
|
||||
np.ndarray: Mono audio array
|
||||
"""
|
||||
if len(audio.shape) == 1:
|
||||
return audio
|
||||
elif len(audio.shape) == 2:
|
||||
if audio.shape[0] == 2: # (2, time)
|
||||
return np.mean(audio, axis=0)
|
||||
elif audio.shape[1] == 2: # (time, 2)
|
||||
return np.mean(audio, axis=1)
|
||||
else:
|
||||
# If one dimension is 1, squeeze it
|
||||
if audio.shape[0] == 1:
|
||||
return audio.squeeze(0)
|
||||
elif audio.shape[1] == 1:
|
||||
return audio.squeeze(1)
|
||||
else:
|
||||
raise ValueError(f"Unexpected audio shape: {audio.shape}")
|
||||
else:
|
||||
raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}")
|
||||
|
||||
def _process_single_audio(self, audio: Union[np.ndarray, List[float]]) -> np.ndarray:
|
||||
"""
|
||||
Process a single audio array.
|
||||
|
||||
Args:
|
||||
audio: Single audio input
|
||||
|
||||
Returns:
|
||||
np.ndarray: Processed audio
|
||||
"""
|
||||
# Convert to numpy array
|
||||
if not isinstance(audio, np.ndarray):
|
||||
audio = np.array(audio, dtype=np.float32)
|
||||
else:
|
||||
audio = audio.astype(np.float32)
|
||||
|
||||
# Ensure mono
|
||||
audio = self._ensure_mono(audio)
|
||||
|
||||
# Normalize if requested
|
||||
if self.normalize_audio and self.normalizer is not None:
|
||||
audio = self.normalizer(audio)
|
||||
|
||||
return audio
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
audio: Union[str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str]] = None,
|
||||
sampling_rate: Optional[int] = None,
|
||||
return_tensors: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Process audio for VibeVoice models.
|
||||
|
||||
Args:
|
||||
audio: Audio input(s) to process. Can be:
|
||||
- str: Path to audio file
|
||||
- np.ndarray: Audio array
|
||||
- List[float]: Audio as list of floats
|
||||
- List[np.ndarray]: Batch of audio arrays
|
||||
- List[str]: Batch of audio file paths
|
||||
sampling_rate (int, optional): Sampling rate of the input audio
|
||||
return_tensors (str, optional): Return format ('pt' for PyTorch, 'np' for NumPy)
|
||||
|
||||
Returns:
|
||||
dict: Processed audio inputs with keys:
|
||||
- input_features: Audio tensor(s) ready for the model
|
||||
"""
|
||||
if audio is None:
|
||||
raise ValueError("Audio input is required")
|
||||
|
||||
# Validate sampling rate
|
||||
if sampling_rate is not None and sampling_rate != self.sampling_rate:
|
||||
logger.warning(
|
||||
f"Input sampling rate ({sampling_rate}) differs from expected "
|
||||
f"sampling rate ({self.sampling_rate}). Please resample your audio."
|
||||
)
|
||||
|
||||
# Handle different input types
|
||||
if isinstance(audio, str):
|
||||
# Single audio file path
|
||||
audio = self._load_audio_from_path(audio)
|
||||
is_batched = False
|
||||
elif isinstance(audio, list):
|
||||
if len(audio) == 0:
|
||||
raise ValueError("Empty audio list provided")
|
||||
|
||||
# Check if it's a list of file paths
|
||||
if all(isinstance(item, str) for item in audio):
|
||||
# Batch of audio file paths
|
||||
audio = [self._load_audio_from_path(path) for path in audio]
|
||||
is_batched = True
|
||||
else:
|
||||
# Check if it's batched audio arrays
|
||||
is_batched = isinstance(audio[0], (np.ndarray, list))
|
||||
else:
|
||||
# Single audio array or list
|
||||
is_batched = False
|
||||
|
||||
# Process audio
|
||||
if is_batched:
|
||||
processed_audio = [self._process_single_audio(a) for a in audio]
|
||||
else:
|
||||
processed_audio = [self._process_single_audio(audio)]
|
||||
|
||||
# Convert to tensors if requested
|
||||
if return_tensors == "pt":
|
||||
if len(processed_audio) == 1:
|
||||
# Create a proper batch dimension (B, T)
|
||||
input_features = torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1)
|
||||
else:
|
||||
# For batched input with different lengths, create a batch properly
|
||||
input_features = torch.stack([torch.from_numpy(a) for a in processed_audio]).unsqueeze(1)
|
||||
elif return_tensors == "np":
|
||||
if len(processed_audio) == 1:
|
||||
input_features = processed_audio[0][np.newaxis, np.newaxis, :]
|
||||
else:
|
||||
input_features = np.stack(processed_audio)[:, np.newaxis, :]
|
||||
else:
|
||||
input_features = processed_audio[0] if len(processed_audio) == 1 else processed_audio
|
||||
|
||||
outputs = {
|
||||
"audio": input_features, # Use "audio" instead of "input_features"
|
||||
}
|
||||
|
||||
return outputs
|
||||
|
||||
def _load_audio_from_path(self, audio_path: str) -> np.ndarray:
|
||||
"""
|
||||
Load audio from file path.
|
||||
|
||||
Args:
|
||||
audio_path (str): Path to audio file
|
||||
|
||||
Returns:
|
||||
np.ndarray: Loaded audio array
|
||||
"""
|
||||
# Get file extension to determine loading method
|
||||
file_ext = os.path.splitext(audio_path)[1].lower()
|
||||
|
||||
if file_ext in ['.wav', '.mp3', '.flac', '.m4a', '.ogg']:
|
||||
# Audio file - use librosa
|
||||
import librosa
|
||||
audio_array, sr = librosa.load(
|
||||
audio_path,
|
||||
sr=self.sampling_rate,
|
||||
mono=True
|
||||
)
|
||||
return audio_array
|
||||
elif file_ext == '.pt':
|
||||
# PyTorch tensor file
|
||||
audio_tensor = torch.load(audio_path, map_location='cpu').squeeze()
|
||||
if isinstance(audio_tensor, torch.Tensor):
|
||||
audio_array = audio_tensor.numpy()
|
||||
else:
|
||||
audio_array = np.array(audio_tensor)
|
||||
return audio_array.astype(np.float32)
|
||||
elif file_ext == '.npy':
|
||||
# NumPy file
|
||||
audio_array = np.load(audio_path)
|
||||
return audio_array.astype(np.float32)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported file format: {file_ext}. "
|
||||
f"Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy, .npz"
|
||||
)
|
||||
|
||||
def preprocess_audio(
|
||||
self,
|
||||
audio_path_or_array: Union[str, np.ndarray],
|
||||
normalize: Optional[bool] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Convenience method to preprocess audio from file path or array.
|
||||
This method is kept for backward compatibility but __call__ is recommended.
|
||||
|
||||
Args:
|
||||
audio_path_or_array: Path to audio file or numpy array
|
||||
normalize: Whether to normalize (overrides default setting)
|
||||
|
||||
Returns:
|
||||
np.ndarray: Preprocessed audio array
|
||||
"""
|
||||
if isinstance(audio_path_or_array, str):
|
||||
audio_array = self._load_audio_from_path(audio_path_or_array)
|
||||
else:
|
||||
audio_array = np.array(audio_path_or_array, dtype=np.float32)
|
||||
|
||||
# Override normalization setting if specified
|
||||
original_normalize = self.normalize_audio
|
||||
if normalize is not None:
|
||||
self.normalize_audio = normalize
|
||||
|
||||
try:
|
||||
processed = self._process_single_audio(audio_array)
|
||||
finally:
|
||||
# Restore original setting
|
||||
self.normalize_audio = original_normalize
|
||||
|
||||
return processed
|
||||
|
||||
# Override to_dict method for configuration saving
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert the object to a dict containing all attributes needed for serialization.
|
||||
"""
|
||||
return self.feature_extractor_dict
|
||||
|
||||
def save_audio(
|
||||
self,
|
||||
audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
|
||||
output_path: str = "output.wav",
|
||||
sampling_rate: Optional[int] = None,
|
||||
normalize: bool = False,
|
||||
batch_prefix: str = "audio_",
|
||||
):
|
||||
"""
|
||||
Save audio data to WAV file(s).
|
||||
|
||||
Args:
|
||||
audio: Audio data to save. Can be:
|
||||
- torch.Tensor: PyTorch tensor with shape (B, C, T) or (B, T) or (T)
|
||||
- np.ndarray: NumPy array with shape (B, C, T) or (B, T) or (T)
|
||||
- List of tensors or arrays
|
||||
output_path: Path where to save the audio. If saving multiple files,
|
||||
this is treated as a directory and individual files will be saved inside.
|
||||
sampling_rate: Sampling rate for the saved audio. Defaults to the processor's rate.
|
||||
normalize: Whether to normalize audio before saving.
|
||||
batch_prefix: Prefix for batch files when saving multiple audios.
|
||||
|
||||
Returns:
|
||||
List[str]: Paths to the saved audio files.
|
||||
"""
|
||||
if sampling_rate is None:
|
||||
sampling_rate = self.sampling_rate
|
||||
|
||||
try:
|
||||
import soundfile as sf
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"soundfile is required to save audio files. "
|
||||
"Install it with: pip install soundfile"
|
||||
)
|
||||
|
||||
# Ensure audio is in the right format
|
||||
if isinstance(audio, torch.Tensor):
|
||||
# Convert PyTorch tensor to numpy
|
||||
audio_np = audio.float().detach().cpu().numpy()
|
||||
elif isinstance(audio, np.ndarray):
|
||||
audio_np = audio
|
||||
elif isinstance(audio, list):
|
||||
# Handle list of tensors or arrays
|
||||
if all(isinstance(a, torch.Tensor) for a in audio):
|
||||
audio_np = [a.float().detach().cpu().numpy() for a in audio]
|
||||
else:
|
||||
audio_np = audio
|
||||
else:
|
||||
raise ValueError(f"Unsupported audio type: {type(audio)}")
|
||||
|
||||
saved_paths = []
|
||||
|
||||
# Handle based on shape or type
|
||||
if isinstance(audio_np, list):
|
||||
# Multiple separate audios to save
|
||||
output_dir = output_path
|
||||
|
||||
# Ensure output directory exists
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Save each audio
|
||||
for i, audio_item in enumerate(audio_np):
|
||||
audio_item = self._prepare_audio_for_save(audio_item, normalize)
|
||||
file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
|
||||
sf.write(file_path, audio_item, sampling_rate)
|
||||
saved_paths.append(file_path)
|
||||
|
||||
else:
|
||||
# Handle different dimensions
|
||||
if len(audio_np.shape) >= 3: # (B, C, T) or similar
|
||||
# Get batch size
|
||||
batch_size = audio_np.shape[0]
|
||||
|
||||
if batch_size > 1:
|
||||
# Multiple audios in a batch
|
||||
output_dir = output_path
|
||||
|
||||
# Ensure output directory exists
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Save each audio in the batch
|
||||
for i in range(batch_size):
|
||||
# Extract single audio and remove channel dim if present
|
||||
single_audio = audio_np[i]
|
||||
if len(single_audio.shape) > 1:
|
||||
if single_audio.shape[0] == 1: # (1, T)
|
||||
single_audio = single_audio.squeeze(0)
|
||||
|
||||
single_audio = self._prepare_audio_for_save(single_audio, normalize)
|
||||
file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
|
||||
sf.write(file_path, single_audio, sampling_rate)
|
||||
saved_paths.append(file_path)
|
||||
else:
|
||||
# Single audio with batch and channel dims
|
||||
audio_item = audio_np.squeeze() # Remove batch and channel dimensions
|
||||
audio_item = self._prepare_audio_for_save(audio_item, normalize)
|
||||
sf.write(output_path, audio_item, sampling_rate)
|
||||
saved_paths.append(output_path)
|
||||
else:
|
||||
# Single audio without batch dimension
|
||||
audio_item = self._prepare_audio_for_save(audio_np, normalize)
|
||||
sf.write(output_path, audio_item, sampling_rate)
|
||||
saved_paths.append(output_path)
|
||||
|
||||
return saved_paths
|
||||
|
||||
def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
|
||||
"""
|
||||
Prepare audio for saving by ensuring it's the right shape and optionally normalizing.
|
||||
|
||||
Args:
|
||||
audio: Audio data as numpy array
|
||||
normalize: Whether to normalize audio
|
||||
|
||||
Returns:
|
||||
np.ndarray: Processed audio ready for saving
|
||||
"""
|
||||
# Ensure right dimensionality
|
||||
if len(audio.shape) > 1 and audio.shape[0] == 1: # (1, T)
|
||||
audio = audio.squeeze(0)
|
||||
|
||||
# Normalize if requested
|
||||
if normalize:
|
||||
max_val = np.abs(audio).max()
|
||||
if max_val > 0:
|
||||
audio = audio / max_val
|
||||
|
||||
return audio
|
||||
|
||||
|
||||
__all__ = ["VibeVoiceTokenizerProcessor", "AudioNormalizer"]
|
||||
0
vibevoice/schedule/__init__.py
Normal file
0
vibevoice/schedule/__init__.py
Normal file
1065
vibevoice/schedule/dpm_solver.py
Normal file
1065
vibevoice/schedule/dpm_solver.py
Normal file
File diff suppressed because it is too large
Load Diff
19
vibevoice/schedule/timestep_sampler.py
Normal file
19
vibevoice/schedule/timestep_sampler.py
Normal file
@ -0,0 +1,19 @@
|
||||
import math
|
||||
import torch
|
||||
|
||||
|
||||
class UniformSampler:
|
||||
def __init__(self, timesteps = 1000):
|
||||
self.timesteps = timesteps
|
||||
def sample(self, batch_size, device):
|
||||
return torch.randint(0, self.timesteps, (batch_size,), device=device)
|
||||
|
||||
class LogitNormalSampler:
|
||||
def __init__(self, timesteps = 1000, m = 0, s = 1):
|
||||
self.timesteps = timesteps
|
||||
timesteps = torch.linspace(0, 1, timesteps)
|
||||
logit = torch.log(timesteps / (1 - timesteps))
|
||||
self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s ** 2) / (s * math.sqrt(2 * math.pi))
|
||||
def sample(self, batch_size, device):
|
||||
return torch.multinomial(self.prob, batch_size, replacement=True).to(device)
|
||||
|
||||
0
vibevoice/scripts/__init__.py
Normal file
0
vibevoice/scripts/__init__.py
Normal file
166
vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py
Normal file
166
vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py
Normal file
@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import torch
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from vibevoice.modular.configuration_vibevoice import (
|
||||
VibeVoiceConfig
|
||||
)
|
||||
from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
def convert_vibevoice_nnscaler_checkpoint_to_hf(
|
||||
checkpoint_path: str,
|
||||
pytorch_dump_folder_path: str,
|
||||
config_path: str = None,
|
||||
):
|
||||
"""
|
||||
Convert a nnscaler VibeVoice checkpoint to HuggingFace format.
|
||||
Supports both regular checkpoints and tensor parallel checkpoints.
|
||||
"""
|
||||
|
||||
# Load regular checkpoint
|
||||
logger.info(f"Loading regular checkpoint from {checkpoint_path}")
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu") # ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader']
|
||||
|
||||
# config = checkpoint['train_args']
|
||||
init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path']
|
||||
pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path']
|
||||
|
||||
init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1]
|
||||
if init_config_path.exists():
|
||||
logger.info(f"Loading initial config from {init_config_path}")
|
||||
with open(init_config_path, 'r') as f:
|
||||
init_config = json.load(f)
|
||||
else:
|
||||
raise FileNotFoundError(f"Initial config file {init_config_path} not found. Please provide a valid path.")
|
||||
|
||||
tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True)
|
||||
logger.info(f"Tie word embeddings: {tie_word_embeddings}")
|
||||
|
||||
init_config['decoder_config']['use_cache'] = True
|
||||
config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings)
|
||||
|
||||
# # Extract the model state dict
|
||||
model_state_dict = {k.replace('model.model.', 'model.'): v for k, v in checkpoint["model"].items() if k.startswith('model.model.')}
|
||||
if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys():
|
||||
# If not tying weights, we need to add the lm_head weight separately
|
||||
model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight']
|
||||
|
||||
# Override with provided config if available
|
||||
if config_path:
|
||||
logger.info(f"Loading config from {config_path}")
|
||||
with open(config_path, 'r') as f:
|
||||
config_dict = json.load(f)
|
||||
config = VibeVoiceConfig.from_dict(config_dict)
|
||||
|
||||
# Set the default dtype to bfloat16 before creating the model
|
||||
original_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
|
||||
# Create the HuggingFace model
|
||||
logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model")
|
||||
model = VibeVoiceForConditionalGeneration(config)
|
||||
|
||||
# Restore original dtype
|
||||
torch.set_default_dtype(original_dtype)
|
||||
|
||||
# Load the state dict
|
||||
logger.info("Loading weights into model")
|
||||
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
logger.warning(f"Missing keys: {missing_keys}")
|
||||
if unexpected_keys:
|
||||
logger.warning(f"Unexpected keys: {unexpected_keys}")
|
||||
|
||||
# Create output directory
|
||||
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
||||
|
||||
# Save the model and config
|
||||
logger.info(f"Saving model to {pytorch_dump_folder_path}")
|
||||
|
||||
# Save config
|
||||
config.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
# Save VibeVoiceProcessor configuration
|
||||
logger.info("Saving VibeVoiceProcessor configuration")
|
||||
processor_config = {
|
||||
"processor_class": "VibeVoiceProcessor",
|
||||
"speech_tok_compress_ratio": 3200,
|
||||
"db_normalize": True,
|
||||
# Audio processor configuration
|
||||
"audio_processor": {
|
||||
"feature_extractor_type": "VibeVoiceTokenizerProcessor",
|
||||
"sampling_rate": 24000,
|
||||
"normalize_audio": True,
|
||||
"target_dB_FS": -25,
|
||||
"eps": 1e-6,
|
||||
},
|
||||
"language_model_pretrained_name": pretrained_name,
|
||||
}
|
||||
|
||||
processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json")
|
||||
with open(processor_config_path, 'w') as f:
|
||||
json.dump(processor_config, f, indent=2)
|
||||
logger.info(f"Saved processor config to {processor_config_path}")
|
||||
|
||||
# Save model with sharding
|
||||
# save_pretrained handles tied weights automatically
|
||||
logger.info("Saving model weights with sharding...")
|
||||
model.save_pretrained(
|
||||
pytorch_dump_folder_path,
|
||||
max_shard_size="2GB", # Set maximum size for each shard
|
||||
safe_serialization=True # Ensure saving in .safetensors format
|
||||
)
|
||||
logger.info(f"Model weights saved to {pytorch_dump_folder_path}")
|
||||
|
||||
logger.info("Conversion complete!")
|
||||
|
||||
# Verify the saved model can be loaded
|
||||
logger.info("Verifying saved model...")
|
||||
loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path)
|
||||
logger.info("Model successfully loaded from saved checkpoint!")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--nnscaler_checkpoint_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the fairseq checkpoint (.pt file). For tensor parallel checkpoints, "
|
||||
"provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), "
|
||||
"and the script will automatically detect and merge all parts.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional path to a config JSON file to override extracted config",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_vibevoice_nnscaler_checkpoint_to_hf(
|
||||
args.nnscaler_checkpoint_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.config_path,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user