# comfyui_vibevoice_chunked_wrapper.py

import math
import torch
from comfy.utils import ProgressBar

from .vibevoice_nodes import VibeVoiceTTSNode

# The base node class is imported from the sibling module .vibevoice_nodes above.
# If it lives in a different module in your setup, adjust the import accordingly, e.g.:
# from your_module import VibeVoiceTTSNode

class VibeVoiceTTS_WrapperNode:
    """
    Wraps VibeVoiceTTSNode and adds:
      - a "Number of Speakers" control (1-4) that gates which speaker_*_voice inputs are used
      - chunking controls for a multiline script ("Speaker N: ..." per line)
      - per-chunk generation, with the outputs concatenated into one AUDIO dict

    Returns ("AUDIO",): a dict with waveform [B, C, T] and sample_rate, per the ComfyUI audio spec.
    """

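    # Script format consumed by "text" (one utterance per line), e.g. the default:
    #   Speaker 1: Hello there!
    #   Speaker 2: And hello from me.
    # When "chunk_lines" is enabled, non-empty lines are grouped lines_per_chunk at
    # a time before being sent to the underlying node.
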
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # Pass-through of model/decoding params to the underlying node:
                "model_name": (list(VibeVoiceTTSNode.INPUT_TYPES()["required"]["model_name"][0]), {
                    "tooltip": "Forwarded to VibeVoiceTTSNode"
                }),
                "text": ("STRING", {
                    "multiline": True,
                    "default": "Speaker 1: Hello there!\nSpeaker 2: And hello from me.",
                    "tooltip": "Multiline script: 'Speaker 1: ...', one line per utterance"
                }),
                "num_speakers": ("INT", {
                    "default": 2, "min": 1, "max": 4, "step": 1,
                    "tooltip": "How many speaker reference audios to use (1–4). Extra inputs are ignored."
                }),
                "chunk_lines": ("BOOLEAN", {
                    "default": False, "label_on": "Chunk", "label_off": "No chunking",
                    "tooltip": "When enabled, splits the script into groups of N lines and runs VibeVoice per chunk."
                }),
                "lines_per_chunk": ("INT", {
                    "default": 20, "min": 1, "max": 999, "step": 1,
                    "tooltip": "Only used when 'Chunk' is enabled."
                }),

                # Forwarded generation knobs:
                "quantize_llm_4bit": ("BOOLEAN", {
                    "default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision"
                }),
                "attention_mode": (["eager", "sdpa", "flash_attention_2"], {"default": "sdpa"}),
                "cfg_scale": ("FLOAT", {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.05}),
                "inference_steps": ("INT", {"default": 10, "min": 1, "max": 50}),
                "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "control_after_generate": True}),
                "do_sample": ("BOOLEAN", {"default": True, "label_on": "Sampling", "label_off": "Greedy"}),
                "temperature": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 2.0, "step": 0.01}),
                "top_p": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01}),
                "top_k": ("INT", {"default": 0, "min": 0, "max": 500, "step": 1}),
            },
            "optional": {
                # Provide up to 4 optional speaker audios; we enforce num_speakers in code.
                "speaker_1_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 1"}),
                "speaker_2_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 2"}),
                "speaker_3_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 3"}),
                "speaker_4_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 4"}),
            },
            # If you REALLY want these hidden until toggled via JS, you can also list them under "hidden"
            # and add a tiny JS extension to flip them visible. Pure-Python dynamic show/hide isn't native (see docs).
        }

    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "run"
    CATEGORY = "audio/tts"

    # --------- helpers ---------
    @staticmethod
    def _split_into_chunks(lines, n):
        """
        Split a list of lines into chunks of size n.
        If the last chunk would be < 40% of n, merge it into the previous chunk.
        """
        if n <= 0:
            return [lines] if lines else []

        chunks = [lines[i:i + n] for i in range(0, len(lines), n)]
        if len(chunks) >= 2:
            tail = chunks[-1]
            if len(tail) < math.ceil(0.4 * n):
                chunks[-2].extend(tail)
                chunks.pop()
        return chunks
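
    # Worked example for _split_into_chunks (hypothetical input): with n=5 and
    # 11 lines, the raw split is [5, 5, 1]; the 1-line tail is shorter than
    # math.ceil(0.4 * 5) == 2 lines, so it is merged into the previous chunk,
    # giving chunk sizes [5, 6].
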
    @staticmethod
    def _concat_audio_dicts(audio_dicts):
        """
        Concatenate a list of ComfyUI AUDIO dicts along the time dimension T.
        Each dict: {"waveform": tensor[B, C, T], "sample_rate": int}
        Returns a single AUDIO dict using the same convention.
        """
        if not audio_dicts:
            # Return 1-sample silence if there is nothing to concatenate
            return {"waveform": torch.zeros((1, 1, 1), dtype=torch.float32), "sample_rate": 24000}

        srs = {ad["sample_rate"] for ad in audio_dicts if ad and "sample_rate" in ad}
        if len(srs) != 1:
            raise ValueError(f"Sample rates differ across chunks: {srs}")
        sr = srs.pop()

        waves = []
        for ad in audio_dicts:
            wf = ad["waveform"]
            # Expect [B, C, T]; promote lower-rank tensors
            if wf.ndim == 1:
                wf = wf.unsqueeze(0).unsqueeze(0)  # -> [1, 1, T]
            elif wf.ndim == 2:
                wf = wf.unsqueeze(0)               # -> [1, C, T]
            waves.append(wf)

        # Concatenate on the time axis (-1). Assumes batch (B) and channel (C) dims match.
        out = torch.cat(waves, dim=-1)
        return {"waveform": out.cpu(), "sample_rate": sr}
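
    # Shape example (hypothetical): concatenating two chunks whose waveforms are
    # [1, 1, 24000] and [1, 1, 12000] at the same sample rate yields a single
    # waveform of shape [1, 1, 36000]; mismatched sample rates raise ValueError.
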
    @staticmethod
    def _filter_speaker_inputs(kwargs, num_speakers):
        """
        Pulls up to num_speakers optional AUDIO inputs from kwargs.
        """
        voices = []
        for i in range(1, num_speakers + 1):
            voices.append(kwargs.get(f"speaker_{i}_voice"))
        # Pad with None so the underlying signature always receives four entries; the extras are ignored there
        while len(voices) < 4:
            voices.append(None)
        return {
            "speaker_1_voice": voices[0],
            "speaker_2_voice": voices[1],
            "speaker_3_voice": voices[2],
            "speaker_4_voice": voices[3],
        }
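
    # Gating example (hypothetical): with num_speakers=2, only speaker_1_voice and
    # speaker_2_voice are read from kwargs; speaker_3_voice and speaker_4_voice are
    # forwarded as None even if something is connected to them in the graph.
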
    # --------- main ---------
    def run(
        self,
        model_name,
        text,
        num_speakers,
        chunk_lines,
        lines_per_chunk,
        quantize_llm_4bit,
        attention_mode,
        cfg_scale,
        inference_steps,
        seed,
        do_sample,
        temperature,
        top_p,
        top_k,
        **kwargs,
    ):
        """
        Orchestrates chunking, calls VibeVoiceTTSNode.generate_audio per chunk,
        and concatenates the results into a single AUDIO dict.
        """

        text = (text or "").strip()
        if not text:
            # Return 1 second of silence at 24 kHz, shape [1, 1, 24000]
            return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)

        # Prepare speaker refs according to the chosen number of speakers
        speaker_kwargs = self._filter_speaker_inputs(kwargs, max(1, min(4, int(num_speakers))))

        # Prepare chunks (list of multiline strings)
        if chunk_lines:
            raw_lines = [ln for ln in text.splitlines() if ln.strip() != ""]
            groups = self._split_into_chunks(raw_lines, lines_per_chunk)
            chunk_texts = ["\n".join(g) for g in groups] if groups else [text]
        else:
            chunk_texts = [text]

        # Progress bar over chunks
        pbar = ProgressBar(total=len(chunk_texts))

        # Call the underlying node per chunk
        base = VibeVoiceTTSNode()
        audio_parts = []
        for chunk in chunk_texts:
            out_audio = base.generate_audio(
                model_name=model_name,
                text=chunk,
                attention_mode=attention_mode,
                cfg_scale=cfg_scale,
                inference_steps=inference_steps,
                seed=seed,  # NOTE: the same seed is reused for every chunk
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                quantize_llm_4bit=quantize_llm_4bit,
                force_offload=False,
                **speaker_kwargs,
            )[0]  # the underlying node returns (AUDIO,)
            audio_parts.append(out_audio)
            pbar.update(1)

        # Concatenate into one AUDIO
        merged = self._concat_audio_dicts(audio_parts)
        return (merged,)


# Register
NODE_CLASS_MAPPINGS = {
    "VibeVoiceTTS_Wrapper": VibeVoiceTTS_WrapperNode,
    # Keep the base node mapping from your original file:
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "VibeVoiceTTS_Wrapper": "VibeVoice TTS (Chunked Wrapper)",
}
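
# Minimal direct-invocation sketch (not part of the node API; assumes a working
# ComfyUI environment, and that `ref_1` / `ref_2` are AUDIO dicts of the form
# {"waveform": tensor[B, C, T], "sample_rate": int}):
#
#   node = VibeVoiceTTS_WrapperNode()
#   (audio,) = node.run(
#       model_name=VibeVoiceTTSNode.INPUT_TYPES()["required"]["model_name"][0][0],
#       text="Speaker 1: Hello there!\nSpeaker 2: And hello from me.",
#       num_speakers=2,
#       chunk_lines=True,
#       lines_per_chunk=20,
#       quantize_llm_4bit=False,
#       attention_mode="sdpa",
#       cfg_scale=1.3,
#       inference_steps=10,
#       seed=42,
#       do_sample=True,
#       temperature=0.95,
#       top_p=0.95,
#       top_k=0,
#       speaker_1_voice=ref_1,
#       speaker_2_voice=ref_2,
#   )
#   # audio["waveform"] has shape [B, C, T]; audio["sample_rate"] is an int.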