# comfyui_vibevoice_chunked_wrapper.py
import math

import torch

from comfy.utils import ProgressBar

from .vibevoice_nodes import VibeVoiceTTSNode
# We assume the base node class from your snippet is in the same module/file.
# If it's in another module, import it instead:
# from your_module import VibeVoiceTTSNode


class VibeVoiceTTS_WrapperNode:
    """
    Wraps VibeVoiceTTSNode, adds:
      - Number of Speakers (1-4) that gates which speaker_*_voice inputs are used
      - Chunking controls for a multiline script ("Speaker N: ...")
      - Iterates per chunk, concatenates outputs into one AUDIO dict

    Returns: ("AUDIO",) — waveform [B, C, T], sample_rate per ComfyUI audio spec.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # Pass-through of model/decoding params to the underlying node:
                "model_name": (list(VibeVoiceTTSNode.INPUT_TYPES()["required"]["model_name"][0]), {
                    "tooltip": "Forwarded to VibeVoiceTTSNode"
                }),
                "text": ("STRING", {
                    "multiline": True,
                    "default": "Speaker 1: Hello there!\nSpeaker 2: And hello from me.",
                    "tooltip": "Multiline script: 'Speaker 1: ...', one line per utterance"
                }),
                "num_speakers": ("INT", {
                    "default": 2, "min": 1, "max": 4, "step": 1,
                    "tooltip": "How many speaker reference audios to use (1–4). Extra inputs are ignored."
                }),
                "chunk_lines": ("BOOLEAN", {
                    "default": False,
                    "label_on": "Chunk",
                    "label_off": "No chunking",
                    "tooltip": "When enabled, splits the script into groups of N lines and runs VibeVoice per chunk."
                }),
                "lines_per_chunk": ("INT", {
                    "default": 20, "min": 1, "max": 999, "step": 1,
                    "tooltip": "Only used when 'Chunk' is enabled."
                }),
                # Forwarded generation knobs:
                "quantize_llm_4bit": ("BOOLEAN", {
                    "default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision"
                }),
                "attention_mode": (["eager", "sdpa", "flash_attention_2"], {"default": "sdpa"}),
                "cfg_scale": ("FLOAT", {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.05}),
                "inference_steps": ("INT", {"default": 10, "min": 1, "max": 50}),
                "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "control_after_generate": True}),
                "do_sample": ("BOOLEAN", {"default": True, "label_on": "Sampling", "label_off": "Greedy"}),
                "temperature": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 2.0, "step": 0.01}),
                "top_p": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01}),
                "top_k": ("INT", {"default": 0, "min": 0, "max": 500, "step": 1}),
            },
            "optional": {
                # Provide up to 4 optional speaker audios; num_speakers is enforced in run().
                "speaker_1_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 1"}),
                "speaker_2_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 2"}),
                "speaker_3_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 3"}),
                "speaker_4_voice": ("AUDIO", {"tooltip": "Reference audio for Speaker 4"}),
            },
            # If you REALLY want these hidden until toggled via JS, you can also list them under
            # "hidden" and add a tiny JS extension to flip them visible. Pure-Python dynamic
            # show/hide isn't native; see docs.
        }

    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "run"
    CATEGORY = "audio/tts"

    # --------- helpers ---------

    @staticmethod
    def _split_into_chunks(lines, n):
        """
        Split a list of lines into chunks of size n.
        If the last chunk would be < 40% of n, merge it into the previous chunk.
        """
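        # Worked example of the merge rule: with n = 10 and a 22-line script, the raw split
        # is [10, 10, 2]; the 2-line tail is shorter than ceil(0.4 * 10) = 4 lines, so it is
        # merged into the previous chunk, giving final chunk sizes [10, 12].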
""" if n <= 0: return [lines] if lines else [] chunks = [lines[i:i+n] for i in range(0, len(lines), n)] if len(chunks) >= 2: tail = chunks[-1] if len(tail) < math.ceil(0.4 * n): chunks[-2].extend(tail) chunks.pop() return chunks @staticmethod def _concat_audio_dicts(audio_dicts): """ Concatenate a list of ComfyUI AUDIO dicts along time dim T. Each dict: {"waveform": tensor[B,C,T], "sample_rate": int} Returns a single AUDIO dict of the same shape convention. """ if not audio_dicts: # Return 1-sample silence if nothing to concat return {"waveform": torch.zeros((1, 1, 1), dtype=torch.float32), "sample_rate": 24000} srs = {ad["sample_rate"] for ad in audio_dicts if ad and "sample_rate" in ad} if len(srs) != 1: raise ValueError(f"Sample rates differ across chunks: {srs}") sr = srs.pop() waves = [] for ad in audio_dicts: wf = ad["waveform"] # Expect [B, C, T] if wf.ndim == 1: wf = wf.unsqueeze(0).unsqueeze(0) # -> [1,1,T] elif wf.ndim == 2: wf = wf.unsqueeze(0) # -> [1,C,T] waves.append(wf) # Concatenate on time axis T (-1). Assumes batch (B) and channels (C) match. out = torch.cat(waves, dim=-1) return {"waveform": out.cpu(), "sample_rate": sr} @staticmethod def _filter_speaker_inputs(kwargs, num_speakers): """ Pulls up to num_speakers optional AUDIO inputs from kwargs. """ voices = [] for i in range(1, num_speakers + 1): voices.append(kwargs.get(f"speaker_{i}_voice")) # Fill the rest with None to align with underlying signature but ignored there while len(voices) < 4: voices.append(None) return { "speaker_1_voice": voices[0], "speaker_2_voice": voices[1], "speaker_3_voice": voices[2], "speaker_4_voice": voices[3], } # --------- main --------- def run( self, model_name, text, num_speakers, chunk_lines, lines_per_chunk, quantize_llm_4bit, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs, ): """ Orchestrates chunking and calls VibeVoiceTTSNode.generate_audio per chunk. Then concatenates to a single AUDIO dict. """ text = (text or "").strip() if not text: # return 1 second of silence at 24kHz, shape [1,1,24000] return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},) # Prepare speaker refs according to chosen number of speakers speaker_kwargs = self._filter_speaker_inputs(kwargs, max(1, min(4, int(num_speakers)))) # Prepare chunks (list of multiline strings) if chunk_lines: raw_lines = [ln for ln in text.splitlines() if ln.strip() != ""] groups = self._split_into_chunks(raw_lines, lines_per_chunk) chunk_texts = ["\n".join(g) for g in groups] if groups else [text] else: chunk_texts = [text] # Progress bar over chunks pbar = ProgressBar(total=len(chunk_texts)) # Call the underlying node per chunk base = VibeVoiceTTSNode() audio_parts = [] for idx, chunk in enumerate(chunk_texts, 1): out_audio = base.generate_audio( model_name=model_name, text=chunk, attention_mode=attention_mode, cfg_scale=cfg_scale, inference_steps=inference_steps, seed=seed, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k, quantize_llm_4bit=quantize_llm_4bit, force_offload=False, **speaker_kwargs, )[0] # underlying returns (AUDIO,) audio_parts.append(out_audio) pbar.update(1) # Concatenate into one AUDIO merged = self._concat_audio_dicts(audio_parts) return (merged,) # Register NODE_CLASS_MAPPINGS = { "VibeVoiceTTS_Wrapper": VibeVoiceTTS_WrapperNode # Keep the base node mapping from your original file: } NODE_DISPLAY_NAME_MAPPINGS = { "VibeVoiceTTS_Wrapper": "VibeVoice TTS (Chunked Wrapper)" }