First commit

This commit is contained in:
2025-09-03 19:24:26 +10:00
parent 2bf7123ae2
commit 1e1402795f
30 changed files with 7582 additions and 33 deletions

35
.gitignore vendored
View File

@ -1,4 +1,3 @@
# ---> Python
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@ -8,7 +7,10 @@ __pycache__/
*.so
# Distribution / packaging
.github
.idea
.Python
__pycache__
build/
develop-eggs/
dist/
@ -95,12 +97,6 @@ ipython_config.py
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
@ -113,10 +109,8 @@ ipython_config.py
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
@ -167,24 +161,3 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# ---> JupyterNotebooks
# gitignore template for Jupyter Notebooks
# website: http://jupyter.org/
.ipynb_checkpoints
*/.ipynb_checkpoints/*
# IPython
profile_default/
ipython_config.py
# Remove previous ipynb_checkpoints
# git rm -r .ipynb_checkpoints/

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 WildAi
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

173
README.md
View File

@ -1,3 +1,172 @@
# VibeVoice-Modifications
<!-- Improved compatibility of back to top link: See: https://github.com/othneildrew/Best-README-Template/pull/73 -->
<a id="readme-top"></a>
VibeVoice-Modifications
<div align="center">
<h1 align="center">ComfyUI-VibeVoice</h1>
<img src="./example_workflows/VibeVoice_example.png" alt="ComfyUI-VibeVoice Nodes" alt="Logo" width="600" height="388">
<p align="center">
A custom node for ComfyUI that integrates Microsoft's VibeVoice, a frontier model for generating expressive, long-form, multi-speaker conversational audio.
<br />
<br />
<a href="https://github.com/wildminder/ComfyUI-VibeVoice/issues/new?labels=bug&template=bug-report---.md">Report Bug</a>
·
<a href="https://github.com/wildminder/ComfyUI-VibeVoice/issues/new?labels=enhancement&template=feature-request---.md">Request Feature</a>
<!-- PROJECT SHIELDS -->
[![Stargazers][stars-shield]][stars-url]
[![Issues][issues-shield]][issues-url]
[![Contributors][contributors-shield]][contributors-url]
[![Forks][forks-shield]][forks-url]
</p>
</div>
<!-- ABOUT THE PROJECT -->
## About The Project
This project brings the power of **VibeVoice** into the modular workflow of ComfyUI. VibeVoice is a novel framework by Microsoft for generating expressive, long-form, multi-speaker conversational audio. It excels at creating natural-sounding dialogue, podcasts, and more, with consistent voices for up to 4 speakers.
The custom node handles everything from model downloading and memory management to audio processing, allowing you to generate high-quality speech directly from a text script and reference audio files.
**Key Features:**
* **Multi-Speaker TTS:** Generate conversations with up to 4 distinct voices in a single audio output.
* **Zero-Shot Voice Cloning:** Use any audio file (`.wav`, `.mp3`) as a reference for a speaker's voice.
* **Automatic Model Management:** Models are downloaded automatically from Hugging Face and managed efficiently by ComfyUI to save VRAM.
* **Fine-Grained Control:** Adjust parameters like CFG scale, temperature, and sampling methods to tune the performance and style of the generated speech.
* **4-Bit Quantization:** Run the large language model component in 4-bit mode to significantly reduce VRAM usage and improve speed on memory-constrained GPUs, especially for the 7B model.
* **Transformers 4.56+ Compatibility:** Fully backwards compatible with both older and newer versions of the Transformers library.
* **Force Offload Option:** Toggle to force model offloading from VRAM after generation to save memory between runs - now with improved ComfyUI compatibility.
<p align="right">(<a href="#readme-top">back to top</a>)</p>
<!-- GETTING STARTED -->
## Getting Started
Follow these steps to get the ComfyUI-VibeVoice node running in your environment.
### Installation
The node can be installed via **ComfyUI Manager:** Find `ComfyUI-VibeVoice` and click "Install". Or, install it manually:
1. **Clone the Repository:**
Navigate to your `ComfyUI/custom_nodes/` directory and clone this repository:
```sh
git clone https://github.com/wildminder/ComfyUI-VibeVoice.git
```
2. **Install Dependencies:**
Open a terminal or command prompt, navigate into the cloned directory, and install the required Python packages. **For quantization support, ensure you install `bitsandbytes`**.
```sh
cd ComfyUI-VibeVoice
pip install -r requirements.txt
```
3. **Start/Restart ComfyUI:**
Launch ComfyUI. The "VibeVoice TTS" node will appear under the `audio/tts` category. The first time you use the node, it will automatically download the selected model to your `ComfyUI/models/tts/VibeVoice/` folder.
## Models
| Model | Context Length | Generation Length | Weight |
|-------|----------------|----------|----------|
| VibeVoice-1.5B | 64K | ~90 min | [HF link](https://huggingface.co/microsoft/VibeVoice-1.5B) |
| VibeVoice-Large| 32K | ~45 min | [HF link](https://huggingface.co/microsoft/VibeVoice-Large) |
<p align="right">(<a href="#readme-top">back to top</a>)</p>
<!-- USAGE EXAMPLES -->
## Usage
The node is designed to be intuitive within the ComfyUI workflow.
1. **Add Nodes:** Add the `VibeVoice TTS` node to your graph. Use ComfyUI's built-in `Load Audio` node to load your reference voice files.
2. **Connect Voices:** Connect the `AUDIO` output from each `Load Audio` node to the corresponding `speaker_*_voice` input on the VibeVoice TTS node.
3. **Write Script:** In the `text` input, write your dialogue. Assign lines to speakers using the format `Speaker 1: ...`, `Speaker 2: ...`, etc., on separate lines.
4. **Generate:** Queue the prompt. The node will process the script and generate a single audio file containing the full conversation.
_For a complete workflow, you can drag the example image from the `example_workflows` folder onto your ComfyUI canvas._
### Node Inputs
* **`model_name`**: Select the VibeVoice model to use.
* **`quantize_llm`**: (New!) Enable to run the LLM component in 4-bit (NF4) mode. This dramatically reduces VRAM and can significantly speed up inference on the 7B model.
* **`text`**: The conversational script. Lines must be prefixed with `Speaker <number>:` (e.g., `Speaker 1:`).
* **`cfg_scale`**: Controls how strongly the model adheres to the reference voice's timbre.
* **`inference_steps`**: Number of diffusion steps for the audio decoder.
* **`seed`**: A seed for reproducibility.
* **`do_sample`, `temperature`, `top_p`, `top_k`**: Standard sampling parameters for controlling the creativity and determinism of the speech generation.
* **`force_offload`**: (New!) Forces the model to be completely offloaded from VRAM after generation. Useful for memory management but may slow down subsequent runs.
* **`speaker_*_voice` (Optional)**: Connect an `AUDIO` output from a `Load Audio` node to provide a voice reference.
### Performance & Quantization
A key feature of this node is the optional **4-bit quantization** for the language model component. This is highly recommended for users with memory-constrained GPUs (e.g., <= 16GB VRAM) who wish to run the larger `VibeVoice-Large-pt` model.
**Benefits of `quantize_llm = Enabled`:**
| Model | Performance Impact | VRAM Savings |
|---|---|---|
| **VibeVoice-Large (7B)** | **~8.5x faster** inference | Saves **>4.4 GB** (over 36%) |
| **VibeVoice-1.5B** | ~1.5x slower inference | Saves **~5.5 GB** (over 63%) |
As shown, quantization provides a massive speedup and VRAM reduction for the 7B model, making it accessible on a wider range of hardware. While it slightly slows down the 1.5B model, the significant VRAM savings may still be beneficial for complex workflows.
### Transformers Library Compatibility
This version includes automatic detection and compatibility for both older and newer versions of the Transformers library:
* **Transformers 4.56+**: Automatically uses the new method signature for `_prepare_cache_for_generation`
* **Older Versions**: Maintains compatibility with pre-4.56 versions using the legacy method signature
* **Fallback Mechanism**: If detection fails, the node will automatically try both versions to ensure maximum compatibility
This ensures the node works seamlessly regardless of your Transformers version without requiring manual updates.
### Tips from the Original Authors
* **Punctuation:** For Chinese text, using English punctuation (commas and periods) can improve stability.
* **Model Choice:** The 7B model variant (`VibeVoice-Large`) is generally more stable.
* **Spontaneous Sounds/Music:** The model may spontaneously generate background music, especially if the reference audio contains it or if the text includes introductory phrases like "Welcome to...". This is an emergent capability and cannot be directly controlled.
* **Singing:** The model was not trained on singing data, but it may attempt to sing as an emergent behavior. Results may vary.
<p align="right">(<a href="#readme-top">back to top</a>)</p>
<!-- BUG FIXES -->
## Recent Bug Fixes
### Force Offload Compatibility Fix
* **Fixed:** Resolved `AttributeError: module 'comfy.model_management' has no attribute 'unload_model_clones'` error when using the force offload option
* **Details:** Updated the force offload implementation to use ComfyUI's standard `unload_all_models()` API instead of the deprecated `unload_model_clones()` function
* **Impact:** Force offload functionality now works correctly with all versions of ComfyUI
### Multi-Speaker DynamicCache Fix
* **Fixed:** Resolved `'DynamicCache' object has no attribute 'key_cache'` error when using multiple speakers
* **Details:** Updated cache access in `modeling_vibevoice_inference.py` to use proper DynamicCache API - accessing layers via indexing instead of deprecated `.key_cache` and `.value_cache` attributes
* **Impact:** Multi-speaker functionality now works correctly with newer versions of Transformers library
<p align="right">(<a href="#readme-top">back to top</a>)</p>
<!-- LICENSE -->
## License
This project is distributed under the MIT License. See `LICENSE.txt` for more information. The VibeVoice model and its components are subject to the licenses provided by Microsoft. Please use responsibly.
<p align="right">(<a href="#readme-top">back to top</a>)</p>
<!-- ACKNOWLEDGMENTS -->
## Acknowledgments
* **Microsoft** for creating and open-sourcing the [VibeVoice](https://github.com/microsoft/VibeVoice) project.
* **The ComfyUI team** for their incredible and extensible platform.
* **othneildrew** for the [Best-README-Template](https://github.com/othneildrew/Best-README-Template).
<p align="right">(<a href="#readme-top">back to top</a>)</p>
<!-- MARKDOWN LINKS & IMAGES -->
[contributors-shield]: https://img.shields.io/github/contributors/wildminder/ComfyUI-VibeVoice.svg?style=for-the-badge
[contributors-url]: https://github.com/wildminder/ComfyUI-VibeVoice/graphs/contributors
[forks-shield]: https://img.shields.io/github/forks/wildminder/ComfyUI-VibeVoice.svg?style=for-the-badge
[forks-url]: https://github.com/wildminder/ComfyUI-VibeVoice/network/members
[stars-shield]: https://img.shields.io/github/stars/wildminder/ComfyUI-VibeVoice.svg?style=for-the-badge
[stars-url]: https://github.com/wildminder/ComfyUI-VibeVoice/stargazers
[issues-shield]: https://img.shields.io/github/issues/wildminder/ComfyUI-VibeVoice.svg?style=for-the-badge
[issues-url]: https://github.com/wildminder/ComfyUI-VibeVoice/issues

63
__init__.py Normal file
View File

@ -0,0 +1,63 @@
import os
import sys
import logging
# allowing absolute imports like 'from vibevoice.modular...' to work.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
sys.path.append(current_dir)
import folder_paths
from .vibevoice_nodes import NODE_CLASS_MAPPINGS as BASE_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as BASE_NODE_DISPLAY_NAME_MAPPINGS
# Configure a logger for the entire custom node package
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
logger.propagate = False
if not logger.hasHandlers():
handler = logging.StreamHandler()
formatter = logging.Formatter(f"[ComfyUI-VibeVoice] %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
VIBEVOICE_MODEL_SUBDIR = os.path.join("tts", "VibeVoice")
vibevoice_models_full_path = os.path.join(folder_paths.models_dir, VIBEVOICE_MODEL_SUBDIR)
os.makedirs(vibevoice_models_full_path, exist_ok=True)
# Register the tts/VibeVoice path with ComfyUI
tts_path = os.path.join(folder_paths.models_dir, "tts")
if "tts" not in folder_paths.folder_names_and_paths:
supported_exts = folder_paths.supported_pt_extensions.union({".safetensors", ".json"})
folder_paths.folder_names_and_paths["tts"] = ([tts_path], supported_exts)
else:
if tts_path not in folder_paths.folder_names_and_paths["tts"][0]:
folder_paths.folder_names_and_paths["tts"][0].append(tts_path)
try:
from .vibevoice_node_chunked_wrapper import (
NODE_CLASS_MAPPINGS as WRAP_NODE_CLASS_MAPPINGS,
NODE_DISPLAY_NAME_MAPPINGS as WRAP_NODE_DISPLAY_NAME_MAPPINGS,
)
except Exception as e:
logger.warning(f"[ComfyUI-VibeVoice] Wrapper failed to load: {e}")
WRAP_NODE_CLASS_MAPPINGS = {}
WRAP_NODE_DISPLAY_NAME_MAPPINGS = {}
# Merge and export
NODE_CLASS_MAPPINGS = {
**BASE_NODE_CLASS_MAPPINGS,
**WRAP_NODE_CLASS_MAPPINGS,
}
NODE_DISPLAY_NAME_MAPPINGS = {
**BASE_NODE_DISPLAY_NAME_MAPPINGS,
**WRAP_NODE_DISPLAY_NAME_MAPPINGS,
}
WEB_DIRECTORY = "./js"
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', 'WEB_DIRECTORY']

View File

@ -0,0 +1,277 @@
{
"id": "b91265e5-1b03-4b63-8dc3-4abd9a030e08",
"revision": 0,
"last_node_id": 11,
"last_link_id": 29,
"nodes": [
{
"id": 4,
"type": "LoadAudio",
"pos": [
-1900,
-1130
],
"size": [
274.080078125,
136
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "AUDIO",
"type": "AUDIO",
"links": [
28
]
}
],
"properties": {
"cnr_id": "comfy-core",
"ver": "0.3.52",
"Node name for S&R": "LoadAudio",
"ue_properties": {
"widget_ue_connectable": {
"audio": true,
"audioUI": true,
"upload": true
},
"version": "7.0.1"
}
},
"widgets_values": [
"male_rickmorty.mp3",
null,
null
]
},
{
"id": 11,
"type": "VibeVoiceTTS",
"pos": [
-1570,
-1130
],
"size": [
460,
510
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "speaker_1_voice",
"shape": 7,
"type": "AUDIO",
"link": 28
},
{
"name": "speaker_2_voice",
"shape": 7,
"type": "AUDIO",
"link": 29
},
{
"name": "speaker_3_voice",
"shape": 7,
"type": "AUDIO",
"link": null
},
{
"name": "speaker_4_voice",
"shape": 7,
"type": "AUDIO",
"link": null
}
],
"outputs": [
{
"name": "AUDIO",
"type": "AUDIO",
"links": [
27
]
}
],
"properties": {
"cnr_id": "ComfyUI-VibeVoice",
"ver": "37803a884fb8f9b43c38286f6d654c7f97181a73",
"Node name for S&R": "VibeVoiceTTS"
},
"widgets_values": [
"VibeVoice-1.5B",
"Speaker 1: I can't believe you did it again. I waited for two hours. Two hours! Not a single call, not a text. Do you have any idea how embarrassing that was, just sitting there alone?\nSpeaker 2: Look, I know, I'm sorry, alright? Work was a complete nightmare. My boss dropped a critical deadline on me at the last minute. I didn't even have a second to breathe, let alone check my phone.\nSpeaker 1: A nightmare? That's the same excuse you used last time. I'm starting to think you just don't care. It's easier to say 'work was crazy' than to just admit that I'm not a priority for you anymore.",
false,
"sdpa",
1.3,
10,
56109085141530,
"randomize",
true,
0.95,
0.95,
0
],
"color": "#232",
"bgcolor": "#353"
},
{
"id": 8,
"type": "LoadAudio",
"pos": [
-1900,
-940
],
"size": [
274.080078125,
136
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "AUDIO",
"type": "AUDIO",
"links": [
29
]
}
],
"properties": {
"cnr_id": "comfy-core",
"ver": "0.3.52",
"Node name for S&R": "LoadAudio",
"ue_properties": {
"widget_ue_connectable": {
"audio": true,
"audioUI": true,
"upload": true
},
"version": "7.0.1"
}
},
"widgets_values": [
"male_stewie.mp3",
null,
null
]
},
{
"id": 10,
"type": "MarkdownNote",
"pos": [
-1030,
-960
],
"size": [
420,
210
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [],
"title": "Notes",
"properties": {
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.0.1"
}
},
"widgets_values": [
"## Models\n\nWill be downloaded on the first run, or download them manually and place them into the directory: /models/tts/VibeVoice\n\n| Model | Context Length | Generation Length | Weight |\n|-------|----------------|----------|----------|\n| VibeVoice-0.5B-Streaming | - | - | On the way |\n| VibeVoice-1.5B | 64K | ~90 min | [HF link](https://huggingface.co/microsoft/VibeVoice-1.5B) |\n| VibeVoice-Large| 32K | ~45 min | [HF link](https://huggingface.co/microsoft/VibeVoice-Large) |"
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 3,
"type": "SaveAudio",
"pos": [
-1040,
-1130
],
"size": [
270,
112
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"name": "audio",
"type": "AUDIO",
"link": 27
}
],
"outputs": [],
"properties": {
"cnr_id": "comfy-core",
"ver": "0.3.52",
"Node name for S&R": "SaveAudio",
"ue_properties": {
"widget_ue_connectable": {
"filename_prefix": true,
"audioUI": true
},
"version": "7.0.1"
}
},
"widgets_values": [
"audio/VibeVoice"
]
}
],
"links": [
[
27,
11,
0,
3,
0,
"AUDIO"
],
[
28,
4,
0,
11,
0,
"AUDIO"
],
[
29,
8,
0,
11,
1,
"AUDIO"
]
],
"groups": [],
"config": {},
"extra": {
"ue_links": [],
"links_added_by_ue": [],
"ds": {
"scale": 1.2100000000000004,
"offset": [
2024.7933884297524,
1252.3140495867776
]
},
"frontendVersion": "1.25.11",
"VHS_latentpreview": false,
"VHS_latentpreviewrate": 0,
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"version": 0.4
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 138 KiB

107
js/vibevoice_wrapper_ui.js Normal file
View File

@ -0,0 +1,107 @@
// custom_nodes/YourPkg/js/vibevoice_wrapper_ui.js
import { app } from "../../scripts/app.js";
app.registerExtension({
name: "vibevoice.wrapper.ui",
async beforeRegisterNodeDef(nodeType, nodeData) {
const isWrapper =
nodeType?.comfyClass === "VibeVoiceTTS_Wrapper" ||
nodeData?.name === "VibeVoice TTS (Chunked Wrapper)";
if (!isWrapper) return;
const origOnCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = function () {
origOnCreated?.apply(this, arguments);
// only set up handlers here; do NOT mutate slots yet
wireUpHandlers(this);
};
function wireUpHandlers(node) {
const findW = (n) => node.widgets?.find((w) => w.name === n);
const wNum = findW("num_speakers");
const wChunk = findW("chunk_lines");
const wLines = findW("lines_per_chunk");
function ensureSpeakerInputs(count) {
// add missing inputs
for (let i = 1; i <= count; i++) {
const name = `speaker_${i}_voice`;
if (node.findInputSlot(name) === -1) node.addInput(name, "AUDIO");
}
// remove extras
for (let i = count + 1; i <= 4; i++) {
const name = `speaker_${i}_voice`;
const idx = node.findInputSlot(name);
if (idx !== -1) node.removeInput(idx);
}
}
// guard: only mutate once node.graph exists (prevents NullGraphError)
function safeMutate(fn) {
const doIt = () => {
if (!node.graph) {
// defer until the node is actually attached to a graph
setTimeout(doIt, 0);
return;
}
fn();
app.graph.setDirtyCanvas(true, true);
};
doIt();
}
function refresh() {
const n = Math.max(1, Math.min(4, Number(wNum?.value ?? 1)));
safeMutate(() => ensureSpeakerInputs(n));
if (wLines) wLines.hidden = !(wChunk?.value);
}
// robust wiring (some frontends only call one of these)
if (wNum) { wNum.callback = refresh; wNum.onChange = refresh; }
if (wChunk) { wChunk.callback = refresh; wChunk.onChange = refresh; }
// dont call refresh yet; node may not be in graph during configure
node.__vv_refresh = refresh;
}
},
// Called for brand-new nodes added from the menu (node has a graph here)
async nodeCreated(node) {
if (
node?.comfyClass === "VibeVoiceTTS_Wrapper" ||
node?.title === "VibeVoice TTS (Chunked Wrapper)"
) {
// next tick to ensure widgets fully exist
setTimeout(() => node.__vv_refresh?.(), 0);
}
},
// Called when nodes are created as part of loading a workflow
loadedGraphNode(node) {
if (
node?.comfyClass === "VibeVoiceTTS_Wrapper" ||
node?.title === "VibeVoice TTS (Chunked Wrapper)"
) {
setTimeout(() => node.__vv_refresh?.(), 0);
}
},
// After the graph finishes configuring (safe point to mutate slots)
async afterConfigureGraph() {
// final pass in case anything was deferred
for (const node of app.graph._nodes) {
if (
node?.comfyClass === "VibeVoiceTTS_Wrapper" ||
node?.title === "VibeVoice TTS (Chunked Wrapper)"
) {
node.__vv_refresh?.();
}
}
},
async setup() {
console.log("[vibevoice.wrapper.ui] setup complete");
},
});

23
pyproject.toml Normal file
View File

@ -0,0 +1,23 @@
[project]
name = "ComfyUI-VibeVoice"
description = "VibeVoice TTS. Expressive, long-form, multi-speaker conversational audio"
version = "1.2.0"
license = {file = "LICENSE"}
dependencies = ["torch", "torchaudio", "librosa", "numpy", "huggingface_hub", "einops", "scipy", "tokenizers", "soundfile", "s3tokenizer", "tqdm", "conformer", "safetensors", "transformers", "diffusers", "bitsandbytes"]
[project.urls]
Repository = "https://github.com/wildminder/ComfyUI-VibeVoice"
# Used by Comfy Registry https://comfyregistry.org
[tool.comfy]
PublisherId = "wildai"
DisplayName = "ComfyUI-VibeVoice"
Icon = ""

17
requirements.txt Normal file
View File

@ -0,0 +1,17 @@
torch
accelerate
torchaudio
librosa
numpy
huggingface_hub
einops
scipy
tokenizers
soundfile
s3tokenizer
conformer
safetensors
transformers
diffusers
tqdm
bitsandbytes

0
vibevoice/__init__.py Normal file
View File

View File

@ -0,0 +1,112 @@
{
"_attn_implementation_autoset": true,
"acoustic_vae_dim": 64,
"acoustic_tokenizer_config": {
"causal": true,
"channels": 1,
"conv_bias": true,
"conv_norm": "none",
"corpus_normalize": 0.0,
"decoder_depths": null,
"decoder_n_filters": 32,
"decoder_ratios": [
8,
5,
5,
4,
2,
2
],
"disable_last_norm": true,
"encoder_depths": "3-3-3-3-3-3-8",
"encoder_n_filters": 32,
"encoder_ratios": [
8,
5,
5,
4,
2,
2
],
"fix_std": 0.5,
"layer_scale_init_value": 1e-06,
"layernorm": "RMSNorm",
"layernorm_elementwise_affine": true,
"layernorm_eps": 1e-05,
"mixer_layer": "depthwise_conv",
"model_type": "vibepod_acoustic_tokenizer",
"pad_mode": "constant",
"std_dist_type": "gaussian",
"vae_dim": 64,
"weight_init_value": 0.01
},
"decoder_config": {
"attention_dropout": 0.0,
"hidden_act": "silu",
"hidden_size": 1536,
"initializer_range": 0.02,
"intermediate_size": 8960,
"max_position_embeddings": 65536,
"max_window_layers": 28,
"model_type": "qwen2",
"num_attention_heads": 12,
"num_hidden_layers": 28,
"num_key_value_heads": 2,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000.0,
"sliding_window": null,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 151936
},
"diffusion_head_config": {
"ddpm_batch_mul": 4,
"ddpm_beta_schedule": "cosine",
"ddpm_num_inference_steps": 20,
"ddpm_num_steps": 1000,
"diffusion_type": "ddpm",
"head_ffn_ratio": 3.0,
"head_layers": 4,
"hidden_size": 1536,
"latent_size": 64,
"model_type": "vibepod_diffusion_head",
"prediction_type": "v_prediction",
"rms_norm_eps": 1e-05,
"speech_vae_dim": 64
},
"model_type": "vibepod",
"semantic_tokenizer_config": {
"causal": true,
"channels": 1,
"conv_bias": true,
"conv_norm": "none",
"corpus_normalize": 0.0,
"disable_last_norm": true,
"encoder_depths": "3-3-3-3-3-3-8",
"encoder_n_filters": 32,
"encoder_ratios": [
8,
5,
5,
4,
2,
2
],
"fix_std": 0,
"layer_scale_init_value": 1e-06,
"layernorm": "RMSNorm",
"layernorm_elementwise_affine": true,
"layernorm_eps": 1e-05,
"mixer_layer": "depthwise_conv",
"model_type": "vibepod_semantic_tokenizer",
"pad_mode": "constant",
"std_dist_type": "none",
"vae_dim": 128,
"weight_init_value": 0.01
},
"semantic_vae_dim": 128,
"torch_dtype": "bfloat16"
}

View File

@ -0,0 +1,113 @@
{
"_attn_implementation_autoset": true,
"acoustic_vae_dim": 64,
"acoustic_tokenizer_config": {
"causal": true,
"channels": 1,
"conv_bias": true,
"conv_norm": "none",
"corpus_normalize": 0.0,
"decoder_depths": null,
"decoder_n_filters": 32,
"decoder_ratios": [
8,
5,
5,
4,
2,
2
],
"disable_last_norm": true,
"encoder_depths": "3-3-3-3-3-3-8",
"encoder_n_filters": 32,
"encoder_ratios": [
8,
5,
5,
4,
2,
2
],
"fix_std": 0.5,
"layer_scale_init_value": 1e-06,
"layernorm": "RMSNorm",
"layernorm_elementwise_affine": true,
"layernorm_eps": 1e-05,
"mixer_layer": "depthwise_conv",
"model_type": "vibepod_acoustic_tokenizer",
"pad_mode": "constant",
"std_dist_type": "gaussian",
"vae_dim": 64,
"weight_init_value": 0.01
},
"decoder_config": {
"attention_dropout": 0.0,
"hidden_act": "silu",
"hidden_size": 3584,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 32768,
"max_window_layers": 28,
"model_type": "qwen2",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_theta": 1000000.0,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.40.1",
"use_cache": true,
"use_mrope": false,
"use_sliding_window": false,
"vocab_size": 152064
},
"diffusion_head_config": {
"ddpm_batch_mul": 4,
"ddpm_beta_schedule": "cosine",
"ddpm_num_inference_steps": 20,
"ddpm_num_steps": 1000,
"diffusion_type": "ddpm",
"head_ffn_ratio": 3.0,
"head_layers": 4,
"hidden_size": 3584,
"latent_size": 64,
"model_type": "vibepod_diffusion_head",
"prediction_type": "v_prediction",
"rms_norm_eps": 1e-05,
"speech_vae_dim": 64
},
"model_type": "vibepod",
"semantic_tokenizer_config": {
"causal": true,
"channels": 1,
"conv_bias": true,
"conv_norm": "none",
"corpus_normalize": 0.0,
"disable_last_norm": true,
"encoder_depths": "3-3-3-3-3-3-8",
"encoder_n_filters": 32,
"encoder_ratios": [
8,
5,
5,
4,
2,
2
],
"fix_std": 0,
"layer_scale_init_value": 1e-06,
"layernorm": "RMSNorm",
"layernorm_elementwise_affine": true,
"layernorm_eps": 1e-05,
"mixer_layer": "depthwise_conv",
"model_type": "vibepod_semantic_tokenizer",
"pad_mode": "constant",
"std_dist_type": "none",
"vae_dim": 128,
"weight_init_value": 0.01
},
"semantic_vae_dim": 128,
"torch_dtype": "bfloat16"
}

View File

View File

@ -0,0 +1,248 @@
""" VibeVoice_AcousticTokenizer model configuration"""
from typing import Dict, List, Optional, Tuple
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
logger = logging.get_logger(__name__)
class VibeVoiceAcousticTokenizerConfig(PretrainedConfig):
model_type = "vibevoice_acoustic_tokenizer"
def __init__(
self,
channels: int = 1,
corpus_normalize: float = 0.0,
causal: bool = True,
vae_dim: int = 64,
fix_std: float = 0.5,
std_dist_type: str = 'gaussian',
# common
mixer_layer: str = 'depthwise_conv',
conv_norm: str = 'none',
pad_mode: str = 'constant',
disable_last_norm: bool = True,
layernorm: str = 'RMSNorm',
layernorm_eps: float = 1e-5,
layernorm_elementwise_affine: bool = True,
conv_bias: bool = True,
layer_scale_init_value: float = 1e-6,
weight_init_value: float = 1e-2,
# encoder specific
encoder_n_filters: int = 32,
encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
encoder_depths: str = "3-3-3-3-3-3-8",
# decoder specific
decoder_n_filters: int = 32,
decoder_ratios: Optional[List[int]] = None, # if None, same as encoder
decoder_depths: Optional[str] = None,
**kwargs
):
super().__init__(**kwargs)
self.channels = channels
self.corpus_normalize = corpus_normalize
self.causal = causal
self.vae_dim = vae_dim
self.fix_std = fix_std
self.std_dist_type = std_dist_type
# common parameters
self.conv_norm = conv_norm
self.pad_mode = pad_mode
self.layernorm_eps = layernorm_eps
self.disable_last_norm = disable_last_norm
self.layernorm = layernorm
self.layernorm_elementwise_affine = layernorm_elementwise_affine
self.conv_bias = conv_bias
self.layer_scale_init_value = layer_scale_init_value
self.weight_init_value = weight_init_value
self.mixer_layer = mixer_layer
# encoder specific parameters
self.encoder_n_filters = encoder_n_filters
self.encoder_ratios = encoder_ratios
self.encoder_depths = encoder_depths
# decoder specific parameters
self.decoder_ratios = decoder_ratios if decoder_ratios is not None else encoder_ratios
self.decoder_n_filters = decoder_n_filters
self.decoder_depths = decoder_depths
class VibeVoiceSemanticTokenizerConfig(PretrainedConfig):
model_type = "vibevoice_semantic_tokenizer"
def __init__(
self,
channels: int = 1,
corpus_normalize: float = 0.0,
causal: bool = True,
vae_dim: int = 64,
fix_std: float = 0,
std_dist_type: str = 'none',
# common
mixer_layer: str = 'depthwise_conv',
conv_norm: str = 'none',
pad_mode: str = 'constant',
disable_last_norm: bool = True,
layernorm: str = 'RMSNorm',
layernorm_eps: float = 1e-5,
layernorm_elementwise_affine: bool = True,
conv_bias: bool = True,
layer_scale_init_value: float = 1e-6,
weight_init_value: float = 1e-2,
# encoder specific
encoder_n_filters: int = 32,
encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
encoder_depths: str = "3-3-3-3-3-3-8",
**kwargs
):
super().__init__(**kwargs)
self.channels = channels
self.corpus_normalize = corpus_normalize
self.causal = causal
self.vae_dim = vae_dim
self.fix_std = fix_std
self.std_dist_type = std_dist_type
# common parameters
self.conv_norm = conv_norm
self.pad_mode = pad_mode
self.layernorm_eps = layernorm_eps
self.disable_last_norm = disable_last_norm
self.layernorm = layernorm
self.layernorm_elementwise_affine = layernorm_elementwise_affine
self.conv_bias = conv_bias
self.layer_scale_init_value = layer_scale_init_value
self.weight_init_value = weight_init_value
self.mixer_layer = mixer_layer
# encoder specific parameters
self.encoder_n_filters = encoder_n_filters
self.encoder_ratios = encoder_ratios
self.encoder_depths = encoder_depths
class VibeVoiceDiffusionHeadConfig(PretrainedConfig):
model_type = "vibevoice_diffusion_head"
def __init__(
self,
hidden_size=768,
head_layers=4,
head_ffn_ratio=3.0,
rms_norm_eps=1e-5,
latent_size=64,
speech_vae_dim=None,
prediction_type="v_prediction",
diffusion_type="ddpm",
ddpm_num_steps=1000,
ddpm_num_inference_steps=20,
ddpm_beta_schedule="cosine",
ddpm_batch_mul=4,
**kwargs
):
self.hidden_size = hidden_size
self.head_layers = head_layers
self.head_ffn_ratio = head_ffn_ratio
self.rms_norm_eps = rms_norm_eps
self.latent_size = latent_size
self.speech_vae_dim = speech_vae_dim
self.prediction_type = prediction_type
self.diffusion_type = diffusion_type
self.ddpm_num_steps = ddpm_num_steps
self.ddpm_num_inference_steps = ddpm_num_inference_steps
self.ddpm_beta_schedule = ddpm_beta_schedule
self.ddpm_batch_mul = ddpm_batch_mul
super().__init__(**kwargs)
class VibeVoiceConfig(PretrainedConfig):
model_type = "vibevoice"
is_composition = True
sub_configs = {
"acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
"semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
"decoder_config": Qwen2Config,
"diffusion_head_config": VibeVoiceDiffusionHeadConfig,
}
# keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen2`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
def __init__(
self,
acoustic_tokenizer_config=None,
semantic_tokenizer_config=None,
decoder_config=None,
diffusion_head_config=None,
**kwargs
):
# kwargs["_attn_implementation"] = "flash_attention_2"
kwargs["_attn_implementation_autoset"] = False
if acoustic_tokenizer_config is None:
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
elif isinstance(acoustic_tokenizer_config, dict):
acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer"
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config)
elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
# If an instance of the config class is provided
self.acoustic_tokenizer_config = acoustic_tokenizer_config
if semantic_tokenizer_config is None:
self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]()
elif isinstance(semantic_tokenizer_config, dict):
semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer"
self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config)
elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig):
# If an instance of the config class is provided
self.semantic_tokenizer_config = semantic_tokenizer_config
if decoder_config is None:
self.decoder_config = self.sub_configs["decoder_config"]()
elif isinstance(decoder_config, dict):
# If a dictionary is provided, instantiate the config class with it
# self.decoder_config = self.sub_configs["decoder_config"](**decoder_config)
if decoder_config.get("model_type", '') == "qwen2":
self.decoder_config = Qwen2Config(**decoder_config)
else:
raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
elif isinstance(decoder_config, (Qwen2Config,)):
# If an instance of the config class is provided
self.decoder_config = decoder_config
if diffusion_head_config is None:
self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
elif isinstance(diffusion_head_config, dict):
diffusion_head_config["model_type"] = "vibevoice_diffusion_head"
self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_head_config)
elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig):
# If an instance of the config class is provided
self.diffusion_head_config = diffusion_head_config
# other parameters
self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)
super().__init__(**kwargs)
__all__ = [
"VibeVoiceAcousticTokenizerConfig",
"VibeVoiceSemanticTokenizerConfig",
"VibeVoiceDiffusionHeadConfig",
"VibeVoiceConfig"
]

View File

@ -0,0 +1,488 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
from .configuration_vibevoice import VibeVoiceConfig
logger = logging.get_logger(__name__)
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
@dataclass
class VibeVoiceCausalLMOutputWithPast(ModelOutput):
loss: Optional[torch.FloatTensor] = None
diffusion_loss: Optional[torch.FloatTensor] = None
speech_token_num: Optional[int] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class VibeVoiceGenerationOutput(ModelOutput):
"""
Output type for VibeVoice generation.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences.
speech_outputs (`List[torch.FloatTensor]`, *optional*):
List of generated speech waveforms or latents for each speech segment.
"""
sequences: torch.LongTensor = None
speech_outputs: Optional[List[torch.FloatTensor]] = None
class SpeechConnector(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.fc1 = nn.Linear(input_dim, output_dim)
self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
self.fc2 = nn.Linear(output_dim, output_dim)
def forward(self, features, **kwargs):
x = self.fc1(features)
x = self.norm(x)
x = self.fc2(x)
return x
# @auto_docstring
class VibeVoicePreTrainedModel(PreTrainedModel):
config_class = VibeVoiceConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
if isinstance(module, VibeVoiceDiffusionHead):
module.initialize_weights()
return
# Use the language model's initializer_range if available
if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
std = self.config.language_model_config.initializer_range
elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
std = self.config.decoder_config.initializer_range
else:
std = 0.02 # Default value
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
# @auto_docstring
class VibeVoiceModel(VibeVoicePreTrainedModel):
def __init__(self, config):
super().__init__(config)
if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
if isinstance(config.torch_dtype, str):
dtype = getattr(torch, config.torch_dtype)
else:
dtype = config.torch_dtype
else:
dtype = torch.float32
# Initialize Qwen2 model for language modeling
lm_config = config.decoder_config
self.language_model = AutoModel.from_config(lm_config)
# Initialize speech components if needed
self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)
self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)
# Register scaling factors as buffers - use 1D tensors for FSDP compatibility
self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))
# Initialize prediction head for speech generation
self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)
# Initialize noise scheduler
self.noise_scheduler = DPMSolverMultistepScheduler(
num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
prediction_type=config.diffusion_head_config.prediction_type
)
def get_input_embeddings(self):
if hasattr(self.language_model, 'embed_tokens'):
# If the language model has an embed_tokens attribute, return it
return self.language_model.embed_tokens
for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed
if attr.orig_name == 'embed_tokens.weight':
return getattr(self.language_model, name)
assert False, 'should not arrive here'
def set_input_embeddings(self, value):
self.language_model.embed_tokens = value
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
"""Set the speech tokenizers used for encoding and decoding speech."""
self.acoustic_tokenizer = acoustic_tokenizer
self.semantic_tokenizer = semantic_tokenizer
# Reset the encoder to evaluation mode
if self.acoustic_tokenizer is not None:
self.acoustic_tokenizer.eval()
if self.semantic_tokenizer is not None:
self.semantic_tokenizer.eval()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Forward through language model
outputs = self.language_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
if not return_dict:
return outputs
return BaseModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)
self.model = VibeVoiceModel(config)
self.vocab_size = config.decoder_config.vocab_size
self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_decoder(self, decoder):
self.model.language_model = decoder
def get_decoder(self):
return self.model.language_model
def tie_weights(self):
"""
Tie the weights between the input embeddings and the output embeddings.
"""
if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
# The standard PreTrainedModel method will handle the tying.
# It typically does a simple parameter object assignment, which is
# CORRECT to do BEFORE FSDP wraps the model.
output_embeddings = self.get_output_embeddings()
input_embeddings = self.get_input_embeddings()
if hasattr(input_embeddings, 'weight'):
output_embeddings.weight = input_embeddings.weight
else:
# maybe returned input_embeddings a tensor directly
output_embeddings.weight = input_embeddings
if getattr(output_embeddings, "bias", None) is not None:
output_embeddings.bias.data = nn.functional.pad(
output_embeddings.bias.data,
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
"constant",
0,
)
print("✅ Tied input and output embeddings using standard assignment.")
else:
print(" tie_word_embeddings is False, not tying weights.")
# Also, ensure set_output_embeddings is safe, though your implementation looks okay.
# The key is to avoid calling it after accelerator.prepare().
def set_output_embeddings(self, new_embeddings):
# Your current implementation using data.copy_ is good practice,
# but the best way is to not call this after prepare().
self.lm_head = new_embeddings
def forward_speech_features(
self,
speech_tensors=None,
speech_masks=None,
speech_type="audio",
return_unmask=False
):
if speech_tensors is None:
# Use config to get vae_dim instead of non-existent self.args
vae_dim = self.config.acoustic_tokenizer_config.vae_dim
audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight)
connect_features = self.model.acoustic_connector(audio_features)
return audio_features, connect_features
else:
with torch.no_grad():
if speech_type == "audio":
with torch.no_grad():
frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0]
elif speech_type == "vae":
# Use config to get vae_dim instead of non-existent self.args
vae_dim = self.config.acoustic_tokenizer_config.vae_dim
speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim)
# gaussian sample from the speech_mode
batch_size = speech_mode.size(0)
value = self.model.acoustic_tokenizer.fix_std / 0.8
std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value
std = std.view(-1, *[1] * (speech_mode.dim() - 1))
audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode)
else:
raise NotImplementedError(f"Speech type {speech_type} not implemented")
if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor):
scaling_factor = 1. / audio_tokens[speech_masks].flatten().std()
bias_factor = -audio_tokens[speech_masks].flatten().mean()
# Only use distributed operations if the process group is initialized
if dist.is_available() and dist.is_initialized():
dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
world_size = dist.get_world_size()
self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
self.model.speech_bias_factor.copy_(bias_factor / world_size)
print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
else:
# Single process case
self.model.speech_scaling_factor.copy_(scaling_factor)
self.model.speech_bias_factor.copy_(bias_factor)
print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor
connect_features = self.model.acoustic_connector(audio_features)
if return_unmask:
return audio_features, connect_features
return audio_features[speech_masks], connect_features[speech_masks]
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
# New arguments for speech processing and loss calculation
speech_tensors: Optional[torch.FloatTensor] = None,
speech_masks: Optional[torch.BoolTensor] = None,
speeches_loss_input: Optional[torch.FloatTensor] = None,
speech_semantic_tensors: Optional[torch.FloatTensor] = None,
acoustic_input_mask: Optional[torch.BoolTensor] = None,
acoustic_loss_mask: Optional[torch.BoolTensor] = None,
ddpm_batch_mul: int = 1,
**kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
x = self.get_input_embeddings()(input_ids)
semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors)
if speeches_loss_input is not None:
# only part audio need diffuse
speech_all_features, speech_all_connect_features = self.forward_speech_features(
speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
speech_masks=speech_masks,
speech_type=kwargs.get("speech_type", "audio"),
return_unmask=True
)
if speech_tensors is not None:
if semantic_speech_all_connect_features is not None:
x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + semantic_speech_all_connect_features[speech_masks]
else:
x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
speech_features = speech_all_features[speeches_loss_input.unsqueeze(-1) & speech_masks] # only part audio need diffuse
speech_connect_features = speech_all_connect_features[speeches_loss_input.unsqueeze(-1) & speech_masks]
else:
speech_features, speech_connect_features = self.forward_speech_features(
speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
speech_masks=speech_masks,
speech_type=kwargs.get("speech_type", "audio"),
)
if speech_tensors is not None:
x[acoustic_input_mask] = speech_connect_features
outputs = self.model(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=x,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=False,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states)
# logits = logits.float()
loss = None
if labels is not None:
# The custom CE loss with masking is calculated in the training script.
# We leave the standard loss calculation here as None.
pass
# --- Diffusion Loss Calculation ---
diffusion_loss = None
# This block is executed only if we are in a context that involves speech.
if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
condition_features = hidden_states[acoustic_loss_mask]
speech_len, latent_size = speech_features.shape
noise = torch.randn(
(speech_len * ddpm_batch_mul, latent_size),
device=hidden_states.device,
dtype=hidden_states.dtype
)
timesteps = torch.multinomial(
torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
speech_len * ddpm_batch_mul,
replacement=True,
).to(hidden_states.device)
speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0)
condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0)
noisy_speech_features = self.model.noise_scheduler.add_noise(
speech_features_repeated, noise, timesteps
)
model_output = self.model.prediction_head(
noisy_speech_features,
timesteps.type_as(x),
condition_features_repeated
)
prediction_type = self.config.diffusion_head_config.prediction_type
if prediction_type == "epsilon":
target_for_loss = noise
elif prediction_type == "v_prediction":
target_for_loss = self.model.noise_scheduler.get_velocity(
speech_features_repeated, noise, timesteps
)
else:
raise NotImplementedError(f"Prediction type {prediction_type} not implemented")
diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum')
if latent_size > 0 and ddpm_batch_mul > 0:
diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
else:
diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
else:
# Dummy loss for DDP to work when there are no speech samples in a batch,
# but we are in a speech context.
diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
# --- End Diffusion Loss Calculation ---
if not return_dict:
output = (logits, speech_len) + outputs.to_tuple()[1:]
return (loss, diffusion_loss) + output
return VibeVoiceCausalLMOutputWithPast(
loss=loss,
diffusion_loss=diffusion_loss,
speech_token_num=speech_len if speech_tensors is not None else 0,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
AutoModel.register(VibeVoiceConfig, VibeVoiceModel)
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration)
__all__ = [
"VibeVoiceModel",
"VibeVoicePreTrainedModel",
"VibeVoiceForConditionalGeneration",
"VibeVoiceCausalLMOutputWithPast",
"VibeVoiceGenerationOutput",
]

View File

@ -0,0 +1,731 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
# from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceTokenizerEncoderOutput
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
from .configuration_vibevoice import VibeVoiceConfig
from .modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizer, VibeVoiceTextTokenizerFast
from .modeling_vibevoice import VibeVoiceModel, VibeVoicePreTrainedModel
from .streamer import AudioStreamer, AsyncAudioStreamer
logger = logging.get_logger(__name__)
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
@dataclass
class VibeVoiceCausalLMOutputWithPast(BaseModelOutputWithPast):
logits: Optional[torch.FloatTensor] = None
@dataclass
class VibeVoiceGenerationOutput(ModelOutput):
"""
Output type for VibeVoice generation.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences.
speech_outputs (`List[torch.FloatTensor]`, *optional*):
List of generated speech waveforms or latents for each speech segment.
"""
sequences: torch.LongTensor = None
speech_outputs: Optional[List[torch.FloatTensor]] = None
reach_max_step_sample: Optional[torch.BoolTensor] = None
class VibeVoiceTokenConstraintProcessor(LogitsProcessor):
"""Constrains token generation to only valid tokens during speech generation."""
def __init__(self, valid_token_ids: List[int], device: torch.device = None):
self.valid_token_ids = torch.tensor(valid_token_ids, dtype=torch.long, device=device)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Create a mask for valid tokens
mask = torch.full_like(scores, float('-inf'))
mask[:, self.valid_token_ids] = 0
# Apply mask to scores
scores = scores + mask
return scores
class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)
# Initialize the base model
self.model = VibeVoiceModel(config)
# LM head for text generation
self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.decoder_config.vocab_size, bias=False)
# inference configuration
self.ddpm_inference_steps = config.diffusion_head_config.ddpm_num_inference_steps
# Initialize weights and apply final processing
self.post_init()
@property
def noise_scheduler(self):
return self.model.noise_scheduler
@property
def prediction_head(self):
return self.model.prediction_head
@property
def speech_scaling_factor(self):
return self.model.speech_scaling_factor
@property
def speech_bias_factor(self):
return self.model.speech_bias_factor
@property
def acoustic_tokenizer(self):
return self.model.acoustic_tokenizer
@property
def semantic_tokenizer(self):
return self.model.semantic_tokenizer
@property
def acoustic_connector(self):
return self.model.acoustic_connector
@property
def semantic_connector(self):
return self.model.semantic_connector
def tie_weights(self):
"""
Tie the weights between the input embeddings and the output embeddings.
"""
# Tie lm_head.weight to language_model.embed_tokens.weight
if not getattr(self.config, 'tie_word_embeddings', False):
return
if hasattr(self, 'lm_head') and hasattr(self.model.language_model, 'embed_tokens'):
self.lm_head.weight = self.model.language_model.embed_tokens.weight
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
"""Set the speech tokenizers used for encoding and decoding speech."""
self.model.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer)
def set_ddpm_inference_steps(self, num_steps=None):
self.ddpm_inference_steps = num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"):
"""Process speech inputs through tokenizers and connectors."""
with torch.no_grad():
if speech_type == "audio":
# Encode audio to acoustic latents
encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
# Apply scaling and bias
acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
# Connect to language model space
acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
return acoustic_features, acoustic_connected
elif speech_type == "pt":
encoder_output = VibeVoiceTokenizerEncoderOutput(mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std)
acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
# Apply scaling and bias
acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
# Connect to language model space
acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
return acoustic_features, acoustic_connected
else:
raise NotImplementedError(f"Speech type {speech_type} not implemented")
# @can_return_tuple
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
speech_tensors: Optional[torch.FloatTensor] = None,
speech_masks: Optional[torch.BoolTensor] = None,
speech_input_mask: Optional[torch.BoolTensor] = None,
logits_to_keep: Union[int, slice] = 0,
**kwargs,
) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
speech_tensors (`torch.FloatTensor`, *optional*):
Input speech waveforms for voice cloning or speech understanding.
speech_masks (`torch.BoolTensor`, *optional*):
Masks indicating valid speech frames.
speech_input_mask (`torch.BoolTensor`, *optional*):
Positions in the input sequence where speech embeddings should be inserted.
Returns:
`VibeVoiceCausalLMOutputWithPast` or tuple
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Get embeddings
if inputs_embeds is None:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
# Process speech inputs if provided
if speech_tensors is not None and speech_masks is not None:
acoustic_features, speech_embeds = self._process_speech_inputs(speech_tensors.to(self.dtype), speech_masks)
if speech_input_mask is not None:
inputs_embeds[speech_input_mask] = speech_embeds
outputs = self.model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
raise NotImplementedError("Loss computation is not implemented in this version.")
return VibeVoiceCausalLMOutputWithPast(
logits=logits,
past_key_values=outputs.past_key_values,
last_hidden_state=hidden_states,
attentions=outputs.attentions,
)
def _build_generate_config_model_kwargs(self, generation_config, inputs, tokenizer, return_processors=False, **kwargs):
if generation_config is None:
generation_config = GenerationConfig(
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id = tokenizer.pad_token_id
)
else:
generation_config = GenerationConfig(
**generation_config,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id = tokenizer.pad_token_id
)
generation_config, model_kwargs = self._prepare_generation_config(
generation_config,
True,
speech_start_id=tokenizer.speech_start_id,
speech_end_id=tokenizer.speech_end_id,
speech_diffusion_id=tokenizer.speech_diffusion_id,
**kwargs
)
generation_config.speech_start_id = tokenizer.speech_start_id
generation_config.speech_end_id = tokenizer.speech_end_id
generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs)
batch_size = inputs_tensor.shape[0]
device = self.device
self._prepare_special_tokens(generation_config, True, device=device)
generation_config.use_cache = True
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = inputs_tensor.to(self.device)
input_ids_length = input_ids.shape[1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
max_cache_length = generation_config.max_length - 1
# Backwards compatible fix for _prepare_cache_for_generation method signature
# New transformers version expects 5 args, old version expects 6
import inspect
try:
sig = inspect.signature(self._prepare_cache_for_generation)
if len(sig.parameters) == 5:
# New transformers version (4.56+)
self._prepare_cache_for_generation(generation_config, model_kwargs, batch_size, max_cache_length, device)
else:
# Old transformers version (pre-4.56)
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
except Exception as e:
# Fallback to try both versions
try:
self._prepare_cache_for_generation(generation_config, model_kwargs, batch_size, max_cache_length, device)
except TypeError:
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
for k, v in model_kwargs.items():
if isinstance(v, torch.Tensor):
model_kwargs[k] = v.to(device=device)
if return_processors:
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=LogitsProcessorList(),
device=inputs_tensor.device,
model_kwargs=model_kwargs,
)
stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=StoppingCriteriaList())
return generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria
else:
return generation_config, model_kwargs, input_ids
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
speech_tensors: Optional[torch.FloatTensor] = None,
speech_masks: Optional[torch.BoolTensor] = None,
speech_input_mask: Optional[torch.BoolTensor] = None,
return_speech: bool = True,
cfg_scale: float = 1.0,
stop_check_fn: Optional[Callable[[], bool]] = None,
**kwargs,
) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
"""
Generates sequences of token ids and optionally speech outputs.
Args:
All standard generation arguments from GenerationMixin
negative_prompt_ids: Negative prompt for CFG in speech generation
negative_prompt_attention_mask: Attention mask for negative prompt
speech_tensors: Input speech for voice cloning
speech_masks: Masks for speech tensors
speech_input_mask: Positions to insert speech embeddings
return_speech: Whether to decode and return speech outputs
cfg_scale: CFG scale for speech generation
stop_check_fn: Optional callable that returns True if generation should stop
Returns:
Generated token sequences and optionally speech outputs
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
parsed_scripts = kwargs.pop("parsed_scripts", None)
all_speakers_list = kwargs.pop("all_speakers_list", None)
max_length_times = kwargs.pop("max_length_times", 2)
if kwargs.get('max_new_tokens', None) is None:
kwargs['max_new_tokens'] = self.config.decoder_config.max_position_embeddings - kwargs['input_ids'].shape[-1]
generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = self._build_generate_config_model_kwargs(
generation_config, inputs, tokenizer, return_processors=True, **kwargs
)
negative_kwargs = {
'input_ids': torch.full((kwargs['input_ids'].shape[0], 1), tokenizer.speech_start_id, dtype=torch.long, device=kwargs['input_ids'].device),
'attention_mask': torch.ones((kwargs['input_ids'].shape[0], 1), dtype=torch.long, device=kwargs['input_ids'].device),
'max_new_tokens': kwargs.get('max_new_tokens', 100)
}
negative_generation_config, negative_model_kwargs, negative_input_ids = self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **negative_kwargs
)
acoustic_cache = VibeVoiceTokenizerStreamingCache()
semantic_cache = VibeVoiceTokenizerStreamingCache()
batch_size = input_ids.shape[0]
device = input_ids.device
finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)
is_prefill = True
inputs_embeds = None
verbose = kwargs.get("verbose", False)
# Initialize audio chunks storage for each sample
audio_chunks = [[] for _ in range(batch_size)]
initial_length = input_ids.shape[-1]
initial_length_per_sample = model_kwargs['attention_mask'].sum(dim=-1)
# Define all valid tokens that can be generated
valid_tokens = [
generation_config.speech_start_id,
generation_config.speech_end_id,
generation_config.speech_diffusion_id,
generation_config.eos_token_id
]
# Add bos_token_id if it exists
if hasattr(generation_config, 'bos_token_id') and generation_config.bos_token_id is not None:
valid_tokens.append(generation_config.bos_token_id)
# Add custom processor to constrain token generation
token_constraint_processor = VibeVoiceTokenConstraintProcessor(valid_tokens, device=device)
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(token_constraint_processor)
max_steps = min(generation_config.max_length - initial_length, int(max_length_times * initial_length))
max_step_per_sample = torch.min(generation_config.max_length - initial_length_per_sample, (max_length_times * initial_length_per_sample).long())
reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
# Create progress iterator if verbose
if kwargs.get("show_progress_bar", True):
progress_bar = tqdm(range(max_steps), desc="Generating", leave=False)
else:
progress_bar = range(max_steps)
for step in progress_bar:
# Check for external stop signal
if stop_check_fn is not None and stop_check_fn():
if verbose:
print(f"Generation stopped externally at step {step + 1}")
# End the audio streamer if it exists
if audio_streamer is not None:
audio_streamer.end()
break
# Check if audio_streamer has been ended (stopped externally)
if audio_streamer is not None and hasattr(audio_streamer, 'finished_flags'):
if any(audio_streamer.finished_flags):
if verbose:
print(f"Audio generation stopped externally at step {step + 1}")
break
if finished_tags.all():
if hasattr(progress_bar, 'set_description'):
progress_bar.set_description("Generation complete")
break
if input_ids.shape[-1] >= generation_config.max_length:
print(f"Reached maximum generation length {generation_config.max_length}, stopped it.")
reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
if reached_samples.numel() > 0:
reach_max_step_sample[reached_samples] = True
break
# Update progress bar description with active samples
if hasattr(progress_bar, 'set_description'):
active_samples = (~finished_tags).sum().item()
progress_bar.set_description(f"Generating (active: {active_samples}/{batch_size})")
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
if is_prefill:
# we process the speech inputs only during the first generation step
prefill_inputs = {
"speech_tensors": speech_tensors.to(device=device),
"speech_masks": speech_masks.to(device),
"speech_input_mask": speech_input_mask.to(device),
}
is_prefill = False
else:
_ = model_inputs.pop('inputs_embeds', None)
prefill_inputs = {'inputs_embeds': inputs_embeds}
# Forward pass through the model
outputs = self(
**model_inputs, **prefill_inputs, logits_to_keep=1, return_dict=True, output_attentions=False, output_hidden_states=False,
)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=False,
)
# Get logits and apply logits processor
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
# next_token_logits = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device)
next_token_scores = logits_processor(input_ids, next_token_logits)
# token selection
if generation_config.do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
next_tokens[finished_tags] = generation_config.eos_token_id
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if not kwargs.get('refresh_negative', True):
negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
# Forward negative pass through the model
if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
negative_model_inputs['inputs_embeds'] = inputs_embeds
negative_model_inputs['input_ids'] = None
negative_outputs = self(
**negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
)
negative_model_kwargs = self._update_model_kwargs_for_generation(
negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
)
negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
# reached end of generation
if (next_tokens == generation_config.eos_token_id).any():
eos_indices = (next_tokens == generation_config.eos_token_id).nonzero(as_tuple=False).squeeze(1)
# Only print for samples that are newly finished (not already marked as finished)
new_eos_indices = eos_indices[~finished_tags[eos_indices]]
if new_eos_indices.numel() > 0:
finished_tags[new_eos_indices] = True
if verbose:
print(f"Samples {new_eos_indices.tolist()} reached EOS token at step {step + 1}.", flush=True)
if audio_streamer is not None:
audio_streamer.end(new_eos_indices)
# Check if any sample reached its maximum generation length
max_length_reached = step >= max_step_per_sample
new_max_length_indices = torch.nonzero(max_length_reached & ~finished_tags, as_tuple=False).squeeze(1)
if new_max_length_indices.numel() > 0:
finished_tags[new_max_length_indices] = True
reach_max_step_sample[new_max_length_indices] = True
if verbose:
print(f"Samples {new_max_length_indices.tolist()} reached max generation length at step {step + 1}.", flush=True)
if audio_streamer is not None:
audio_streamer.end(new_max_length_indices)
# speech_end
diffusion_end_indices = (next_tokens == generation_config.speech_end_id).nonzero(as_tuple=False).squeeze(1)
if diffusion_end_indices.numel() > 0:
# Clear tokenizer caches for samples that reached speech end
acoustic_cache.set_to_zero(diffusion_end_indices)
semantic_cache.set_to_zero(diffusion_end_indices)
# speech_begin
diffusion_start_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_start_id)]
if diffusion_start_indices.numel() > 0 and kwargs.get('refresh_negative', True):
# update attention mask
for i, sample_idx in enumerate(diffusion_start_indices.tolist()):
negative_model_kwargs['attention_mask'][sample_idx, :] = 0
negative_model_kwargs['attention_mask'][sample_idx, -1] = 1
# update past key values
for layer_idx in range(len(negative_model_kwargs['past_key_values'])):
k_cache, v_cache = negative_model_kwargs['past_key_values'][layer_idx]
# Process each non-diffusion sample
for sample_idx in diffusion_start_indices.tolist():
# Shift cache for this sample
k_cache[sample_idx, :, -1, :] = k_cache[sample_idx, :, 0, :].clone()
v_cache[sample_idx, :, -1, :] = v_cache[sample_idx, :, 0, :].clone()
# update negative_input_ids
for sample_idx in diffusion_start_indices.tolist():
negative_input_ids[sample_idx, -1] = generation_config.speech_start_id
# Prepare inputs_embeds for next iteration
# Initialize with default embeddings for all tokens
next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1) # [batch_size, 1, hidden_size]
# forward diffusion
# Diffusion indices are those that are not finished and not special tokens
diffusion_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_diffusion_id)]
if diffusion_indices.numel() > 0:
if kwargs.get('refresh_negative', True):
negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
# Forward negative pass through the model
if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
negative_model_inputs['inputs_embeds'] = inputs_embeds
negative_model_inputs['input_ids'] = None
negative_outputs = self(
**negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
)
negative_model_kwargs = self._update_model_kwargs_for_generation(
negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
)
negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
# correct the non-diffusion indices
# we forward all samples' negative outputs even if
# they are not in diffusion mode to keep the cache consistent
# So we need to correct the kv cache of non-diffusion samples
non_diffusion_mask = ~finished_tags & (next_tokens != generation_config.speech_diffusion_id)
if non_diffusion_mask.any():
non_diffusion_indices = torch.arange(batch_size, device=device)[non_diffusion_mask]
start_indices = correct_cnt[non_diffusion_indices]
# 1. Update attention_mask - need to handle each sample separately
seq_len = negative_model_kwargs['attention_mask'].shape[1]
for i, (sample_idx, start_idx) in enumerate(zip(non_diffusion_indices.tolist(), start_indices.tolist())):
# Shift the attention mask for this sample
if start_idx + 1 < seq_len - 1:
negative_model_kwargs['attention_mask'][sample_idx, start_idx+1:] = \
negative_model_kwargs['attention_mask'][sample_idx, start_idx:-1].clone()
negative_model_kwargs['attention_mask'][sample_idx, start_idx] = 0
# 2. Update past_key_values
for layer_idx in range(len(negative_model_kwargs['past_key_values'])):
k_cache, v_cache = negative_model_kwargs['past_key_values'][layer_idx]
# Process each non-diffusion sample
for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
if start_idx + 1 < k_cache.shape[2] - 1:
# Shift cache for this sample
k_cache[sample_idx, :, start_idx+1:, :] = k_cache[sample_idx, :, start_idx:-1, :].clone()
v_cache[sample_idx, :, start_idx+1:, :] = v_cache[sample_idx, :, start_idx:-1, :].clone()
# 3. Update negative_input_ids
for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
if start_idx + 1 < negative_input_ids.shape[1] - 1:
negative_input_ids[sample_idx, start_idx+1:] = \
negative_input_ids[sample_idx, start_idx:-1].clone()
correct_cnt[non_diffusion_indices] += 1
positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]
negative_condition = negative_outputs.last_hidden_state[diffusion_indices, -1, :]
speech_latent = self.sample_speech_tokens(
positive_condition,
negative_condition,
cfg_scale=cfg_scale,
).unsqueeze(1)
# Decode acoustic latent to audio using acoustic streaming cache
scaled_latent = speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - self.model.speech_bias_factor.to(speech_latent.device)
audio_chunk = self.model.acoustic_tokenizer.decode(
scaled_latent.to(self.model.acoustic_tokenizer.device),
cache=acoustic_cache, # Use acoustic-specific cache
sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
use_cache=True,
debug=False
)
# Store audio chunks for each sample
for i, sample_idx in enumerate(diffusion_indices):
idx = sample_idx.item()
# Only append audio chunk if the sample is not finished
if not finished_tags[idx]:
audio_chunks[idx].append(audio_chunk[i])
# Add streaming support here
if audio_streamer is not None:
# Stream the audio chunks immediately
audio_streamer.put(audio_chunk, diffusion_indices)
# Encode audio to semantic features using semantic streaming cache
semantic_features = self.model.semantic_tokenizer.encode(
audio_chunk,
cache=semantic_cache, # Use semantic-specific cache
sample_indices=diffusion_indices,
use_cache=True,
debug=False
).mean # semantic tokenizer has no VAE.
# Combine acoustic and semantic features for next input
acoustic_embed = self.model.acoustic_connector(speech_latent)
semantic_embed = self.model.semantic_connector(semantic_features)
diffusion_embeds = acoustic_embed + semantic_embed
# Update embeddings for diffusion indices
next_inputs_embeds[diffusion_indices] = diffusion_embeds
# Set inputs_embeds for next iteration
inputs_embeds = next_inputs_embeds
if audio_streamer is not None:
audio_streamer.end()
# Concatenate audio chunks for each sample
final_audio_outputs = []
for sample_chunks in audio_chunks:
if sample_chunks:
# Concatenate all chunks along the time dimension (assumed to be the last dimension)
concatenated_audio = torch.cat(sample_chunks, dim=-1)
final_audio_outputs.append(concatenated_audio)
else:
# If no audio was generated for this sample, append None
final_audio_outputs.append(None)
return VibeVoiceGenerationOutput(
sequences=input_ids,
speech_outputs=final_audio_outputs if return_speech else None,
reach_max_step_sample=reach_max_step_sample,
)
@torch.no_grad()
def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps)
condition = torch.cat([condition, neg_condition], dim=0).to(self.model.prediction_head.device)
speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to(condition)
for t in self.model.noise_scheduler.timesteps:
half = speech[: len(speech) // 2]
combined = torch.cat([half, half], dim=0)
eps = self.model.prediction_head(combined, t.repeat(combined.shape[0]).to(combined), condition=condition)
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample
return speech[: len(speech) // 2]
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGenerationInference)
__all__ = [
"VibeVoiceForConditionalGenerationInference",
]

View File

@ -0,0 +1,287 @@
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto import AutoModel
from transformers.modeling_utils import PreTrainedModel
# from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.activations import ACT2FN
from transformers.utils import logging
from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
logger = logging.get_logger(__name__)
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
super().__init__()
self.dim = dim
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.register_parameter('weight', None)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if self.weight is not None:
output = output * self.weight
return output
def extra_repr(self) -> str:
return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
def modulate(x, shift, scale):
"""Apply modulation to input tensor."""
return x * (1 + scale) + shift
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
Args:
hidden_size (`int`): Size of the output embedding
frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=False),
# nn.SiLU(),
ACT2FN['silu'],
nn.Linear(hidden_size, hidden_size, bias=False),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Args:
t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
These may be fractional.
dim (`int`): The dimension of the output.
max_period (`int`, optional): Controls the minimum frequency of the embeddings.
Returns:
`torch.Tensor`: An [N, D] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding.to(t.dtype)
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class FeedForwardNetwork(nn.Module):
"""
Standard feed-forward network with SwiGLU activation.
Args:
embed_dim (`int`): Input dimension
ffn_dim (`int`): Hidden dimension
"""
def __init__(
self,
embed_dim,
ffn_dim,
):
super().__init__()
self.embed_dim = embed_dim
self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
self.act_fn = ACT2FN['silu'] # Using SiLU as the activation function
def forward(self, x):
gate = self.gate_proj(x)
up = self.up_proj(x)
# SwiGLU activation
# gate = F.silu(gate)
gate = self.act_fn(gate)
return self.down_proj(gate * up)
class HeadLayer(nn.Module):
"""
A layer in the diffusion head.
Args:
embed_dim (`int`): Input dimension
ffn_dim (`int`): Hidden dimension
cond_dim (`int`): Condition embedding dimension
norm_eps (`float`, optional): Epsilon for normalization
"""
def __init__(
self,
embed_dim,
ffn_dim,
cond_dim,
norm_eps=1e-5,
):
super().__init__()
self.embed_dim = embed_dim
self.cond_dim = cond_dim
self.ffn_dim = ffn_dim
self.ffn = FeedForwardNetwork(
self.embed_dim,
self.ffn_dim,
)
self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
self.adaLN_modulation = nn.Sequential(
# nn.SiLU(),
ACT2FN['silu'],
nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
)
def forward(self, x, c):
shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn))
return x
class FinalLayer(nn.Module):
"""
Final layer in the diffusion head.
Args:
hidden_size (`int`): Input dimension
output_size (`int`): Output dimension
cond_size (`int`): Condition embedding dimension
norm_eps (`float`, optional): Epsilon for normalization
"""
def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
super().__init__()
self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
self.linear = nn.Linear(hidden_size, output_size, bias=False)
self.adaLN_modulation = nn.Sequential(
# nn.SiLU(),
ACT2FN['silu'],
nn.Linear(cond_size, 2 * hidden_size, bias=False)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class VibeVoiceDiffusionHead(PreTrainedModel):
"""
Diffusion head model for vibevoice.
Args:
config (`VibeVoiceDiffusionHeadConfig`): Model configuration
latent_size (`int`, optional): Size of the latent space. If not provided, uses `config.latent_size`.
"""
config_class = VibeVoiceDiffusionHeadConfig
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(
self,
config,
):
super().__init__(config)
self.config = config
self.cond_dim = config.hidden_size
latent_size = config.latent_size
self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
self.t_embedder = TimestepEmbedder(self.cond_dim)
ffn_dim = int(config.hidden_size * config.head_ffn_ratio)
# Create the intermediate layers
self.layers = nn.ModuleList([
HeadLayer(
embed_dim=config.hidden_size,
ffn_dim=ffn_dim,
cond_dim=self.cond_dim,
norm_eps=config.rms_norm_eps
)
for _ in range(config.head_layers)
])
# Final layer for output
self.final_layer = FinalLayer(
hidden_size=config.hidden_size,
output_size=latent_size,
cond_size=self.cond_dim,
norm_eps=config.rms_norm_eps
)
self.initialize_weights()
def initialize_weights(self):
"""Initialize the weights of the model."""
# Initialize timestep embedder
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers
for layer in self.layers:
nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)
# Zero-out output layers
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
def forward(
self,
noisy_images,
timesteps,
condition,
):
"""
Forward pass of the prediction head.
Args:
noisy_images (`torch.Tensor`): Noisy images/latents to denoise
timesteps (`torch.Tensor`): Timesteps for diffusion
condition (`torch.Tensor`): Conditioning information
Returns:
`torch.Tensor`: The predicted noise/velocity
"""
x = self.noisy_images_proj(noisy_images)
t = self.t_embedder(timesteps)
condition = self.cond_proj(condition)
c = condition + t
for layer in self.layers:
x = layer(x, c)
x = self.final_layer(x, c)
return x
AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)
__all__ = [
"VibeVoiceDiffusionHead",
]

View File

@ -0,0 +1,214 @@
"""Tokenization classes for vibevoice."""
from typing import List, Optional, Union
from transformers.utils import logging
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
logger = logging.get_logger(__name__)
class VibeVoiceTextTokenizer(Qwen2Tokenizer):
"""
Construct a VibeVoice tokenizer. Based on the Qwen2 tokenizer with additional special tokens for speech.
Args:
vocab_file (`str`):
Path to the vocabulary file.
merges_file (`str`):
Path to the merges file.
errors (`str`, *optional*, defaults to `"replace"`):
Paradigm to follow when decoding bytes to UTF-8.
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token.
bos_token (`str`, *optional*):
The beginning of sequence token. Not used for vibevoice.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The token used for padding.
add_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to add special tokens when encoding.
"""
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
merges_file,
errors="replace",
unk_token="<|endoftext|>",
bos_token=None,
eos_token="<|endoftext|>",
pad_token="<|endoftext|>",
add_prefix_space=False,
add_special_tokens=True,
**kwargs,
):
super().__init__(
vocab_file=vocab_file,
merges_file=merges_file,
errors=errors,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
add_special_tokens=add_special_tokens,
**kwargs,
)
# Add VibeVoice-specific special tokens
self._add_vibevoice_special_tokens()
def _add_vibevoice_special_tokens(self):
"""Add VibeVoice-specific special tokens."""
special_tokens = {
"additional_special_tokens": [
"<|vision_start|>", # Speech start (reusing vision tokens)
"<|vision_end|>", # Speech end
"<|vision_pad|>", # Speech diffusion pad
]
}
num_added = self.add_special_tokens(special_tokens)
# Cache special token IDs
self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
return num_added
@property
def eos_id(self) -> int:
"""Id of the end of sequence token."""
return self._eos_id
@property
def speech_start_id(self) -> int:
"""Id of the speech start token."""
return self._speech_start_id
@property
def speech_end_id(self) -> int:
"""Id of the speech end token."""
return self._speech_end_id
@property
def speech_diffusion_id(self) -> int:
"""Id of the speech diffusion token."""
return self._speech_diffusion_id
@property
def pad_id(self) -> int:
"""Id used for padding (returns -100 for loss masking)."""
return -100
class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
"""
Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library).
Based on the Qwen2 tokenizer with additional special tokens for speech.
Args:
vocab_file (`str`, *optional*):
Path to the vocabulary file.
merges_file (`str`, *optional*):
Path to the merges file.
tokenizer_file (`str`, *optional*):
Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token.
bos_token (`str`, *optional*):
The beginning of sequence token. Not used for vibevoice.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The token used for padding.
"""
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file=None,
merges_file=None,
tokenizer_file=None,
unk_token="<|endoftext|>",
bos_token=None,
eos_token="<|endoftext|>",
pad_token="<|endoftext|>",
add_prefix_space=False,
**kwargs,
):
super().__init__(
vocab_file=vocab_file,
merges_file=merges_file,
tokenizer_file=tokenizer_file,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
**kwargs,
)
# Add VibeVoice-specific special tokens
self._add_vibevoice_special_tokens()
def _add_vibevoice_special_tokens(self):
"""Add VibeVoice-specific special tokens."""
special_tokens = {
"additional_special_tokens": [
"<|vision_start|>", # Speech start (reusing vision tokens)
"<|vision_end|>", # Speech end
"<|vision_pad|>", # Speech diffusion pad
]
}
num_added = self.add_special_tokens(special_tokens)
# Cache special token IDs
self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
# self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
self._eos_id = self.eos_token_id # qwen2 / qwen3
self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')
return num_added
@property
def eos_id(self) -> int:
"""Id of the end of sequence token."""
return self._eos_id
@property
def speech_start_id(self) -> int:
"""Id of the speech start token."""
return self._speech_start_id
@property
def speech_end_id(self) -> int:
"""Id of the speech end token."""
return self._speech_end_id
@property
def speech_diffusion_id(self) -> int:
"""Id of the speech diffusion token."""
return self._speech_diffusion_id
@property
def pad_id(self) -> int:
"""Id used for padding (returns -100 for loss masking)."""
return self._pad_id
__all__ = [
"VibeVoiceTextTokenizer",
"VibeVoiceTextTokenizerFast",
]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,264 @@
from __future__ import annotations
import torch
import asyncio
from queue import Queue
from typing import TYPE_CHECKING, Optional
from transformers.generation import BaseStreamer
class AudioStreamer(BaseStreamer):
"""
Audio streamer that stores audio chunks in queues for each sample in the batch.
This allows streaming audio generation for multiple samples simultaneously.
Parameters:
batch_size (`int`):
The batch size for generation
stop_signal (`any`, *optional*):
The signal to put in the queue when generation ends. Defaults to None.
timeout (`float`, *optional*):
The timeout for the audio queue. If `None`, the queue will block indefinitely.
"""
def __init__(
self,
batch_size: int,
stop_signal: Optional[any] = None,
timeout: Optional[float] = None,
):
self.batch_size = batch_size
self.stop_signal = stop_signal
self.timeout = timeout
# Create a queue for each sample in the batch
self.audio_queues = [Queue() for _ in range(batch_size)]
self.finished_flags = [False for _ in range(batch_size)]
self.sample_indices_map = {} # Maps from sample index to queue index
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
"""
Receives audio chunks and puts them in the appropriate queues.
Args:
audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks
sample_indices: Tensor indicating which samples these chunks belong to
"""
for i, sample_idx in enumerate(sample_indices):
idx = sample_idx.item()
if idx < self.batch_size and not self.finished_flags[idx]:
# Convert to numpy or keep as tensor based on preference
audio_chunk = audio_chunks[i].detach().cpu()
self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)
def end(self, sample_indices: Optional[torch.Tensor] = None):
"""
Signals the end of generation for specified samples or all samples.
Args:
sample_indices: Optional tensor of sample indices to end. If None, ends all.
"""
if sample_indices is None:
# End all samples
for idx in range(self.batch_size):
if not self.finished_flags[idx]:
self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
self.finished_flags[idx] = True
else:
# End specific samples
for sample_idx in sample_indices:
idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx
if idx < self.batch_size and not self.finished_flags[idx]:
self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
self.finished_flags[idx] = True
def __iter__(self):
"""Returns an iterator over the batch of audio streams."""
return AudioBatchIterator(self)
def get_stream(self, sample_idx: int):
"""Get the audio stream for a specific sample."""
if sample_idx >= self.batch_size:
raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
return AudioSampleIterator(self, sample_idx)
class AudioSampleIterator:
"""Iterator for a single audio stream from the batch."""
def __init__(self, streamer: AudioStreamer, sample_idx: int):
self.streamer = streamer
self.sample_idx = sample_idx
def __iter__(self):
return self
def __next__(self):
value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
if value == self.streamer.stop_signal:
raise StopIteration()
return value
class AudioBatchIterator:
"""Iterator that yields audio chunks for all samples in the batch."""
def __init__(self, streamer: AudioStreamer):
self.streamer = streamer
self.active_samples = set(range(streamer.batch_size))
def __iter__(self):
return self
def __next__(self):
if not self.active_samples:
raise StopIteration()
batch_chunks = {}
samples_to_remove = set()
# Try to get chunks from all active samples
for idx in self.active_samples:
try:
value = self.streamer.audio_queues[idx].get(block=False)
if value == self.streamer.stop_signal:
samples_to_remove.add(idx)
else:
batch_chunks[idx] = value
except:
# Queue is empty for this sample, skip it this iteration
pass
# Remove finished samples
self.active_samples -= samples_to_remove
if batch_chunks:
return batch_chunks
elif self.active_samples:
# If no chunks were ready but we still have active samples,
# wait a bit and try again
import time
time.sleep(0.01)
return self.__next__()
else:
raise StopIteration()
class AsyncAudioStreamer(AudioStreamer):
"""
Async version of AudioStreamer for use in async contexts.
"""
def __init__(
self,
batch_size: int,
stop_signal: Optional[any] = None,
timeout: Optional[float] = None,
):
super().__init__(batch_size, stop_signal, timeout)
# Replace regular queues with async queues
self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
self.loop = asyncio.get_running_loop()
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
"""Put audio chunks in the appropriate async queues."""
for i, sample_idx in enumerate(sample_indices):
idx = sample_idx.item()
if idx < self.batch_size and not self.finished_flags[idx]:
audio_chunk = audio_chunks[i].detach().cpu()
self.loop.call_soon_threadsafe(
self.audio_queues[idx].put_nowait, audio_chunk
)
def end(self, sample_indices: Optional[torch.Tensor] = None):
"""Signal the end of generation for specified samples."""
if sample_indices is None:
indices_to_end = range(self.batch_size)
else:
indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
for idx in indices_to_end:
if idx < self.batch_size and not self.finished_flags[idx]:
self.loop.call_soon_threadsafe(
self.audio_queues[idx].put_nowait, self.stop_signal
)
self.finished_flags[idx] = True
async def get_stream(self, sample_idx: int):
"""Get async iterator for a specific sample's audio stream."""
if sample_idx >= self.batch_size:
raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
while True:
value = await self.audio_queues[sample_idx].get()
if value == self.stop_signal:
break
yield value
def __aiter__(self):
"""Returns an async iterator over all audio streams."""
return AsyncAudioBatchIterator(self)
class AsyncAudioBatchIterator:
"""Async iterator for batch audio streaming."""
def __init__(self, streamer: AsyncAudioStreamer):
self.streamer = streamer
self.active_samples = set(range(streamer.batch_size))
def __aiter__(self):
return self
async def __anext__(self):
if not self.active_samples:
raise StopAsyncIteration()
batch_chunks = {}
samples_to_remove = set()
# Create tasks for all active samples
tasks = {
idx: asyncio.create_task(self._get_chunk(idx))
for idx in self.active_samples
}
# Wait for at least one chunk to be ready
done, pending = await asyncio.wait(
tasks.values(),
return_when=asyncio.FIRST_COMPLETED,
timeout=self.streamer.timeout
)
# Cancel pending tasks
for task in pending:
task.cancel()
# Process completed tasks
for idx, task in tasks.items():
if task in done:
try:
value = await task
if value == self.streamer.stop_signal:
samples_to_remove.add(idx)
else:
batch_chunks[idx] = value
except asyncio.CancelledError:
pass
self.active_samples -= samples_to_remove
if batch_chunks:
return batch_chunks
elif self.active_samples:
# Try again if we still have active samples
return await self.__anext__()
else:
raise StopAsyncIteration()
async def _get_chunk(self, idx):
"""Helper to get a chunk from a specific queue."""
return await self.streamer.audio_queues[idx].get()

View File

View File

@ -0,0 +1,677 @@
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import os
import re
import numpy as np
import torch
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType, logging
from .vibevoice_tokenizer_processor import AudioNormalizer
logger = logging.get_logger(__name__)
class VibeVoiceProcessor:
r"""
Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor.
[`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`].
See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information.
Args:
tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):
The tokenizer for text processing.
audio_processor (`VibeVoiceTokenizerProcessor`):
The audio processor for speech processing.
speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):
The compression ratio for speech tokenization.
db_normalize (`bool`, *optional*, defaults to True):
Whether to apply decibel normalization to audio inputs.
"""
def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
self.tokenizer = tokenizer
self.audio_processor = audio_processor
self.speech_tok_compress_ratio = speech_tok_compress_ratio
self.db_normalize = db_normalize
self.audio_normalizer = AudioNormalizer() if db_normalize else None
self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a pretrained model
- a path to a *directory* containing processor config
Returns:
[`VibeVoiceProcessor`]: The processor object instantiated from pretrained model.
"""
import os
import json
from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
from vibevoice.modular.modular_vibevoice_text_tokenizer import (
VibeVoiceTextTokenizer,
VibeVoiceTextTokenizerFast
)
# Load processor configuration
config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
if os.path.exists(config_path):
with open(config_path, 'r') as f:
config = json.load(f)
else:
logger.warning(f"No preprocessor_config.json found at {pretrained_model_name_or_path}, using defaults")
config = {
"speech_tok_compress_ratio": 3200,
"db_normalize": True,
}
# Extract main processor parameters
speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
db_normalize = config.get("db_normalize", True)
# Load tokenizer - try from model path first, then fallback to Qwen
language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
if 'qwen' in language_model_pretrained_name.lower():
tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(
language_model_pretrained_name,
**kwargs
)
else:
raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Supported types: Qwen, Llama, Gemma.")
# Load audio processor
if "audio_processor" in config:
# Create audio processor from config
audio_config = config["audio_processor"]
audio_processor = VibeVoiceTokenizerProcessor(
sampling_rate=audio_config.get("sampling_rate", 24000),
normalize_audio=audio_config.get("normalize_audio", True),
target_dB_FS=audio_config.get("target_dB_FS", -25),
eps=audio_config.get("eps", 1e-6),
)
else:
# Create default audio processor
audio_processor = VibeVoiceTokenizerProcessor()
# Create and return the processor
return cls(
tokenizer=tokenizer,
audio_processor=audio_processor,
speech_tok_compress_ratio=speech_tok_compress_ratio,
db_normalize=db_normalize,
)
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
"""
Save a processor to a directory, so that it can be re-loaded using the
[`~VibeVoiceProcessor.from_pretrained`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the processor will be saved.
"""
import os
import json
os.makedirs(save_directory, exist_ok=True)
# Save processor configuration
processor_config = {
"processor_class": "VibeVoiceProcessor",
"speech_tok_compress_ratio": self.speech_tok_compress_ratio,
"db_normalize": self.db_normalize,
"audio_processor": {
"feature_extractor_type": "VibeVoiceTokenizerProcessor",
"sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
"normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
"target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
"eps": getattr(self.audio_processor, 'eps', 1e-6),
}
}
config_path = os.path.join(save_directory, "preprocessor_config.json")
with open(config_path, 'w') as f:
json.dump(processor_config, f, indent=2)
logger.info(f"Processor configuration saved in {config_path}")
def __call__(
self,
text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: bool = True,
**kwargs,
) -> BatchEncoding:
"""
Main method to process one or more podcast scripts with optional voice samples.
Args:
text (`str`, `List[str]`):
The input text(s) to process. Can be:
- A single script string
- A list of script strings for batch processing
- A path to a .json or .txt file
- A list of paths
voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*):
Voice samples for each script. Can be:
- A list of samples for a single script
- A list of lists for batch processing
padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
Whether to pad sequences to the same length
truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):
Whether to truncate sequences
max_length (`int`, *optional*):
Maximum length of the returned sequences
return_tensors (`str` or `TensorType`, *optional*):
If set, will return tensors of a particular framework
return_attention_mask (`bool`, defaults to `True`):
Whether to return the attention mask
Returns:
`BatchEncoding`: A BatchEncoding with the following fields:
- **input_ids** -- List of token id sequences or tensor
- **attention_mask** -- List of attention masks or tensor
- **speech_tensors** -- Padded speech inputs (if voice_samples provided)
- **speech_masks** -- Speech masks (if voice_samples provided)
- **speech_input_mask** -- Boolean masks indicating speech token positions
"""
# Handle single vs batch input
if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
# Single input
texts = [text]
is_batched = False
else:
# Batch input
texts = text
is_batched = True
# Handle voice samples
if voice_samples is not None:
if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
# Single set of voice samples
voice_samples_list = [voice_samples]
else:
# Batch of voice samples
voice_samples_list = voice_samples
else:
voice_samples_list = [None] * len(texts)
# Process each input
all_encodings = []
for text_input, voice_input in zip(texts, voice_samples_list):
encoding = self._process_single(text_input, voice_input)
all_encodings.append(encoding)
# Combine batch
batch_encoding = self._batch_encode(
all_encodings,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=return_tensors,
return_attention_mask=return_attention_mask,
)
return batch_encoding
def _process_single(
self,
text: Union[str, TextInput],
voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
) -> Dict[str, Any]:
"""Process a single podcast script."""
# Determine if text is a file path or direct script
script = None
if isinstance(text, str):
# Check if it's a file path
if text.endswith('.json') and os.path.exists(text):
script = self._convert_json_to_script(text)
elif text.endswith('.txt') and os.path.exists(text):
script = self._convert_text_to_script(text)
else:
# Assume it's the script content directly
script = text
if script is None:
raise ValueError(f"Could not process input text: {text}")
# Parse the script
parsed_lines = self._parse_script(script)
all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))
# Create system prompt
# system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
system_tokens = self.tokenizer.encode(self.system_prompt)
# Process voice samples if provided
if voice_samples:
voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
else:
voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []
# Build full token sequence
full_tokens = system_tokens + voice_tokens
speech_input_mask = [False] * len(system_tokens) + voice_speech_masks
# Add text input section
full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False))
for speaker_id, speaker_text in parsed_lines:
speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
full_tokens += speaker_text_tokens
speech_input_mask += [False] * len(speaker_text_tokens)
# Add speech output section
full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1)
return {
"input_ids": full_tokens,
"speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
"speech_input_mask": speech_input_mask,
"parsed_script": parsed_lines,
"all_speakers": all_speakers,
}
def _batch_encode(
self,
encodings: List[Dict[str, Any]],
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: bool = True,
) -> BatchEncoding:
"""Combine multiple encodings into a batch with padding."""
# Extract input_ids and create attention_mask
input_ids_list = [enc["input_ids"] for enc in encodings]
speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]
# Determine padding strategy
if isinstance(padding, bool):
padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
elif isinstance(padding, str):
padding_strategy = PaddingStrategy(padding)
else:
padding_strategy = padding
# Apply padding to input_ids
if padding_strategy != PaddingStrategy.DO_NOT_PAD:
if padding_strategy == PaddingStrategy.LONGEST:
max_len = max(len(ids) for ids in input_ids_list)
elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
max_len = max_length
else:
max_len = max(len(ids) for ids in input_ids_list)
# Pad sequences
padded_input_ids = []
attention_masks = []
padded_speech_input_masks = []
for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list):
# Truncate if needed
if truncation and len(input_ids) > max_len:
input_ids = input_ids[:max_len]
speech_mask = speech_mask[:max_len]
# Pad
padding_length = max_len - len(input_ids)
# padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids
padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids
attention_mask = [0] * padding_length + [1] * len(input_ids)
padded_speech_mask = [False] * padding_length + speech_mask
padded_input_ids.append(padded_ids)
attention_masks.append(attention_mask)
padded_speech_input_masks.append(padded_speech_mask)
input_ids_list = padded_input_ids
speech_input_masks_list = padded_speech_input_masks
else:
# No padding, just create attention masks
attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None
# Process speech inputs
all_speech_inputs = []
has_speech = False
for enc in encodings:
if enc["speech_inputs"] is not None:
all_speech_inputs.extend(enc["speech_inputs"])
has_speech = True
# Prepare batch encoding
batch_encoding = BatchEncoding()
# Handle tensor conversion
if return_tensors is not None:
batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
if return_attention_mask and attention_masks is not None:
batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
else:
batch_encoding["input_ids"] = input_ids_list
if return_attention_mask and attention_masks is not None:
batch_encoding["attention_mask"] = attention_masks
batch_encoding["speech_input_mask"] = speech_input_masks_list
# Process speech tensors if present
if has_speech:
speech_dict = self.prepare_speech_inputs(
all_speech_inputs,
return_tensors=return_tensors,
)
batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
batch_encoding["speech_masks"] = speech_dict["speech_masks"]
else:
batch_encoding["speech_tensors"] = None
batch_encoding["speech_masks"] = None
# Add metadata
batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]
return batch_encoding
def _create_voice_prompt(
self,
speaker_samples: List[Union[str, np.ndarray]]
) -> Tuple[List[int], List[np.ndarray], List[bool]]:
"""
Create voice prompt tokens and process audio samples.
Returns:
tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
"""
vae_token_id = self.tokenizer.speech_diffusion_id
voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
voice_speech_inputs = []
voice_speech_masks = [False] * len(voice_full_tokens)
for speaker_id, speaker_audio in enumerate(speaker_samples):
prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)
# Process audio
if isinstance(speaker_audio, str):
# Load audio from file
wav = self.audio_processor._load_audio_from_path(speaker_audio)
else:
wav = np.array(speaker_audio, dtype=np.float32)
# Apply normalization if needed
if self.db_normalize and self.audio_normalizer:
wav = self.audio_normalizer(wav)
# Calculate token length based on compression ratio
# if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
# vae_tok_len = wav.shape[0]
# else:
vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
# Build tokens and masks
speaker_tokens = (prefix_tokens +
[self.tokenizer.speech_start_id] +
[vae_token_id] * vae_tok_len +
[self.tokenizer.speech_end_id] +
self.tokenizer.encode('\n', add_special_tokens=False))
vae_input_mask = ([False] * len(prefix_tokens) +
[False] +
[True] * vae_tok_len +
[False] +
[False])
voice_full_tokens.extend(speaker_tokens)
voice_speech_masks.extend(vae_input_mask)
voice_speech_inputs.append(wav)
return voice_full_tokens, voice_speech_inputs, voice_speech_masks
def prepare_speech_inputs(
self,
speech_inputs: List[np.ndarray],
return_tensors: Optional[Union[str, TensorType]] = None,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> Dict[str, Any]:
"""
Prepare speech inputs for model consumption.
Args:
speech_inputs: List of speech arrays
return_tensors: Output tensor type
device: Device to place tensors on
dtype: Data type for tensors
Returns:
Dictionary with padded_speeches and speech_masks
"""
if not speech_inputs:
return {"padded_speeches": None, "speech_masks": None}
# Calculate sequence lengths
vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
# vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs]
max_speech_length = max(s.shape[0] for s in speech_inputs)
# Pad speeches
if speech_inputs[0].ndim == 1:
padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
else:
padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)
for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
padded_speeches[i, :len(speech)] = speech
speech_masks[i, :vae_tok_length] = True
result = {
"padded_speeches": padded_speeches,
"speech_masks": speech_masks,
}
# Convert to tensors if requested
if return_tensors == "pt":
result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)
return result
def _convert_json_to_script(self, json_file: str) -> str:
"""
Convert JSON format to script format.
Expected JSON format:
[
{"speaker": "1", "text": "Hello everyone..."},
{"speaker": "2", "text": "Great to be here..."}
]
"""
import json
with open(json_file, 'r', encoding='utf-8') as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError("JSON file must contain a list of speaker entries")
script_lines = []
for item in data:
if not isinstance(item, dict):
logger.warning(f"Skipping non-dict entry: {item}")
continue
speaker = item.get('speaker')
text = item.get('text')
if speaker is None or text is None:
logger.warning(f"Skipping entry missing speaker or text: {item}")
continue
# Ensure speaker ID is valid
try:
speaker_id = int(speaker)
except (ValueError, TypeError):
logger.warning(f"Invalid speaker ID: {speaker}, skipping entry")
continue
# Clean up text
text = text.strip()
if text:
script_lines.append(f"Speaker {speaker_id}: {text}")
if not script_lines:
raise ValueError("No valid entries found in JSON file")
return "\n".join(script_lines)
def _convert_text_to_script(self, text_file: str) -> str:
"""
Convert text file to script format.
Handles multiple formats:
1. Already formatted as "Speaker X: text"
2. Plain text (assigns to Speaker 1)
Handles edge cases like multiple colons in a line.
"""
with open(text_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
script_lines = []
current_speaker = 1
for line in lines:
line = line.strip()
if not line:
continue
# Try to parse as "Speaker X: text" format
# Use regex to be more robust
speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)
if speaker_match:
speaker_id = int(speaker_match.group(1))
text = speaker_match.group(2).strip()
if text:
script_lines.append(f"Speaker {speaker_id}: {text}")
else:
# Treat as plain text - assign to current speaker
script_lines.append(f"Speaker {current_speaker}: {line}")
if not script_lines:
raise ValueError("No valid content found in text file")
return "\n".join(script_lines)
def _parse_script(self, script: str) -> List[Tuple[int, str]]:
"""Parse script into list of (speaker_id, text) tuples."""
lines = script.strip().split("\n")
parsed_lines = []
speaker_ids = []
# First pass: parse all lines and collect speaker IDs
for line in lines:
if not line.strip():
continue
# Use regex to handle edge cases like multiple colons
match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)
if match:
speaker_id = int(match.group(1))
text = ' ' + match.group(2).strip()
parsed_lines.append((speaker_id, text))
speaker_ids.append(speaker_id)
else:
logger.warning(f"Could not parse line: '{line}'")
if not parsed_lines:
raise ValueError("No valid speaker lines found in script")
# Check if we need to normalize speaker IDs (only if all are > 0)
min_speaker_id = min(speaker_ids)
if min_speaker_id > 0:
# Normalize to start from 0
normalized_lines = []
for speaker_id, text in parsed_lines:
normalized_lines.append((speaker_id - 1, text))
return normalized_lines
else:
# Keep original IDs
return parsed_lines
def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding:
"""Merge text and audio inputs into a single BatchEncoding."""
# Start with text inputs
merged = BatchEncoding(text_inputs)
# Add audio-specific fields
if "audio" in audio_inputs:
merged["speech_inputs"] = audio_inputs["audio"]
if "streaming" in audio_inputs:
merged["streaming"] = audio_inputs["streaming"]
return merged
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`].
Please refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`].
Please refer to the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
"""
Return the list of inputs accepted by the model.
"""
tokenizer_input_names = self.tokenizer.model_input_names
audio_processor_input_names = self.audio_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))
def save_audio(self,
audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
output_path: str = "output.wav",
sampling_rate: Optional[int] = None,
normalize: bool = False,
batch_prefix: str = "audio_",
) -> str:
"""
Save audio data to a file.
Args:
audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
The audio data to save. Can be a single tensor/array or a list of them.
output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".
Returns:
str: The path to the saved audio file.
"""
return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix)
__all__ = [
"VibeVoiceProcessor",
]

View File

@ -0,0 +1,483 @@
"""
Processor class for VibeVoice models.
"""
import os
import json
import warnings
from typing import List, Optional, Union, Dict, Any
import numpy as np
import torch
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.utils import logging
logger = logging.get_logger(__name__)
class AudioNormalizer:
"""
Audio normalization class for VibeVoice tokenizer.
This class provides audio normalization to ensure consistent input levels
for the VibeVoice tokenizer while maintaining audio quality.
"""
def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
"""
Initialize the audio normalizer.
Args:
target_dB_FS (float): Target dB FS level for the audio. Default: -25
eps (float): Small value to avoid division by zero. Default: 1e-6
"""
self.target_dB_FS = target_dB_FS
self.eps = eps
def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
"""
Adjust the audio to the target dB FS level.
Args:
audio (np.ndarray): Input audio signal
Returns:
tuple: (normalized_audio, rms, scalar)
"""
rms = np.sqrt(np.mean(audio**2))
scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
normalized_audio = audio * scalar
return normalized_audio, rms, scalar
def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
"""
Avoid clipping by scaling down if necessary.
Args:
audio (np.ndarray): Input audio signal
scalar (float, optional): Explicit scaling factor
Returns:
tuple: (normalized_audio, scalar)
"""
if scalar is None:
max_val = np.max(np.abs(audio))
if max_val > 1.0:
scalar = max_val + self.eps
else:
scalar = 1.0
return audio / scalar, scalar
def __call__(self, audio: np.ndarray) -> np.ndarray:
"""
Normalize the audio by adjusting to target dB FS and avoiding clipping.
Args:
audio (np.ndarray): Input audio signal
Returns:
np.ndarray: Normalized audio signal
"""
# First adjust to target dB FS
audio, _, _ = self.tailor_dB_FS(audio)
# Then avoid clipping
audio, _ = self.avoid_clipping(audio)
return audio
# Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components
class VibeVoiceTokenizerProcessor(FeatureExtractionMixin):
"""
Processor for VibeVoice acoustic tokenizer models.
This processor handles audio preprocessing for VibeVoice models, including:
- Audio format conversion (stereo to mono)
- Optional audio normalization
- Streaming support for infinite-length audio
Args:
sampling_rate (int, optional): Expected sampling rate. Defaults to 24000.
normalize_audio (bool, optional): Whether to normalize audio. Defaults to True.
target_dB_FS (float, optional): Target dB FS for normalization. Defaults to -25.
eps (float, optional): Small value for numerical stability. Defaults to 1e-6.
"""
model_input_names = ["input_features"]
def __init__(
self,
sampling_rate: int = 24000,
normalize_audio: bool = True,
target_dB_FS: float = -25,
eps: float = 1e-6,
**kwargs,
):
super().__init__(**kwargs)
self.sampling_rate = sampling_rate
self.normalize_audio = normalize_audio
# Initialize audio normalizer if needed
if self.normalize_audio:
self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps)
else:
self.normalizer = None
# Save config
self.feature_extractor_dict = {
"sampling_rate": sampling_rate,
"normalize_audio": normalize_audio,
"target_dB_FS": target_dB_FS,
"eps": eps,
}
def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
"""
Convert stereo audio to mono if needed.
Args:
audio (np.ndarray): Input audio array
Returns:
np.ndarray: Mono audio array
"""
if len(audio.shape) == 1:
return audio
elif len(audio.shape) == 2:
if audio.shape[0] == 2: # (2, time)
return np.mean(audio, axis=0)
elif audio.shape[1] == 2: # (time, 2)
return np.mean(audio, axis=1)
else:
# If one dimension is 1, squeeze it
if audio.shape[0] == 1:
return audio.squeeze(0)
elif audio.shape[1] == 1:
return audio.squeeze(1)
else:
raise ValueError(f"Unexpected audio shape: {audio.shape}")
else:
raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}")
def _process_single_audio(self, audio: Union[np.ndarray, List[float]]) -> np.ndarray:
"""
Process a single audio array.
Args:
audio: Single audio input
Returns:
np.ndarray: Processed audio
"""
# Convert to numpy array
if not isinstance(audio, np.ndarray):
audio = np.array(audio, dtype=np.float32)
else:
audio = audio.astype(np.float32)
# Ensure mono
audio = self._ensure_mono(audio)
# Normalize if requested
if self.normalize_audio and self.normalizer is not None:
audio = self.normalizer(audio)
return audio
def __call__(
self,
audio: Union[str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str]] = None,
sampling_rate: Optional[int] = None,
return_tensors: Optional[str] = None,
**kwargs,
):
"""
Process audio for VibeVoice models.
Args:
audio: Audio input(s) to process. Can be:
- str: Path to audio file
- np.ndarray: Audio array
- List[float]: Audio as list of floats
- List[np.ndarray]: Batch of audio arrays
- List[str]: Batch of audio file paths
sampling_rate (int, optional): Sampling rate of the input audio
return_tensors (str, optional): Return format ('pt' for PyTorch, 'np' for NumPy)
Returns:
dict: Processed audio inputs with keys:
- input_features: Audio tensor(s) ready for the model
"""
if audio is None:
raise ValueError("Audio input is required")
# Validate sampling rate
if sampling_rate is not None and sampling_rate != self.sampling_rate:
logger.warning(
f"Input sampling rate ({sampling_rate}) differs from expected "
f"sampling rate ({self.sampling_rate}). Please resample your audio."
)
# Handle different input types
if isinstance(audio, str):
# Single audio file path
audio = self._load_audio_from_path(audio)
is_batched = False
elif isinstance(audio, list):
if len(audio) == 0:
raise ValueError("Empty audio list provided")
# Check if it's a list of file paths
if all(isinstance(item, str) for item in audio):
# Batch of audio file paths
audio = [self._load_audio_from_path(path) for path in audio]
is_batched = True
else:
# Check if it's batched audio arrays
is_batched = isinstance(audio[0], (np.ndarray, list))
else:
# Single audio array or list
is_batched = False
# Process audio
if is_batched:
processed_audio = [self._process_single_audio(a) for a in audio]
else:
processed_audio = [self._process_single_audio(audio)]
# Convert to tensors if requested
if return_tensors == "pt":
if len(processed_audio) == 1:
# Create a proper batch dimension (B, T)
input_features = torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1)
else:
# For batched input with different lengths, create a batch properly
input_features = torch.stack([torch.from_numpy(a) for a in processed_audio]).unsqueeze(1)
elif return_tensors == "np":
if len(processed_audio) == 1:
input_features = processed_audio[0][np.newaxis, np.newaxis, :]
else:
input_features = np.stack(processed_audio)[:, np.newaxis, :]
else:
input_features = processed_audio[0] if len(processed_audio) == 1 else processed_audio
outputs = {
"audio": input_features, # Use "audio" instead of "input_features"
}
return outputs
def _load_audio_from_path(self, audio_path: str) -> np.ndarray:
"""
Load audio from file path.
Args:
audio_path (str): Path to audio file
Returns:
np.ndarray: Loaded audio array
"""
# Get file extension to determine loading method
file_ext = os.path.splitext(audio_path)[1].lower()
if file_ext in ['.wav', '.mp3', '.flac', '.m4a', '.ogg']:
# Audio file - use librosa
import librosa
audio_array, sr = librosa.load(
audio_path,
sr=self.sampling_rate,
mono=True
)
return audio_array
elif file_ext == '.pt':
# PyTorch tensor file
audio_tensor = torch.load(audio_path, map_location='cpu').squeeze()
if isinstance(audio_tensor, torch.Tensor):
audio_array = audio_tensor.numpy()
else:
audio_array = np.array(audio_tensor)
return audio_array.astype(np.float32)
elif file_ext == '.npy':
# NumPy file
audio_array = np.load(audio_path)
return audio_array.astype(np.float32)
else:
raise ValueError(
f"Unsupported file format: {file_ext}. "
f"Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy, .npz"
)
def preprocess_audio(
self,
audio_path_or_array: Union[str, np.ndarray],
normalize: Optional[bool] = None,
) -> np.ndarray:
"""
Convenience method to preprocess audio from file path or array.
This method is kept for backward compatibility but __call__ is recommended.
Args:
audio_path_or_array: Path to audio file or numpy array
normalize: Whether to normalize (overrides default setting)
Returns:
np.ndarray: Preprocessed audio array
"""
if isinstance(audio_path_or_array, str):
audio_array = self._load_audio_from_path(audio_path_or_array)
else:
audio_array = np.array(audio_path_or_array, dtype=np.float32)
# Override normalization setting if specified
original_normalize = self.normalize_audio
if normalize is not None:
self.normalize_audio = normalize
try:
processed = self._process_single_audio(audio_array)
finally:
# Restore original setting
self.normalize_audio = original_normalize
return processed
# Override to_dict method for configuration saving
def to_dict(self) -> Dict[str, Any]:
"""
Convert the object to a dict containing all attributes needed for serialization.
"""
return self.feature_extractor_dict
def save_audio(
self,
audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
output_path: str = "output.wav",
sampling_rate: Optional[int] = None,
normalize: bool = False,
batch_prefix: str = "audio_",
):
"""
Save audio data to WAV file(s).
Args:
audio: Audio data to save. Can be:
- torch.Tensor: PyTorch tensor with shape (B, C, T) or (B, T) or (T)
- np.ndarray: NumPy array with shape (B, C, T) or (B, T) or (T)
- List of tensors or arrays
output_path: Path where to save the audio. If saving multiple files,
this is treated as a directory and individual files will be saved inside.
sampling_rate: Sampling rate for the saved audio. Defaults to the processor's rate.
normalize: Whether to normalize audio before saving.
batch_prefix: Prefix for batch files when saving multiple audios.
Returns:
List[str]: Paths to the saved audio files.
"""
if sampling_rate is None:
sampling_rate = self.sampling_rate
try:
import soundfile as sf
except ImportError:
raise ImportError(
"soundfile is required to save audio files. "
"Install it with: pip install soundfile"
)
# Ensure audio is in the right format
if isinstance(audio, torch.Tensor):
# Convert PyTorch tensor to numpy
audio_np = audio.float().detach().cpu().numpy()
elif isinstance(audio, np.ndarray):
audio_np = audio
elif isinstance(audio, list):
# Handle list of tensors or arrays
if all(isinstance(a, torch.Tensor) for a in audio):
audio_np = [a.float().detach().cpu().numpy() for a in audio]
else:
audio_np = audio
else:
raise ValueError(f"Unsupported audio type: {type(audio)}")
saved_paths = []
# Handle based on shape or type
if isinstance(audio_np, list):
# Multiple separate audios to save
output_dir = output_path
# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)
# Save each audio
for i, audio_item in enumerate(audio_np):
audio_item = self._prepare_audio_for_save(audio_item, normalize)
file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
sf.write(file_path, audio_item, sampling_rate)
saved_paths.append(file_path)
else:
# Handle different dimensions
if len(audio_np.shape) >= 3: # (B, C, T) or similar
# Get batch size
batch_size = audio_np.shape[0]
if batch_size > 1:
# Multiple audios in a batch
output_dir = output_path
# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)
# Save each audio in the batch
for i in range(batch_size):
# Extract single audio and remove channel dim if present
single_audio = audio_np[i]
if len(single_audio.shape) > 1:
if single_audio.shape[0] == 1: # (1, T)
single_audio = single_audio.squeeze(0)
single_audio = self._prepare_audio_for_save(single_audio, normalize)
file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
sf.write(file_path, single_audio, sampling_rate)
saved_paths.append(file_path)
else:
# Single audio with batch and channel dims
audio_item = audio_np.squeeze() # Remove batch and channel dimensions
audio_item = self._prepare_audio_for_save(audio_item, normalize)
sf.write(output_path, audio_item, sampling_rate)
saved_paths.append(output_path)
else:
# Single audio without batch dimension
audio_item = self._prepare_audio_for_save(audio_np, normalize)
sf.write(output_path, audio_item, sampling_rate)
saved_paths.append(output_path)
return saved_paths
def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
"""
Prepare audio for saving by ensuring it's the right shape and optionally normalizing.
Args:
audio: Audio data as numpy array
normalize: Whether to normalize audio
Returns:
np.ndarray: Processed audio ready for saving
"""
# Ensure right dimensionality
if len(audio.shape) > 1 and audio.shape[0] == 1: # (1, T)
audio = audio.squeeze(0)
# Normalize if requested
if normalize:
max_val = np.abs(audio).max()
if max_val > 0:
audio = audio / max_val
return audio
__all__ = ["VibeVoiceTokenizerProcessor", "AudioNormalizer"]

View File

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,19 @@
import math
import torch
class UniformSampler:
def __init__(self, timesteps = 1000):
self.timesteps = timesteps
def sample(self, batch_size, device):
return torch.randint(0, self.timesteps, (batch_size,), device=device)
class LogitNormalSampler:
def __init__(self, timesteps = 1000, m = 0, s = 1):
self.timesteps = timesteps
timesteps = torch.linspace(0, 1, timesteps)
logit = torch.log(timesteps / (1 - timesteps))
self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s ** 2) / (s * math.sqrt(2 * math.pi))
def sample(self, batch_size, device):
return torch.multinomial(self.prob, batch_size, replacement=True).to(device)

View File

View File

@ -0,0 +1,166 @@
#!/usr/bin/env python
# coding=utf-8
import argparse
import json
import os
from pathlib import Path
import re
import torch
from typing import Dict, List, Tuple
from vibevoice.modular.configuration_vibevoice import (
VibeVoiceConfig
)
from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
from transformers.utils import logging
logger = logging.get_logger(__name__)
def convert_vibevoice_nnscaler_checkpoint_to_hf(
checkpoint_path: str,
pytorch_dump_folder_path: str,
config_path: str = None,
):
"""
Convert a nnscaler VibeVoice checkpoint to HuggingFace format.
Supports both regular checkpoints and tensor parallel checkpoints.
"""
# Load regular checkpoint
logger.info(f"Loading regular checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location="cpu") # ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader']
# config = checkpoint['train_args']
init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path']
pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path']
init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1]
if init_config_path.exists():
logger.info(f"Loading initial config from {init_config_path}")
with open(init_config_path, 'r') as f:
init_config = json.load(f)
else:
raise FileNotFoundError(f"Initial config file {init_config_path} not found. Please provide a valid path.")
tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True)
logger.info(f"Tie word embeddings: {tie_word_embeddings}")
init_config['decoder_config']['use_cache'] = True
config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings)
# # Extract the model state dict
model_state_dict = {k.replace('model.model.', 'model.'): v for k, v in checkpoint["model"].items() if k.startswith('model.model.')}
if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys():
# If not tying weights, we need to add the lm_head weight separately
model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight']
# Override with provided config if available
if config_path:
logger.info(f"Loading config from {config_path}")
with open(config_path, 'r') as f:
config_dict = json.load(f)
config = VibeVoiceConfig.from_dict(config_dict)
# Set the default dtype to bfloat16 before creating the model
original_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
# Create the HuggingFace model
logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model")
model = VibeVoiceForConditionalGeneration(config)
# Restore original dtype
torch.set_default_dtype(original_dtype)
# Load the state dict
logger.info("Loading weights into model")
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
if missing_keys:
logger.warning(f"Missing keys: {missing_keys}")
if unexpected_keys:
logger.warning(f"Unexpected keys: {unexpected_keys}")
# Create output directory
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
# Save the model and config
logger.info(f"Saving model to {pytorch_dump_folder_path}")
# Save config
config.save_pretrained(pytorch_dump_folder_path)
# Save VibeVoiceProcessor configuration
logger.info("Saving VibeVoiceProcessor configuration")
processor_config = {
"processor_class": "VibeVoiceProcessor",
"speech_tok_compress_ratio": 3200,
"db_normalize": True,
# Audio processor configuration
"audio_processor": {
"feature_extractor_type": "VibeVoiceTokenizerProcessor",
"sampling_rate": 24000,
"normalize_audio": True,
"target_dB_FS": -25,
"eps": 1e-6,
},
"language_model_pretrained_name": pretrained_name,
}
processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json")
with open(processor_config_path, 'w') as f:
json.dump(processor_config, f, indent=2)
logger.info(f"Saved processor config to {processor_config_path}")
# Save model with sharding
# save_pretrained handles tied weights automatically
logger.info("Saving model weights with sharding...")
model.save_pretrained(
pytorch_dump_folder_path,
max_shard_size="2GB", # Set maximum size for each shard
safe_serialization=True # Ensure saving in .safetensors format
)
logger.info(f"Model weights saved to {pytorch_dump_folder_path}")
logger.info("Conversion complete!")
# Verify the saved model can be loaded
logger.info("Verifying saved model...")
loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path)
logger.info("Model successfully loaded from saved checkpoint!")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--nnscaler_checkpoint_path",
type=str,
required=True,
help="Path to the fairseq checkpoint (.pt file). For tensor parallel checkpoints, "
"provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), "
"and the script will automatically detect and merge all parts.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
type=str,
required=True,
help="Path to the output PyTorch model directory",
)
parser.add_argument(
"--config_path",
type=str,
default=None,
help="Optional path to a config JSON file to override extracted config",
)
args = parser.parse_args()
convert_vibevoice_nnscaler_checkpoint_to_hf(
args.nnscaler_checkpoint_path,
args.pytorch_dump_folder_path,
args.config_path,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,220 @@
# comfyui_vibevoice_chunked_wrapper.py
import math
import torch
from comfy.utils import ProgressBar
from .vibevoice_nodes import VibeVoiceTTSNode
# We assume the base node class from your snippet is in the same module/file.
# If it's in another module, import it instead:
# from your_module import VibeVoiceTTSNode
class VibeVoiceTTS_WrapperNode:
"""
Wraps VibeVoiceTTSNode, adds:
- Number of Speakers (1-4) that gates which speaker_*_voice inputs are used
- Chunking controls for multiline script ("Speaker N: ...")
- Iterates per chunk, concatenates outputs into one AUDIO dict
Returns: ("AUDIO",) — waveform [B, C, T], sample_rate per ComfyUI audio spec.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
# Pass-through of model/decoding params to the underlying node:
"model_name": (list(VibeVoiceTTSNode.INPUT_TYPES()["required"]["model_name"][0]), {
"tooltip": "Forwarded to VibeVoiceTTSNode"
}),
"text": ("STRING", {
"multiline": True,
"default": "Speaker 1: Hello there!\nSpeaker 2: And hello from me.",
"tooltip": "Multiline script: 'Speaker 1: ...' one line per utterance"
}),
"num_speakers": ("INT", {
"default": 2, "min": 1, "max": 4, "step": 1,
"tooltip": "How many speaker reference audios to use (14). Extra inputs are ignored."
}),
"chunk_lines": ("BOOLEAN", {
"default": False, "label_on": "Chunk", "label_off": "No chunking",
"tooltip": "When enabled, splits the script into groups of N lines and runs VibeVoice per chunk."
}),
"lines_per_chunk": ("INT", {
"default": 20, "min": 1, "max": 999, "step": 1,
"tooltip": "Only used when 'Chunk' is enabled."
}),
# Forwarded generation knobs:
"quantize_llm_4bit": ("BOOLEAN", {
"default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision"
}),
"attention_mode": (["eager", "sdpa", "flash_attention_2"], {"default": "sdpa"}),
"cfg_scale": ("FLOAT", {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.05}),
"inference_steps": ("INT", {"default": 10, "min": 1, "max": 50}),
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "control_after_generate": True}),
"do_sample": ("BOOLEAN", {"default": True, "label_on": "Sampling", "label_off": "Greedy"}),
"temperature": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 2.0, "step": 0.01}),
"top_p": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01}),
"top_k": ("INT", {"default": 0, "min": 0, "max": 500, "step": 1}),
},
"optional": {
# Provide up to 4 optional speaker audios; we enforce num_speakers in code.
"speaker_1_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 1"}),
"speaker_2_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 2"}),
"speaker_3_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 3"}),
"speaker_4_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 4"}),
},
# If you REALLY want these hidden until toggled via JS, you can also list them under "hidden"
# and add a tiny JS extension to flip them visible. Pure-Python dynamic show/hide isnt native. # see docs
}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "run"
CATEGORY = "audio/tts"
# --------- helpers ---------
@staticmethod
def _split_into_chunks(lines, n):
"""
Split list of lines into chunks of size n.
If the last chunk would be < 40% of n, merge it into the previous chunk.
"""
if n <= 0:
return [lines] if lines else []
chunks = [lines[i:i+n] for i in range(0, len(lines), n)]
if len(chunks) >= 2:
tail = chunks[-1]
if len(tail) < math.ceil(0.4 * n):
chunks[-2].extend(tail)
chunks.pop()
return chunks
@staticmethod
def _concat_audio_dicts(audio_dicts):
"""
Concatenate a list of ComfyUI AUDIO dicts along time dim T.
Each dict: {"waveform": tensor[B,C,T], "sample_rate": int}
Returns a single AUDIO dict of the same shape convention.
"""
if not audio_dicts:
# Return 1-sample silence if nothing to concat
return {"waveform": torch.zeros((1, 1, 1), dtype=torch.float32), "sample_rate": 24000}
srs = {ad["sample_rate"] for ad in audio_dicts if ad and "sample_rate" in ad}
if len(srs) != 1:
raise ValueError(f"Sample rates differ across chunks: {srs}")
sr = srs.pop()
waves = []
for ad in audio_dicts:
wf = ad["waveform"]
# Expect [B, C, T]
if wf.ndim == 1:
wf = wf.unsqueeze(0).unsqueeze(0) # -> [1,1,T]
elif wf.ndim == 2:
wf = wf.unsqueeze(0) # -> [1,C,T]
waves.append(wf)
# Concatenate on time axis T (-1). Assumes batch (B) and channels (C) match.
out = torch.cat(waves, dim=-1)
return {"waveform": out.cpu(), "sample_rate": sr}
@staticmethod
def _filter_speaker_inputs(kwargs, num_speakers):
"""
Pulls up to num_speakers optional AUDIO inputs from kwargs.
"""
voices = []
for i in range(1, num_speakers + 1):
voices.append(kwargs.get(f"speaker_{i}_voice"))
# Fill the rest with None to align with underlying signature but ignored there
while len(voices) < 4:
voices.append(None)
return {
"speaker_1_voice": voices[0],
"speaker_2_voice": voices[1],
"speaker_3_voice": voices[2],
"speaker_4_voice": voices[3],
}
# --------- main ---------
def run(
self,
model_name,
text,
num_speakers,
chunk_lines,
lines_per_chunk,
quantize_llm_4bit,
attention_mode,
cfg_scale,
inference_steps,
seed,
do_sample,
temperature,
top_p,
top_k,
**kwargs,
):
"""
Orchestrates chunking and calls VibeVoiceTTSNode.generate_audio per chunk.
Then concatenates to a single AUDIO dict.
"""
text = (text or "").strip()
if not text:
# return 1 second of silence at 24kHz, shape [1,1,24000]
return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
# Prepare speaker refs according to chosen number of speakers
speaker_kwargs = self._filter_speaker_inputs(kwargs, max(1, min(4, int(num_speakers))))
# Prepare chunks (list of multiline strings)
if chunk_lines:
raw_lines = [ln for ln in text.splitlines() if ln.strip() != ""]
groups = self._split_into_chunks(raw_lines, lines_per_chunk)
chunk_texts = ["\n".join(g) for g in groups] if groups else [text]
else:
chunk_texts = [text]
# Progress bar over chunks
pbar = ProgressBar(total=len(chunk_texts))
# Call the underlying node per chunk
base = VibeVoiceTTSNode()
audio_parts = []
for idx, chunk in enumerate(chunk_texts, 1):
out_audio = base.generate_audio(
model_name=model_name,
text=chunk,
attention_mode=attention_mode,
cfg_scale=cfg_scale,
inference_steps=inference_steps,
seed=seed,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
top_k=top_k,
quantize_llm_4bit=quantize_llm_4bit,
force_offload=False,
**speaker_kwargs,
)[0] # underlying returns (AUDIO,)
audio_parts.append(out_audio)
pbar.update(1)
# Concatenate into one AUDIO
merged = self._concat_audio_dicts(audio_parts)
return (merged,)
# Register
NODE_CLASS_MAPPINGS = {
"VibeVoiceTTS_Wrapper": VibeVoiceTTS_WrapperNode
# Keep the base node mapping from your original file:
}
NODE_DISPLAY_NAME_MAPPINGS = {
"VibeVoiceTTS_Wrapper": "VibeVoice TTS (Chunked Wrapper)"
}

617
vibevoice_nodes.py Normal file
View File

@ -0,0 +1,617 @@
import os
import re
import torch
import numpy as np
import random
from huggingface_hub import hf_hub_download, snapshot_download
import logging
import gc
import folder_paths
import comfy.model_management as model_management
import comfy.model_patcher
from comfy.utils import ProgressBar
from comfy.model_management import throw_exception_if_processing_interrupted
from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig
from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
from .vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
from .vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast
try:
import librosa
except ImportError:
print("VibeVoice Node: `librosa` is not installed. Resampling of reference audio will not be available.")
librosa = None
logger = logging.getLogger(__name__)
LOADED_MODELS = {}
VIBEVOICE_PATCHER_CACHE = {}
MODEL_CONFIGS = {
"VibeVoice-1.5B": {
"repo_id": "microsoft/VibeVoice-1.5B",
"size_gb": 3.0,
"tokenizer_repo": "Qwen/Qwen2.5-1.5B"
},
"VibeVoice-Large": {
"repo_id": "microsoft/VibeVoice-Large",
"size_gb": 17.4,
"tokenizer_repo": "Qwen/Qwen2.5-7B"
}
}
ATTENTION_MODES = ["eager", "sdpa", "flash_attention_2"]
def cleanup_old_models(keep_cache_key=None):
"""Clean up old models, optionally keeping one specific model loaded"""
global LOADED_MODELS, VIBEVOICE_PATCHER_CACHE
keys_to_remove = []
# Clear LOADED_MODELS
for key in list(LOADED_MODELS.keys()):
if key != keep_cache_key:
keys_to_remove.append(key)
del LOADED_MODELS[key]
# Clear VIBEVOICE_PATCHER_CACHE - but more carefully
for key in list(VIBEVOICE_PATCHER_CACHE.keys()):
if key != keep_cache_key:
# Set the model/processor to None but don't delete the patcher itself
# This lets ComfyUI's model management handle the patcher cleanup
try:
patcher = VIBEVOICE_PATCHER_CACHE[key]
if hasattr(patcher, 'model') and patcher.model:
patcher.model.model = None
patcher.model.processor = None
# Remove from our cache but let ComfyUI handle the rest
del VIBEVOICE_PATCHER_CACHE[key]
except Exception as e:
logger.warning(f"Error cleaning up patcher {key}: {e}")
if keys_to_remove:
logger.info(f"Cleaned up cached models: {keys_to_remove}")
gc.collect()
model_management.soft_empty_cache()
class VibeVoiceModelHandler(torch.nn.Module):
"""A torch.nn.Module wrapper to hold the VibeVoice model and processor."""
def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False):
super().__init__()
self.model_pack_name = model_pack_name
self.attention_mode = attention_mode
self.use_llm_4bit = use_llm_4bit
self.cache_key = f"{model_pack_name}_attn_{attention_mode}"
self.model = None
self.processor = None
self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3))
def load_model(self, device, attention_mode="eager"):
self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit)
self.model.to(device)
class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
"""Custom ModelPatcher for managing VibeVoice models in ComfyUI."""
def __init__(self, model, attention_mode="eager", *args, **kwargs):
super().__init__(model, *args, **kwargs)
self.attention_mode = attention_mode
self.cache_key = model.cache_key
@property
def is_loaded(self):
"""Check if the model is currently loaded in memory."""
return hasattr(self, 'model') and self.model is not None and hasattr(self.model, 'model') and self.model.model is not None
def patch_model(self, device_to=None, *args, **kwargs):
target_device = self.load_device
if self.model.model is None:
logger.info(f"Loading VibeVoice models for '{self.model.model_pack_name}' to {target_device}...")
mode_names = {
"eager": "Eager (Most Compatible)",
"sdpa": "SDPA (Balanced Speed/Compatibility)",
"flash_attention_2": "Flash Attention 2 (Fastest)"
}
logger.info(f"Attention Mode: {mode_names.get(self.attention_mode, self.attention_mode)}")
self.model.load_model(target_device, self.attention_mode)
self.model.model.to(target_device)
return super().patch_model(device_to=target_device, *args, **kwargs)
def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs):
if unpatch_weights:
logger.info(f"Offloading VibeVoice models for '{self.model.model_pack_name}' ({self.attention_mode}) to {device_to}...")
self.model.model = None
self.model.processor = None
# Clear using the correct cache key
if self.cache_key in LOADED_MODELS:
del LOADED_MODELS[self.cache_key]
logger.info(f"Cleared LOADED_MODELS cache for: {self.cache_key}")
# DON'T delete from VIBEVOICE_PATCHER_CACHE here - let ComfyUI handle it
# This prevents the IndexError in ComfyUI's model management
# Force garbage collection
gc.collect()
model_management.soft_empty_cache()
return super().unpatch_model(device_to, unpatch_weights, *args, **kwargs)
class VibeVoiceLoader:
@staticmethod
def get_model_path(model_name: str):
if model_name not in MODEL_CONFIGS:
raise ValueError(f"Unknown VibeVoice model: {model_name}")
vibevoice_path = os.path.join(folder_paths.get_folder_paths("tts")[0], "VibeVoice")
model_path = os.path.join(vibevoice_path, model_name)
index_file = os.path.join(model_path, "model.safetensors.index.json")
if not os.path.exists(index_file):
print(f"Downloading VibeVoice model: {model_name}...")
repo_id = MODEL_CONFIGS[model_name]["repo_id"]
snapshot_download(repo_id=repo_id, local_dir=model_path)
return model_path
@staticmethod
def _check_attention_compatibility(attention_mode: str, torch_dtype, device_name: str = ""):
"""Check if the requested attention mode is compatible with current setup."""
# Check for SDPA availability (PyTorch 2.0+)
if attention_mode == "sdpa":
if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
logger.warning("SDPA not available (requires PyTorch 2.0+), falling back to eager")
return "eager"
# Check for Flash Attention availability
elif attention_mode == "flash_attention_2":
if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
logger.warning("Flash Attention not available, falling back to eager")
return "eager"
elif torch_dtype == torch.float32:
logger.warning("Flash Attention not recommended with float32, falling back to SDPA")
return "sdpa" if hasattr(torch.nn.functional, 'scaled_dot_product_attention') else "eager"
# Just informational messages, no forced fallbacks
if device_name and torch.cuda.is_available():
if "RTX 50" in device_name or "Blackwell" in device_name:
if attention_mode == "flash_attention_2":
logger.info(f"Using Flash Attention on {device_name}")
elif attention_mode == "sdpa":
logger.info(f"Using SDPA on {device_name}")
return attention_mode
@staticmethod
def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False):
# Validate attention mode
if attention_mode not in ATTENTION_MODES:
logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
attention_mode = "eager"
if use_llm_4bit and attention_mode == "flash_attention_2":
attention_mode = "sdpa"
# Create cache key that includes attention mode
cache_key = f"{model_name}_attn_{attention_mode}"
if cache_key in LOADED_MODELS:
logger.info(f"Using cached model with {attention_mode} attention")
return LOADED_MODELS[cache_key]
model_path = VibeVoiceLoader.get_model_path(model_name)
logger.info(f"Loading VibeVoice model components from: {model_path}")
tokenizer_repo = MODEL_CONFIGS[model_name].get("tokenizer_repo")
try:
tokenizer_file_path = hf_hub_download(repo_id=tokenizer_repo, filename="tokenizer.json")
except Exception as e:
raise RuntimeError(f"Could not download tokenizer.json for {tokenizer_repo}. Error: {e}")
vibevoice_tokenizer = VibeVoiceTextTokenizerFast(tokenizer_file=tokenizer_file_path)
audio_processor = VibeVoiceTokenizerProcessor()
processor = VibeVoiceProcessor(tokenizer=vibevoice_tokenizer, audio_processor=audio_processor)
torch_dtype = model_management.text_encoder_dtype(device)
device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else ""
# Check compatibility and potentially fall back to safer mode
final_attention_mode = VibeVoiceLoader._check_attention_compatibility(
attention_mode, torch_dtype, device_name
)
# Build optional 4-bit config (LLM only)
quant_config = None
if use_llm_4bit:
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16,
)
logger.info(f"Requested attention mode: {attention_mode}")
if final_attention_mode != attention_mode:
logger.info(f"Using attention mode: {final_attention_mode} (automatic fallback)")
# Update cache key to reflect actual mode used
cache_key = f"{model_name}_attn_{final_attention_mode}"
if cache_key in LOADED_MODELS:
return LOADED_MODELS[cache_key]
else:
logger.info(f"Using attention mode: {final_attention_mode}")
logger.info(f"Final attention implementation: {final_attention_mode}")
# Modify config for non-flash attention modes
if final_attention_mode in ["eager", "sdpa"]:
import json
config_path = os.path.join(model_path, "config.json")
if os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
config = json.load(f)
# Remove flash attention settings
removed_keys = []
for key in ['_attn_implementation', 'attn_implementation', 'use_flash_attention_2']:
if key in config:
config.pop(key)
removed_keys.append(key)
if removed_keys:
with open(config_path, 'w') as f:
json.dump(config, f, indent=2)
logger.info(f"Removed FlashAttention settings from config.json: {removed_keys}")
except Exception as e:
logger.warning(f"Could not modify config.json: {e}")
try:
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
model_path,
torch_dtype=torch.bfloat16 if quant_config else torch_dtype,
attn_implementation=final_attention_mode,
device_map="auto" if quant_config else device,
quantization_config=quant_config, # <- forwarded if supported
)
model.eval()
setattr(model, "_llm_4bit", bool(quant_config))
# Store with the actual attention mode used (not the requested one)
LOADED_MODELS[cache_key] = (model, processor)
logger.info(f"Successfully loaded model with {final_attention_mode} attention")
return model, processor
except Exception as e:
logger.error(f"Failed to load model with {final_attention_mode} attention: {e}")
# Progressive fallback: flash -> sdpa -> eager
if final_attention_mode == "flash_attention_2":
logger.info("Attempting fallback to SDPA...")
return VibeVoiceLoader.load_model(model_name, device, "sdpa")
elif final_attention_mode == "sdpa":
logger.info("Attempting fallback to eager...")
return VibeVoiceLoader.load_model(model_name, device, "eager")
else:
# If eager fails, something is seriously wrong
raise RuntimeError(f"Failed to load model even with eager attention: {e}")
def set_vibevoice_seed(seed: int):
"""Sets the seed for torch, numpy, and random, handling large seeds for numpy."""
if seed == 0:
seed = random.randint(1, 0xffffffffffffffff)
MAX_NUMPY_SEED = 2**32 - 1
numpy_seed = seed % MAX_NUMPY_SEED
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
np.random.seed(numpy_seed)
random.seed(seed)
def parse_script_1_based(script: str) -> tuple[list[tuple[int, str]], list[int]]:
"""
Parses a 1-based speaker script into a list of (speaker_id, text) tuples
and a list of unique speaker IDs in the order of their first appearance.
Internally, it converts speaker IDs to 0-based for the model.
"""
parsed_lines = []
speaker_ids_in_script = [] # This will store the 1-based IDs from the script
for line in script.strip().split("\n"):
if not (line := line.strip()): continue
match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)
if match:
speaker_id = int(match.group(1))
if speaker_id < 1:
logger.warning(f"Speaker ID must be 1 or greater. Skipping line: '{line}'")
continue
text = ' ' + match.group(2).strip()
# Internally, the model expects 0-based indexing for speakers
internal_speaker_id = speaker_id - 1
parsed_lines.append((internal_speaker_id, text))
if speaker_id not in speaker_ids_in_script:
speaker_ids_in_script.append(speaker_id)
else:
logger.warning(f"Could not parse line, skipping: '{line}'")
return parsed_lines, sorted(list(set(speaker_ids_in_script)))
def preprocess_comfy_audio(audio_dict: dict, target_sr: int = 24000) -> np.ndarray:
"""
Converts a ComfyUI AUDIO dict to a mono NumPy array, resampling if necessary.
"""
if not audio_dict: return None
waveform_tensor = audio_dict.get('waveform')
if waveform_tensor is None or waveform_tensor.numel() == 0: return None
waveform = waveform_tensor[0].cpu().numpy()
original_sr = audio_dict['sample_rate']
if waveform.ndim > 1:
waveform = np.mean(waveform, axis=0)
# Check for invalid values
if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)):
logger.error("Audio contains NaN or Inf values, replacing with zeros")
waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)
# Ensure audio is not completely silent or has extreme values
if np.all(waveform == 0):
logger.warning("Audio waveform is completely silent")
# Normalize extreme values
max_val = np.abs(waveform).max()
if max_val > 10.0:
logger.warning(f"Audio values are very large (max: {max_val}), normalizing")
waveform = waveform / max_val
if original_sr != target_sr:
if librosa is None:
raise ImportError("`librosa` package is required for audio resampling. Please install it with `pip install librosa`.")
logger.warning(f"Resampling reference audio from {original_sr}Hz to {target_sr}Hz.")
waveform = librosa.resample(y=waveform, orig_sr=original_sr, target_sr=target_sr)
# Final check after resampling
if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)):
logger.error("Audio contains NaN or Inf after resampling, replacing with zeros")
waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)
return waveform.astype(np.float32)
def check_for_interrupt():
try:
throw_exception_if_processing_interrupted()
return False
except:
return True
class VibeVoiceTTSNode:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model_name": (list(MODEL_CONFIGS.keys()), {
"tooltip": "Select the VibeVoice model to use. Models will be downloaded automatically if not present."
}),
"text": ("STRING", {
"multiline": True,
"default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.",
"tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line."
}),
"quantize_llm_4bit": ("BOOLEAN", {
"default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision",
"tooltip": "Quantize the Qwen2.5 LLM to 4-bit NF4 via bitsandbytes. Diffusion head stays BF16/FP32."
}),
"attention_mode": (["eager", "sdpa", "flash_attention_2"], {
"default": "sdpa",
"tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)"
}),
"cfg_scale": ("FLOAT", {
"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.05,
"tooltip": "Classifier-Free Guidance scale. Higher values increase adherence to the voice prompt but may reduce naturalness. Recommended: 1.3"
}),
"inference_steps": ("INT", {
"default": 10, "min": 1, "max": 50,
"tooltip": "Number of diffusion steps for audio generation. More steps can improve quality but take longer. Recommended: 10"
}),
"seed": ("INT", {
"default": 42, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "control_after_generate": True,
"tooltip": "Seed for reproducibility. Set to 0 for a random seed on each run."
}),
"do_sample": ("BOOLEAN", {
"default": True, "label_on": "Enabled (Sampling)", "label_off": "Disabled (Greedy)",
"tooltip": "Enable to use sampling methods (like temperature and top_p) for more varied output. Disable for deterministic (greedy) decoding."
}),
"temperature": ("FLOAT", {
"default": 0.95, "min": 0.0, "max": 2.0, "step": 0.01,
"tooltip": "Controls randomness. Higher values make the output more random and creative, while lower values make it more focused and deterministic. Active only if 'do_sample' is enabled."
}),
"top_p": ("FLOAT", {
"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01,
"tooltip": "Nucleus sampling (Top-P). The model samples from the smallest set of tokens whose cumulative probability exceeds this value. Active only if 'do_sample' is enabled."
}),
"top_k": ("INT", {
"default": 0, "min": 0, "max": 500, "step": 1,
"tooltip": "Top-K sampling. Restricts sampling to the K most likely next tokens. Set to 0 to disable. Active only if 'do_sample' is enabled."
}),
"force_offload": ("BOOLEAN", {
"default": False, "label_on": "Force Offload", "label_off": "Keep in VRAM",
"tooltip": "Force model to be offloaded from VRAM after generation. Useful to free up memory between generations but may slow down subsequent runs."
}),
},
"optional": {
"speaker_1_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 1' in the script."}),
"speaker_2_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 2' in the script."}),
"speaker_3_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 3' in the script."}),
"speaker_4_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 4' in the script."}),
}
}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "generate_audio"
CATEGORY = "audio/tts"
def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, force_offload, **kwargs):
if not text.strip():
logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.")
return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
# Create cache key that includes attention mode
cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(quantize_llm_4bit)}"
# Clean up old models when switching to a different model
if cache_key not in VIBEVOICE_PATCHER_CACHE:
# Only keep models that are currently being requested
cleanup_old_models(keep_cache_key=cache_key)
model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit)
patcher = VibeVoicePatcher(
model_handler,
attention_mode=attention_mode,
load_device=model_management.get_torch_device(),
offload_device=model_management.unet_offload_device(),
size=model_handler.size
)
VIBEVOICE_PATCHER_CACHE[cache_key] = patcher
patcher = VIBEVOICE_PATCHER_CACHE[cache_key]
model_management.load_model_gpu(patcher)
model = patcher.model.model
processor = patcher.model.processor
if model is None or processor is None:
raise RuntimeError("VibeVoice model and processor could not be loaded. Check logs for errors.")
parsed_lines_0_based, speaker_ids_1_based = parse_script_1_based(text)
if not parsed_lines_0_based:
raise ValueError("Script is empty or invalid. Use 'Speaker 1:', 'Speaker 2:', etc. format.")
full_script = "\n".join([f"Speaker {spk}: {txt}" for spk, txt in parsed_lines_0_based])
speaker_inputs = {i: kwargs.get(f"speaker_{i}_voice") for i in range(1, 5)}
voice_samples_np = [preprocess_comfy_audio(speaker_inputs[sid]) for sid in speaker_ids_1_based]
if any(v is None for v in voice_samples_np):
missing_ids = [sid for sid, v in zip(speaker_ids_1_based, voice_samples_np) if v is None]
raise ValueError(f"Script requires voices for Speakers {missing_ids}, but they were not provided.")
set_vibevoice_seed(seed)
try:
inputs = processor(
text=[full_script], voice_samples=[voice_samples_np], padding=True,
return_tensors="pt", return_attention_mask=True
)
# Validate inputs before moving to GPU
for key, value in inputs.items():
if isinstance(value, torch.Tensor):
if torch.any(torch.isnan(value)) or torch.any(torch.isinf(value)):
logger.error(f"Input tensor '{key}' contains NaN or Inf values")
raise ValueError(f"Invalid values in input tensor: {key}")
inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
model.set_ddpm_inference_steps(num_steps=inference_steps)
generation_config = {'do_sample': do_sample}
if do_sample:
generation_config['temperature'] = temperature
generation_config['top_p'] = top_p
if top_k > 0:
generation_config['top_k'] = top_k
# Hardware-specific optimizations - only for eager mode
if attention_mode == "eager":
# Apply RTX 5090 / Blackwell compatibility fixes only for eager
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.cuda.empty_cache()
# Apply additional tensor fixes for eager mode
model = model.float()
processed_inputs = {}
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
# Keep integer/boolean tensors as-is (token IDs, attention masks, etc.)
if v.dtype in [torch.int, torch.long, torch.int32, torch.int64, torch.bool, torch.uint8]:
processed_inputs[k] = v
# Keep tensors with "mask" in their name as boolean
elif "mask" in k.lower():
processed_inputs[k] = v.bool() if v.dtype != torch.bool else v
else:
# Convert float/bfloat16 tensors to float32
processed_inputs[k] = v.float()
else:
processed_inputs[k] = v
inputs = processed_inputs
with torch.no_grad():
# Create progress bar for inference steps
pbar = ProgressBar(inference_steps)
def progress_callback(step, total_steps):
pbar.update(1)
# Check for interruption from ComfyUI
if model_management.interrupt_current_processing:
raise comfy.model_management.InterruptProcessingException()
# Custom generation loop with interruption support
try:
outputs = model.generate(
**inputs, max_new_tokens=None, cfg_scale=cfg_scale,
tokenizer=processor.tokenizer, generation_config=generation_config,
verbose=False, stop_check_fn=check_for_interrupt
)
# Note: The model.generate method doesn't support progress callbacks in the current VibeVoice implementation
# But we check for interruption at the start and end of generation
pbar.update(inference_steps - pbar.current)
except RuntimeError as e:
error_msg = str(e).lower()
if "assertion" in error_msg or "cuda" in error_msg:
logger.error(f"CUDA assertion failed with {attention_mode} attention: {e}")
logger.error("This might be due to invalid input data, GPU memory issues, or incompatible attention mode.")
logger.error("Try restarting ComfyUI, using different audio files, or switching to 'eager' attention mode.")
raise e
except comfy.model_management.InterruptProcessingException:
logger.info("VibeVoice generation interrupted by user")
raise
finally:
pbar.update_absolute(inference_steps)
except comfy.model_management.InterruptProcessingException:
logger.info("VibeVoice TTS generation was cancelled")
# Return silent audio on cancellation
return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
except Exception as e:
logger.error(f"Error during VibeVoice generation with {attention_mode} attention: {e}")
if "interrupt" in str(e).lower() or "cancel" in str(e).lower():
logger.info("Generation was interrupted")
return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
raise
output_waveform = outputs.speech_outputs[0]
if output_waveform.ndim == 1: output_waveform = output_waveform.unsqueeze(0)
if output_waveform.ndim == 2: output_waveform = output_waveform.unsqueeze(0)
# Force offload model if requested
if force_offload:
logger.info(f"Force offloading VibeVoice model '{model_name}' from VRAM...")
# Force offload by unpatching the model and freeing memory
if patcher.is_loaded:
patcher.unpatch_model(unpatch_weights=True)
# Force unload all models to free memory
model_management.unload_all_models()
gc.collect()
model_management.soft_empty_cache()
logger.info("Model force offload completed")
return ({"waveform": output_waveform.detach().cpu(), "sample_rate": 24000},)
NODE_CLASS_MAPPINGS = {"VibeVoiceTTS": VibeVoiceTTSNode}
NODE_DISPLAY_NAME_MAPPINGS = {"VibeVoiceTTS": "VibeVoice TTS"}