First commit

2025-09-03 19:24:26 +10:00
parent 2bf7123ae2
commit 1e1402795f
30 changed files with 7582 additions and 33 deletions


@@ -0,0 +1,220 @@
# comfyui_vibevoice_chunked_wrapper.py
import math
import torch
from comfy.utils import ProgressBar
from .vibevoice_nodes import VibeVoiceTTSNode
# We assume the base node class from your snippet is in the same module/file.
# If it's in another module, import it instead:
# from your_module import VibeVoiceTTSNode


class VibeVoiceTTS_WrapperNode:
    """
    Wraps VibeVoiceTTSNode and adds:
      - Number of Speakers (1-4) that gates which speaker_*_voice inputs are used
      - Chunking controls for a multiline script ("Speaker N: ...")
      - Per-chunk iteration that concatenates the outputs into one AUDIO dict
    Returns: ("AUDIO",) with waveform [B, C, T] and sample_rate, per the ComfyUI audio spec.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # Pass-through of model/decoding params to the underlying node:
                "model_name": (list(VibeVoiceTTSNode.INPUT_TYPES()["required"]["model_name"][0]), {
                    "tooltip": "Forwarded to VibeVoiceTTSNode"
                }),
                "text": ("STRING", {
                    "multiline": True,
                    "default": "Speaker 1: Hello there!\nSpeaker 2: And hello from me.",
                    "tooltip": "Multiline script: 'Speaker 1: ...', one line per utterance"
                }),
                "num_speakers": ("INT", {
                    "default": 2, "min": 1, "max": 4, "step": 1,
                    "tooltip": "How many speaker reference audios to use (1-4). Extra inputs are ignored."
                }),
                "chunk_lines": ("BOOLEAN", {
                    "default": False, "label_on": "Chunk", "label_off": "No chunking",
                    "tooltip": "When enabled, splits the script into groups of N lines and runs VibeVoice per chunk."
                }),
                "lines_per_chunk": ("INT", {
                    "default": 20, "min": 1, "max": 999, "step": 1,
                    "tooltip": "Only used when 'Chunk' is enabled."
                }),
                # Forwarded generation knobs:
                "quantize_llm_4bit": ("BOOLEAN", {
                    "default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision"
                }),
                "attention_mode": (["eager", "sdpa", "flash_attention_2"], {"default": "sdpa"}),
                "cfg_scale": ("FLOAT", {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.05}),
                "inference_steps": ("INT", {"default": 10, "min": 1, "max": 50}),
                "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "control_after_generate": True}),
                "do_sample": ("BOOLEAN", {"default": True, "label_on": "Sampling", "label_off": "Greedy"}),
                "temperature": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 2.0, "step": 0.01}),
                "top_p": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01}),
                "top_k": ("INT", {"default": 0, "min": 0, "max": 500, "step": 1}),
            },
            "optional": {
                # Provide up to 4 optional speaker audios; we enforce num_speakers in code.
                "speaker_1_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 1"}),
                "speaker_2_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 2"}),
                "speaker_3_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 3"}),
                "speaker_4_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 4"}),
            },
            # If you REALLY want these hidden until toggled via JS, you can also list them under "hidden"
            # and add a tiny JS extension to flip them visible. Pure-Python dynamic show/hide isn't native (see docs).
        }

    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "run"
    CATEGORY = "audio/tts"

    # --------- helpers ---------
    @staticmethod
    def _split_into_chunks(lines, n):
        """
        Split a list of lines into chunks of size n.
        If the last chunk would be < 40% of n, merge it into the previous chunk.
        """
        if n <= 0:
            return [lines] if lines else []
        chunks = [lines[i:i+n] for i in range(0, len(lines), n)]
        if len(chunks) >= 2:
            tail = chunks[-1]
            if len(tail) < math.ceil(0.4 * n):
                chunks[-2].extend(tail)
                chunks.pop()
        return chunks
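
    # Worked example of the 40% merge rule above (illustrative only): 45 lines with n=20 first split
    # into chunks of 20, 20 and 5; since 5 < ceil(0.4 * 20) = 8, the 5-line tail is folded into the
    # previous chunk, giving final chunk sizes of 20 and 25.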

    @staticmethod
    def _concat_audio_dicts(audio_dicts):
        """
        Concatenate a list of ComfyUI AUDIO dicts along the time dimension T.
        Each dict: {"waveform": tensor[B, C, T], "sample_rate": int}
        Returns a single AUDIO dict with the same shape convention.
        """
        if not audio_dicts:
            # Return 1-sample silence if there is nothing to concatenate
            return {"waveform": torch.zeros((1, 1, 1), dtype=torch.float32), "sample_rate": 24000}
        srs = {ad["sample_rate"] for ad in audio_dicts if ad and "sample_rate" in ad}
        if len(srs) != 1:
            raise ValueError(f"Sample rates differ across chunks: {srs}")
        sr = srs.pop()
        waves = []
        for ad in audio_dicts:
            wf = ad["waveform"]
            # Expect [B, C, T]; promote lower-rank tensors to that layout
            if wf.ndim == 1:
                wf = wf.unsqueeze(0).unsqueeze(0)  # -> [1, 1, T]
            elif wf.ndim == 2:
                wf = wf.unsqueeze(0)  # -> [1, C, T]
            waves.append(wf)
        # Concatenate on the time axis (dim=-1). Assumes batch (B) and channel (C) sizes match.
        out = torch.cat(waves, dim=-1)
        return {"waveform": out.cpu(), "sample_rate": sr}
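
    # Shape example (illustrative only): two mono chunks with waveforms of shape [1, 1, 24000] and
    # [1, 1, 12000], both at 24000 Hz, concatenate into a single [1, 1, 36000] waveform at 24000 Hz.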

    @staticmethod
    def _filter_speaker_inputs(kwargs, num_speakers):
        """
        Pull up to num_speakers optional AUDIO inputs from kwargs.
        """
        voices = []
        for i in range(1, num_speakers + 1):
            voices.append(kwargs.get(f"speaker_{i}_voice"))
        # Pad with None so the underlying signature is satisfied; the extras are ignored there
        while len(voices) < 4:
            voices.append(None)
        return {
            "speaker_1_voice": voices[0],
            "speaker_2_voice": voices[1],
            "speaker_3_voice": voices[2],
            "speaker_4_voice": voices[3],
        }
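
    # For example, with num_speakers=2 only speaker_1_voice and speaker_2_voice are read from kwargs;
    # speaker_3_voice and speaker_4_voice are forwarded as None and ignored by the underlying node.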

    # --------- main ---------
    def run(
        self,
        model_name,
        text,
        num_speakers,
        chunk_lines,
        lines_per_chunk,
        quantize_llm_4bit,
        attention_mode,
        cfg_scale,
        inference_steps,
        seed,
        do_sample,
        temperature,
        top_p,
        top_k,
        **kwargs,
    ):
        """
        Orchestrates chunking and calls VibeVoiceTTSNode.generate_audio per chunk,
        then concatenates the results into a single AUDIO dict.
        """
        text = (text or "").strip()
        if not text:
            # Return 1 second of silence at 24 kHz, shape [1, 1, 24000]
            return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)

        # Prepare speaker refs according to the chosen number of speakers
        speaker_kwargs = self._filter_speaker_inputs(kwargs, max(1, min(4, int(num_speakers))))

        # Prepare chunks (list of multiline strings)
        if chunk_lines:
            raw_lines = [ln for ln in text.splitlines() if ln.strip() != ""]
            groups = self._split_into_chunks(raw_lines, lines_per_chunk)
            chunk_texts = ["\n".join(g) for g in groups] if groups else [text]
        else:
            chunk_texts = [text]

        # Progress bar over chunks
        pbar = ProgressBar(total=len(chunk_texts))

        # Call the underlying node per chunk
        base = VibeVoiceTTSNode()
        audio_parts = []
        for chunk in chunk_texts:
            out_audio = base.generate_audio(
                model_name=model_name,
                text=chunk,
                attention_mode=attention_mode,
                cfg_scale=cfg_scale,
                inference_steps=inference_steps,
                seed=seed,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                quantize_llm_4bit=quantize_llm_4bit,
                force_offload=False,
                **speaker_kwargs,
            )[0]  # the underlying node returns (AUDIO,)
            audio_parts.append(out_audio)
            pbar.update(1)

        # Concatenate into one AUDIO
        merged = self._concat_audio_dicts(audio_parts)
        return (merged,)


# Register
NODE_CLASS_MAPPINGS = {
    "VibeVoiceTTS_Wrapper": VibeVoiceTTS_WrapperNode,
    # Keep the base node mapping from your original file:
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "VibeVoiceTTS_Wrapper": "VibeVoice TTS (Chunked Wrapper)",
}
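
# A minimal sketch (an assumption, not part of this commit) of how the package __init__.py might
# merge these mappings with the base node's so both nodes stay registered. The mapping names
# imported from .vibevoice_nodes are assumed to exist in your original file.
#
#   from .vibevoice_nodes import (
#       NODE_CLASS_MAPPINGS as BASE_CLASS_MAPPINGS,
#       NODE_DISPLAY_NAME_MAPPINGS as BASE_DISPLAY_MAPPINGS,
#   )
#   from .comfyui_vibevoice_chunked_wrapper import (
#       NODE_CLASS_MAPPINGS as WRAPPER_CLASS_MAPPINGS,
#       NODE_DISPLAY_NAME_MAPPINGS as WRAPPER_DISPLAY_MAPPINGS,
#   )
#
#   NODE_CLASS_MAPPINGS = {**BASE_CLASS_MAPPINGS, **WRAPPER_CLASS_MAPPINGS}
#   NODE_DISPLAY_NAME_MAPPINGS = {**BASE_DISPLAY_MAPPINGS, **WRAPPER_DISPLAY_MAPPINGS}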