# vibevoice_node_chunked_wrapper.py
import math

import torch
from comfy.utils import ProgressBar

from .vibevoice_nodes import VibeVoiceTTSNode
# The base node class is assumed to live in vibevoice_nodes.py in the same package.
# If it is defined in another module, adjust the import accordingly, e.g.:
# from your_module import VibeVoiceTTSNode


class VibeVoiceTTS_WrapperNode:
    """
    Wraps VibeVoiceTTSNode and adds:
      - Number of Speakers (1-4) that gates which speaker_*_voice inputs are used
      - Chunking controls for a multiline script ("Speaker N: ...")
      - Per-chunk generation, with outputs concatenated into one AUDIO dict
    Returns: ("AUDIO",) with waveform [B, C, T] and sample_rate, per the ComfyUI audio spec.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # Pass-through of model/decoding params to the underlying node:
                "model_name": (list(VibeVoiceTTSNode.INPUT_TYPES()["required"]["model_name"][0]), {
                    "tooltip": "Forwarded to VibeVoiceTTSNode"
                }),
                "text": ("STRING", {
                    "multiline": True,
                    "default": "Speaker 1: Hello there!\nSpeaker 2: And hello from me.",
                    "tooltip": "Multiline script: 'Speaker 1: ...', one line per utterance"
                }),
                "num_speakers": ("INT", {
                    "default": 2, "min": 1, "max": 4, "step": 1,
                    "tooltip": "How many speaker reference audios to use (1-4). Extra inputs are ignored."
                }),
                "chunk_lines": ("BOOLEAN", {
                    "default": False, "label_on": "Chunk", "label_off": "No chunking",
                    "tooltip": "When enabled, splits the script into groups of N lines and runs VibeVoice per chunk."
                }),
                "lines_per_chunk": ("INT", {
                    "default": 20, "min": 1, "max": 999, "step": 1,
                    "tooltip": "Only used when 'Chunk' is enabled."
                }),
                # Forwarded generation knobs:
                "quantize_llm_4bit": ("BOOLEAN", {
                    "default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision"
                }),
                "attention_mode": (["eager", "sdpa", "flash_attention_2"], {"default": "sdpa"}),
                "cfg_scale": ("FLOAT", {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.05}),
                "inference_steps": ("INT", {"default": 10, "min": 1, "max": 50}),
                "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "control_after_generate": True}),
                "do_sample": ("BOOLEAN", {"default": True, "label_on": "Sampling", "label_off": "Greedy"}),
                "temperature": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 2.0, "step": 0.01}),
                "top_p": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01}),
                "top_k": ("INT", {"default": 0, "min": 0, "max": 500, "step": 1}),
            },
            "optional": {
                # Up to 4 optional speaker reference audios; num_speakers is enforced in code.
                "speaker_1_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 1"}),
                "speaker_2_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 2"}),
                "speaker_3_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 3"}),
                "speaker_4_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 4"}),
            },
            # To hide these until toggled, list them under "hidden" and add a small JS
            # extension to flip them visible; pure-Python dynamic show/hide isn't native (see docs).
        }
RETURN_TYPES = ("AUDIO",)
FUNCTION = "run"
CATEGORY = "audio/tts"

    # --------- helpers ---------
    @staticmethod
    def _split_into_chunks(lines, n):
        """
        Split a list of lines into chunks of size n.
        If the last chunk would be < 40% of n, merge it into the previous chunk.
        """
        if n <= 0:
            return [lines] if lines else []
        chunks = [lines[i:i + n] for i in range(0, len(lines), n)]
        if len(chunks) >= 2:
            tail = chunks[-1]
            if len(tail) < math.ceil(0.4 * n):
                chunks[-2].extend(tail)
                chunks.pop()
        return chunks
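
    # Worked example (sketch): with 45 non-empty lines and n=20, the raw split is
    # [20, 20, 5]; the 5-line tail is below ceil(0.4 * 20) = 8, so it is merged into
    # the previous chunk, giving chunk sizes [20, 25]. With 48 lines and n=20, the
    # 8-line tail is kept, giving [20, 20, 8].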

    @staticmethod
    def _concat_audio_dicts(audio_dicts):
        """
        Concatenate a list of ComfyUI AUDIO dicts along the time dimension T.
        Each dict: {"waveform": tensor[B, C, T], "sample_rate": int}
        Returns a single AUDIO dict with the same shape convention.
        """
        if not audio_dicts:
            # Return 1-sample silence if there is nothing to concatenate
            return {"waveform": torch.zeros((1, 1, 1), dtype=torch.float32), "sample_rate": 24000}
        srs = {ad["sample_rate"] for ad in audio_dicts if ad and "sample_rate" in ad}
        if len(srs) != 1:
            raise ValueError(f"Sample rates differ across chunks: {srs}")
        sr = srs.pop()
        waves = []
        for ad in audio_dicts:
            wf = ad["waveform"]
            # Expect [B, C, T]; promote lower-rank tensors
            if wf.ndim == 1:
                wf = wf.unsqueeze(0).unsqueeze(0)  # -> [1, 1, T]
            elif wf.ndim == 2:
                wf = wf.unsqueeze(0)  # -> [1, C, T]
            waves.append(wf)
        # Concatenate on the time axis (-1). Assumes batch (B) and channels (C) match.
        out = torch.cat(waves, dim=-1)
        return {"waveform": out.cpu(), "sample_rate": sr}

    @staticmethod
    def _filter_speaker_inputs(kwargs, num_speakers):
        """
        Pull up to num_speakers optional AUDIO inputs from kwargs.
        """
        voices = []
        for i in range(1, num_speakers + 1):
            voices.append(kwargs.get(f"speaker_{i}_voice"))
        # Pad with None to match the underlying signature; the extras are ignored there
        while len(voices) < 4:
            voices.append(None)
        return {
            "speaker_1_voice": voices[0],
            "speaker_2_voice": voices[1],
            "speaker_3_voice": voices[2],
            "speaker_4_voice": voices[3],
        }
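
    # Example (sketch): with num_speakers=2, speaker_1_voice and speaker_2_voice are
    # forwarded from kwargs (or None if not connected), while speaker_3_voice and
    # speaker_4_voice are always passed as None, even if they are wired up in the graph.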

    # --------- main ---------
    def run(
        self,
        model_name,
        text,
        num_speakers,
        chunk_lines,
        lines_per_chunk,
        quantize_llm_4bit,
        attention_mode,
        cfg_scale,
        inference_steps,
        seed,
        do_sample,
        temperature,
        top_p,
        top_k,
        **kwargs,
    ):
        """
        Orchestrates chunking, calls VibeVoiceTTSNode.generate_audio per chunk,
        then concatenates the results into a single AUDIO dict.
        """
        text = (text or "").strip()
        if not text:
            # Return 1 second of silence at 24 kHz, shape [1, 1, 24000]
            return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)

        # Prepare speaker refs according to the chosen number of speakers
        speaker_kwargs = self._filter_speaker_inputs(kwargs, max(1, min(4, int(num_speakers))))

        # Prepare chunks (list of multiline strings)
        if chunk_lines:
            raw_lines = [ln for ln in text.splitlines() if ln.strip() != ""]
            groups = self._split_into_chunks(raw_lines, lines_per_chunk)
            chunk_texts = ["\n".join(g) for g in groups] if groups else [text]
        else:
            chunk_texts = [text]

        # Progress bar over chunks
        pbar = ProgressBar(total=len(chunk_texts))

        # Call the underlying node per chunk
        base = VibeVoiceTTSNode()
        audio_parts = []
        for idx, chunk in enumerate(chunk_texts, 1):
            out_audio = base.generate_audio(
                model_name=model_name,
                text=chunk,
                attention_mode=attention_mode,
                cfg_scale=cfg_scale,
                inference_steps=inference_steps,
                seed=seed,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                quantize_llm_4bit=quantize_llm_4bit,
                force_offload=False,
                **speaker_kwargs,
            )[0]  # the underlying node returns (AUDIO,)
            audio_parts.append(out_audio)
            pbar.update(1)

        # Concatenate into one AUDIO
        merged = self._concat_audio_dicts(audio_parts)
        return (merged,)
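
    # Note: the same seed and sampling settings are forwarded to every chunk; only the
    # script text differs between calls, and force_offload is passed as False for each
    # chunk (presumably so the base node does not offload the model between calls).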


# Register
NODE_CLASS_MAPPINGS = {
    "VibeVoiceTTS_Wrapper": VibeVoiceTTS_WrapperNode,
    # The base VibeVoiceTTSNode keeps its own mapping in its original file.
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "VibeVoiceTTS_Wrapper": "VibeVoice TTS (Chunked Wrapper)",
}
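

if __name__ == "__main__":
    # Sanity-check sketch for the pure helpers above. This is not part of the node and
    # only runs when the module is executed directly in an environment where the
    # comfy / .vibevoice_nodes imports at the top resolve (i.e. inside ComfyUI's Python).
    demo_lines = [f"Speaker 1: line {i}" for i in range(45)]
    groups = VibeVoiceTTS_WrapperNode._split_into_chunks(demo_lines, 20)
    print([len(g) for g in groups])  # expected: [20, 25] (5-line tail merged into previous chunk)

    a = {"waveform": torch.zeros((1, 1, 24000)), "sample_rate": 24000}
    b = {"waveform": torch.zeros((1, 1, 12000)), "sample_rate": 24000}
    merged = VibeVoiceTTS_WrapperNode._concat_audio_dicts([a, b])
    print(tuple(merged["waveform"].shape), merged["sample_rate"])  # expected: (1, 1, 36000) 24000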