First commit
vibevoice_node_chunked_wrapper.py (new file, 220 lines)
@@ -0,0 +1,220 @@
# comfyui_vibevoice_chunked_wrapper.py

import math
import torch
from comfy.utils import ProgressBar

from .vibevoice_nodes import VibeVoiceTTSNode

# We assume the base node class from your snippet is in the same module/file.
# If it's in another module, import it instead:
# from your_module import VibeVoiceTTSNode


class VibeVoiceTTS_WrapperNode:
    """
    Wraps VibeVoiceTTSNode, adds:
    - Number of Speakers (1-4) that gates which speaker_*_voice inputs are used
    - Chunking controls for multiline script ("Speaker N: ...")
    - Iterates per chunk, concatenates outputs into one AUDIO dict

    Returns: ("AUDIO",) — waveform [B, C, T], sample_rate per ComfyUI audio spec.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # Pass-through of model/decoding params to the underlying node:
                "model_name": (list(VibeVoiceTTSNode.INPUT_TYPES()["required"]["model_name"][0]), {
                    "tooltip": "Forwarded to VibeVoiceTTSNode"
                }),
                "text": ("STRING", {
                    "multiline": True,
                    "default": "Speaker 1: Hello there!\nSpeaker 2: And hello from me.",
                    "tooltip": "Multiline script: 'Speaker 1: ...' one line per utterance"
                }),
                "num_speakers": ("INT", {
                    "default": 2, "min": 1, "max": 4, "step": 1,
                    "tooltip": "How many speaker reference audios to use (1–4). Extra inputs are ignored."
                }),
                "chunk_lines": ("BOOLEAN", {
                    "default": False, "label_on": "Chunk", "label_off": "No chunking",
                    "tooltip": "When enabled, splits the script into groups of N lines and runs VibeVoice per chunk."
                }),
                "lines_per_chunk": ("INT", {
                    "default": 20, "min": 1, "max": 999, "step": 1,
                    "tooltip": "Only used when 'Chunk' is enabled."
                }),

                # Forwarded generation knobs:
                "quantize_llm_4bit": ("BOOLEAN", {
                    "default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision"
                }),
                "attention_mode": (["eager", "sdpa", "flash_attention_2"], {"default": "sdpa"}),
                "cfg_scale": ("FLOAT", {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.05}),
                "inference_steps": ("INT", {"default": 10, "min": 1, "max": 50}),
                "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "control_after_generate": True}),
                "do_sample": ("BOOLEAN", {"default": True, "label_on": "Sampling", "label_off": "Greedy"}),
                "temperature": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 2.0, "step": 0.01}),
                "top_p": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01}),
                "top_k": ("INT", {"default": 0, "min": 0, "max": 500, "step": 1}),
            },
            "optional": {
                # Provide up to 4 optional speaker audios; we enforce num_speakers in code.
                "speaker_1_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 1"}),
                "speaker_2_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 2"}),
                "speaker_3_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 3"}),
                "speaker_4_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 4"}),
            },
            # If you REALLY want these hidden until toggled via JS, you can also list them under "hidden"
            # and add a tiny JS extension to flip them visible. Pure-Python dynamic show/hide isn't native (see docs).
        }

    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "run"
    CATEGORY = "audio/tts"

    # --------- helpers ---------
    @staticmethod
    def _split_into_chunks(lines, n):
        """
        Split list of lines into chunks of size n.
        If the last chunk would be < 40% of n, merge it into the previous chunk.
        """
        if n <= 0:
            return [lines] if lines else []

        chunks = [lines[i:i+n] for i in range(0, len(lines), n)]
        if len(chunks) >= 2:
            tail = chunks[-1]
            if len(tail) < math.ceil(0.4 * n):
                chunks[-2].extend(tail)
                chunks.pop()
        return chunks
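
    # Worked example (hypothetical numbers): with 45 non-empty lines and
    # lines_per_chunk=20 the raw split is [20, 20, 5]; since 5 < ceil(0.4 * 20) = 8,
    # the 5-line tail is merged into the previous chunk, giving chunks of 20 and 25 lines.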

    @staticmethod
    def _concat_audio_dicts(audio_dicts):
        """
        Concatenate a list of ComfyUI AUDIO dicts along time dim T.
        Each dict: {"waveform": tensor[B,C,T], "sample_rate": int}
        Returns a single AUDIO dict of the same shape convention.
        """
        if not audio_dicts:
            # Return 1-sample silence if nothing to concat
            return {"waveform": torch.zeros((1, 1, 1), dtype=torch.float32), "sample_rate": 24000}

        srs = {ad["sample_rate"] for ad in audio_dicts if ad and "sample_rate" in ad}
        if len(srs) != 1:
            raise ValueError(f"Sample rates differ across chunks: {srs}")
        sr = srs.pop()

        waves = []
        for ad in audio_dicts:
            wf = ad["waveform"]
            # Expect [B, C, T]
            if wf.ndim == 1:
                wf = wf.unsqueeze(0).unsqueeze(0)  # -> [1,1,T]
            elif wf.ndim == 2:
                wf = wf.unsqueeze(0)  # -> [1,C,T]
            waves.append(wf)

        # Concatenate on time axis T (-1). Assumes batch (B) and channels (C) match.
        out = torch.cat(waves, dim=-1)
        return {"waveform": out.cpu(), "sample_rate": sr}

    @staticmethod
    def _filter_speaker_inputs(kwargs, num_speakers):
        """
        Pulls up to num_speakers optional AUDIO inputs from kwargs.
        """
        voices = []
        for i in range(1, num_speakers + 1):
            voices.append(kwargs.get(f"speaker_{i}_voice"))
        # Fill the rest with None to align with the underlying signature; they are ignored there
        while len(voices) < 4:
            voices.append(None)
        return {
            "speaker_1_voice": voices[0],
            "speaker_2_voice": voices[1],
            "speaker_3_voice": voices[2],
            "speaker_4_voice": voices[3],
        }
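
    # For instance (illustrative values): with num_speakers=2 and only speaker_1_voice
    # and speaker_2_voice connected, this returns
    #   {"speaker_1_voice": <AUDIO>, "speaker_2_voice": <AUDIO>,
    #    "speaker_3_voice": None, "speaker_4_voice": None}
    # so any audio wired to speakers 3-4 is ignored, as the num_speakers tooltip states.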

    # --------- main ---------
    def run(
        self,
        model_name,
        text,
        num_speakers,
        chunk_lines,
        lines_per_chunk,
        quantize_llm_4bit,
        attention_mode,
        cfg_scale,
        inference_steps,
        seed,
        do_sample,
        temperature,
        top_p,
        top_k,
        **kwargs,
    ):
        """
        Orchestrates chunking and calls VibeVoiceTTSNode.generate_audio per chunk.
        Then concatenates to a single AUDIO dict.
        """
        text = (text or "").strip()
        if not text:
            # return 1 second of silence at 24kHz, shape [1,1,24000]
            return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)

        # Prepare speaker refs according to chosen number of speakers
        speaker_kwargs = self._filter_speaker_inputs(kwargs, max(1, min(4, int(num_speakers))))

        # Prepare chunks (list of multiline strings)
        if chunk_lines:
            raw_lines = [ln for ln in text.splitlines() if ln.strip() != ""]
            groups = self._split_into_chunks(raw_lines, lines_per_chunk)
            chunk_texts = ["\n".join(g) for g in groups] if groups else [text]
        else:
            chunk_texts = [text]

        # Progress bar over chunks
        pbar = ProgressBar(total=len(chunk_texts))

        # Call the underlying node per chunk
        base = VibeVoiceTTSNode()
        audio_parts = []
        for idx, chunk in enumerate(chunk_texts, 1):
            out_audio = base.generate_audio(
                model_name=model_name,
                text=chunk,
                attention_mode=attention_mode,
                cfg_scale=cfg_scale,
                inference_steps=inference_steps,
                seed=seed,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                quantize_llm_4bit=quantize_llm_4bit,
                force_offload=False,
                **speaker_kwargs,
            )[0]  # underlying returns (AUDIO,)
            audio_parts.append(out_audio)
            pbar.update(1)

        # Concatenate into one AUDIO
        merged = self._concat_audio_dicts(audio_parts)
        return (merged,)
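
    # Note: every chunk is generated with the same seed and sampling settings, and
    # force_offload=False keeps the model resident between chunks; only the script
    # text changes from call to call.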


# Register
NODE_CLASS_MAPPINGS = {
    "VibeVoiceTTS_Wrapper": VibeVoiceTTS_WrapperNode,
    # Keep the base node mapping from your original file:
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "VibeVoiceTTS_Wrapper": "VibeVoice TTS (Chunked Wrapper)",
}