vibingvoice committed · verified
Commit cd0b70a · 1 Parent(s): 5db5dd1

Upload 34 files
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Fabio Sarracino - enemyx.net
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
__init__.py ADDED
@@ -0,0 +1,116 @@
+ # Created by Fabio Sarracino
+ __version__ = "1.3.0"
+ __author__ = "Fabio Sarracino"
+ __title__ = "VibeVoice ComfyUI"
+
+ import logging
+ import os
+ import sys
+ import subprocess
+
+ # Setup logging
+ logger = logging.getLogger("VibeVoice")
+ logger.propagate = False
+
+ if not logger.handlers:
+     handler = logging.StreamHandler()
+     formatter = logging.Formatter('[VibeVoice] %(message)s')
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+     logger.setLevel(logging.INFO)
+
+ def apply_timm_compatibility_patches():
+     """Apply compatibility patches for timm package conflicts"""
+     try:
+         import timm.data
+
+         # Patch missing functions that cause import errors
+         patches = {
+             'ImageNetInfo': lambda: type('ImageNetInfo', (), {'__init__': lambda self: None})(),
+             'infer_imagenet_subset': lambda class_to_idx: 'imagenet',
+             'get_imagenet_subset_labels': lambda *args, **kwargs: [],
+             'get_imagenet_subset_info': lambda *args, **kwargs: {},
+             'resolve_data_config': lambda *args, **kwargs: {}
+         }
+
+         for attr_name, patch_func in patches.items():
+             if not hasattr(timm.data, attr_name):
+                 if attr_name == 'ImageNetInfo':
+                     setattr(timm.data, attr_name, type('ImageNetInfo', (), {'__init__': lambda self: None}))
+                 else:
+                     setattr(timm.data, attr_name, patch_func)
+
+         return True
+     except Exception as e:
+         return False
+
+ def check_embedded_vibevoice():
+     """Check if embedded VibeVoice is available"""
+     vvembed_path = os.path.join(os.path.dirname(__file__), 'vvembed')
+     if not os.path.exists(vvembed_path):
+         logger.error(f"Embedded VibeVoice not found at {vvembed_path}")
+         return False
+
+     # Add vvembed to path if not already there
+     if vvembed_path not in sys.path:
+         sys.path.insert(0, vvembed_path)
+
+     logger.info("Using embedded VibeVoice (MIT licensed)")
+     return True
+
+ def ensure_dependencies():
+     """Ensure required dependencies are installed"""
+     try:
+         import transformers
+         from packaging import version
+         if version.parse(transformers.__version__) < version.parse("4.44.0"):
+             logger.warning("Transformers version < 4.44.0, some features may not work correctly")
+     except ImportError:
+         logger.warning("Transformers not installed. Please install: pip install transformers>=4.44.0")
+         return False
+
+     # Apply timm patches if needed
+     apply_timm_compatibility_patches()
+
+     return True
+
+ # Initialize node mappings
+ NODE_CLASS_MAPPINGS = {}
+ NODE_DISPLAY_NAME_MAPPINGS = {}
+
+ # Register text loading node (always available)
+ try:
+     from .nodes.load_text_node import LoadTextFromFileNode
+     NODE_CLASS_MAPPINGS["LoadTextFromFileNode"] = LoadTextFromFileNode
+     NODE_DISPLAY_NAME_MAPPINGS["LoadTextFromFileNode"] = "VibeVoice Load Text From File"
+ except Exception as e:
+     logger.error(f"Failed to register LoadTextFromFile node: {e}")
+
+ # Register VibeVoice nodes (using embedded VibeVoice)
+ if check_embedded_vibevoice() and ensure_dependencies():
+     try:
+         from .nodes.single_speaker_node import VibeVoiceSingleSpeakerNode
+         from .nodes.multi_speaker_node import VibeVoiceMultipleSpeakersNode
+         from .nodes.free_memory_node import VibeVoiceFreeMemoryNode
+
+         # Single speaker node
+         NODE_CLASS_MAPPINGS["VibeVoiceSingleSpeakerNode"] = VibeVoiceSingleSpeakerNode
+         NODE_DISPLAY_NAME_MAPPINGS["VibeVoiceSingleSpeakerNode"] = "VibeVoice Single Speaker"
+
+         # Multi speaker node
+         NODE_CLASS_MAPPINGS["VibeVoiceMultipleSpeakersNode"] = VibeVoiceMultipleSpeakersNode
+         NODE_DISPLAY_NAME_MAPPINGS["VibeVoiceMultipleSpeakersNode"] = "VibeVoice Multiple Speakers"
+
+         # Free memory node
+         NODE_CLASS_MAPPINGS["VibeVoiceFreeMemoryNode"] = VibeVoiceFreeMemoryNode
+         NODE_DISPLAY_NAME_MAPPINGS["VibeVoiceFreeMemoryNode"] = "VibeVoice Free Memory"
+
+         logger.info("VibeVoice nodes registered successfully")
+
+     except Exception as e:
+         logger.error(f"Failed to register VibeVoice nodes: {e}")
+         logger.info("Please ensure transformers>=4.44.0 is installed")
+ else:
+     logger.warning("VibeVoice nodes unavailable - check embedded module and dependencies")
+
+ __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', '__version__']
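For context, ComfyUI discovers custom nodes through exactly these two dictionaries. As an illustrative sketch only (not part of this commit), a class placed in NODE_CLASS_MAPPINGS is expected to follow ComfyUI's node interface, roughly:

# Hypothetical minimal node, for illustration only - the real classes
# (e.g. VibeVoiceSingleSpeakerNode) live in this repo's nodes/ package.
class ExampleTextNode:
    @classmethod
    def INPUT_TYPES(cls):
        # Declare the inputs/widgets ComfyUI should render for this node
        return {"required": {"text": ("STRING", {"multiline": True})}}

    RETURN_TYPES = ("STRING",)   # one STRING output socket
    FUNCTION = "run"             # method ComfyUI calls when the node executes
    CATEGORY = "VibeVoiceWrapper"

    def run(self, text):
        return (text,)

# Registration then mirrors what __init__.py does above:
# NODE_CLASS_MAPPINGS["ExampleTextNode"] = ExampleTextNode
# NODE_DISPLAY_NAME_MAPPINGS["ExampleTextNode"] = "Example Text Node"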
examples/Multiple-Speaker.json ADDED
@@ -0,0 +1 @@
+ {"id":"e5ca15c5-18b5-4d37-8852-795692a14b29","revision":0,"last_node_id":28,"last_link_id":41,"nodes":[{"id":20,"type":"Note","pos":[-76.50189208984375,752.0235595703125],"size":[415,88],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[],"properties":{},"widgets_values":["Use Load Text From File if you want to use a .txt file instead of text-area. You can load .txt files from ComfyUI/input, ComfyUI/output or ComfyUI/temp directories."],"color":"#432","bgcolor":"#653"},{"id":16,"type":"PreviewAudio","pos":[914.1022338867188,189.5508270263672],"size":[270,88],"flags":{},"order":6,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","link":41},{"localized_name":"audioUI","name":"audioUI","type":"AUDIO_UI","widget":{"name":"audioUI"},"link":null}],"outputs":[],"properties":{"cnr_id":"comfy-core","ver":"0.3.49","Node name for S&R":"PreviewAudio"},"widgets_values":[],"color":"#323","bgcolor":"#535"},{"id":15,"type":"LoadAudio","pos":[-12.263749122619629,190.64144897460938],"size":[270,136],"flags":{},"order":1,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"COMBO","widget":{"name":"audio"},"link":null},{"localized_name":"audioUI","name":"audioUI","type":"AUDIO_UI","widget":{"name":"audioUI"},"link":null},{"localized_name":"upload","name":"upload","type":"AUDIOUPLOAD","widget":{"name":"upload"},"link":null}],"outputs":[{"localized_name":"AUDIO","name":"AUDIO","type":"AUDIO","links":[39]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.49","Node name for S&R":"LoadAudio"},"widgets_values":["Voice1.mp3",null,null],"color":"#2a363b","bgcolor":"#3f5159"},{"id":17,"type":"LoadAudio","pos":[-11.774602890014648,403.2247009277344],"size":[270,136],"flags":{},"order":2,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"COMBO","widget":{"name":"audio"},"link":null},{"localized_name":"audioUI","name":"audioUI","type":"AUDIO_UI","widget":{"name":"audioUI"},"link":null},{"localized_name":"upload","name":"upload","type":"AUDIOUPLOAD","widget":{"name":"upload"},"link":null}],"outputs":[{"localized_name":"AUDIO","name":"AUDIO","type":"AUDIO","links":[40]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.49","Node name for S&R":"LoadAudio"},"widgets_values":["Voice2.mp3",null,null],"color":"#2a363b","bgcolor":"#3f5159"},{"id":19,"type":"LoadTextFromFileNode","pos":[-8.260560989379883,651.4061279296875],"size":[270,58],"flags":{},"order":3,"mode":4,"inputs":[{"localized_name":"file","name":"file","type":"COMBO","widget":{"name":"file"},"link":null}],"outputs":[{"localized_name":"text","name":"text","type":"STRING","links":null}],"properties":{"Node name for S&R":"LoadTextFromFileNode","cnr_id":"VibeVoice-ComfyUI","ver":"5a24489a7b0bf0c406d291dd51e82a085d338d44"},"widgets_values":["No text files found in any 
directory"],"color":"#323","bgcolor":"#535"},{"id":28,"type":"VibeVoiceMultipleSpeakersNode","pos":[392.798095703125,188.42286682128906],"size":[400,388],"flags":{},"order":5,"mode":0,"inputs":[{"localized_name":"speaker1_voice","name":"speaker1_voice","shape":7,"type":"AUDIO","link":39},{"localized_name":"speaker2_voice","name":"speaker2_voice","shape":7,"type":"AUDIO","link":40},{"localized_name":"speaker3_voice","name":"speaker3_voice","shape":7,"type":"AUDIO","link":null},{"localized_name":"speaker4_voice","name":"speaker4_voice","shape":7,"type":"AUDIO","link":null},{"localized_name":"text","name":"text","type":"STRING","widget":{"name":"text"},"link":null},{"localized_name":"model","name":"model","type":"COMBO","widget":{"name":"model"},"link":null},{"localized_name":"attention_type","name":"attention_type","type":"COMBO","widget":{"name":"attention_type"},"link":null},{"localized_name":"free_memory_after_generate","name":"free_memory_after_generate","type":"BOOLEAN","widget":{"name":"free_memory_after_generate"},"link":null},{"localized_name":"diffusion_steps","name":"diffusion_steps","type":"INT","widget":{"name":"diffusion_steps"},"link":null},{"localized_name":"seed","name":"seed","type":"INT","widget":{"name":"seed"},"link":null},{"localized_name":"cfg_scale","name":"cfg_scale","type":"FLOAT","widget":{"name":"cfg_scale"},"link":null},{"localized_name":"use_sampling","name":"use_sampling","type":"BOOLEAN","widget":{"name":"use_sampling"},"link":null},{"localized_name":"temperature","name":"temperature","shape":7,"type":"FLOAT","widget":{"name":"temperature"},"link":null},{"localized_name":"top_p","name":"top_p","shape":7,"type":"FLOAT","widget":{"name":"top_p"},"link":null}],"outputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","links":[41]}],"properties":{"Node name for S&R":"VibeVoiceMultipleSpeakersNode","cnr_id":"VibeVoice-ComfyUI","ver":"5a24489a7b0bf0c406d291dd51e82a085d338d44"},"widgets_values":["[1]: Hello, this is the first speaker.\n[2]: Hi there, I'm the second speaker.\n[1]: Nice to meet you!\n[2]: Nice to meet you too!","VibeVoice-Large","auto",true,20,42,"fixed",1.3,false,0.95,0.95],"color":"#223","bgcolor":"#335"},{"id":21,"type":"Note","pos":[387.95281982421875,624.7701416015625],"size":[415,88],"flags":{},"order":4,"mode":0,"inputs":[],"outputs":[],"properties":{},"widgets_values":["The first time you use a model, it will be downloaded to ComfyUI/models/vibevoice/. This can take several minutes!\nThe VibeVoice-1.5B model (about 5GB)\nThe VibeVoice-Large model (about 17GB)\nThe VibeVoice-Large-Quant-4Bit model (about 7GB)"],"color":"#432","bgcolor":"#653"}],"links":[[39,15,0,28,0,"AUDIO"],[40,17,0,28,1,"AUDIO"],[41,28,0,16,0,"AUDIO"]],"groups":[],"config":{},"extra":{"ds":{"scale":1.1,"offset":[200.3891510142744,-69.52646130728805]}},"version":0.4}
examples/Pause-Tag.json ADDED
@@ -0,0 +1 @@
+ {"id":"b70cf6f7-8531-4faa-9843-9c963a4ba577","revision":0,"last_node_id":38,"last_link_id":49,"nodes":[{"id":21,"type":"Note","pos":[-128.1415557861328,534.7645263671875],"size":[415,88],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[],"properties":{},"widgets_values":["Use Load Text From File if you want to use a .txt file instead of text-area. You can load .txt files from ComfyUI/input, ComfyUI/output or ComfyUI/temp directories."],"color":"#432","bgcolor":"#653"},{"id":28,"type":"LoadTextFromFileNode","pos":[-65.6552963256836,428.2049865722656],"size":[289.5152282714844,58],"flags":{},"order":1,"mode":4,"inputs":[{"localized_name":"file","name":"file","type":"COMBO","widget":{"name":"file"},"link":null}],"outputs":[{"localized_name":"text","name":"text","type":"STRING","links":null}],"properties":{"Node name for S&R":"LoadTextFromFileNode","cnr_id":"VibeVoice-ComfyUI","ver":"5a24489a7b0bf0c406d291dd51e82a085d338d44"},"widgets_values":["No text files found in any directory"],"color":"#323","bgcolor":"#535"},{"id":15,"type":"LoadAudio","pos":[15.256911277770996,126.44892883300781],"size":[270,136],"flags":{},"order":2,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"COMBO","widget":{"name":"audio"},"link":null},{"localized_name":"audioUI","name":"audioUI","type":"AUDIO_UI","widget":{"name":"audioUI"},"link":null},{"localized_name":"upload","name":"upload","type":"AUDIOUPLOAD","widget":{"name":"upload"},"link":null}],"outputs":[{"localized_name":"AUDIO","name":"AUDIO","type":"AUDIO","links":[48]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.49","Node name for S&R":"LoadAudio"},"widgets_values":["Voice.mp3",null,null],"color":"#2a363b","bgcolor":"#3f5159"},{"id":16,"type":"PreviewAudio","pos":[892.3655395507812,127.41075897216797],"size":[270,88],"flags":{},"order":6,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","link":49},{"localized_name":"audioUI","name":"audioUI","type":"AUDIO_UI","widget":{"name":"audioUI"},"link":null}],"outputs":[],"properties":{"cnr_id":"comfy-core","ver":"0.3.49","Node name for S&R":"PreviewAudio"},"widgets_values":[],"color":"#323","bgcolor":"#535"},{"id":22,"type":"Note","pos":[365.11663818359375,535.5830078125],"size":[415,88],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[],"properties":{},"widgets_values":["The first time you use a model, it will be downloaded to ComfyUI/models/vibevoice/. This can take several minutes!\nThe VibeVoice-1.5B model (about 5GB)\nThe VibeVoice-Large model (about 17GB)\nThe VibeVoice-Large-Quant-4Bit model (about 7GB)"],"color":"#432","bgcolor":"#653"},{"id":38,"type":"Note","pos":[818.8140869140625,270.3061218261719],"size":[415,88],"flags":{},"order":4,"mode":0,"inputs":[],"outputs":[],"properties":{},"widgets_values":["[pause]: add 1 second of silence.\n[pause:{number}] add {number}ms of pause\nWARNING: the pause tag forces the text to be split into chunks. This may worsen the model’s ability to understand the context. 
The model’s context is represented ONLY by its own chunk."],"color":"#432","bgcolor":"#653"},{"id":37,"type":"VibeVoiceSingleSpeakerNode","pos":[376.42669677734375,126.94989013671875],"size":[400,352],"flags":{},"order":5,"mode":0,"inputs":[{"localized_name":"voice_to_clone","name":"voice_to_clone","shape":7,"type":"AUDIO","link":48},{"localized_name":"text","name":"text","type":"STRING","widget":{"name":"text"},"link":null},{"localized_name":"model","name":"model","type":"COMBO","widget":{"name":"model"},"link":null},{"localized_name":"attention_type","name":"attention_type","type":"COMBO","widget":{"name":"attention_type"},"link":null},{"localized_name":"free_memory_after_generate","name":"free_memory_after_generate","type":"BOOLEAN","widget":{"name":"free_memory_after_generate"},"link":null},{"localized_name":"diffusion_steps","name":"diffusion_steps","type":"INT","widget":{"name":"diffusion_steps"},"link":null},{"localized_name":"seed","name":"seed","type":"INT","widget":{"name":"seed"},"link":null},{"localized_name":"cfg_scale","name":"cfg_scale","type":"FLOAT","widget":{"name":"cfg_scale"},"link":null},{"localized_name":"use_sampling","name":"use_sampling","type":"BOOLEAN","widget":{"name":"use_sampling"},"link":null},{"localized_name":"temperature","name":"temperature","shape":7,"type":"FLOAT","widget":{"name":"temperature"},"link":null},{"localized_name":"top_p","name":"top_p","shape":7,"type":"FLOAT","widget":{"name":"top_p"},"link":null},{"localized_name":"max_words_per_chunk","name":"max_words_per_chunk","shape":7,"type":"INT","widget":{"name":"max_words_per_chunk"},"link":null}],"outputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","links":[49]}],"properties":{"Node name for S&R":"VibeVoiceSingleSpeakerNode","cnr_id":"VibeVoice-ComfyUI","ver":"5a24489a7b0bf0c406d291dd51e82a085d338d44"},"widgets_values":["Hello, this is a test of the VibeVoice text-to-speech system. [pause] Do you like my voice? [pause:500] What's your name?","VibeVoice-1.5B","auto",true,20,42,"fixed",1.3,false,0.95,0.95,250],"color":"#223","bgcolor":"#335"}],"links":[[48,15,0,37,0,"AUDIO"],[49,37,0,16,0,"AUDIO"]],"groups":[],"config":{},"extra":{"ds":{"scale":1.1000000000000005,"offset":[127.54923408733805,16.966619865757746]}},"version":0.4}
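The note embedded in this workflow explains the pause syntax: [pause] inserts one second of silence, [pause:N] inserts N milliseconds, and each tag forces the surrounding text to be generated as separate chunks. A minimal sketch of how such tags are recognized (the actual logic is _parse_pause_keywords in nodes/base_vibevoice.py further down in this commit):

import re

# Matches "[pause]" or "[pause:1500]"; group 1 holds the optional duration in ms.
PAUSE_PATTERN = r'\[pause(?::(\d+))?\]'

text = "Do you like my voice? [pause:500] What's your name?"
for match in re.finditer(PAUSE_PATTERN, text):
    duration_ms = int(match.group(1)) if match.group(1) else 1000  # default: 1 second
    print(f"pause of {duration_ms} ms at position {match.start()}")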
examples/Single-Speaker.json ADDED
@@ -0,0 +1 @@
+ {"id":"c6ef8963-032c-45f6-954f-b5f6b354343b","revision":0,"last_node_id":37,"last_link_id":49,"nodes":[{"id":21,"type":"Note","pos":[-128.1415557861328,534.7645263671875],"size":[415,88],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[],"properties":{},"widgets_values":["Use Load Text From File if you want to use a .txt file instead of text-area. You can load .txt files from ComfyUI/input, ComfyUI/output or ComfyUI/temp directories."],"color":"#432","bgcolor":"#653"},{"id":28,"type":"LoadTextFromFileNode","pos":[-65.6552963256836,428.2049865722656],"size":[289.5152282714844,58],"flags":{},"order":1,"mode":4,"inputs":[{"localized_name":"file","name":"file","type":"COMBO","widget":{"name":"file"},"link":null}],"outputs":[{"localized_name":"text","name":"text","type":"STRING","links":null}],"properties":{"Node name for S&R":"LoadTextFromFileNode","cnr_id":"VibeVoice-ComfyUI","ver":"5a24489a7b0bf0c406d291dd51e82a085d338d44"},"widgets_values":["No text files found in any directory"],"color":"#323","bgcolor":"#535"},{"id":37,"type":"VibeVoiceSingleSpeakerNode","pos":[376.42669677734375,126.94989013671875],"size":[400,352],"flags":{},"order":4,"mode":0,"inputs":[{"localized_name":"voice_to_clone","name":"voice_to_clone","shape":7,"type":"AUDIO","link":48},{"localized_name":"text","name":"text","type":"STRING","widget":{"name":"text"},"link":null},{"localized_name":"model","name":"model","type":"COMBO","widget":{"name":"model"},"link":null},{"localized_name":"attention_type","name":"attention_type","type":"COMBO","widget":{"name":"attention_type"},"link":null},{"localized_name":"free_memory_after_generate","name":"free_memory_after_generate","type":"BOOLEAN","widget":{"name":"free_memory_after_generate"},"link":null},{"localized_name":"diffusion_steps","name":"diffusion_steps","type":"INT","widget":{"name":"diffusion_steps"},"link":null},{"localized_name":"seed","name":"seed","type":"INT","widget":{"name":"seed"},"link":null},{"localized_name":"cfg_scale","name":"cfg_scale","type":"FLOAT","widget":{"name":"cfg_scale"},"link":null},{"localized_name":"use_sampling","name":"use_sampling","type":"BOOLEAN","widget":{"name":"use_sampling"},"link":null},{"localized_name":"temperature","name":"temperature","shape":7,"type":"FLOAT","widget":{"name":"temperature"},"link":null},{"localized_name":"top_p","name":"top_p","shape":7,"type":"FLOAT","widget":{"name":"top_p"},"link":null},{"localized_name":"max_words_per_chunk","name":"max_words_per_chunk","shape":7,"type":"INT","widget":{"name":"max_words_per_chunk"},"link":null}],"outputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","links":[49]}],"properties":{"Node name for S&R":"VibeVoiceSingleSpeakerNode","cnr_id":"VibeVoice-ComfyUI","ver":"5a24489a7b0bf0c406d291dd51e82a085d338d44"},"widgets_values":["Hello, this is a test of the VibeVoice text-to-speech system.","VibeVoice-1.5B","auto",true,20,42,"fixed",1.3,false,0.95,0.95,250],"color":"#223","bgcolor":"#335"},{"id":16,"type":"PreviewAudio","pos":[892.3655395507812,127.41075897216797],"size":[270,88],"flags":{},"order":5,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","link":49},{"localized_name":"audioUI","name":"audioUI","type":"AUDIO_UI","widget":{"name":"audioUI"},"link":null}],"outputs":[],"properties":{"cnr_id":"comfy-core","ver":"0.3.49","Node name for 
S&R":"PreviewAudio"},"widgets_values":[],"color":"#323","bgcolor":"#535"},{"id":15,"type":"LoadAudio","pos":[15.256911277770996,126.44892883300781],"size":[270,136],"flags":{},"order":3,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"COMBO","widget":{"name":"audio"},"link":null},{"localized_name":"audioUI","name":"audioUI","type":"AUDIO_UI","widget":{"name":"audioUI"},"link":null},{"localized_name":"upload","name":"upload","type":"AUDIOUPLOAD","widget":{"name":"upload"},"link":null}],"outputs":[{"localized_name":"AUDIO","name":"AUDIO","type":"AUDIO","links":[48]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.49","Node name for S&R":"LoadAudio"},"widgets_values":["Voice.mp3",null,null],"color":"#2a363b","bgcolor":"#3f5159"},{"id":22,"type":"Note","pos":[365.11663818359375,535.5830078125],"size":[415,88],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[],"properties":{},"widgets_values":["The first time you use a model, it will be downloaded to ComfyUI/models/vibevoice/. This can take several minutes!\nThe VibeVoice-1.5B model (about 5GB)\nThe VibeVoice-Large model (about 17GB)\nThe VibeVoice-Large-Quant-4Bit model (about 7GB)"],"color":"#432","bgcolor":"#653"}],"links":[[48,15,0,37,0,"AUDIO"],[49,37,0,16,0,"AUDIO"]],"groups":[],"config":{},"extra":{"ds":{"scale":1.1000000000000005,"offset":[249.36741590551986,40.60298350212133]}},"version":0.4}
examples/VibeVoice-Unload-Memory.json ADDED
@@ -0,0 +1 @@
+ {"id":"fc471b7e-ccef-427f-be3f-29dec93a90ea","revision":0,"last_node_id":37,"last_link_id":47,"nodes":[{"id":34,"type":"VibeVoiceFreeMemoryNode","pos":[913.2552490234375,126.25599670410156],"size":[189.03964233398438,26],"flags":{},"order":6,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","link":47}],"outputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","links":[42]}],"properties":{"Node name for S&R":"VibeVoiceFreeMemoryNode","cnr_id":"VibeVoice-ComfyUI","ver":"5a24489a7b0bf0c406d291dd51e82a085d338d44"},"widgets_values":[],"color":"#322","bgcolor":"#533"},{"id":35,"type":"Note","pos":[809.6192016601562,208.98324584960938],"size":[415,88],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[],"properties":{},"widgets_values":["The VibeVoice Free Memory node releases memory as soon as it receives the audio input (acting as a passthrough for the audio itself). In this specific use case, however, it’s redundant, since it would be enough to enable the “free_memory_after_generate” parameter of the previous node. The ideal use case is, for example, when you have a loop generating multiple audio clips, and only after the final generation you pass the last audio and free the memory."],"color":"#432","bgcolor":"#653"},{"id":21,"type":"Note","pos":[-128.1415557861328,534.7645263671875],"size":[415,88],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[],"properties":{},"widgets_values":["Use Load Text From File if you want to use a .txt file instead of text-area. You can load .txt files from ComfyUI/input, ComfyUI/output or ComfyUI/temp directories."],"color":"#432","bgcolor":"#653"},{"id":16,"type":"PreviewAudio","pos":[1271.0958251953125,126.20075988769531],"size":[270,88],"flags":{},"order":7,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","link":42},{"localized_name":"audioUI","name":"audioUI","type":"AUDIO_UI","widget":{"name":"audioUI"},"link":null}],"outputs":[],"properties":{"cnr_id":"comfy-core","ver":"0.3.49","Node name for S&R":"PreviewAudio"},"widgets_values":[],"color":"#323","bgcolor":"#535"},{"id":15,"type":"LoadAudio","pos":[15.256911277770996,126.44892883300781],"size":[270,136],"flags":{},"order":2,"mode":0,"inputs":[{"localized_name":"audio","name":"audio","type":"COMBO","widget":{"name":"audio"},"link":null},{"localized_name":"audioUI","name":"audioUI","type":"AUDIO_UI","widget":{"name":"audioUI"},"link":null},{"localized_name":"upload","name":"upload","type":"AUDIOUPLOAD","widget":{"name":"upload"},"link":null}],"outputs":[{"localized_name":"AUDIO","name":"AUDIO","type":"AUDIO","links":[46]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.49","Node name for 
S&R":"LoadAudio"},"widgets_values":["Voice.mp3",null,null],"color":"#2a363b","bgcolor":"#3f5159"},{"id":37,"type":"VibeVoiceSingleSpeakerNode","pos":[353.6889343261719,125.9052505493164],"size":[400,328],"flags":{},"order":5,"mode":0,"inputs":[{"localized_name":"voice_to_clone","name":"voice_to_clone","shape":7,"type":"AUDIO","link":46},{"localized_name":"text","name":"text","type":"STRING","widget":{"name":"text"},"link":null},{"localized_name":"model","name":"model","type":"COMBO","widget":{"name":"model"},"link":null},{"localized_name":"attention_type","name":"attention_type","type":"COMBO","widget":{"name":"attention_type"},"link":null},{"localized_name":"free_memory_after_generate","name":"free_memory_after_generate","type":"BOOLEAN","widget":{"name":"free_memory_after_generate"},"link":null},{"localized_name":"diffusion_steps","name":"diffusion_steps","type":"INT","widget":{"name":"diffusion_steps"},"link":null},{"localized_name":"seed","name":"seed","type":"INT","widget":{"name":"seed"},"link":null},{"localized_name":"cfg_scale","name":"cfg_scale","type":"FLOAT","widget":{"name":"cfg_scale"},"link":null},{"localized_name":"use_sampling","name":"use_sampling","type":"BOOLEAN","widget":{"name":"use_sampling"},"link":null},{"localized_name":"temperature","name":"temperature","shape":7,"type":"FLOAT","widget":{"name":"temperature"},"link":null},{"localized_name":"top_p","name":"top_p","shape":7,"type":"FLOAT","widget":{"name":"top_p"},"link":null}],"outputs":[{"localized_name":"audio","name":"audio","type":"AUDIO","links":[47]}],"properties":{"Node name for S&R":"VibeVoiceSingleSpeakerNode","cnr_id":"VibeVoice-ComfyUI","ver":"5a24489a7b0bf0c406d291dd51e82a085d338d44"},"widgets_values":["Hello, this is a test of the VibeVoice text-to-speech system.","VibeVoice-1.5B","auto",false,20,42,"fixed",1.3,false,0.95,0.95],"color":"#223","bgcolor":"#335"},{"id":28,"type":"LoadTextFromFileNode","pos":[-65.6552963256836,428.2049865722656],"size":[289.5152282714844,58],"flags":{},"order":3,"mode":4,"inputs":[{"localized_name":"file","name":"file","type":"COMBO","widget":{"name":"file"},"link":null}],"outputs":[{"localized_name":"text","name":"text","type":"STRING","links":null}],"properties":{"Node name for S&R":"LoadTextFromFileNode","cnr_id":"VibeVoice-ComfyUI","ver":"5a24489a7b0bf0c406d291dd51e82a085d338d44"},"widgets_values":["No text files found in any directory"],"color":"#323","bgcolor":"#535"},{"id":22,"type":"Note","pos":[349.11663818359375,500.03680419921875],"size":[415,88],"flags":{},"order":4,"mode":0,"inputs":[],"outputs":[],"properties":{},"widgets_values":["The first time you use a model, it will be downloaded to ComfyUI/models/vibevoice/. This can take several minutes!\nThe VibeVoice-1.5B model (about 5GB)\nThe VibeVoice-Large model (about 17GB)\nThe VibeVoice-Large-Quant-4Bit model (about 7GB)"],"color":"#432","bgcolor":"#653"}],"links":[[42,34,0,16,0,"AUDIO"],[46,15,0,37,0,"AUDIO"],[47,37,0,34,0,"AUDIO"]],"groups":[],"config":{},"extra":{"ds":{"scale":1,"offset":[186.3110536351836,148.09475114094386]}},"version":0.4}
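The note in this workflow describes the VibeVoice Free Memory node: it passes the incoming audio straight through and releases the loaded model as a side effect. The release follows the usual PyTorch pattern, as implemented in free_memory() in nodes/base_vibevoice.py below; a rough sketch, with a hypothetical holder object standing in for the node:

import gc
import torch

def release_model(holder):
    """Drop references to the model/processor and reclaim GPU memory."""
    holder.model = None            # drop the strong references
    holder.processor = None
    gc.collect()                   # let Python free the underlying tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # return cached VRAM to the driver
        torch.cuda.synchronize()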
node_list.json ADDED
@@ -0,0 +1,6 @@
+ {
+     "VibeVoice Load Text From File": "Load .txt from ComfyUI input/output/temp",
+     "VibeVoice Single Speaker": "Single-speaker TTS with optional voice cloning",
+     "VibeVoice Multiple Speakers": "Multi-speaker TTS ([1]..[4]) with optional clones",
+     "VibeVoice Free Memory": "Frees loaded VibeVoice models; passthrough audio"
+ }
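The "[1]..[4]" notation refers to the speaker tags used in the multi-speaker text input. The Multiple-Speaker example workflow above uses exactly this format in its text widget:

[1]: Hello, this is the first speaker.
[2]: Hi there, I'm the second speaker.
[1]: Nice to meet you!
[2]: Nice to meet you too!

Each numbered tag corresponds to the matching speakerN_voice input of the VibeVoice Multiple Speakers node.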
nodes/__init__.py ADDED
@@ -0,0 +1,17 @@
+ # Created by Fabio Sarracino
+ # Nodes module for VibeVoiceWrapper
+ """
+ This module contains all the ComfyUI nodes for VibeVoice integration.
+ """
+
+ from .load_text_node import LoadTextFromFileNode
+ from .single_speaker_node import VibeVoiceSingleSpeakerNode
+ from .multi_speaker_node import VibeVoiceMultipleSpeakersNode
+ from .free_memory_node import VibeVoiceFreeMemoryNode
+
+ __all__ = [
+     'LoadTextFromFileNode',
+     'VibeVoiceSingleSpeakerNode',
+     'VibeVoiceMultipleSpeakersNode',
+     'VibeVoiceFreeMemoryNode'
+ ]
nodes/base_vibevoice.py ADDED
@@ -0,0 +1,872 @@
1
+ # Created by Fabio Sarracino
2
+ # Base class for VibeVoice nodes with common functionality
3
+
4
+ import logging
5
+ import os
6
+ import tempfile
7
+ import torch
8
+ import numpy as np
9
+ import re
10
+ import gc
11
+ from typing import List, Optional, Tuple, Any
12
+
13
+ # Setup logging
14
+ logger = logging.getLogger("VibeVoice")
15
+
16
+ # Import for interruption support
17
+ try:
18
+ import execution
19
+ INTERRUPTION_SUPPORT = True
20
+ except ImportError:
21
+ INTERRUPTION_SUPPORT = False
22
+ logger.warning("Interruption support not available")
23
+
24
+ # Check for SageAttention availability
25
+ try:
26
+ from sageattention import sageattn
27
+ SAGE_AVAILABLE = True
28
+ logger.info("SageAttention available for acceleration")
29
+ except ImportError:
30
+ SAGE_AVAILABLE = False
31
+ logger.debug("SageAttention not available - install with: pip install sageattention")
32
+
33
+ def get_optimal_device():
34
+ """Get the best available device (cuda, mps, or cpu)"""
35
+ if torch.cuda.is_available():
36
+ return "cuda"
37
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
38
+ return "mps"
39
+ else:
40
+ return "cpu"
41
+
42
+ def get_device_map():
43
+ """Get device map for model loading"""
44
+ device = get_optimal_device()
45
+ # Note: device_map "auto" might work better for MPS in some cases
46
+ return device if device != "mps" else "mps"
47
+
48
+ class BaseVibeVoiceNode:
49
+ """Base class for VibeVoice nodes containing common functionality"""
50
+
51
+ def __init__(self):
52
+ self.model = None
53
+ self.processor = None
54
+ self.current_model_path = None
55
+ self.current_attention_type = None
56
+
57
+ def free_memory(self):
58
+ """Free model and processor from memory"""
59
+ try:
60
+ if self.model is not None:
61
+ del self.model
62
+ self.model = None
63
+
64
+ if self.processor is not None:
65
+ del self.processor
66
+ self.processor = None
67
+
68
+ self.current_model_path = None
69
+
70
+ # Force garbage collection and clear CUDA cache if available
71
+ import gc
72
+ gc.collect()
73
+
74
+ if torch.cuda.is_available():
75
+ torch.cuda.empty_cache()
76
+ torch.cuda.synchronize()
77
+
78
+ logger.info("Model and processor memory freed successfully")
79
+
80
+ except Exception as e:
81
+ logger.error(f"Error freeing memory: {e}")
82
+
83
+ def _check_dependencies(self):
84
+ """Check if VibeVoice is available and import it with fallback installation"""
85
+ try:
86
+ import sys
87
+ import os
88
+
89
+ # Add vvembed to path
90
+ current_dir = os.path.dirname(os.path.abspath(__file__))
91
+ parent_dir = os.path.dirname(current_dir)
92
+ vvembed_path = os.path.join(parent_dir, 'vvembed')
93
+
94
+ if vvembed_path not in sys.path:
95
+ sys.path.insert(0, vvembed_path)
96
+
97
+ # Import from embedded version
98
+ from modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
99
+
100
+ logger.info(f"Using embedded VibeVoice from {vvembed_path}")
101
+ return None, VibeVoiceForConditionalGenerationInference
102
+
103
+ except ImportError as e:
104
+ logger.error(f"Embedded VibeVoice import failed: {e}")
105
+
106
+ # Try fallback to installed version if available
107
+ try:
108
+ import vibevoice
109
+ from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
110
+ logger.warning("Falling back to system-installed VibeVoice")
111
+ return vibevoice, VibeVoiceForConditionalGenerationInference
112
+ except ImportError:
113
+ pass
114
+
115
+ raise Exception(
116
+ "VibeVoice embedded module import failed. Please ensure the vvembed folder exists "
117
+ "and transformers>=4.51.3 is installed."
118
+ )
119
+
120
+ def _apply_sage_attention(self):
121
+ """Apply SageAttention to the loaded model by monkey-patching attention layers"""
122
+ try:
123
+ from sageattention import sageattn
124
+ import torch.nn.functional as F
125
+
126
+ # Counter for patched layers
127
+ patched_count = 0
128
+
129
+ def patch_attention_forward(module):
130
+ """Recursively patch attention layers to use SageAttention"""
131
+ nonlocal patched_count
132
+
133
+ # Check if this module has scaled_dot_product_attention
134
+ if hasattr(module, 'forward'):
135
+ original_forward = module.forward
136
+
137
+ # Create wrapper that replaces F.scaled_dot_product_attention with sageattn
138
+ def sage_forward(*args, **kwargs):
139
+ # Temporarily replace F.scaled_dot_product_attention
140
+ original_sdpa = F.scaled_dot_product_attention
141
+
142
+ def sage_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, **kwargs):
143
+ """Wrapper that converts sdpa calls to sageattn"""
144
+ # Log any unexpected parameters for debugging
145
+ if kwargs:
146
+ unexpected_params = list(kwargs.keys())
147
+ logger.debug(f"SageAttention: Ignoring unsupported parameters: {unexpected_params}")
148
+
149
+ try:
150
+ # SageAttention expects tensors in specific format
151
+ # Transformers typically use (batch, heads, seq_len, head_dim)
152
+
153
+ # Check tensor dimensions to determine layout
154
+ if query.dim() == 4:
155
+ # 4D tensor: (batch, heads, seq, dim)
156
+ batch_size = query.shape[0]
157
+ num_heads = query.shape[1]
158
+ seq_len_q = query.shape[2]
159
+ seq_len_k = key.shape[2]
160
+ head_dim = query.shape[3]
161
+
162
+ # Reshape to (batch*heads, seq, dim) for HND layout
163
+ query_reshaped = query.reshape(batch_size * num_heads, seq_len_q, head_dim)
164
+ key_reshaped = key.reshape(batch_size * num_heads, seq_len_k, head_dim)
165
+ value_reshaped = value.reshape(batch_size * num_heads, seq_len_k, head_dim)
166
+
167
+ # Call sageattn with HND layout
168
+ output = sageattn(
169
+ query_reshaped, key_reshaped, value_reshaped,
170
+ is_causal=is_causal,
171
+ tensor_layout="HND" # Heads*batch, seqN, Dim
172
+ )
173
+
174
+ # Output should be (batch*heads, seq_len_q, head_dim)
175
+ # Reshape back to (batch, heads, seq, dim)
176
+ if output.dim() == 3:
177
+ output = output.reshape(batch_size, num_heads, seq_len_q, head_dim)
178
+
179
+ return output
180
+ else:
181
+ # For 3D tensors, assume they're already in HND format
182
+ output = sageattn(
183
+ query, key, value,
184
+ is_causal=is_causal,
185
+ tensor_layout="HND"
186
+ )
187
+ return output
188
+
189
+ except Exception as e:
190
+ # If SageAttention fails, fall back to original implementation
191
+ logger.debug(f"SageAttention failed, using original: {e}")
192
+ # Call with proper arguments - scale is a keyword argument in PyTorch 2.0+
193
+ # Pass through any additional kwargs that the original sdpa might support
194
+ if scale is not None:
195
+ return original_sdpa(query, key, value, attn_mask=attn_mask,
196
+ dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)
197
+ else:
198
+ return original_sdpa(query, key, value, attn_mask=attn_mask,
199
+ dropout_p=dropout_p, is_causal=is_causal, **kwargs)
200
+
201
+ # Replace the function
202
+ F.scaled_dot_product_attention = sage_sdpa
203
+
204
+ try:
205
+ # Call original forward with patched attention
206
+ result = original_forward(*args, **kwargs)
207
+ finally:
208
+ # Restore original function
209
+ F.scaled_dot_product_attention = original_sdpa
210
+
211
+ return result
212
+
213
+ # Check if this module likely uses attention
214
+ # Look for common attention module names
215
+ module_name = module.__class__.__name__.lower()
216
+ if any(name in module_name for name in ['attention', 'attn', 'multihead']):
217
+ module.forward = sage_forward
218
+ patched_count += 1
219
+
220
+ # Recursively patch child modules
221
+ for child in module.children():
222
+ patch_attention_forward(child)
223
+
224
+ # Apply patching to the entire model
225
+ patch_attention_forward(self.model)
226
+
227
+ logger.info(f"Patched {patched_count} attention layers with SageAttention")
228
+
229
+ if patched_count == 0:
230
+ logger.warning("No attention layers found to patch - SageAttention may not be applied")
231
+
232
+ except Exception as e:
233
+ logger.error(f"Failed to apply SageAttention: {e}")
234
+ logger.warning("Continuing with standard attention implementation")
235
+
236
+ def load_model(self, model_name: str, model_path: str, attention_type: str = "auto"):
237
+ """Load VibeVoice model with specified attention implementation
238
+
239
+ Args:
240
+ model_name: The display name of the model (e.g., "VibeVoice-Large-Quant-4Bit")
241
+ model_path: The HuggingFace model path
242
+ attention_type: The attention implementation to use
243
+ """
244
+ # Check if we need to reload model due to attention type change
245
+ current_attention = getattr(self, 'current_attention_type', None)
246
+ if (self.model is None or
247
+ getattr(self, 'current_model_path', None) != model_path or
248
+ current_attention != attention_type):
249
+
250
+ # Free existing model before loading new one (important for attention type changes)
251
+ if self.model is not None and (current_attention != attention_type or getattr(self, 'current_model_path', None) != model_path):
252
+ logger.info(f"Freeing existing model before loading with new settings (attention: {current_attention} -> {attention_type})")
253
+ self.free_memory()
254
+
255
+ try:
256
+ vibevoice, VibeVoiceInferenceModel = self._check_dependencies()
257
+
258
+ # Set ComfyUI models directory
259
+ import folder_paths
260
+ models_dir = folder_paths.get_folder_paths("checkpoints")[0]
261
+ comfyui_models_dir = os.path.join(os.path.dirname(models_dir), "vibevoice")
262
+ os.makedirs(comfyui_models_dir, exist_ok=True)
263
+
264
+ # Force HuggingFace to use ComfyUI directory
265
+ original_hf_home = os.environ.get('HF_HOME')
266
+ original_hf_cache = os.environ.get('HUGGINGFACE_HUB_CACHE')
267
+
268
+ os.environ['HF_HOME'] = comfyui_models_dir
269
+ os.environ['HUGGINGFACE_HUB_CACHE'] = comfyui_models_dir
270
+
271
+ # Import time for timing
272
+ import time
273
+ start_time = time.time()
274
+
275
+ # Suppress verbose logs
276
+ import transformers
277
+ import warnings
278
+ transformers.logging.set_verbosity_error()
279
+ warnings.filterwarnings("ignore", category=UserWarning)
280
+
281
+ # Check if model exists locally
282
+ model_dir = os.path.join(comfyui_models_dir, f"models--{model_path.replace('/', '--')}")
283
+ model_exists_in_comfyui = os.path.exists(model_dir)
284
+
285
+ # Check if this is a quantized model based on the model name
286
+ is_quantized_4bit = "Quant-4Bit" in model_name
287
+ is_quantized_8bit = "Quant-8Bit" in model_name # Future support
288
+
289
+ # Prepare attention implementation kwargs
290
+ model_kwargs = {
291
+ "cache_dir": comfyui_models_dir,
292
+ "trust_remote_code": True,
293
+ "torch_dtype": torch.bfloat16,
294
+ "device_map": get_device_map(),
295
+ }
296
+
297
+ # Handle 4-bit quantized model loading
298
+ if is_quantized_4bit:
299
+ # Check if CUDA is available (required for 4-bit quantization)
300
+ if not torch.cuda.is_available():
301
+ raise Exception("4-bit quantized models require a CUDA GPU. Please use standard models on CPU/MPS.")
302
+
303
+ # Try to import bitsandbytes
304
+ try:
305
+ from transformers import BitsAndBytesConfig
306
+ logger.info("Loading 4-bit quantized model with bitsandbytes...")
307
+
308
+ # Configure 4-bit quantization
309
+ bnb_config = BitsAndBytesConfig(
310
+ load_in_4bit=True,
311
+ bnb_4bit_compute_dtype=torch.bfloat16,
312
+ bnb_4bit_use_double_quant=True,
313
+ bnb_4bit_quant_type='nf4'
314
+ )
315
+ model_kwargs["quantization_config"] = bnb_config
316
+ model_kwargs["device_map"] = "cuda" # Force CUDA for 4-bit
317
+ model_kwargs["subfolder"] = "4bit" # Point to 4bit subfolder
318
+
319
+ except ImportError:
320
+ raise Exception(
321
+ "4-bit quantized models require 'bitsandbytes' library.\n"
322
+ "Please install it with: pip install bitsandbytes\n"
323
+ "Or use the standard VibeVoice models instead."
324
+ )
325
+
326
+ # Set attention implementation based on user selection
327
+ use_sage_attention = False
328
+ if attention_type == "sage":
329
+ # SageAttention requires special handling - can't be set via attn_implementation
330
+ if not SAGE_AVAILABLE:
331
+ logger.warning("SageAttention not installed, falling back to sdpa")
332
+ logger.warning("Install with: pip install sageattention")
333
+ model_kwargs["attn_implementation"] = "sdpa"
334
+ elif not torch.cuda.is_available():
335
+ logger.warning("SageAttention requires CUDA GPU, falling back to sdpa")
336
+ model_kwargs["attn_implementation"] = "sdpa"
337
+ else:
338
+ # Don't set attn_implementation for sage, will apply after loading
339
+ use_sage_attention = True
340
+ logger.info("Will apply SageAttention after model loading")
341
+ elif attention_type != "auto":
342
+ model_kwargs["attn_implementation"] = attention_type
343
+ logger.info(f"Using {attention_type} attention implementation")
344
+ else:
345
+ # Auto mode - let transformers decide the best implementation
346
+ logger.info("Using auto attention implementation selection")
347
+
348
+ # Try to load locally first
349
+ try:
350
+ if model_exists_in_comfyui:
351
+ model_kwargs["local_files_only"] = True
352
+ logger.info(f"Loading model from local cache: {model_path}")
353
+ if is_quantized_4bit:
354
+ logger.info(f"Using 4-bit quantization with subfolder: {model_kwargs.get('subfolder', 'None')}")
355
+ self.model = VibeVoiceInferenceModel.from_pretrained(
356
+ model_path,
357
+ **model_kwargs
358
+ )
359
+ else:
360
+ raise FileNotFoundError("Model not found locally")
361
+ except (FileNotFoundError, OSError) as e:
362
+ logger.info(f"Downloading {model_path}...")
363
+ if is_quantized_4bit:
364
+ logger.info(f"Downloading 4-bit quantized model with subfolder: {model_kwargs.get('subfolder', 'None')}")
365
+
366
+ model_kwargs["local_files_only"] = False
367
+ self.model = VibeVoiceInferenceModel.from_pretrained(
368
+ model_path,
369
+ **model_kwargs
370
+ )
371
+ elapsed = time.time() - start_time
372
+ else:
373
+ elapsed = time.time() - start_time
374
+
375
+ # Verify model was loaded
376
+ if self.model is None:
377
+ raise Exception("Model failed to load - model is None after loading")
378
+
379
+ # Load processor with proper error handling
380
+ from processor.vibevoice_processor import VibeVoiceProcessor
381
+
382
+ # Prepare processor kwargs
383
+ processor_kwargs = {
384
+ "trust_remote_code": True,
385
+ "cache_dir": comfyui_models_dir
386
+ }
387
+
388
+ # Add subfolder for quantized models
389
+ if is_quantized_4bit:
390
+ processor_kwargs["subfolder"] = "4bit"
391
+
392
+ try:
393
+ # First try with local files if model was loaded locally
394
+ if model_exists_in_comfyui:
395
+ processor_kwargs["local_files_only"] = True
396
+ self.processor = VibeVoiceProcessor.from_pretrained(
397
+ model_path,
398
+ **processor_kwargs
399
+ )
400
+ else:
401
+ # Download from HuggingFace
402
+ self.processor = VibeVoiceProcessor.from_pretrained(
403
+ model_path,
404
+ **processor_kwargs
405
+ )
406
+ except Exception as proc_error:
407
+ logger.warning(f"Failed to load processor from {model_path}: {proc_error}")
408
+
409
+ # Check if error is about missing Qwen tokenizer
410
+ if "Qwen" in str(proc_error) and "tokenizer" in str(proc_error).lower():
411
+ logger.info("Downloading required Qwen tokenizer files...")
412
+ # The processor needs the Qwen tokenizer, ensure it's available
413
+ try:
414
+ from transformers import AutoTokenizer
415
+ # Pre-download the Qwen tokenizer that VibeVoice depends on
416
+ _ = AutoTokenizer.from_pretrained(
417
+ "Qwen/Qwen2.5-1.5B",
418
+ trust_remote_code=True,
419
+ cache_dir=comfyui_models_dir
420
+ )
421
+ logger.info("Qwen tokenizer downloaded, retrying processor load...")
422
+ except Exception as tokenizer_error:
423
+ logger.warning(f"Failed to download Qwen tokenizer: {tokenizer_error}")
424
+
425
+ logger.info("Attempting to load processor with fallback method...")
426
+
427
+ # Fallback: try loading without local_files_only constraint
428
+ try:
429
+ self.processor = VibeVoiceProcessor.from_pretrained(
430
+ model_path,
431
+ local_files_only=False,
432
+ trust_remote_code=True,
433
+ cache_dir=comfyui_models_dir
434
+ )
435
+ except Exception as fallback_error:
436
+ logger.error(f"Processor loading failed completely: {fallback_error}")
437
+ raise Exception(
438
+ f"Failed to load VibeVoice processor. Error: {fallback_error}\n"
439
+ f"This might be due to missing tokenizer files. Try:\n"
440
+ f"1. Ensure you have internet connection for first-time download\n"
441
+ f"2. Clear the ComfyUI/models/vibevoice folder and retry\n"
442
+ f"3. Install transformers: pip install transformers>=4.51.3\n"
443
+ f"4. Manually download Qwen tokenizer: from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('Qwen/Qwen2.5-1.5B')"
444
+ )
445
+
446
+ # Restore environment variables
447
+ if original_hf_home is not None:
448
+ os.environ['HF_HOME'] = original_hf_home
449
+ elif 'HF_HOME' in os.environ:
450
+ del os.environ['HF_HOME']
451
+
452
+ if original_hf_cache is not None:
453
+ os.environ['HUGGINGFACE_HUB_CACHE'] = original_hf_cache
454
+ elif 'HUGGINGFACE_HUB_CACHE' in os.environ:
455
+ del os.environ['HUGGINGFACE_HUB_CACHE']
456
+
457
+ # Move to appropriate device (skip for quantized models as they use device_map)
458
+ if not is_quantized_4bit and not is_quantized_8bit:
459
+ device = get_optimal_device()
460
+ if device == "cuda":
461
+ self.model = self.model.cuda()
462
+ elif device == "mps":
463
+ self.model = self.model.to("mps")
464
+ else:
465
+ logger.info("Quantized model already mapped to device via device_map")
466
+
467
+ # Apply SageAttention if requested and available
468
+ if use_sage_attention and SAGE_AVAILABLE:
469
+ self._apply_sage_attention()
470
+ logger.info("SageAttention successfully applied to model")
471
+
472
+ self.current_model_path = model_path
473
+ self.current_attention_type = attention_type
474
+
475
+ except Exception as e:
476
+ logger.error(f"Failed to load VibeVoice model: {str(e)}")
477
+ raise Exception(f"Model loading failed: {str(e)}")
478
+
479
+ def _create_synthetic_voice_sample(self, speaker_idx: int) -> np.ndarray:
480
+ """Create synthetic voice sample for a specific speaker"""
481
+ sample_rate = 24000
482
+ duration = 1.0
483
+ samples = int(sample_rate * duration)
484
+
485
+ t = np.linspace(0, duration, samples, False)
486
+
487
+ # Create realistic voice-like characteristics for each speaker
488
+ # Use different base frequencies for different speaker types
489
+ base_frequencies = [120, 180, 140, 200] # Mix of male/female-like frequencies
490
+ base_freq = base_frequencies[speaker_idx % len(base_frequencies)]
491
+
492
+ # Create vowel-like formants (like "ah" sound) - unique per speaker
493
+ formant1 = 800 + speaker_idx * 100 # First formant
494
+ formant2 = 1200 + speaker_idx * 150 # Second formant
495
+
496
+ # Generate more voice-like waveform
497
+ voice_sample = (
498
+ # Fundamental with harmonics (voice-like)
499
+ 0.6 * np.sin(2 * np.pi * base_freq * t) +
500
+ 0.25 * np.sin(2 * np.pi * base_freq * 2 * t) +
501
+ 0.15 * np.sin(2 * np.pi * base_freq * 3 * t) +
502
+
503
+ # Formant resonances (vowel-like characteristics)
504
+ 0.1 * np.sin(2 * np.pi * formant1 * t) * np.exp(-t * 2) +
505
+ 0.05 * np.sin(2 * np.pi * formant2 * t) * np.exp(-t * 3) +
506
+
507
+ # Natural breath noise (reduced)
508
+ 0.02 * np.random.normal(0, 1, len(t))
509
+ )
510
+
511
+ # Add natural envelope (like human speech pattern)
512
+ # Quick attack, slower decay with slight vibrato (unique per speaker)
513
+ vibrato_freq = 4 + speaker_idx * 0.3 # Slightly different vibrato per speaker
514
+ envelope = (np.exp(-t * 0.3) * (1 + 0.1 * np.sin(2 * np.pi * vibrato_freq * t)))
515
+ voice_sample *= envelope * 0.08 # Lower volume
516
+
517
+ return voice_sample.astype(np.float32)
518
+
519
+ def _prepare_audio_from_comfyui(self, voice_audio, target_sample_rate: int = 24000) -> Optional[np.ndarray]:
520
+ """Prepare audio from ComfyUI format to numpy array"""
521
+ if voice_audio is None:
522
+ return None
523
+
524
+ # Extract waveform from ComfyUI audio format
525
+ if isinstance(voice_audio, dict) and "waveform" in voice_audio:
526
+ waveform = voice_audio["waveform"]
527
+ input_sample_rate = voice_audio.get("sample_rate", target_sample_rate)
528
+
529
+ # Convert to numpy (handling BFloat16 tensors)
530
+ if isinstance(waveform, torch.Tensor):
531
+ # Convert to float32 first as numpy doesn't support BFloat16
532
+ audio_np = waveform.cpu().float().numpy()
533
+ else:
534
+ audio_np = np.array(waveform)
535
+
536
+ # Handle different audio shapes
537
+ if audio_np.ndim == 3: # (batch, channels, samples)
538
+ audio_np = audio_np[0, 0, :] # Take first batch, first channel
539
+ elif audio_np.ndim == 2: # (channels, samples)
540
+ audio_np = audio_np[0, :] # Take first channel
541
+ # If 1D, leave as is
542
+
543
+ # Resample if needed
544
+ if input_sample_rate != target_sample_rate:
545
+ target_length = int(len(audio_np) * target_sample_rate / input_sample_rate)
546
+ audio_np = np.interp(np.linspace(0, len(audio_np), target_length),
547
+ np.arange(len(audio_np)), audio_np)
548
+
549
+ # Ensure audio is in correct range [-1, 1]
550
+ audio_max = np.abs(audio_np).max()
551
+ if audio_max > 0:
552
+ audio_np = audio_np / max(audio_max, 1.0) # Normalize
553
+
554
+ return audio_np.astype(np.float32)
555
+
556
+ return None
557
+
558
+ def _get_model_mapping(self) -> dict:
559
+ """Get model name mappings"""
560
+ return {
561
+ "VibeVoice-1.5B": "microsoft/VibeVoice-1.5B",
562
+ "VibeVoice-Large": "aoi-ot/VibeVoice-Large",
563
+ "VibeVoice-Large-Quant-4Bit": "DevParker/VibeVoice7b-low-vram"
564
+ }
565
+
566
+ def _split_text_into_chunks(self, text: str, max_words: int = 250) -> List[str]:
567
+ """Split long text into manageable chunks at sentence boundaries
568
+
569
+ Args:
570
+ text: The text to split
571
+ max_words: Maximum words per chunk (default 250 for safety)
572
+
573
+ Returns:
574
+ List of text chunks
575
+ """
576
+ import re
577
+
578
+ # Split into sentences (handling common abbreviations)
579
+ # This regex tries to split on sentence endings while avoiding common abbreviations
580
+ sentence_pattern = r'(?<=[.!?])\s+(?=[A-Z])'
581
+ sentences = re.split(sentence_pattern, text)
582
+
583
+ # If regex split didn't work well, fall back to simple split
584
+ if len(sentences) == 1 and len(text.split()) > max_words:
585
+ # Fall back to splitting on any period followed by space
586
+ sentences = text.replace('. ', '.|').split('|')
587
+ sentences = [s.strip() for s in sentences if s.strip()]
588
+
589
+ chunks = []
590
+ current_chunk = []
591
+ current_word_count = 0
592
+
593
+ for sentence in sentences:
594
+ sentence = sentence.strip()
595
+ if not sentence:
596
+ continue
597
+
598
+ sentence_words = sentence.split()
599
+ sentence_word_count = len(sentence_words)
600
+
601
+ # If single sentence is too long, split it further
602
+ if sentence_word_count > max_words:
603
+ # Split long sentence at commas or semicolons
604
+ sub_parts = re.split(r'[,;]', sentence)
605
+ for part in sub_parts:
606
+ part = part.strip()
607
+ if not part:
608
+ continue
609
+ part_words = part.split()
610
+ part_word_count = len(part_words)
611
+
612
+ if current_word_count + part_word_count > max_words and current_chunk:
613
+ # Save current chunk
614
+ chunks.append(' '.join(current_chunk))
615
+ current_chunk = [part]
616
+ current_word_count = part_word_count
617
+ else:
618
+ current_chunk.append(part)
619
+ current_word_count += part_word_count
620
+ else:
621
+ # Check if adding this sentence would exceed the limit
622
+ if current_word_count + sentence_word_count > max_words and current_chunk:
623
+ # Save current chunk and start new one
624
+ chunks.append(' '.join(current_chunk))
625
+ current_chunk = [sentence]
626
+ current_word_count = sentence_word_count
627
+ else:
628
+ # Add sentence to current chunk
629
+ current_chunk.append(sentence)
630
+ current_word_count += sentence_word_count
631
+
632
+ # Add remaining chunk
633
+ if current_chunk:
634
+ chunks.append(' '.join(current_chunk))
635
+
636
+ # If no chunks were created, return the original text
637
+ if not chunks:
638
+ chunks = [text]
639
+
640
+ logger.info(f"Split text into {len(chunks)} chunks (max {max_words} words each)")
641
+ for i, chunk in enumerate(chunks):
642
+ word_count = len(chunk.split())
643
+ logger.debug(f"Chunk {i+1}: {word_count} words")
644
+
645
+ return chunks
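# Usage sketch for the chunker above (illustration only; the module normally
# runs inside ComfyUI, and this calls the internal helper directly):
#
# node = BaseVibeVoiceNode()
# long_text = "First sentence about the topic. " * 200   # ~1000 words
# chunks = node._split_text_into_chunks(long_text, max_words=250)
# # Each chunk ends on a sentence boundary and stays within max_words,
# # unless a single sentence exceeds the budget and is split at , or ;
# print(len(chunks), [len(c.split()) for c in chunks])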
646
+
647
+ def _parse_pause_keywords(self, text: str) -> List[Tuple[str, Any]]:
648
+ """Parse [pause] and [pause:ms] keywords from text
649
+
650
+ Args:
651
+ text: Text potentially containing pause keywords
652
+
653
+ Returns:
654
+ List of tuples: ('text', str) or ('pause', duration_ms)
655
+ """
656
+ segments = []
657
+ # Pattern matches [pause] or [pause:1500] where 1500 is milliseconds
658
+ pattern = r'\[pause(?::(\d+))?\]'
659
+
660
+ last_end = 0
661
+ for match in re.finditer(pattern, text):
662
+ # Add text segment before pause (if any)
663
+ if match.start() > last_end:
664
+ text_segment = text[last_end:match.start()].strip()
665
+ if text_segment: # Only add non-empty text segments
666
+ segments.append(('text', text_segment))
667
+
668
+ # Add pause segment with duration (default 1000ms = 1 second)
669
+ duration_ms = int(match.group(1)) if match.group(1) else 1000
670
+ segments.append(('pause', duration_ms))
671
+ last_end = match.end()
672
+
673
+ # Add remaining text after last pause (if any)
674
+ if last_end < len(text):
675
+ remaining_text = text[last_end:].strip()
676
+ if remaining_text:
677
+ segments.append(('text', remaining_text))
678
+
679
+ # If no pauses found, return original text as single segment
680
+ if not segments:
681
+ segments.append(('text', text))
682
+
683
+ logger.debug(f"Parsed text into {len(segments)} segments (including pauses)")
684
+ return segments
685
+
686
+ def _generate_silence(self, duration_ms: int, sample_rate: int = 24000) -> dict:
687
+ """Generate silence audio tensor for specified duration
688
+
689
+ Args:
690
+ duration_ms: Duration of silence in milliseconds
691
+ sample_rate: Sample rate (default 24000 Hz for VibeVoice)
692
+
693
+ Returns:
694
+ Audio dict with silence waveform
695
+ """
696
+ # Calculate number of samples for the duration
697
+ num_samples = int(sample_rate * duration_ms / 1000.0)
698
+
699
+ # Create silence tensor with shape (1, 1, num_samples) to match audio format
700
+ silence_waveform = torch.zeros(1, 1, num_samples, dtype=torch.float32)
701
+
702
+ logger.info(f"Generated {duration_ms}ms silence ({num_samples} samples)")
703
+
704
+ return {
705
+ "waveform": silence_waveform,
706
+ "sample_rate": sample_rate
707
+ }
708
+
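A small worked example (not part of the commit) of the sample-count arithmetic above, at VibeVoice's 24 kHz output rate:

import torch

sample_rate, duration_ms = 24000, 1500
num_samples = int(sample_rate * duration_ms / 1000.0)          # 36000 samples for 1.5 s
silence = torch.zeros(1, 1, num_samples, dtype=torch.float32)  # (batch, channel, time)
print(silence.shape)  # torch.Size([1, 1, 36000])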
709
+ def _format_text_for_vibevoice(self, text: str, speakers: list) -> str:
710
+ """Format text with speaker information for VibeVoice using correct format"""
711
+ # Remove any newlines from the text to prevent parsing issues
712
+ # The processor splits by newline and expects each line to have "Speaker N:" format
713
+ text = text.replace('\n', ' ').replace('\r', ' ')
714
+ # Clean up multiple spaces
715
+ text = ' '.join(text.split())
716
+
717
+ # VibeVoice expects format: "Speaker 1: text" not "Name: text"
718
+ if len(speakers) == 1:
719
+ return f"Speaker 1: {text}"
720
+ else:
721
+ # Check if text already has proper Speaker N: format
722
+ if re.match(r'^\s*Speaker\s+\d+\s*:', text, re.IGNORECASE):
723
+ return text
724
+ # If text has name format, convert to Speaker N format
725
+ elif any(f"{speaker}:" in text for speaker in speakers):
726
+ formatted_text = text
727
+ for i, speaker in enumerate(speakers):
728
+ formatted_text = formatted_text.replace(f"{speaker}:", f"Speaker {i+1}:")
729
+ return formatted_text
730
+ else:
731
+ # Plain text, assign to first speaker
732
+ return f"Speaker 1: {text}"
733
+
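For illustration only (the speaker names below are hypothetical), this is the name-to-label rewrite the method above performs when the text uses "Name:" prefixes:

text = "Alice: Hi there. Bob: Hello!"
speakers = ["Alice", "Bob"]
for i, speaker in enumerate(speakers):
    text = text.replace(f"{speaker}:", f"Speaker {i+1}:")
print(text)  # Speaker 1: Hi there. Speaker 2: Hello!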
734
+ def _generate_with_vibevoice(self, formatted_text: str, voice_samples: List[np.ndarray],
735
+ cfg_scale: float, seed: int, diffusion_steps: int, use_sampling: bool,
736
+ temperature: float = 0.95, top_p: float = 0.95) -> dict:
737
+ """Generate audio using VibeVoice model"""
738
+ try:
739
+ # Ensure model and processor are loaded
740
+ if self.model is None or self.processor is None:
741
+ raise Exception("Model or processor not loaded")
742
+
743
+ # Set seeds for reproducibility
744
+ torch.manual_seed(seed)
745
+ if torch.cuda.is_available():
746
+ torch.cuda.manual_seed(seed)
747
+ torch.cuda.manual_seed_all(seed) # For multi-GPU
748
+
749
+ # Also set numpy seed for any numpy operations
750
+ np.random.seed(seed)
751
+
752
+ # Set diffusion steps
753
+ self.model.set_ddpm_inference_steps(diffusion_steps)
754
+ logger.info(f"Starting audio generation with {diffusion_steps} diffusion steps...")
755
+
756
+ # Check for interruption before starting generation
757
+ if INTERRUPTION_SUPPORT:
758
+ try:
759
+ import comfy.model_management as mm
760
+
761
+ # Check if we're being interrupted right now
762
+ # The interrupt flag is reset by ComfyUI before each node execution
763
+ # So we only check model_management's throw_exception_if_processing_interrupted
764
+ # which is the proper way to check for interruption
765
+ mm.throw_exception_if_processing_interrupted()
766
+
767
+ except ImportError:
768
+ # If comfy.model_management is not available, skip this check
769
+ pass
770
+
771
+ # Prepare inputs using processor
772
+ inputs = self.processor(
773
+ [formatted_text], # Wrap text in list
774
+ voice_samples=[voice_samples], # Provide voice samples for reference
775
+ return_tensors="pt",
776
+ return_attention_mask=True
777
+ )
778
+
779
+ # Move to device
780
+ device = next(self.model.parameters()).device
781
+ inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
782
+
783
+ # Estimate tokens for user information (not used as limit)
784
+ text_length = len(formatted_text.split())
785
+ estimated_tokens = int(text_length * 2.5) # More accurate estimate for display
786
+
787
+ # Log generation start with explanation
788
+ logger.info(f"Generating audio with {diffusion_steps} diffusion steps...")
789
+ logger.info(f"Note: the progress bar shows the maximum possible tokens, not the actual number needed (~{estimated_tokens} estimated)")
790
+ logger.info("The generation will stop automatically when audio is complete")
791
+
792
+ # Create stop check function for interruption support
793
+ stop_check_fn = None
794
+ if INTERRUPTION_SUPPORT:
795
+ def check_comfyui_interrupt():
796
+ """Check if ComfyUI has requested interruption"""
797
+ try:
798
+ if hasattr(execution, 'PromptExecutor') and hasattr(execution.PromptExecutor, 'interrupted'):
799
+ interrupted = execution.PromptExecutor.interrupted
800
+ if interrupted:
801
+ logger.info("Generation interrupted by user via stop_check_fn")
802
+ return interrupted
803
+ except:
804
+ pass
805
+ return False
806
+
807
+ stop_check_fn = check_comfyui_interrupt
808
+
809
+ # Generate with official parameters
810
+ with torch.no_grad():
811
+ if use_sampling:
812
+ # Use sampling mode (less stable but more varied)
813
+ output = self.model.generate(
814
+ **inputs,
815
+ tokenizer=self.processor.tokenizer,
816
+ cfg_scale=cfg_scale,
817
+ max_new_tokens=None,
818
+ do_sample=True,
819
+ temperature=temperature,
820
+ top_p=top_p,
821
+ stop_check_fn=stop_check_fn,
822
+ )
823
+ else:
824
+ # Use deterministic mode like official examples
825
+ output = self.model.generate(
826
+ **inputs,
827
+ tokenizer=self.processor.tokenizer,
828
+ cfg_scale=cfg_scale,
829
+ max_new_tokens=None,
830
+ do_sample=False, # More deterministic generation
831
+ stop_check_fn=stop_check_fn,
832
+ )
833
+
834
+ # Check if we got actual audio output
835
+ if hasattr(output, 'speech_outputs') and output.speech_outputs:
836
+ speech_tensors = output.speech_outputs
837
+
838
+ if isinstance(speech_tensors, list) and len(speech_tensors) > 0:
839
+ audio_tensor = torch.cat(speech_tensors, dim=-1)
840
+ else:
841
+ audio_tensor = speech_tensors
842
+
843
+ # Ensure proper format (1, 1, samples)
844
+ if audio_tensor.dim() == 1:
845
+ audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)
846
+ elif audio_tensor.dim() == 2:
847
+ audio_tensor = audio_tensor.unsqueeze(0)
848
+
849
+ # Convert to float32 for compatibility with downstream nodes (Save Audio, etc.)
850
+ # Many audio processing nodes don't support BFloat16
851
+ return {
852
+ "waveform": audio_tensor.cpu().float(),
853
+ "sample_rate": 24000
854
+ }
855
+
856
+ elif hasattr(output, 'sequences'):
857
+ logger.error("VibeVoice returned only text tokens, no audio generated")
858
+ raise Exception("VibeVoice failed to generate audio - only text tokens returned")
859
+
860
+ else:
861
+ logger.error(f"Unexpected output format from VibeVoice: {type(output)}")
862
+ raise Exception(f"VibeVoice returned unexpected output format: {type(output)}")
863
+
864
+ except Exception as e:
865
+ # Re-raise interruption exceptions without wrapping
866
+ import comfy.model_management as mm
867
+ if isinstance(e, mm.InterruptProcessingException):
868
+ raise # Let the interruption propagate
869
+
870
+ # For real errors, log and re-raise with context
871
+ logger.error(f"VibeVoice generation failed: {e}")
872
+ raise Exception(f"VibeVoice generation failed: {str(e)}")
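A small sketch (not part of the commit) of the output normalization performed above: whatever shape the model returns, downstream nodes receive a float32 tensor shaped (1, 1, samples) at 24 kHz. The random tensor below stands in for real model output.

import torch

audio_tensor = torch.randn(48000, dtype=torch.bfloat16)  # stand-in for speech_outputs
if audio_tensor.dim() == 1:
    audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)
elif audio_tensor.dim() == 2:
    audio_tensor = audio_tensor.unsqueeze(0)
audio = {"waveform": audio_tensor.cpu().float(), "sample_rate": 24000}
print(audio["waveform"].shape, audio["waveform"].dtype)  # torch.Size([1, 1, 48000]) torch.float32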
nodes/free_memory_node.py ADDED
@@ -0,0 +1,110 @@
1
+ # Created by Fabio Sarracino
2
+ # Node to free VibeVoice model memory
3
+
4
+ import logging
5
+ import torch
6
+ import gc
7
+ from typing import Any
8
+
9
+ # Setup logging
10
+ logger = logging.getLogger("VibeVoice")
11
+
12
+ class VibeVoiceFreeMemoryNode:
13
+ """Node to explicitly free VibeVoice model memory"""
14
+
15
+ # Class variables to store node instances
16
+ _single_speaker_instances = []
17
+ _multi_speaker_instances = []
18
+
19
+ @classmethod
20
+ def INPUT_TYPES(cls):
21
+ return {
22
+ "required": {
23
+ "audio": ("AUDIO", {"tooltip": "Audio input that triggers memory cleanup and gets passed through"}),
24
+ }
25
+ }
26
+
27
+ RETURN_TYPES = ("AUDIO",)
28
+ RETURN_NAMES = ("audio",)
29
+ FUNCTION = "free_vibevoice_memory"
30
+ CATEGORY = "VibeVoiceWrapper"
31
+ DESCRIPTION = "Free all loaded VibeVoice models from memory when audio passes through"
32
+
33
+ @classmethod
34
+ def register_single_speaker(cls, node_instance):
35
+ """Register a single speaker node instance"""
36
+ if node_instance not in cls._single_speaker_instances:
37
+ cls._single_speaker_instances.append(node_instance)
38
+
39
+ @classmethod
40
+ def register_multi_speaker(cls, node_instance):
41
+ """Register a multi speaker node instance"""
42
+ if node_instance not in cls._multi_speaker_instances:
43
+ cls._multi_speaker_instances.append(node_instance)
44
+
45
+ def free_vibevoice_memory(self, audio):
46
+ """Free memory from all VibeVoice nodes and pass through the audio"""
47
+
48
+ try:
49
+ freed_count = 0
50
+
51
+ # Try to access and free memory from globally cached instances
52
+ # ComfyUI might cache node instances
53
+ try:
54
+ import sys
55
+ from .base_vibevoice import BaseVibeVoiceNode
56
+
57
+ # Search in all modules for BaseVibeVoiceNode instances
58
+ for module_name, module in sys.modules.items():
59
+ if module and 'vibevoice' in module_name.lower():
60
+ for attr_name in dir(module):
61
+ if not attr_name.startswith('_'):
62
+ try:
63
+ attr = getattr(module, attr_name)
64
+ if isinstance(attr, type) and issubclass(attr, BaseVibeVoiceNode):
65
+ # Check if the class has any cached instances
66
+ for instance_attr in dir(attr):
67
+ instance = getattr(attr, instance_attr)
68
+ if isinstance(instance, BaseVibeVoiceNode) and hasattr(instance, 'free_memory'):
69
+ instance.free_memory()
70
+ freed_count += 1
71
+ except:
72
+ pass
73
+ except:
74
+ pass
75
+
76
+ # Free from registered single speaker instances
77
+ for node in self._single_speaker_instances:
78
+ if hasattr(node, 'free_memory'):
79
+ node.free_memory()
80
+ freed_count += 1
81
+
82
+ # Free from registered multi speaker instances
83
+ for node in self._multi_speaker_instances:
84
+ if hasattr(node, 'free_memory'):
85
+ node.free_memory()
86
+ freed_count += 1
87
+
88
+ # Force garbage collection
89
+ gc.collect()
90
+
91
+ # Clear CUDA cache if available
92
+ if torch.cuda.is_available():
93
+ torch.cuda.empty_cache()
94
+ torch.cuda.synchronize()
95
+ logger.info(f"Freed VibeVoice memory from {freed_count} nodes and cleared CUDA cache")
96
+ else:
97
+ logger.info(f"Freed VibeVoice memory from {freed_count} nodes")
98
+
99
+ # Pass through the audio unchanged
100
+ return (audio,)
101
+
102
+ except Exception as e:
103
+ logger.error(f"Error freeing VibeVoice memory: {str(e)}")
104
+ # Still pass through audio even if error occurs
105
+ return (audio,)
106
+
107
+ @classmethod
108
+ def IS_CHANGED(cls, **kwargs):
109
+ """Always execute this node"""
110
+ return float("nan") # Forces re-execution every time
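For reference (not part of the commit), the release sequence this node ultimately relies on, assuming each registered node's free_memory() drops its model and processor references:

import gc
import torch

# Once the node references are dropped, reclaim Python and CUDA memory.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()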
nodes/load_text_node.py ADDED
@@ -0,0 +1,173 @@
1
+ # Created by Fabio Sarracino
2
+
3
+ import os
4
+ import logging
5
+ import hashlib
6
+ import folder_paths
7
+
8
+ # Setup logging
9
+ logger = logging.getLogger("VibeVoice")
10
+
11
+ class LoadTextFromFileNode:
12
+ @classmethod
13
+ def INPUT_TYPES(cls):
14
+ # Get all text files from all directories
15
+ all_files = []
16
+
17
+ # Add files from each directory with prefix
18
+ for dir_name in ["input", "output", "temp"]:
19
+ files = cls.get_files_for_directory(dir_name)
20
+ for f in files:
21
+ if f != "No text files found":
22
+ all_files.append(f"{dir_name}/{f}")
23
+
24
+ if not all_files:
25
+ all_files = ["No text files found in any directory"]
26
+
27
+ return {
28
+ "required": {
29
+ "file": (sorted(all_files), {
30
+ "tooltip": "Select a text file to load (format: directory/filename)"
31
+ }),
32
+ }
33
+ }
34
+
35
+ @classmethod
36
+ def get_files_for_directory(cls, source_dir):
37
+ """Get list of text files for the selected directory"""
38
+ # Get the appropriate directory path
39
+ if source_dir == "input":
40
+ dir_path = folder_paths.get_input_directory()
41
+ elif source_dir == "output":
42
+ dir_path = folder_paths.get_output_directory()
43
+ elif source_dir == "temp":
44
+ dir_path = folder_paths.get_temp_directory()
45
+ else:
46
+ return []
47
+
48
+ files = []
49
+ try:
50
+ for f in os.listdir(dir_path):
51
+ if os.path.isfile(os.path.join(dir_path, f)):
52
+ # Check for text file extensions
53
+ if f.lower().endswith('.txt'):
54
+ files.append(f)
55
+ except Exception as e:
56
+ logger.warning(f"Error listing files in {source_dir}: {e}")
57
+
58
+ return files
59
+
60
+ RETURN_TYPES = ("STRING",)
61
+ RETURN_NAMES = ("text",)
62
+ FUNCTION = "load_text"
63
+ CATEGORY = "VibeVoiceWrapper"
64
+ DESCRIPTION = "Load text content from a .txt file"
65
+
66
+ def load_text(self, file: str):
67
+ """Load text content from file"""
68
+
69
+ try:
70
+ # Check if no file selected
71
+ if not file or file == "No text files found in any directory":
72
+ raise Exception("Please select a valid text file.")
73
+
74
+ # Parse directory and filename from the combined string
75
+ if "/" not in file:
76
+ raise Exception(f"Invalid file format: {file}")
77
+
78
+ source_dir, filename = file.split("/", 1)
79
+
80
+ # Get the appropriate directory path
81
+ if source_dir == "input":
82
+ dir_path = folder_paths.get_input_directory()
83
+ elif source_dir == "output":
84
+ dir_path = folder_paths.get_output_directory()
85
+ elif source_dir == "temp":
86
+ dir_path = folder_paths.get_temp_directory()
87
+ else:
88
+ raise Exception(f"Invalid source directory: {source_dir}")
89
+
90
+ # Build full file path
91
+ file_path = os.path.join(dir_path, filename)
92
+
93
+ if not os.path.exists(file_path):
94
+ raise Exception(f"File not found: {file_path}")
95
+
96
+ # Read file with UTF-8 encoding (most common)
97
+ with open(file_path, 'r', encoding='utf-8') as f:
98
+ text_content = f.read()
99
+
100
+ if not text_content.strip():
101
+ raise Exception("File is empty or contains only whitespace")
102
+
103
+ return (text_content,)
104
+
105
+ except UnicodeDecodeError as e:
106
+ raise Exception(f"Encoding error reading file: {str(e)}. File may not be UTF-8 encoded.")
107
+ except Exception as e:
108
+ logger.error(f"Failed to load text file: {str(e)}")
109
+ raise Exception(f"Error loading text file: {str(e)}")
110
+
111
+ @classmethod
112
+ def IS_CHANGED(cls, file):
113
+ """Cache key for ComfyUI"""
114
+ if not file or file == "No text files found in any directory":
115
+ return "no_file"
116
+
117
+ # Parse directory and filename
118
+ if "/" not in file:
119
+ return f"{file}_invalid"
120
+
121
+ source_dir, filename = file.split("/", 1)
122
+
123
+ # Get the appropriate directory path
124
+ if source_dir == "input":
125
+ dir_path = folder_paths.get_input_directory()
126
+ elif source_dir == "output":
127
+ dir_path = folder_paths.get_output_directory()
128
+ elif source_dir == "temp":
129
+ dir_path = folder_paths.get_temp_directory()
130
+ else:
131
+ return f"{file}_invalid_dir"
132
+
133
+ file_path = os.path.join(dir_path, filename)
134
+
135
+ if not os.path.exists(file_path):
136
+ return f"{file}_not_found"
137
+
138
+ # Use file hash for cache invalidation
139
+ try:
140
+ m = hashlib.sha256()
141
+ with open(file_path, 'rb') as f:
142
+ m.update(f.read())
143
+ return m.digest().hex()
144
+ except:
145
+ return f"{file}_error"
146
+
147
+ @classmethod
148
+ def VALIDATE_INPUTS(cls, file, **kwargs):
149
+ """Validate that the file exists"""
150
+ if not file or file == "No text files found in any directory":
151
+ return "No valid text file selected"
152
+
153
+ # Parse directory and filename
154
+ if "/" not in file:
155
+ return f"Invalid file format: {file}"
156
+
157
+ source_dir, filename = file.split("/", 1)
158
+
159
+ # Get the appropriate directory path
160
+ if source_dir == "input":
161
+ dir_path = folder_paths.get_input_directory()
162
+ elif source_dir == "output":
163
+ dir_path = folder_paths.get_output_directory()
164
+ elif source_dir == "temp":
165
+ dir_path = folder_paths.get_temp_directory()
166
+ else:
167
+ return f"Invalid source directory: {source_dir}"
168
+
169
+ file_path = os.path.join(dir_path, filename)
170
+ if not os.path.exists(file_path):
171
+ return f"File not found: {filename} in {source_dir}"
172
+
173
+ return True
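A minimal sketch (not part of the commit) of the content-hash cache key that IS_CHANGED computes above, so editing a file invalidates ComfyUI's cached result; the helper name and path argument are hypothetical.

import hashlib

def file_cache_key(path: str) -> str:
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        digest.update(f.read())
    return digest.hexdigest()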
nodes/multi_speaker_node.py ADDED
@@ -0,0 +1,342 @@
1
+ # Created by Fabio Sarracino
2
+
3
+ import logging
4
+ import os
5
+ import re
6
+ import tempfile
7
+ import torch
8
+ import numpy as np
9
+ from typing import List, Optional
10
+
11
+ from .base_vibevoice import BaseVibeVoiceNode
12
+
13
+ # Setup logging
14
+ logger = logging.getLogger("VibeVoice")
15
+
16
+ class VibeVoiceMultipleSpeakersNode(BaseVibeVoiceNode):
17
+ def __init__(self):
18
+ super().__init__()
19
+ # Register this instance for memory management
20
+ try:
21
+ from .free_memory_node import VibeVoiceFreeMemoryNode
22
+ VibeVoiceFreeMemoryNode.register_multi_speaker(self)
23
+ except:
24
+ pass
25
+
26
+ @classmethod
27
+ def INPUT_TYPES(cls):
28
+ return {
29
+ "required": {
30
+ "text": ("STRING", {
31
+ "multiline": True,
32
+ "default": "[1]: Hello, this is the first speaker.\n[2]: Hi there, I'm the second speaker.\n[1]: Nice to meet you!\n[2]: Nice to meet you too!",
33
+ "tooltip": "Text with speaker labels. Use '[N]:' format where N is 1-4. Gets disabled when connected to another node.",
34
+ "forceInput": False,
35
+ "dynamicPrompts": True
36
+ }),
37
+ "model": (["VibeVoice-1.5B", "VibeVoice-Large", "VibeVoice-Large-Quant-4Bit"], {
38
+ "default": "VibeVoice-Large", # Large recommended for multi-speaker
39
+ "tooltip": "Model to use. Large is recommended for multi-speaker generation, Quant-4Bit uses less VRAM (CUDA only)"
40
+ }),
41
+ "attention_type": (["auto", "eager", "sdpa", "flash_attention_2", "sage"], {
42
+ "default": "auto",
43
+ "tooltip": "Attention implementation. Auto selects the best available, eager is standard, sdpa is optimized PyTorch, flash_attention_2 requires compatible GPU, sage uses quantized attention for speedup (CUDA only)"
44
+ }),
45
+ "free_memory_after_generate": ("BOOLEAN", {"default": True, "tooltip": "Free model from memory after generation to save VRAM/RAM. Disable to keep model loaded for faster subsequent generations"}),
46
+ "diffusion_steps": ("INT", {"default": 20, "min": 5, "max": 100, "step": 1, "tooltip": "Number of denoising steps. More steps = better quality but slower. Default: 20"}),
47
+ "seed": ("INT", {"default": 42, "min": 0, "max": 2**32-1, "tooltip": "Random seed for generation. Default 42 is used in official examples"}),
48
+ "cfg_scale": ("FLOAT", {"default": 1.3, "min": 0.5, "max": 3.5, "step": 0.05, "tooltip": "Classifier-free guidance scale (official default: 1.3)"}),
49
+ "use_sampling": ("BOOLEAN", {"default": False, "tooltip": "Enable sampling mode. When False (default), uses deterministic generation like official examples"}),
50
+ },
51
+ "optional": {
52
+ "speaker1_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 1. If not provided, synthetic voice will be used."}),
53
+ "speaker2_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 2. If not provided, synthetic voice will be used."}),
54
+ "speaker3_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 3. If not provided, synthetic voice will be used."}),
55
+ "speaker4_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 4. If not provided, synthetic voice will be used."}),
56
+ "temperature": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 2.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
57
+ "top_p": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 1.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
58
+ }
59
+ }
60
+
61
+ RETURN_TYPES = ("AUDIO",)
62
+ RETURN_NAMES = ("audio",)
63
+ FUNCTION = "generate_speech"
64
+ CATEGORY = "VibeVoiceWrapper"
65
+ DESCRIPTION = "Generate multi-speaker conversations with up to 4 distinct voices using Microsoft VibeVoice"
66
+
67
+ def _prepare_voice_sample(self, voice_audio, speaker_idx: int) -> Optional[np.ndarray]:
68
+ """Prepare a single voice sample from input audio"""
69
+ return self._prepare_audio_from_comfyui(voice_audio)
70
+
71
+ def generate_speech(self, text: str = "", model: str = "VibeVoice-Large",
72
+ attention_type: str = "auto", free_memory_after_generate: bool = True,
73
+ diffusion_steps: int = 20, seed: int = 42, cfg_scale: float = 1.3,
74
+ use_sampling: bool = False, speaker1_voice=None, speaker2_voice=None,
75
+ speaker3_voice=None, speaker4_voice=None,
76
+ temperature: float = 0.95, top_p: float = 0.95):
77
+ """Generate multi-speaker speech from text using VibeVoice"""
78
+
79
+ try:
80
+ # Check text input
81
+ if not text or not text.strip():
82
+ raise Exception("No text provided. Please enter text with speaker labels (e.g., '[1]: Hello' or '[2]: Hi')")
83
+
84
+ # First detect how many speakers are in the text
85
+ bracket_pattern = r'\[(\d+)\]\s*:'
86
+ speakers_numbers = sorted(list(set([int(m) for m in re.findall(bracket_pattern, text)])))
87
+
88
+ # Limit to 1-4 speakers
89
+ if not speakers_numbers:
90
+ num_speakers = 1 # Default to 1 if no speaker format found
91
+ else:
92
+ num_speakers = min(max(speakers_numbers), 4) # Max speaker number, capped at 4
93
+ if max(speakers_numbers) > 4:
94
+ logger.warning(f"Found {max(speakers_numbers)} speakers, limiting to 4")
95
+
96
+ # Direct conversion from [N]: to Speaker (N-1): for VibeVoice processor
97
+ # This avoids multiple conversion steps
98
+ converted_text = text
99
+
100
+ # Find all [N]: patterns in the text
101
+ speakers_in_text = sorted(list(set([int(m) for m in re.findall(bracket_pattern, text)])))
102
+
103
+ if not speakers_in_text:
104
+ # No [N]: format found, try Speaker N: format
105
+ speaker_pattern = r'Speaker\s+(\d+)\s*:'
106
+ speakers_in_text = sorted(list(set([int(m) for m in re.findall(speaker_pattern, text)])))
107
+
108
+ if speakers_in_text:
109
+ # Text already in Speaker N format, convert to 0-based
110
+ for speaker_num in sorted(speakers_in_text, reverse=True):
111
+ pattern = f'Speaker\\s+{speaker_num}\\s*:'
112
+ replacement = f'Speaker {speaker_num - 1}:'
113
+ converted_text = re.sub(pattern, replacement, converted_text)
114
+ else:
115
+ # No speaker format found
116
+ speakers_in_text = [1]
117
+
118
+ # Parse pause keywords even for single speaker
119
+ pause_segments = self._parse_pause_keywords(text)
120
+
121
+ # Store speaker segments for pause processing
122
+ speaker_segments_with_pauses = []
123
+ segments = []
124
+
125
+ for seg_type, seg_content in pause_segments:
126
+ if seg_type == 'pause':
127
+ speaker_segments_with_pauses.append(('pause', seg_content, None))
128
+ else:
129
+ # Clean up newlines
130
+ text_clean = seg_content.replace('\n', ' ').replace('\r', ' ')
131
+ text_clean = ' '.join(text_clean.split())
132
+
133
+ if text_clean:
134
+ speaker_segments_with_pauses.append(('text', text_clean, 1))
135
+ segments.append(f"Speaker 0: {text_clean}")
136
+
137
+ # Join all segments for fallback
138
+ converted_text = '\n'.join(segments) if segments else f"Speaker 0: {text}"
139
+ else:
140
+ # Convert [N]: directly to Speaker (N-1): and handle multi-line text
141
+ # Split text to preserve speaker segments while cleaning up newlines within each segment
142
+ segments = []
143
+
144
+ # Find all speaker markers with their positions
145
+ speaker_matches = list(re.finditer(f'\\[({"|".join(map(str, speakers_in_text))})\\]\\s*:', converted_text))
146
+
147
+ # Store speaker segments for pause processing
148
+ speaker_segments_with_pauses = []
149
+
150
+ for i, match in enumerate(speaker_matches):
151
+ speaker_num = int(match.group(1))
152
+ start = match.end()
153
+
154
+ # Find where this speaker's text ends (at next speaker or end of text)
155
+ if i + 1 < len(speaker_matches):
156
+ end = speaker_matches[i + 1].start()
157
+ else:
158
+ end = len(converted_text)
159
+
160
+ # Extract the speaker's text (keep pause keywords for now)
161
+ speaker_text = converted_text[start:end].strip()
162
+
163
+ # Parse pause keywords within this speaker's text
164
+ pause_segments = self._parse_pause_keywords(speaker_text)
165
+
166
+ # Process each segment (text or pause) for this speaker
167
+ for seg_type, seg_content in pause_segments:
168
+ if seg_type == 'pause':
169
+ # Add pause segment
170
+ speaker_segments_with_pauses.append(('pause', seg_content, None))
171
+ else:
172
+ # Clean up the text segment
173
+ text_clean = seg_content.replace('\n', ' ').replace('\r', ' ')
174
+ text_clean = ' '.join(text_clean.split())
175
+
176
+ if text_clean: # Only add non-empty text
177
+ # Add text segment with speaker info
178
+ speaker_segments_with_pauses.append(('text', text_clean, speaker_num))
179
+ # Also build the traditional segments for fallback
180
+ segments.append(f'Speaker {speaker_num - 1}: {text_clean}')
181
+
182
+ # Join all segments with newlines (required for multi-speaker format) - for fallback
183
+ converted_text = '\n'.join(segments) if segments else ""
184
+
185
+ # Build speaker names list - these are just for logging, not used by processor
186
+ # The processor uses the speaker labels in the text itself
187
+ speakers = [f"Speaker {i}" for i in range(len(speakers_in_text))]
188
+
189
+ # Get model mapping and load model with attention type
190
+ model_mapping = self._get_model_mapping()
191
+ model_path = model_mapping.get(model, model)
192
+ self.load_model(model, model_path, attention_type)
193
+
194
+ voice_inputs = [speaker1_voice, speaker2_voice, speaker3_voice, speaker4_voice]
195
+
196
+ # Prepare voice samples in order of appearance
197
+ voice_samples = []
198
+ for i, speaker_num in enumerate(speakers_in_text):
199
+ idx = speaker_num - 1 # Convert to 0-based for voice array
200
+
201
+ # Try to use provided voice sample
202
+ if idx < len(voice_inputs) and voice_inputs[idx] is not None:
203
+ voice_sample = self._prepare_voice_sample(voice_inputs[idx], idx)
204
+ if voice_sample is None:
205
+ # Use the actual speaker index for consistent synthetic voice
206
+ voice_sample = self._create_synthetic_voice_sample(idx)
207
+ else:
208
+ # Use the actual speaker index for consistent synthetic voice
209
+ voice_sample = self._create_synthetic_voice_sample(idx)
210
+
211
+ voice_samples.append(voice_sample)
212
+
213
+ # Ensure voice_samples count matches detected speakers
214
+ if len(voice_samples) != len(speakers_in_text):
215
+ logger.error(f"Mismatch: {len(speakers_in_text)} speakers but {len(voice_samples)} voice samples!")
216
+ raise Exception(f"Voice sample count mismatch: expected {len(speakers_in_text)}, got {len(voice_samples)}")
217
+
218
+ # Check if we have pause segments to process
219
+ if 'speaker_segments_with_pauses' in locals() and speaker_segments_with_pauses:
220
+ # Process segments with pauses
221
+ all_audio_segments = []
222
+ sample_rate = 24000 # VibeVoice uses 24kHz
223
+
224
+ # Group consecutive text segments from same speaker for efficiency
225
+ grouped_segments = []
226
+ current_group = []
227
+ current_speaker = None
228
+
229
+ for seg_type, seg_content, speaker_num in speaker_segments_with_pauses:
230
+ if seg_type == 'pause':
231
+ # Save current group if any
232
+ if current_group:
233
+ grouped_segments.append(('text_group', current_group, current_speaker))
234
+ current_group = []
235
+ current_speaker = None
236
+ # Add pause
237
+ grouped_segments.append(('pause', seg_content, None))
238
+ else:
239
+ # Text segment
240
+ if speaker_num == current_speaker:
241
+ # Same speaker, add to current group
242
+ current_group.append(seg_content)
243
+ else:
244
+ # Different speaker, save current group and start new one
245
+ if current_group:
246
+ grouped_segments.append(('text_group', current_group, current_speaker))
247
+ current_group = [seg_content]
248
+ current_speaker = speaker_num
249
+
250
+ # Save last group if any
251
+ if current_group:
252
+ grouped_segments.append(('text_group', current_group, current_speaker))
253
+
254
+ # Process grouped segments
255
+ for seg_type, seg_content, speaker_num in grouped_segments:
256
+ if seg_type == 'pause':
257
+ # Generate silence
258
+ duration_ms = seg_content
259
+ logger.info(f"Adding {duration_ms}ms pause")
260
+ silence_audio = self._generate_silence(duration_ms, sample_rate)
261
+ all_audio_segments.append(silence_audio)
262
+ else:
263
+ # Process text group for a speaker
264
+ combined_text = ' '.join(seg_content)
265
+ formatted_text = f"Speaker {speaker_num - 1}: {combined_text}"
266
+
267
+ # Get voice sample for this speaker
268
+ speaker_idx = speakers_in_text.index(speaker_num)
269
+ speaker_voice_samples = [voice_samples[speaker_idx]]
270
+
271
+ logger.info(f"Generating audio for Speaker {speaker_num}: {len(combined_text.split())} words")
272
+
273
+ # Generate audio for this speaker's text
274
+ segment_audio = self._generate_with_vibevoice(
275
+ formatted_text, speaker_voice_samples, cfg_scale, seed,
276
+ diffusion_steps, use_sampling, temperature, top_p
277
+ )
278
+
279
+ all_audio_segments.append(segment_audio)
280
+
281
+ # Concatenate all audio segments
282
+ if all_audio_segments:
283
+ logger.info(f"Concatenating {len(all_audio_segments)} audio segments (including pauses)...")
284
+
285
+ # Extract waveforms
286
+ waveforms = []
287
+ for audio_segment in all_audio_segments:
288
+ if isinstance(audio_segment, dict) and "waveform" in audio_segment:
289
+ waveforms.append(audio_segment["waveform"])
290
+
291
+ if waveforms:
292
+ # Filter out None values if any
293
+ valid_waveforms = [w for w in waveforms if w is not None]
294
+
295
+ if valid_waveforms:
296
+ # Concatenate along time dimension
297
+ combined_waveform = torch.cat(valid_waveforms, dim=-1)
298
+
299
+ audio_dict = {
300
+ "waveform": combined_waveform,
301
+ "sample_rate": sample_rate
302
+ }
303
+ logger.info(f"Successfully generated multi-speaker audio with pauses")
304
+ else:
305
+ raise Exception("No valid audio waveforms generated")
306
+ else:
307
+ raise Exception("Failed to extract waveforms from audio segments")
308
+ else:
309
+ raise Exception("No audio segments generated")
310
+ else:
311
+ # Fallback to original method without pause support
312
+ logger.info("Processing without pause support (no pause keywords found)")
313
+ audio_dict = self._generate_with_vibevoice(
314
+ converted_text, voice_samples, cfg_scale, seed, diffusion_steps,
315
+ use_sampling, temperature, top_p
316
+ )
317
+
318
+ # Free memory if requested
319
+ if free_memory_after_generate:
320
+ self.free_memory()
321
+
322
+ return (audio_dict,)
323
+
324
+ except Exception as e:
325
+ # Check if this is an interruption by the user
326
+ import comfy.model_management as mm
327
+ if isinstance(e, mm.InterruptProcessingException):
328
+ # User interrupted - just log it and re-raise to stop the workflow
329
+ logger.info("Generation interrupted by user")
330
+ raise # Propagate the interruption to stop the workflow
331
+ else:
332
+ # Real error - show it
333
+ logger.error(f"Multi-speaker speech generation failed: {str(e)}")
334
+ raise Exception(f"Error generating multi-speaker speech: {str(e)}")
335
+
336
+ @classmethod
337
+ def IS_CHANGED(cls, text="", model="VibeVoice-Large",
338
+ speaker1_voice=None, speaker2_voice=None,
339
+ speaker3_voice=None, speaker4_voice=None, **kwargs):
340
+ """Cache key for ComfyUI"""
341
+ voices_hash = hash(str([speaker1_voice, speaker2_voice, speaker3_voice, speaker4_voice]))
342
+ return f"{hash(text)}_{model}_{voices_hash}_{kwargs.get('cfg_scale', 1.3)}_{kwargs.get('seed', 0)}"
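As an aside (not part of the commit), a compact illustration of the bracket-label handling above: detect which `[N]:` speakers appear, then map them to the 0-based `Speaker N:` labels the processor expects.

import re

script = "[1]: Hello there.\n[2]: Hi! [pause:300]\n[1]: Shall we start?"
speakers = sorted(set(int(m) for m in re.findall(r'\[(\d+)\]\s*:', script)))
print(speakers)  # [1, 2]
converted = re.sub(r'\[(\d+)\]\s*:', lambda m: f"Speaker {int(m.group(1)) - 1}:", script)
print(converted.splitlines()[0])  # Speaker 0: Hello there.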
nodes/single_speaker_node.py ADDED
@@ -0,0 +1,220 @@
1
+ # Created by Fabio Sarracino
2
+
3
+ import logging
4
+ import os
5
+ import tempfile
6
+ import torch
7
+ import numpy as np
8
+ import re
9
+ from typing import List, Optional
10
+
11
+ from .base_vibevoice import BaseVibeVoiceNode
12
+
13
+ # Setup logging
14
+ logger = logging.getLogger("VibeVoice")
15
+
16
+ class VibeVoiceSingleSpeakerNode(BaseVibeVoiceNode):
17
+ def __init__(self):
18
+ super().__init__()
19
+ # Register this instance for memory management
20
+ try:
21
+ from .free_memory_node import VibeVoiceFreeMemoryNode
22
+ VibeVoiceFreeMemoryNode.register_single_speaker(self)
23
+ except:
24
+ pass
25
+
26
+ @classmethod
27
+ def INPUT_TYPES(cls):
28
+ return {
29
+ "required": {
30
+ "text": ("STRING", {
31
+ "multiline": True,
32
+ "default": "Hello, this is a test of the VibeVoice text-to-speech system.",
33
+ "tooltip": "Text to convert to speech. Gets disabled when connected to another node.",
34
+ "forceInput": False,
35
+ "dynamicPrompts": True
36
+ }),
37
+ "model": (["VibeVoice-1.5B", "VibeVoice-Large", "VibeVoice-Large-Quant-4Bit"], {
38
+ "default": "VibeVoice-1.5B",
39
+ "tooltip": "Model to use. 1.5B is faster, Large has better quality, Quant-4Bit uses less VRAM (CUDA only)"
40
+ }),
41
+ "attention_type": (["auto", "eager", "sdpa", "flash_attention_2", "sage"], {
42
+ "default": "auto",
43
+ "tooltip": "Attention implementation. Auto selects the best available, eager is standard, sdpa is optimized PyTorch, flash_attention_2 requires compatible GPU, sage uses quantized attention for speedup (CUDA only)"
44
+ }),
45
+ "free_memory_after_generate": ("BOOLEAN", {"default": True, "tooltip": "Free model from memory after generation to save VRAM/RAM. Disable to keep model loaded for faster subsequent generations"}),
46
+ "diffusion_steps": ("INT", {"default": 20, "min": 5, "max": 100, "step": 1, "tooltip": "Number of denoising steps. More steps = better quality but slower. Default: 20"}),
47
+ "seed": ("INT", {"default": 42, "min": 0, "max": 2**32-1, "tooltip": "Random seed for generation. Default 42 is used in official examples"}),
48
+ "cfg_scale": ("FLOAT", {"default": 1.3, "min": 0.5, "max": 3.5, "step": 0.05, "tooltip": "Classifier-free guidance scale (official default: 1.3)"}),
49
+ "use_sampling": ("BOOLEAN", {"default": False, "tooltip": "Enable sampling mode. When False (default), uses deterministic generation like official examples"}),
50
+ },
51
+ "optional": {
52
+ "voice_to_clone": ("AUDIO", {"tooltip": "Optional: Reference voice to clone. If not provided, synthetic voice will be used."}),
53
+ "temperature": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 2.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
54
+ "top_p": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 1.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
55
+ "max_words_per_chunk": ("INT", {"default": 250, "min": 100, "max": 500, "step": 50, "tooltip": "Maximum words per chunk for long texts. Lower values prevent speed issues but create more chunks."}),
56
+ }
57
+ }
58
+
59
+ RETURN_TYPES = ("AUDIO",)
60
+ RETURN_NAMES = ("audio",)
61
+ FUNCTION = "generate_speech"
62
+ CATEGORY = "VibeVoiceWrapper"
63
+ DESCRIPTION = "Generate speech from text using Microsoft VibeVoice with optional voice cloning"
64
+
65
+ def _prepare_voice_samples(self, speakers: list, voice_to_clone) -> List[np.ndarray]:
66
+ """Prepare voice samples from input audio or create synthetic ones"""
67
+
68
+ if voice_to_clone is not None:
69
+ # Use the base class method to prepare audio
70
+ audio_np = self._prepare_audio_from_comfyui(voice_to_clone)
71
+ if audio_np is not None:
72
+ return [audio_np]
73
+
74
+ # Create synthetic voice samples for speakers
75
+ voice_samples = []
76
+ for i, speaker in enumerate(speakers):
77
+ voice_sample = self._create_synthetic_voice_sample(i)
78
+ voice_samples.append(voice_sample)
79
+
80
+ return voice_samples
81
+
82
+ def generate_speech(self, text: str = "", model: str = "VibeVoice-1.5B",
83
+ attention_type: str = "auto", free_memory_after_generate: bool = True,
84
+ diffusion_steps: int = 20, seed: int = 42, cfg_scale: float = 1.3,
85
+ use_sampling: bool = False, voice_to_clone=None,
86
+ temperature: float = 0.95, top_p: float = 0.95,
87
+ max_words_per_chunk: int = 250):
88
+ """Generate speech from text using VibeVoice"""
89
+
90
+ try:
91
+ # Use text directly (it now serves as both manual input and connection input)
92
+ if text and text.strip():
93
+ final_text = text
94
+ else:
95
+ raise Exception("No text provided. Please enter text or connect from LoadTextFromFile node.")
96
+
97
+ # Get model mapping and load model with attention type
98
+ model_mapping = self._get_model_mapping()
99
+ model_path = model_mapping.get(model, model)
100
+ self.load_model(model, model_path, attention_type)
101
+
102
+ # For single speaker, we just use ["Speaker 1"]
103
+ speakers = ["Speaker 1"]
104
+
105
+ # Parse pause keywords from text
106
+ segments = self._parse_pause_keywords(final_text)
107
+
108
+ # Process segments
109
+ all_audio_segments = []
110
+ voice_samples = None # Will be created on first text segment
111
+ sample_rate = 24000 # VibeVoice uses 24kHz
112
+
113
+ for seg_idx, (seg_type, seg_content) in enumerate(segments):
114
+ if seg_type == 'pause':
115
+ # Generate silence for pause
116
+ duration_ms = seg_content
117
+ logger.info(f"Adding {duration_ms}ms pause")
118
+ silence_audio = self._generate_silence(duration_ms, sample_rate)
119
+ all_audio_segments.append(silence_audio)
120
+
121
+ elif seg_type == 'text':
122
+ # Process text segment (with chunking if needed)
123
+ word_count = len(seg_content.split())
124
+
125
+ if word_count > max_words_per_chunk:
126
+ # Split long text into chunks
127
+ logger.info(f"Text segment {seg_idx+1} has {word_count} words, splitting into chunks...")
128
+ text_chunks = self._split_text_into_chunks(seg_content, max_words_per_chunk)
129
+
130
+ for chunk_idx, chunk in enumerate(text_chunks):
131
+ logger.info(f"Processing chunk {chunk_idx+1}/{len(text_chunks)} of segment {seg_idx+1}...")
132
+
133
+ # Format chunk for VibeVoice
134
+ formatted_text = self._format_text_for_vibevoice(chunk, speakers)
135
+
136
+ # Create voice samples on first text segment
137
+ if voice_samples is None:
138
+ voice_samples = self._prepare_voice_samples(speakers, voice_to_clone)
139
+
140
+ # Generate audio for this chunk
141
+ chunk_audio = self._generate_with_vibevoice(
142
+ formatted_text, voice_samples, cfg_scale,
143
+ seed, # Use same seed for voice consistency
144
+ diffusion_steps, use_sampling, temperature, top_p
145
+ )
146
+
147
+ all_audio_segments.append(chunk_audio)
148
+ else:
149
+ # Process as single chunk
150
+ logger.info(f"Processing text segment {seg_idx+1} ({word_count} words)")
151
+
152
+ # Format text for VibeVoice
153
+ formatted_text = self._format_text_for_vibevoice(seg_content, speakers)
154
+
155
+ # Create voice samples on first text segment
156
+ if voice_samples is None:
157
+ voice_samples = self._prepare_voice_samples(speakers, voice_to_clone)
158
+
159
+ # Generate audio
160
+ segment_audio = self._generate_with_vibevoice(
161
+ formatted_text, voice_samples, cfg_scale, seed, diffusion_steps,
162
+ use_sampling, temperature, top_p
163
+ )
164
+
165
+ all_audio_segments.append(segment_audio)
166
+
167
+ # Concatenate all audio segments (including pauses)
168
+ if all_audio_segments:
169
+ logger.info(f"Concatenating {len(all_audio_segments)} audio segments (including pauses)...")
170
+
171
+ # Extract waveforms from all segments
172
+ waveforms = []
173
+ for audio_segment in all_audio_segments:
174
+ if isinstance(audio_segment, dict) and "waveform" in audio_segment:
175
+ waveforms.append(audio_segment["waveform"])
176
+
177
+ if waveforms:
178
+ # Filter out None values if any
179
+ valid_waveforms = [w for w in waveforms if w is not None]
180
+
181
+ if valid_waveforms:
182
+ # Concatenate along the time dimension (last dimension)
183
+ combined_waveform = torch.cat(valid_waveforms, dim=-1)
184
+
185
+ # Create final audio dict
186
+ audio_dict = {
187
+ "waveform": combined_waveform,
188
+ "sample_rate": sample_rate
189
+ }
190
+ logger.info(f"Successfully generated audio with {len(segments)} segments")
191
+ else:
192
+ raise Exception("No valid audio waveforms generated")
193
+ else:
194
+ raise Exception("Failed to extract waveforms from audio segments")
195
+ else:
196
+ raise Exception("No audio segments generated")
197
+
198
+ # Free memory if requested
199
+ if free_memory_after_generate:
200
+ self.free_memory()
201
+
202
+ return (audio_dict,)
203
+
204
+ except Exception as e:
205
+ # Check if this is an interruption by the user
206
+ import comfy.model_management as mm
207
+ if isinstance(e, mm.InterruptProcessingException):
208
+ # User interrupted - just log it and re-raise to stop the workflow
209
+ logger.info("Generation interrupted by user")
210
+ raise # Propagate the interruption to stop the workflow
211
+ else:
212
+ # Real error - show it
213
+ logger.error(f"Single speaker speech generation failed: {str(e)}")
214
+ raise Exception(f"Error generating speech: {str(e)}")
215
+
216
+ @classmethod
217
+ def IS_CHANGED(cls, text="", model="VibeVoice-1.5B", voice_to_clone=None, **kwargs):
218
+ """Cache key for ComfyUI"""
219
+ voice_hash = hash(str(voice_to_clone)) if voice_to_clone else 0
220
+ return f"{hash(text)}_{model}_{voice_hash}_{kwargs.get('cfg_scale', 1.3)}_{kwargs.get('seed', 0)}"
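For reference (not part of the commit), the time-axis concatenation the node performs when stitching generated chunks and silences back together:

import torch

segments = [torch.zeros(1, 1, 24000), torch.zeros(1, 1, 12000), torch.zeros(1, 1, 24000)]
combined = torch.cat(segments, dim=-1)   # concatenate along the time dimension
print(combined.shape[-1] / 24000)        # 2.5 seconds of audio at 24 kHz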
pyproject.toml ADDED
@@ -0,0 +1,18 @@
1
+ [project]
2
+ name = "VibeVoice-ComfyUI"
3
+ version = "1.3.0"
4
+ description = "ComfyUI wrapper for Microsoft VibeVoice TTS model. Supports single speaker, multi-speaker, and text file loading"
5
+ license = {file = "LICENSE"}
6
+ authors = [{name = "Fabio Sarracino"}]
7
+ dependencies = ["accelerate==1.6.0", "torch>=2.0.0", "torchaudio>=2.0.0", "numpy>=1.20.0", "transformers>=4.51.3", "librosa>=0.9.0", "soundfile>=0.12.0", "av>=14.3.0", "peft>=0.17.0", "huggingface_hub>=0.25.1", "diffusers", "tqdm", "scipy", "ml-collections", "absl-py", "aiortc", "bitsandbytes"]
8
+
9
+ [project.urls]
10
+ Repository = "https://github.com/Enemyx-net/VibeVoice-ComfyUI"
11
+ "Bug Tracker" = "https://github.com/Enemyx-net/VibeVoice-ComfyUI/issues"
12
+
13
+ [tool.comfy]
14
+ PublisherId = "enemyx"
15
+ DisplayName = "VibeVoice ComfyUI"
16
+ Icon = ""
17
+ includes = []
18
+
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ accelerate>=1.6.0
2
+ transformers>=4.51.3
3
+ diffusers
4
+ tqdm
5
+ scipy
6
+ ml-collections
7
+ torch>=2.0.0
8
+ torchaudio>=2.0.0
9
+ numpy>=1.20.0
10
+ librosa>=0.9.0
11
+ soundfile>=0.12.0
12
+ av>=14.3.0
13
+ peft>=0.17.0
14
+ huggingface_hub>=0.25.1
15
+ absl-py
16
+ aiortc
17
+ bitsandbytes
vvembed/LICENSE ADDED
@@ -0,0 +1,26 @@
1
+ MIT License
2
+
3
+ Copyright (c) Microsoft Corporation.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
23
+ ---
24
+
25
+ This is the original VibeVoice code from Microsoft, embedded here as the
26
+ repository has been removed from GitHub. The code is used under the MIT license.
vvembed/README.md ADDED
@@ -0,0 +1,21 @@
1
+ # Embedded VibeVoice
2
+
3
+ This folder contains the embedded VibeVoice code from Microsoft.
4
+
5
+ ## Why Embedded?
6
+
7
+ The original VibeVoice repository (https://github.com/microsoft/VibeVoice) has been removed from GitHub. Since VibeVoice is licensed under MIT, we have embedded the code here to ensure continued functionality of the ComfyUI wrapper.
8
+
9
+ ## License
10
+
11
+ The code in this folder is licensed under the MIT License (see LICENSE file). Original copyright belongs to Microsoft Corporation.
12
+
13
+ ## Modifications
14
+
15
+ The only modifications made to the original code are:
16
+ - Changed absolute imports from `vibevoice` to relative imports
17
+ - No functional changes to the core logic
18
+
19
+ ## Note
20
+
21
+ This is a preservation copy to ensure the continued availability of VibeVoice for the ComfyUI community.
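For clarity (not part of the commit), the kind of import rewrite described above; the module names are illustrative:

# Before, in the original package layout (absolute import):
#     from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
# After, inside the embedded copy (relative import):
#     from .configuration_vibevoice import VibeVoiceConfig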
vvembed/__init__.py ADDED
File without changes
vvembed/modular/__init__.py ADDED
File without changes
vvembed/modular/configuration_vibevoice.py ADDED
@@ -0,0 +1,270 @@
1
+ # Original code by Microsoft
2
+ # updated by Fabio Sarracino - Enemyx-net
3
+
4
+ """ VibeVoice_AcousticTokenizer model configuration"""
5
+
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ from transformers.configuration_utils import PretrainedConfig
9
+ from transformers.utils import logging
10
+
11
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+ # to be improved...
16
+
17
+
18
+ class VibeVoiceAcousticTokenizerConfig(PretrainedConfig):
19
+ model_type = "vibevoice_acoustic_tokenizer"
20
+
21
+ def __init__(
22
+ self,
23
+ channels: int = 1,
24
+ corpus_normalize: float = 0.0,
25
+ causal: bool = True,
26
+ vae_dim: int = 64,
27
+ fix_std: float = 0.5,
28
+ std_dist_type: str = 'gaussian',
29
+ # common
30
+ mixer_layer: str = 'depthwise_conv',
31
+ conv_norm: str = 'none',
32
+ pad_mode: str = 'constant',
33
+ disable_last_norm: bool = True,
34
+ layernorm: str = 'RMSNorm',
35
+ layernorm_eps: float = 1e-5,
36
+ layernorm_elementwise_affine: bool = True,
37
+ conv_bias: bool = True,
38
+ layer_scale_init_value: float = 1e-6,
39
+ weight_init_value: float = 1e-2,
40
+ # encoder specific
41
+ encoder_n_filters: int = 32,
42
+ encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
43
+ encoder_depths: str = "3-3-3-3-3-3-8",
44
+ # decoder specific
45
+ decoder_n_filters: int = 32,
46
+ decoder_ratios: Optional[List[int]] = None, # if None, same as encoder
47
+ decoder_depths: Optional[str] = None,
48
+ **kwargs
49
+ ):
50
+ super().__init__(**kwargs)
51
+ self.channels = channels
52
+ self.corpus_normalize = corpus_normalize
53
+ self.causal = causal
54
+ self.vae_dim = vae_dim
55
+ self.fix_std = fix_std
56
+ self.std_dist_type = std_dist_type
57
+
58
+ # common parameters
59
+ self.conv_norm = conv_norm
60
+ self.pad_mode = pad_mode
61
+ self.layernorm_eps = layernorm_eps
62
+ self.disable_last_norm = disable_last_norm
63
+ self.layernorm = layernorm
64
+ self.layernorm_elementwise_affine = layernorm_elementwise_affine
65
+ self.conv_bias = conv_bias
66
+ self.layer_scale_init_value = layer_scale_init_value
67
+ self.weight_init_value = weight_init_value
68
+ self.mixer_layer = mixer_layer
69
+
70
+ # encoder specific parameters
71
+ self.encoder_n_filters = encoder_n_filters
72
+ self.encoder_ratios = encoder_ratios
73
+ self.encoder_depths = encoder_depths
74
+
75
+ # decoder specific parameters
76
+ self.decoder_ratios = decoder_ratios if decoder_ratios is not None else encoder_ratios
77
+ self.decoder_n_filters = decoder_n_filters
78
+ self.decoder_depths = decoder_depths
79
+
80
+
81
+ class VibeVoiceSemanticTokenizerConfig(PretrainedConfig):
82
+ model_type = "vibevoice_semantic_tokenizer"
83
+
84
+ def __init__(
85
+ self,
86
+ channels: int = 1,
87
+ corpus_normalize: float = 0.0,
88
+ causal: bool = True,
89
+ vae_dim: int = 64,
90
+ fix_std: float = 0,
91
+ std_dist_type: str = 'none',
92
+ # common
93
+ mixer_layer: str = 'depthwise_conv',
94
+ conv_norm: str = 'none',
95
+ pad_mode: str = 'constant',
96
+ disable_last_norm: bool = True,
97
+ layernorm: str = 'RMSNorm',
98
+ layernorm_eps: float = 1e-5,
99
+ layernorm_elementwise_affine: bool = True,
100
+ conv_bias: bool = True,
101
+ layer_scale_init_value: float = 1e-6,
102
+ weight_init_value: float = 1e-2,
103
+ # encoder specific
104
+ encoder_n_filters: int = 32,
105
+ encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
106
+ encoder_depths: str = "3-3-3-3-3-3-8",
107
+ **kwargs
108
+ ):
109
+ super().__init__(**kwargs)
110
+ self.channels = channels
111
+ self.corpus_normalize = corpus_normalize
112
+ self.causal = causal
113
+ self.vae_dim = vae_dim
114
+ self.fix_std = fix_std
115
+ self.std_dist_type = std_dist_type
116
+
117
+ # common parameters
118
+ self.conv_norm = conv_norm
119
+ self.pad_mode = pad_mode
120
+ self.layernorm_eps = layernorm_eps
121
+ self.disable_last_norm = disable_last_norm
122
+ self.layernorm = layernorm
123
+ self.layernorm_elementwise_affine = layernorm_elementwise_affine
124
+ self.conv_bias = conv_bias
125
+ self.layer_scale_init_value = layer_scale_init_value
126
+ self.weight_init_value = weight_init_value
127
+ self.mixer_layer = mixer_layer
128
+
129
+ # encoder specific parameters
130
+ self.encoder_n_filters = encoder_n_filters
131
+ self.encoder_ratios = encoder_ratios
132
+ self.encoder_depths = encoder_depths
133
+
134
+
135
+ class VibeVoiceDiffusionHeadConfig(PretrainedConfig):
136
+ model_type = "vibevoice_diffusion_head"
137
+
138
+ def __init__(
139
+ self,
140
+ hidden_size=768,
141
+ head_layers=4,
142
+ head_ffn_ratio=3.0,
143
+ rms_norm_eps=1e-5,
144
+ latent_size=64,
145
+ speech_vae_dim=None,
146
+ prediction_type="v_prediction",
147
+ diffusion_type="ddpm",
148
+ ddpm_num_steps=1000,
149
+ ddpm_num_inference_steps=20,
150
+ ddpm_beta_schedule="cosine",
151
+ ddpm_batch_mul=4,
152
+ **kwargs
153
+ ):
154
+ self.hidden_size = hidden_size
155
+ self.head_layers = head_layers
156
+ self.head_ffn_ratio = head_ffn_ratio
157
+ self.rms_norm_eps = rms_norm_eps
158
+ self.latent_size = latent_size
159
+ self.speech_vae_dim = speech_vae_dim
160
+ self.prediction_type = prediction_type
161
+ self.diffusion_type = diffusion_type
162
+ self.ddpm_num_steps = ddpm_num_steps
163
+ self.ddpm_num_inference_steps = ddpm_num_inference_steps
164
+ self.ddpm_beta_schedule = ddpm_beta_schedule
165
+ self.ddpm_batch_mul = ddpm_batch_mul
166
+
167
+ super().__init__(**kwargs)
168
+
169
+ class VibeVoiceConfig(PretrainedConfig):
170
+ model_type = "vibevoice"
171
+ is_composition = True
172
+ sub_configs = {
173
+ "acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
174
+ "semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
175
+ "decoder_config": Qwen2Config,
176
+ "diffusion_head_config": VibeVoiceDiffusionHeadConfig,
177
+ }
178
+ # keys_to_ignore_at_inference = ["past_key_values"]
179
+ # Default tensor parallel plan for base model `Qwen2`
180
+ base_model_tp_plan = {
181
+ "layers.*.self_attn.q_proj": "colwise",
182
+ "layers.*.self_attn.k_proj": "colwise",
183
+ "layers.*.self_attn.v_proj": "colwise",
184
+ "layers.*.self_attn.o_proj": "rowwise",
185
+ "layers.*.mlp.gate_proj": "colwise",
186
+ "layers.*.mlp.up_proj": "colwise",
187
+ "layers.*.mlp.down_proj": "rowwise",
188
+ }
189
+
190
+ def __init__(
191
+ self,
192
+ acoustic_tokenizer_config=None,
193
+ semantic_tokenizer_config=None,
194
+ decoder_config=None,
195
+ diffusion_head_config=None,
196
+ **kwargs
197
+ ):
198
+
199
+ # kwargs["_attn_implementation"] = "flash_attention_2"
200
+ kwargs["_attn_implementation_autoset"] = False
201
+
202
+ if acoustic_tokenizer_config is None:
203
+ self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
204
+ elif isinstance(acoustic_tokenizer_config, dict):
205
+ acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer"
206
+ self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config)
207
+ elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
208
+ # If an instance of the config class is provided
209
+ self.acoustic_tokenizer_config = acoustic_tokenizer_config
210
+
211
+ if semantic_tokenizer_config is None:
212
+ self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]()
213
+ elif isinstance(semantic_tokenizer_config, dict):
214
+ semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer"
215
+ self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config)
216
+ elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig):
217
+ # If an instance of the config class is provided
218
+ self.semantic_tokenizer_config = semantic_tokenizer_config
219
+
220
+ if decoder_config is None:
221
+ self.decoder_config = self.sub_configs["decoder_config"]()
222
+ elif isinstance(decoder_config, dict):
223
+ # If a dictionary is provided, instantiate the config class with it
224
+ # self.decoder_config = self.sub_configs["decoder_config"](**decoder_config)
225
+ if decoder_config.get("model_type", '') == "qwen2":
226
+ self.decoder_config = Qwen2Config(**decoder_config)
227
+ else:
228
+ raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
229
+ elif isinstance(decoder_config, (Qwen2Config,)):
230
+ # If an instance of the config class is provided
231
+ self.decoder_config = decoder_config
232
+
233
+ if diffusion_head_config is None:
234
+ self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
235
+ elif isinstance(diffusion_head_config, dict):
236
+ diffusion_head_config["model_type"] = "vibevoice_diffusion_head"
237
+ self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_head_config)
238
+ elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig):
239
+ # If an instance of the config class is provided
240
+ self.diffusion_head_config = diffusion_head_config
241
+
242
+ # other parameters
243
+ self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
244
+ self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)
245
+
246
+ # Add attributes required by newer transformers versions from decoder_config
247
+ # These are used by GenerationMixin in newer versions
248
+ if hasattr(self.decoder_config, 'num_hidden_layers'):
249
+ self.num_hidden_layers = self.decoder_config.num_hidden_layers
250
+ if hasattr(self.decoder_config, 'vocab_size'):
251
+ self.vocab_size = self.decoder_config.vocab_size
252
+ if hasattr(self.decoder_config, 'hidden_size'):
253
+ self.hidden_size = self.decoder_config.hidden_size
254
+ if hasattr(self.decoder_config, 'num_attention_heads'):
255
+ self.num_attention_heads = self.decoder_config.num_attention_heads
256
+ if hasattr(self.decoder_config, 'num_key_value_heads'):
257
+ self.num_key_value_heads = self.decoder_config.num_key_value_heads
258
+ if hasattr(self.decoder_config, 'intermediate_size'):
259
+ self.intermediate_size = self.decoder_config.intermediate_size
260
+ if hasattr(self.decoder_config, 'max_position_embeddings'):
261
+ self.max_position_embeddings = self.decoder_config.max_position_embeddings
262
+
263
+ super().__init__(**kwargs)
264
+
265
+ __all__ = [
266
+ "VibeVoiceAcousticTokenizerConfig",
267
+ "VibeVoiceSemanticTokenizerConfig",
268
+ "VibeVoiceDiffusionHeadConfig",
269
+ "VibeVoiceConfig"
270
+ ]
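# Usage sketch (illustrative, not from the uploaded file): how the constructor
# above dispatches on dict vs. config-instance arguments. The decoder values are
# illustrative assumptions, and the import assumes the vvembed folder is on sys.path.
from modular.configuration_vibevoice import VibeVoiceConfig
cfg = VibeVoiceConfig(
    decoder_config={"model_type": "qwen2", "hidden_size": 1024, "num_hidden_layers": 4},
)
# Sub-configs left as None fall back to their default classes, and decoder
# attributes such as hidden_size are mirrored onto the top-level config.
print(cfg.hidden_size, cfg.acoustic_vae_dim)  # 1024, plus the acoustic vae_dim (64 via the getattr fallback)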
vvembed/modular/modeling_vibevoice.py ADDED
@@ -0,0 +1,488 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict, List, Optional, Tuple, Union, Callable
3
+ from tqdm import tqdm
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.distributed as dist
8
+
9
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
10
+
11
+ from transformers.activations import ACT2FN
12
+ from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
13
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
14
+ from transformers import modeling_utils
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
17
+ from transformers.utils import logging
18
+
19
+
20
+ from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
21
+ from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
22
+ from schedule.dpm_solver import DPMSolverMultistepScheduler
23
+
24
+ from .configuration_vibevoice import VibeVoiceConfig
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
30
+ modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
31
+
32
+ @dataclass
33
+ class VibeVoiceCausalLMOutputWithPast(ModelOutput):
34
+ loss: Optional[torch.FloatTensor] = None
35
+ diffusion_loss: Optional[torch.FloatTensor] = None
36
+ speech_token_num: Optional[int] = None
37
+ logits: torch.FloatTensor = None
38
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
39
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
40
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
41
+
42
+
43
+ @dataclass
44
+ class VibeVoiceGenerationOutput(ModelOutput):
45
+ """
46
+ Output type for VibeVoice generation.
47
+
48
+ Args:
49
+ sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
50
+ The generated sequences.
51
+ speech_outputs (`List[torch.FloatTensor]`, *optional*):
52
+ List of generated speech waveforms or latents for each speech segment.
53
+ """
54
+ sequences: torch.LongTensor = None
55
+ speech_outputs: Optional[List[torch.FloatTensor]] = None
56
+
57
+
58
+ class SpeechConnector(nn.Module):
59
+ def __init__(self, input_dim, output_dim):
60
+ super().__init__()
61
+ self.fc1 = nn.Linear(input_dim, output_dim)
62
+ self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
63
+ self.fc2 = nn.Linear(output_dim, output_dim)
64
+
65
+ def forward(self, features, **kwargs):
66
+ x = self.fc1(features)
67
+ x = self.norm(x)
68
+ x = self.fc2(x)
69
+ return x
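# Shape sketch (illustrative dims, not the shipped ones): SpeechConnector projects
# tokenizer latents [batch, frames, vae_dim] into the language model hidden space.
connector = SpeechConnector(input_dim=64, output_dim=1536)
latents = torch.randn(2, 10, 64)   # two samples, ten latent frames
hidden = connector(latents)        # -> torch.Size([2, 10, 1536])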
70
+
71
+
72
+ # @auto_docstring
73
+ class VibeVoicePreTrainedModel(PreTrainedModel):
74
+ config_class = VibeVoiceConfig
75
+ base_model_prefix = "model"
76
+ supports_gradient_checkpointing = True
77
+ _skip_keys_device_placement = "past_key_values"
78
+ _supports_cache_class = True
79
+ _supports_flash_attn_2 = True
80
+ _supports_sdpa = True
81
+ _supports_quantized_cache = True
82
+ _supports_static_cache = True
83
+ _supports_attention_backend = True
84
+
85
+ def _init_weights(self, module):
86
+ if isinstance(module, VibeVoiceDiffusionHead):
87
+ module.initialize_weights()
88
+ return
89
+
90
+ # Use the language model's initializer_range if available
91
+ if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
92
+ std = self.config.language_model_config.initializer_range
93
+ elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
94
+ std = self.config.decoder_config.initializer_range
95
+ else:
96
+ std = 0.02 # Default value
97
+
98
+ if isinstance(module, nn.Linear):
99
+ module.weight.data.normal_(mean=0.0, std=std)
100
+ if module.bias is not None:
101
+ module.bias.data.zero_()
102
+ elif isinstance(module, nn.LayerNorm):
103
+ module.weight.data.fill_(1.0)
104
+ module.bias.data.zero_()
105
+
106
+ # @auto_docstring
107
+ class VibeVoiceModel(VibeVoicePreTrainedModel):
108
+ def __init__(self, config):
109
+ super().__init__(config)
110
+
111
+ if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
112
+ if isinstance(config.torch_dtype, str):
113
+ dtype = getattr(torch, config.torch_dtype)
114
+ else:
115
+ dtype = config.torch_dtype
116
+ else:
117
+ dtype = torch.float32
118
+
119
+ # Initialize Qwen2 model for language modeling
120
+ lm_config = config.decoder_config
121
+ self.language_model = AutoModel.from_config(lm_config)
122
+
123
+ # Initialize speech components if needed
124
+ self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
125
+ self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)
126
+
127
+ self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
128
+ self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)
129
+
130
+ # Register scaling factors as buffers - use 1D tensors for FSDP compatibility
131
+ self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
132
+ self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))
133
+
134
+ # Initialize prediction head for speech generation
135
+ self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)
136
+
137
+ # Initialize noise scheduler
138
+ self.noise_scheduler = DPMSolverMultistepScheduler(
139
+ num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
140
+ beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
141
+ prediction_type=config.diffusion_head_config.prediction_type
142
+ )
143
+
144
+ def get_input_embeddings(self):
145
+ if hasattr(self.language_model, 'embed_tokens'):
146
+ # If the language model has an embed_tokens attribute, return it
147
+ return self.language_model.embed_tokens
148
+
149
+ for name, attr in self.language_model.fullmap.items(): # when parallelized by nnscaler, parameter names are remapped
150
+ if attr.orig_name == 'embed_tokens.weight':
151
+ return getattr(self.language_model, name)
152
+ assert False, 'should not arrive here'
153
+
154
+ def set_input_embeddings(self, value):
155
+ self.language_model.embed_tokens = value
156
+
157
+ def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
158
+ """Set the speech tokenizers used for encoding and decoding speech."""
159
+ self.acoustic_tokenizer = acoustic_tokenizer
160
+ self.semantic_tokenizer = semantic_tokenizer
161
+
162
+ # Reset the encoder to evaluation mode
163
+ if self.acoustic_tokenizer is not None:
164
+ self.acoustic_tokenizer.eval()
165
+
166
+ if self.semantic_tokenizer is not None:
167
+ self.semantic_tokenizer.eval()
168
+
169
+ def forward(
170
+ self,
171
+ input_ids: torch.LongTensor = None,
172
+ attention_mask: Optional[torch.Tensor] = None,
173
+ position_ids: Optional[torch.LongTensor] = None,
174
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
175
+ inputs_embeds: Optional[torch.FloatTensor] = None,
176
+ use_cache: Optional[bool] = None,
177
+ output_attentions: Optional[bool] = None,
178
+ output_hidden_states: Optional[bool] = None,
179
+ return_dict: Optional[bool] = None,
180
+ cache_position: Optional[torch.LongTensor] = None,
181
+ **kwargs,
182
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
183
+
184
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
185
+
186
+ # Forward through language model
187
+ outputs = self.language_model(
188
+ input_ids=input_ids,
189
+ attention_mask=attention_mask,
190
+ position_ids=position_ids,
191
+ past_key_values=past_key_values,
192
+ inputs_embeds=inputs_embeds,
193
+ use_cache=use_cache,
194
+ output_attentions=output_attentions,
195
+ output_hidden_states=output_hidden_states,
196
+ return_dict=return_dict,
197
+ cache_position=cache_position,
198
+ **kwargs,
199
+ )
200
+
201
+ if not return_dict:
202
+ return outputs
203
+
204
+ return BaseModelOutputWithPast(
205
+ last_hidden_state=outputs.last_hidden_state,
206
+ past_key_values=outputs.past_key_values,
207
+ hidden_states=outputs.hidden_states,
208
+ attentions=outputs.attentions,
209
+ )
210
+
211
+
212
+ class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel):
213
+ _tied_weights_keys = ["lm_head.weight"]
214
+ _tp_plan = {"lm_head": "colwise_rep"}
215
+
216
+ def __init__(self, config):
217
+ super().__init__(config)
218
+ self.model = VibeVoiceModel(config)
219
+ self.vocab_size = config.decoder_config.vocab_size
220
+ self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False)
221
+
222
+ self.post_init()
223
+
224
+ def get_input_embeddings(self):
225
+ return self.model.get_input_embeddings()
226
+
227
+ def set_input_embeddings(self, value):
228
+ self.model.set_input_embeddings(value)
229
+
230
+ def get_output_embeddings(self):
231
+ return self.lm_head
232
+
233
+ def set_decoder(self, decoder):
234
+ self.model.language_model = decoder
235
+
236
+ def get_decoder(self):
237
+ return self.model.language_model
238
+
239
+ def tie_weights(self):
240
+ """
241
+ Tie the weights between the input embeddings and the output embeddings.
242
+ """
243
+ if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
244
+ # The standard PreTrainedModel method will handle the tying.
245
+ # It typically does a simple parameter object assignment, which is
246
+ # CORRECT to do BEFORE FSDP wraps the model.
247
+ output_embeddings = self.get_output_embeddings()
248
+ input_embeddings = self.get_input_embeddings()
249
+ if hasattr(input_embeddings, 'weight'):
250
+ output_embeddings.weight = input_embeddings.weight
251
+ else:
252
+ # get_input_embeddings may have returned the embedding weight tensor directly instead of a module
253
+ output_embeddings.weight = input_embeddings
254
+
255
+ if getattr(output_embeddings, "bias", None) is not None:
256
+ output_embeddings.bias.data = nn.functional.pad(
257
+ output_embeddings.bias.data,
258
+ (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
259
+ "constant",
260
+ 0,
261
+ )
262
+ print("✅ Tied input and output embeddings using standard assignment.")
263
+ else:
264
+ print("ℹ️ tie_word_embeddings is False, not tying weights.")
265
+
266
+ # Note: set_output_embeddings should only be called before accelerator.prepare();
267
+ # the key constraint is to avoid re-assigning the head after the model has been wrapped.
268
+ def set_output_embeddings(self, new_embeddings):
269
+ # A plain assignment is sufficient here; it must simply happen
270
+ # before prepare() wraps the model.
271
+ self.lm_head = new_embeddings
272
+
273
+ def forward_speech_features(
274
+ self,
275
+ speech_tensors=None,
276
+ speech_masks=None,
277
+ speech_type="audio",
278
+ return_unmask=False
279
+ ):
280
+ if speech_tensors is None:
281
+ # Use config to get vae_dim instead of non-existent self.args
282
+ vae_dim = self.config.acoustic_tokenizer_config.vae_dim
283
+ audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight)
284
+ connect_features = self.model.acoustic_connector(audio_features)
285
+ return audio_features, connect_features
286
+ else:
287
+ with torch.no_grad():
288
+ if speech_type == "audio":
289
+ with torch.no_grad():
290
+ frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
291
+ audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0]
292
+
293
+ elif speech_type == "vae":
294
+ # Use config to get vae_dim instead of non-existent self.args
295
+ vae_dim = self.config.acoustic_tokenizer_config.vae_dim
296
+ speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim)
297
+
298
+ # gaussian sample from the speech_mode
299
+ batch_size = speech_mode.size(0)
300
+ value = self.model.acoustic_tokenizer.fix_std / 0.8
301
+ std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value
302
+ std = std.view(-1, *[1] * (speech_mode.dim() - 1))
303
+ audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode)
304
+ else:
305
+ raise NotImplementedError(f"Speech type {speech_type} not implemented")
306
+
307
+ if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor):
308
+ scaling_factor = 1. / audio_tokens[speech_masks].flatten().std()
309
+ bias_factor = -audio_tokens[speech_masks].flatten().mean()
310
+
311
+ # Only use distributed operations if the process group is initialized
312
+ if dist.is_available() and dist.is_initialized():
313
+ dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
314
+ dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
315
+ world_size = dist.get_world_size()
316
+ self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
317
+ self.model.speech_bias_factor.copy_(bias_factor / world_size)
318
+ print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
319
+ else:
320
+ # Single process case
321
+ self.model.speech_scaling_factor.copy_(scaling_factor)
322
+ self.model.speech_bias_factor.copy_(bias_factor)
323
+ print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
324
+
325
+ audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor
326
+
327
+ connect_features = self.model.acoustic_connector(audio_features)
328
+ if return_unmask:
329
+ return audio_features, connect_features
330
+ return audio_features[speech_masks], connect_features[speech_masks]
331
+
332
+ def forward(
333
+ self,
334
+ input_ids: torch.LongTensor = None,
335
+ attention_mask: Optional[torch.Tensor] = None,
336
+ position_ids: Optional[torch.LongTensor] = None,
337
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
338
+ inputs_embeds: Optional[torch.FloatTensor] = None,
339
+ labels: Optional[torch.LongTensor] = None,
340
+ use_cache: Optional[bool] = False,
341
+ output_attentions: Optional[bool] = None,
342
+ output_hidden_states: Optional[bool] = None,
343
+ return_dict: Optional[bool] = None,
344
+ cache_position: Optional[torch.LongTensor] = None,
345
+ # New arguments for speech processing and loss calculation
346
+ speech_tensors: Optional[torch.FloatTensor] = None,
347
+ speech_masks: Optional[torch.BoolTensor] = None,
348
+ speeches_loss_input: Optional[torch.FloatTensor] = None,
349
+ speech_semantic_tensors: Optional[torch.FloatTensor] = None,
350
+ acoustic_input_mask: Optional[torch.BoolTensor] = None,
351
+ acoustic_loss_mask: Optional[torch.BoolTensor] = None,
352
+ ddpm_batch_mul: int = 1,
353
+ **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
354
+ ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
355
+
356
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
357
+
358
+ x = self.get_input_embeddings()(input_ids)
359
+
360
+ semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors)
361
+ if speeches_loss_input is not None:
362
+ # only part audio need diffuse
363
+ speech_all_features, speech_all_connect_features = self.forward_speech_features(
364
+ speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
365
+ speech_masks=speech_masks,
366
+ speech_type=kwargs.get("speech_type", "audio"),
367
+ return_unmask=True
368
+ )
369
+ if speech_tensors is not None:
370
+ if semantic_speech_all_connect_features is not None:
371
+ x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + semantic_speech_all_connect_features[speech_masks]
372
+ else:
373
+ x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
374
+ speech_features = speech_all_features[speeches_loss_input.unsqueeze(-1) & speech_masks] # only part audio need diffuse
375
+ speech_connect_features = speech_all_connect_features[speeches_loss_input.unsqueeze(-1) & speech_masks]
376
+ else:
377
+ speech_features, speech_connect_features = self.forward_speech_features(
378
+ speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
379
+ speech_masks=speech_masks,
380
+ speech_type=kwargs.get("speech_type", "audio"),
381
+ )
382
+ if speech_tensors is not None:
383
+ x[acoustic_input_mask] = speech_connect_features
384
+
385
+ outputs = self.model(
386
+ input_ids=None,
387
+ attention_mask=attention_mask,
388
+ position_ids=position_ids,
389
+ past_key_values=past_key_values,
390
+ inputs_embeds=x,
391
+ use_cache=use_cache,
392
+ output_attentions=output_attentions,
393
+ output_hidden_states=False,
394
+ return_dict=return_dict,
395
+ cache_position=cache_position,
396
+ )
397
+
398
+ hidden_states = outputs.last_hidden_state
399
+ logits = self.lm_head(hidden_states)
400
+ # logits = logits.float()
401
+
402
+ loss = None
403
+ if labels is not None:
404
+ # The custom CE loss with masking is calculated in the training script.
405
+ # We leave the standard loss calculation here as None.
406
+ pass
407
+
408
+ # --- Diffusion Loss Calculation ---
409
+ diffusion_loss = None
410
+ # This block is executed only if we are in a context that involves speech.
411
+ if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
412
+ condition_features = hidden_states[acoustic_loss_mask]
413
+
414
+ speech_len, latent_size = speech_features.shape
415
+
416
+ noise = torch.randn(
417
+ (speech_len * ddpm_batch_mul, latent_size),
418
+ device=hidden_states.device,
419
+ dtype=hidden_states.dtype
420
+ )
421
+
422
+ timesteps = torch.multinomial(
423
+ torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
424
+ speech_len * ddpm_batch_mul,
425
+ replacement=True,
426
+ ).to(hidden_states.device)
427
+
428
+ speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0)
429
+ condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0)
430
+
431
+ noisy_speech_features = self.model.noise_scheduler.add_noise(
432
+ speech_features_repeated, noise, timesteps
433
+ )
434
+
435
+ model_output = self.model.prediction_head(
436
+ noisy_speech_features,
437
+ timesteps.type_as(x),
438
+ condition_features_repeated
439
+ )
440
+
441
+ prediction_type = self.config.diffusion_head_config.prediction_type
442
+ if prediction_type == "epsilon":
443
+ target_for_loss = noise
444
+ elif prediction_type == "v_prediction":
445
+ target_for_loss = self.model.noise_scheduler.get_velocity(
446
+ speech_features_repeated, noise, timesteps
447
+ )
448
+ else:
449
+ raise NotImplementedError(f"Prediction type {prediction_type} not implemented")
450
+
451
+ diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum')
452
+ if latent_size > 0 and ddpm_batch_mul > 0:
453
+ diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
454
+ else:
455
+ diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
456
+
457
+ else:
458
+ # Dummy loss for DDP to work when there are no speech samples in a batch,
459
+ # but we are in a speech context.
460
+ diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
461
+ diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
462
+ diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
463
+ # --- End Diffusion Loss Calculation ---
464
+
465
+ if not return_dict:
466
+ output = (logits, speech_len) + outputs.to_tuple()[1:]
467
+ return (loss, diffusion_loss) + output
468
+
469
+ return VibeVoiceCausalLMOutputWithPast(
470
+ loss=loss,
471
+ diffusion_loss=diffusion_loss,
472
+ speech_token_num=speech_len if speech_tensors is not None else 0,
473
+ logits=logits,
474
+ past_key_values=outputs.past_key_values,
475
+ hidden_states=outputs.hidden_states,
476
+ attentions=outputs.attentions,
477
+ )
478
+
479
+ AutoModel.register(VibeVoiceConfig, VibeVoiceModel)
480
+ AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration)
481
+
482
+ __all__ = [
483
+ "VibeVoiceModel",
484
+ "VibeVoicePreTrainedModel",
485
+ "VibeVoiceForConditionalGeneration",
486
+ "VibeVoiceCausalLMOutputWithPast",
487
+ "VibeVoiceGenerationOutput",
488
+ ]
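# Usage sketch (illustrative): the two register() calls above let the composite
# model be built straight from its config via the Auto classes. Assumes the
# tokenizer/diffusion sub-model classes are registered when their modules are
# imported, and that the default config fits in memory.
from transformers.models.auto import AutoModelForCausalLM
cfg = VibeVoiceConfig()
model = AutoModelForCausalLM.from_config(cfg)  # -> VibeVoiceForConditionalGeneration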
vvembed/modular/modeling_vibevoice_inference.py ADDED
@@ -0,0 +1,838 @@
1
+ # Original code by Microsoft
2
+ # updated by Fabio Sarracino - Enemyx-net
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Dict, List, Optional, Tuple, Union, Callable
6
+ from tqdm import tqdm
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
11
+
12
+ from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
14
+ from transformers import modeling_utils
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
17
+ from transformers.utils import logging
18
+
19
+
20
+ # from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
21
+ from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceTokenizerEncoderOutput
22
+ from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
23
+ from schedule.dpm_solver import DPMSolverMultistepScheduler
24
+
25
+ from .configuration_vibevoice import VibeVoiceConfig
26
+
27
+ from .modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizer, VibeVoiceTextTokenizerFast
28
+
29
+ from .modeling_vibevoice import VibeVoiceModel, VibeVoicePreTrainedModel
30
+ from .streamer import AudioStreamer, AsyncAudioStreamer
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
35
+ modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
36
+
37
+ @dataclass
38
+ class VibeVoiceCausalLMOutputWithPast(BaseModelOutputWithPast):
39
+ logits: Optional[torch.FloatTensor] = None
40
+
41
+ @dataclass
42
+ class VibeVoiceGenerationOutput(ModelOutput):
43
+ """
44
+ Output type for VibeVoice generation.
45
+
46
+ Args:
47
+ sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
48
+ The generated sequences.
49
+ speech_outputs (`List[torch.FloatTensor]`, *optional*):
50
+ List of generated speech waveforms or latents for each speech segment.
51
+ """
52
+ sequences: torch.LongTensor = None
53
+ speech_outputs: Optional[List[torch.FloatTensor]] = None
54
+ reach_max_step_sample: Optional[torch.BoolTensor] = None
55
+
56
+ class VibeVoiceTokenConstraintProcessor(LogitsProcessor):
57
+ """Constrains token generation to only valid tokens during speech generation."""
58
+
59
+ def __init__(self, valid_token_ids: List[int], device: torch.device = None):
60
+ self.valid_token_ids = torch.tensor(valid_token_ids, dtype=torch.long, device=device)
61
+
62
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
63
+ # Create a mask for valid tokens
64
+ mask = torch.full_like(scores, float('-inf'))
65
+ mask[:, self.valid_token_ids] = 0
66
+
67
+ # Apply mask to scores
68
+ scores = scores + mask
69
+ return scores
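# Minimal sketch of the processor above (token ids are made-up for illustration):
# everything outside the whitelist is pushed to -inf, so sampling can only pick
# the speech-control / EOS tokens.
proc = VibeVoiceTokenConstraintProcessor(valid_token_ids=[2, 5, 6, 7], device=torch.device("cpu"))
masked = proc(input_ids=torch.zeros(1, 1, dtype=torch.long), scores=torch.zeros(1, 10))
# masked[0, [2, 5, 6, 7]] == 0 while every other column is -inf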
70
+
71
+ def access_cache_safely(cache, layer_idx):
72
+ """Access cache tensors safely across different transformers versions
73
+
74
+ This function handles the different DynamicCache structures across transformers versions:
75
+ - Old versions (< 4.36): cache.key_cache, cache.value_cache
76
+ - Intermediate versions: cache._keys, cache._values
77
+ - New versions (4.36+): Various new structures
78
+
79
+ Returns (k_cache, v_cache) or (None, None) if cache structure is incompatible
80
+ """
81
+ try:
82
+ # Method 1: Old versions (< 4.36)
83
+ if hasattr(cache, 'key_cache') and hasattr(cache, 'value_cache'):
84
+ if layer_idx < len(cache.key_cache):
85
+ return cache.key_cache[layer_idx], cache.value_cache[layer_idx]
86
+
87
+ # Method 2: Private attributes (some intermediate versions)
88
+ if hasattr(cache, '_keys') and hasattr(cache, '_values'):
89
+ if layer_idx < len(cache._keys):
90
+ return cache._keys[layer_idx], cache._values[layer_idx]
91
+
92
+ # Method 3: New versions with get_seq_length or similar
93
+ # Some versions store as list of tuples
94
+ if isinstance(cache, (list, tuple)) and len(cache) > layer_idx:
95
+ layer_cache = cache[layer_idx]
96
+ if isinstance(layer_cache, (list, tuple)) and len(layer_cache) >= 2:
97
+ return layer_cache[0], layer_cache[1]
98
+ elif hasattr(layer_cache, 'key_states') and hasattr(layer_cache, 'value_states'):
99
+ return layer_cache.key_states, layer_cache.value_states
100
+
101
+ # Method 4: Check if cache has a different structure entirely
102
+ # Some very new versions might not expose cache directly
103
+ if hasattr(cache, 'to_legacy_tuple'):
104
+ # Convert to legacy format if possible
105
+ legacy = cache.to_legacy_tuple()
106
+ if legacy and layer_idx < len(legacy):
107
+ return legacy[layer_idx][0], legacy[layer_idx][1]
108
+
109
+ except (AttributeError, IndexError, TypeError) as e:
110
+ # Log the issue but don't fail
111
+ logger.debug(f"Could not access cache at layer {layer_idx}: {e}")
112
+
113
+ # Return None if we can't access the cache safely
114
+ return None, None
115
+
116
+ def get_num_layers_from_cache(cache):
117
+ """Get the number of layers in the cache structure"""
118
+ try:
119
+ if hasattr(cache, 'key_cache'):
120
+ return len(cache.key_cache)
121
+ elif hasattr(cache, '_keys'):
122
+ return len(cache._keys)
123
+ elif isinstance(cache, (list, tuple)):
124
+ return len(cache)
125
+ elif hasattr(cache, 'num_layers'):
126
+ return cache.num_layers
127
+ # Default fallback - most models have 32 or fewer layers
128
+ return 32
129
+ except Exception:
130
+ return 32
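# Sketch (illustrative shapes) of the two helpers above on the legacy
# list-of-(key, value) cache layout; tensors are [batch, heads, seq, head_dim].
legacy = [(torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64)) for _ in range(2)]
k, v = access_cache_safely(legacy, layer_idx=0)       # resolved via the list/tuple branch
assert k is not None and get_num_layers_from_cache(legacy) == 2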
131
+
132
+ class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, GenerationMixin):
133
+ _tied_weights_keys = ["lm_head.weight"]
134
+ _tp_plan = {"lm_head": "colwise_rep"}
135
+
136
+ def __init__(self, config):
137
+ super().__init__(config)
138
+
139
+ # Initialize the base model
140
+ self.model = VibeVoiceModel(config)
141
+
142
+ # LM head for text generation
143
+ self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.decoder_config.vocab_size, bias=False)
144
+
145
+ # inference configuration
146
+ self.ddpm_inference_steps = config.diffusion_head_config.ddpm_num_inference_steps
147
+
148
+ # Initialize weights and apply final processing
149
+ self.post_init()
150
+
151
+ @property
152
+ def noise_scheduler(self):
153
+ return self.model.noise_scheduler
154
+
155
+ @property
156
+ def prediction_head(self):
157
+ return self.model.prediction_head
158
+
159
+ @property
160
+ def speech_scaling_factor(self):
161
+ return self.model.speech_scaling_factor
162
+
163
+ @property
164
+ def speech_bias_factor(self):
165
+ return self.model.speech_bias_factor
166
+
167
+ @property
168
+ def acoustic_tokenizer(self):
169
+ return self.model.acoustic_tokenizer
170
+
171
+ @property
172
+ def semantic_tokenizer(self):
173
+ return self.model.semantic_tokenizer
174
+
175
+ @property
176
+ def acoustic_connector(self):
177
+ return self.model.acoustic_connector
178
+
179
+ @property
180
+ def semantic_connector(self):
181
+ return self.model.semantic_connector
182
+
183
+ def tie_weights(self):
184
+ """
185
+ Tie the weights between the input embeddings and the output embeddings.
186
+ """
187
+ # Tie lm_head.weight to language_model.embed_tokens.weight
188
+ if not getattr(self.config, 'tie_word_embeddings', False):
189
+ return
190
+
191
+ if hasattr(self, 'lm_head') and hasattr(self.model.language_model, 'embed_tokens'):
192
+ self.lm_head.weight = self.model.language_model.embed_tokens.weight
193
+
194
+ def get_input_embeddings(self):
195
+ return self.model.get_input_embeddings()
196
+
197
+ def set_input_embeddings(self, value):
198
+ self.model.set_input_embeddings(value)
199
+
200
+ def get_output_embeddings(self):
201
+ return self.lm_head
202
+
203
+ def set_output_embeddings(self, new_embeddings):
204
+ self.lm_head = new_embeddings
205
+
206
+ def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
207
+ """Set the speech tokenizers used for encoding and decoding speech."""
208
+ self.model.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer)
209
+
210
+ def set_ddpm_inference_steps(self, num_steps=None):
211
+ self.ddpm_inference_steps = num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
212
+
213
+ def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"):
214
+ """Process speech inputs through tokenizers and connectors."""
215
+ with torch.no_grad():
216
+ if speech_type == "audio":
217
+ # Encode audio to acoustic latents
218
+ encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
219
+ acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
220
+
221
+ # Apply scaling and bias
222
+ acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
223
+
224
+ # Connect to language model space
225
+ acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
226
+
227
+ return acoustic_features, acoustic_connected
228
+ elif speech_type == "pt":
229
+ encoder_output = VibeVoiceTokenizerEncoderOutput(mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std)
230
+ acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
231
+
232
+ # Apply scaling and bias
233
+ acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device)
234
+
235
+ # Connect to language model space
236
+ acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
237
+
238
+ return acoustic_features, acoustic_connected
239
+ else:
240
+ raise NotImplementedError(f"Speech type {speech_type} not implemented")
241
+
242
+ # @can_return_tuple
243
+ def forward(
244
+ self,
245
+ input_ids: torch.LongTensor = None,
246
+ attention_mask: Optional[torch.Tensor] = None,
247
+ position_ids: Optional[torch.LongTensor] = None,
248
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
249
+ inputs_embeds: Optional[torch.FloatTensor] = None,
250
+ labels: Optional[torch.LongTensor] = None,
251
+ use_cache: Optional[bool] = None,
252
+ output_attentions: Optional[bool] = None,
253
+ output_hidden_states: Optional[bool] = None,
254
+ return_dict: Optional[bool] = None,
255
+ cache_position: Optional[torch.LongTensor] = None,
256
+ speech_tensors: Optional[torch.FloatTensor] = None,
257
+ speech_masks: Optional[torch.BoolTensor] = None,
258
+ speech_input_mask: Optional[torch.BoolTensor] = None,
259
+ logits_to_keep: Union[int, slice] = 0,
260
+ **kwargs,
261
+ ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
262
+ """
263
+ Args:
264
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
265
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
266
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
267
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
268
+ speech_tensors (`torch.FloatTensor`, *optional*):
269
+ Input speech waveforms for voice cloning or speech understanding.
270
+ speech_masks (`torch.BoolTensor`, *optional*):
271
+ Masks indicating valid speech frames.
272
+ speech_input_mask (`torch.BoolTensor`, *optional*):
273
+ Positions in the input sequence where speech embeddings should be inserted.
274
+
275
+ Returns:
276
+ `VibeVoiceCausalLMOutputWithPast` or tuple
277
+ """
278
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
279
+
280
+ # Get embeddings
281
+ if inputs_embeds is None:
282
+ inputs_embeds = self.model.get_input_embeddings()(input_ids)
283
+
284
+ # Process speech inputs if provided
285
+ if speech_tensors is not None and speech_masks is not None:
286
+ acoustic_features, speech_embeds = self._process_speech_inputs(speech_tensors.to(self.dtype), speech_masks)
287
+ if speech_input_mask is not None:
288
+ inputs_embeds[speech_input_mask] = speech_embeds
289
+
290
+ outputs = self.model(
291
+ inputs_embeds=inputs_embeds,
292
+ attention_mask=attention_mask,
293
+ position_ids=position_ids,
294
+ past_key_values=past_key_values,
295
+ use_cache=use_cache,
296
+ output_attentions=output_attentions,
297
+ output_hidden_states=output_hidden_states,
298
+ return_dict=return_dict,
299
+ cache_position=cache_position,
300
+ **kwargs,
301
+ )
302
+
303
+ hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
304
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
305
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
306
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
307
+
308
+ if labels is not None:
309
+ raise NotImplementedError("Loss computation is not implemented in this version.")
310
+
311
+ return VibeVoiceCausalLMOutputWithPast(
312
+ logits=logits,
313
+ past_key_values=outputs.past_key_values,
314
+ last_hidden_state=hidden_states,
315
+ attentions=outputs.attentions,
316
+ )
317
+
318
+ def _build_generate_config_model_kwargs(self, generation_config, inputs, tokenizer, return_processors=False, **kwargs):
319
+ if generation_config is None:
320
+ generation_config = GenerationConfig(
321
+ bos_token_id=tokenizer.bos_token_id,
322
+ eos_token_id=tokenizer.eos_token_id,
323
+ pad_token_id=tokenizer.pad_token_id
324
+ )
325
+ else:
326
+ generation_config = GenerationConfig(
327
+ **generation_config,
328
+ bos_token_id=tokenizer.bos_token_id,
329
+ eos_token_id=tokenizer.eos_token_id,
330
+ pad_token_id=tokenizer.pad_token_id
331
+ )
332
+
333
+ generation_config, model_kwargs = self._prepare_generation_config(
334
+ generation_config,
335
+ True,
336
+ speech_start_id=tokenizer.speech_start_id,
337
+ speech_end_id=tokenizer.speech_end_id,
338
+ speech_diffusion_id=tokenizer.speech_diffusion_id,
339
+ **kwargs
340
+ )
341
+ generation_config.speech_start_id = tokenizer.speech_start_id
342
+ generation_config.speech_end_id = tokenizer.speech_end_id
343
+ generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id
344
+
345
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs)
346
+ batch_size = inputs_tensor.shape[0]
347
+ device = self.device
348
+
349
+ self._prepare_special_tokens(generation_config, True, device=device)
350
+ generation_config.use_cache = True
351
+ model_kwargs["use_cache"] = generation_config.use_cache
352
+ input_ids = inputs_tensor.to(self.device)
353
+
354
+ input_ids_length = input_ids.shape[1]
355
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
356
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
357
+ generation_config = self._prepare_generated_length(
358
+ generation_config=generation_config,
359
+ has_default_max_length=has_default_max_length,
360
+ has_default_min_length=has_default_min_length,
361
+ model_input_name=model_input_name,
362
+ inputs_tensor=inputs_tensor,
363
+ input_ids_length=input_ids_length,
364
+ )
365
+
366
+ max_cache_length = generation_config.max_length - 1
367
+
368
+ # Fix for transformers compatibility: detect number of parameters accepted
369
+ import inspect
370
+ try:
371
+ sig = inspect.signature(self._prepare_cache_for_generation)
372
+ num_params = len(sig.parameters)
373
+
374
+ # Newer transformers expects 6 parameters (without 'device')
375
+ # Older transformers expects 7 parameters (with 'device')
376
+ if num_params == 6 or 'device' not in sig.parameters:
377
+ # New signature: (self, generation_config, model_kwargs, assistant_model, batch_size, max_cache_length)
378
+ self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length)
379
+ else:
380
+ # Old signature: (self, generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device)
381
+ self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
382
+ except Exception as e:
383
+ # Fallback: try both signatures
384
+ try:
385
+ # Try new signature first (6 params)
386
+ self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length)
387
+ except TypeError:
388
+ # Fall back to old signature (7 params)
389
+ self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
390
+
391
+ model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
392
+ for k, v in model_kwargs.items():
393
+ if isinstance(v, torch.Tensor):
394
+ model_kwargs[k] = v.to(device=device)
395
+
396
+ if return_processors:
397
+ logits_processor = self._get_logits_processor(
398
+ generation_config=generation_config,
399
+ input_ids_seq_length=input_ids_length,
400
+ encoder_input_ids=inputs_tensor,
401
+ prefix_allowed_tokens_fn=None,
402
+ logits_processor=LogitsProcessorList(),
403
+ device=inputs_tensor.device,
404
+ model_kwargs=model_kwargs,
405
+ )
406
+
407
+ stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=StoppingCriteriaList())
408
+
409
+ return generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria
410
+ else:
411
+ return generation_config, model_kwargs, input_ids
412
+
413
+ @torch.no_grad()
414
+ def generate(
415
+ self,
416
+ inputs: Optional[torch.Tensor] = None,
417
+ generation_config: Optional[GenerationConfig] = None,
418
+ logits_processor: Optional[LogitsProcessorList] = None,
419
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
420
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
421
+ synced_gpus: Optional[bool] = None,
422
+ assistant_model: Optional["PreTrainedModel"] = None,
423
+ audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
424
+ negative_prompt_ids: Optional[torch.Tensor] = None,
425
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
426
+ speech_tensors: Optional[torch.FloatTensor] = None,
427
+ speech_masks: Optional[torch.BoolTensor] = None,
428
+ speech_input_mask: Optional[torch.BoolTensor] = None,
429
+ return_speech: bool = True,
430
+ cfg_scale: float = 1.0,
431
+ stop_check_fn: Optional[Callable[[], bool]] = None,
432
+ **kwargs,
433
+ ) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
434
+ """
435
+ Generates sequences of token ids and optionally speech outputs.
436
+
437
+ Args:
438
+ All standard generation arguments from GenerationMixin
439
+ negative_prompt_ids: Negative prompt for CFG in speech generation
440
+ negative_prompt_attention_mask: Attention mask for negative prompt
441
+ speech_tensors: Input speech for voice cloning
442
+ speech_masks: Masks for speech tensors
443
+ speech_input_mask: Positions to insert speech embeddings
444
+ return_speech: Whether to decode and return speech outputs
445
+ cfg_scale: CFG scale for speech generation
446
+ stop_check_fn: Optional callable that returns True if generation should stop
447
+
448
+ Returns:
449
+ Generated token sequences and optionally speech outputs
450
+ """
451
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
452
+ tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
453
+ parsed_scripts = kwargs.pop("parsed_scripts", None)
454
+ all_speakers_list = kwargs.pop("all_speakers_list", None)
455
+ max_length_times = kwargs.pop("max_length_times", 2)
456
+
457
+ if kwargs.get('max_new_tokens', None) is None:
458
+ kwargs['max_new_tokens'] = self.config.decoder_config.max_position_embeddings - kwargs['input_ids'].shape[-1]
459
+
460
+ generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = self._build_generate_config_model_kwargs(
461
+ generation_config, inputs, tokenizer, return_processors=True, **kwargs
462
+ )
463
+
464
+ negative_kwargs = {
465
+ 'input_ids': torch.full((kwargs['input_ids'].shape[0], 1), tokenizer.speech_start_id, dtype=torch.long, device=kwargs['input_ids'].device),
466
+ 'attention_mask': torch.ones((kwargs['input_ids'].shape[0], 1), dtype=torch.long, device=kwargs['input_ids'].device),
467
+ 'max_new_tokens': kwargs.get('max_new_tokens', 100)
468
+ }
469
+ negative_generation_config, negative_model_kwargs, negative_input_ids = self._build_generate_config_model_kwargs(
470
+ None, None, tokenizer, return_processors=False, **negative_kwargs
471
+ )
472
+
473
+ acoustic_cache = VibeVoiceTokenizerStreamingCache()
474
+ semantic_cache = VibeVoiceTokenizerStreamingCache()
475
+
476
+ batch_size = input_ids.shape[0]
477
+ device = input_ids.device
478
+ finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
479
+ correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)
480
+ is_prefill = True
481
+ inputs_embeds = None
482
+ verbose = kwargs.get("verbose", False)
483
+
484
+ # Initialize audio chunks storage for each sample
485
+ audio_chunks = [[] for _ in range(batch_size)]
486
+
487
+ initial_length = input_ids.shape[-1]
488
+ initial_length_per_sample = model_kwargs['attention_mask'].sum(dim=-1)
489
+
490
+ # Define all valid tokens that can be generated
491
+ valid_tokens = [
492
+ generation_config.speech_start_id,
493
+ generation_config.speech_end_id,
494
+ generation_config.speech_diffusion_id,
495
+ generation_config.eos_token_id
496
+ ]
497
+ # Add bos_token_id if it exists
498
+ if hasattr(generation_config, 'bos_token_id') and generation_config.bos_token_id is not None:
499
+ valid_tokens.append(generation_config.bos_token_id)
500
+
501
+ # Add custom processor to constrain token generation
502
+ token_constraint_processor = VibeVoiceTokenConstraintProcessor(valid_tokens, device=device)
503
+ if logits_processor is None:
504
+ logits_processor = LogitsProcessorList()
505
+ logits_processor.append(token_constraint_processor)
506
+
507
+ max_steps = min(generation_config.max_length - initial_length, int(max_length_times * initial_length))
508
+ max_step_per_sample = torch.min(generation_config.max_length - initial_length_per_sample, (max_length_times * initial_length_per_sample).long())
509
+ reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
510
+
511
+ # Create progress iterator if verbose
512
+ if kwargs.get("show_progress_bar", True):
513
+ progress_bar = tqdm(range(max_steps), desc="Generating", leave=False)
514
+ else:
515
+ progress_bar = range(max_steps)
516
+
517
+ for step in progress_bar:
518
+ # Check for external stop signal
519
+ if stop_check_fn is not None and stop_check_fn():
520
+ if verbose:
521
+ print(f"Generation stopped externally at step {step + 1}")
522
+ # End the audio streamer if it exists
523
+ if audio_streamer is not None:
524
+ audio_streamer.end()
525
+ break
526
+
527
+ # Check if audio_streamer has been ended (stopped externally)
528
+ if audio_streamer is not None and hasattr(audio_streamer, 'finished_flags'):
529
+ if any(audio_streamer.finished_flags):
530
+ if verbose:
531
+ print(f"Audio generation stopped externally at step {step + 1}")
532
+ break
533
+
534
+ if finished_tags.all():
535
+ if hasattr(progress_bar, 'set_description'):
536
+ progress_bar.set_description("Generation complete")
537
+ break
538
+
539
+ if input_ids.shape[-1] >= generation_config.max_length:
540
+ print(f"Reached maximum generation length {generation_config.max_length}, stopped it.")
541
+ reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
542
+ if reached_samples.numel() > 0:
543
+ reach_max_step_sample[reached_samples] = True
544
+ break
545
+
546
+ # Update progress bar description with active samples
547
+ if hasattr(progress_bar, 'set_description'):
548
+ active_samples = (~finished_tags).sum().item()
549
+ progress_bar.set_description(f"Generating (active: {active_samples}/{batch_size})")
550
+
551
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
552
+ if is_prefill:
553
+ # we process the speech inputs only during the first generation step
554
+ prefill_inputs = {
555
+ "speech_tensors": speech_tensors.to(device=device),
556
+ "speech_masks": speech_masks.to(device),
557
+ "speech_input_mask": speech_input_mask.to(device),
558
+ }
559
+ is_prefill = False
560
+ else:
561
+ _ = model_inputs.pop('inputs_embeds', None)
562
+ prefill_inputs = {'inputs_embeds': inputs_embeds}
563
+
564
+ # Forward pass through the model
565
+ outputs = self(
566
+ **model_inputs, **prefill_inputs, logits_to_keep=1, return_dict=True, output_attentions=False, output_hidden_states=False,
567
+ )
568
+ model_kwargs = self._update_model_kwargs_for_generation(
569
+ outputs, model_kwargs, is_encoder_decoder=False,
570
+ )
571
+
572
+ # Get logits and apply logits processor
573
+ next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
574
+ # next_token_logits = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device)
575
+ next_token_scores = logits_processor(input_ids, next_token_logits)
576
+
577
+ # token selection
578
+ if generation_config.do_sample:
579
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
580
+ # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
581
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
582
+ else:
583
+ next_tokens = torch.argmax(next_token_scores, dim=-1)
584
+
585
+ next_tokens[finished_tags] = generation_config.eos_token_id
586
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
587
+
588
+ if not kwargs.get('refresh_negative', True):
589
+ negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
590
+ # Forward negative pass through the model
591
+ if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
592
+ negative_model_inputs['inputs_embeds'] = inputs_embeds
593
+ negative_model_inputs['input_ids'] = None
594
+
595
+ negative_outputs = self(
596
+ **negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
597
+ )
598
+ negative_model_kwargs = self._update_model_kwargs_for_generation(
599
+ negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
600
+ )
601
+ negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
602
+
603
+ # reached end of generation
604
+ if (next_tokens == generation_config.eos_token_id).any():
605
+ eos_indices = (next_tokens == generation_config.eos_token_id).nonzero(as_tuple=False).squeeze(1)
606
+ # Only print for samples that are newly finished (not already marked as finished)
607
+ new_eos_indices = eos_indices[~finished_tags[eos_indices]]
608
+ if new_eos_indices.numel() > 0:
609
+ finished_tags[new_eos_indices] = True
610
+ if verbose:
611
+ print(f"Samples {new_eos_indices.tolist()} reached EOS token at step {step + 1}.", flush=True)
612
+ if audio_streamer is not None:
613
+ audio_streamer.end(new_eos_indices)
614
+
615
+ # Check if any sample reached its maximum generation length
616
+ max_length_reached = step >= max_step_per_sample
617
+ new_max_length_indices = torch.nonzero(max_length_reached & ~finished_tags, as_tuple=False).squeeze(1)
618
+ if new_max_length_indices.numel() > 0:
619
+ finished_tags[new_max_length_indices] = True
620
+ reach_max_step_sample[new_max_length_indices] = True
621
+ if verbose:
622
+ print(f"Samples {new_max_length_indices.tolist()} reached max generation length at step {step + 1}.", flush=True)
623
+ if audio_streamer is not None:
624
+ audio_streamer.end(new_max_length_indices)
625
+
626
+ # speech_end
627
+ diffusion_end_indices = (next_tokens == generation_config.speech_end_id).nonzero(as_tuple=False).squeeze(1)
628
+ if diffusion_end_indices.numel() > 0:
629
+ # Clear tokenizer caches for samples that reached speech end
630
+ acoustic_cache.set_to_zero(diffusion_end_indices)
631
+ semantic_cache.set_to_zero(diffusion_end_indices)
632
+
633
+ # speech_begin
634
+ diffusion_start_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_start_id)]
635
+ if diffusion_start_indices.numel() > 0 and kwargs.get('refresh_negative', True):
636
+ # update attention mask
637
+ for i, sample_idx in enumerate(diffusion_start_indices.tolist()):
638
+ negative_model_kwargs['attention_mask'][sample_idx, :] = 0
639
+ negative_model_kwargs['attention_mask'][sample_idx, -1] = 1
640
+ # update past key values - using safe cache access
641
+ cache = negative_model_kwargs['past_key_values']
642
+ num_layers = get_num_layers_from_cache(cache)
643
+ cache_update_failed = False
644
+
645
+ for layer_idx in range(num_layers):
646
+ k_cache, v_cache = access_cache_safely(cache, layer_idx)
647
+ if k_cache is None or v_cache is None:
648
+ # Cache structure not compatible, skip optimization
649
+ logger.debug(f"Cache optimization skipped at layer {layer_idx} - incompatible structure")
650
+ cache_update_failed = True
651
+ break
652
+
653
+ # Process each non-diffusion sample
654
+ for sample_idx in diffusion_start_indices.tolist():
655
+ try:
656
+ # Shift cache for this sample
657
+ k_cache[sample_idx, :, -1, :] = k_cache[sample_idx, :, 0, :].clone()
658
+ v_cache[sample_idx, :, -1, :] = v_cache[sample_idx, :, 0, :].clone()
659
+ except (IndexError, RuntimeError) as e:
660
+ logger.debug(f"Cache update failed for sample {sample_idx}: {e}")
661
+ cache_update_failed = True
662
+ break
663
+
664
+ if cache_update_failed:
665
+ break
666
+ # update negative_input_ids
667
+ for sample_idx in diffusion_start_indices.tolist():
668
+ negative_input_ids[sample_idx, -1] = generation_config.speech_start_id
669
+
670
+ # Prepare inputs_embeds for next iteration
671
+ # Initialize with default embeddings for all tokens
672
+ next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1) # [batch_size, 1, hidden_size]
673
+
674
+ # forward diffusion
675
+ # Diffusion indices are those that are not finished and not special tokens
676
+ diffusion_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_diffusion_id)]
677
+
678
+ if diffusion_indices.numel() > 0:
679
+ if kwargs.get('refresh_negative', True):
680
+ negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs)
681
+ # Forward negative pass through the model
682
+ if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None:
683
+ negative_model_inputs['inputs_embeds'] = inputs_embeds
684
+ negative_model_inputs['input_ids'] = None
685
+
686
+ negative_outputs = self(
687
+ **negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False,
688
+ )
689
+ negative_model_kwargs = self._update_model_kwargs_for_generation(
690
+ negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
691
+ )
692
+ negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1)
693
+ # Correct the non-diffusion indices: the negative pass is run for every
694
+ # sample, even those not in diffusion mode, to keep the cache consistent,
695
+ # so the KV cache entries of the non-diffusion samples have to be
696
+ # corrected afterwards.
697
+ non_diffusion_mask = ~finished_tags & (next_tokens != generation_config.speech_diffusion_id)
698
+ if non_diffusion_mask.any():
699
+ non_diffusion_indices = torch.arange(batch_size, device=device)[non_diffusion_mask]
700
+ start_indices = correct_cnt[non_diffusion_indices]
701
+
702
+ # 1. Update attention_mask - need to handle each sample separately
703
+ seq_len = negative_model_kwargs['attention_mask'].shape[1]
704
+ for i, (sample_idx, start_idx) in enumerate(zip(non_diffusion_indices.tolist(), start_indices.tolist())):
705
+ # Shift the attention mask for this sample
706
+ if start_idx + 1 < seq_len - 1:
707
+ negative_model_kwargs['attention_mask'][sample_idx, start_idx+1:] = \
708
+ negative_model_kwargs['attention_mask'][sample_idx, start_idx:-1].clone()
709
+ negative_model_kwargs['attention_mask'][sample_idx, start_idx] = 0
710
+
711
+ # 2. Update past_key_values - using safe cache access
712
+ cache = negative_model_kwargs['past_key_values']
713
+ num_layers = get_num_layers_from_cache(cache)
714
+ cache_update_failed = False
715
+
716
+ for layer_idx in range(num_layers):
717
+ k_cache, v_cache = access_cache_safely(cache, layer_idx)
718
+ if k_cache is None or v_cache is None:
719
+ # Cache structure not compatible, skip optimization
720
+ logger.debug(f"Cache optimization skipped at layer {layer_idx} - incompatible structure")
721
+ cache_update_failed = True
722
+ break
723
+
724
+ # Process each non-diffusion sample
725
+ for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
726
+ try:
727
+ if start_idx + 1 < k_cache.shape[2] - 1:
728
+ # Shift cache for this sample
729
+ k_cache[sample_idx, :, start_idx+1:, :] = k_cache[sample_idx, :, start_idx:-1, :].clone()
730
+ v_cache[sample_idx, :, start_idx+1:, :] = v_cache[sample_idx, :, start_idx:-1, :].clone()
731
+ except (IndexError, RuntimeError) as e:
732
+ logger.debug(f"Cache update failed for sample {sample_idx}: {e}")
733
+ cache_update_failed = True
734
+ break
735
+
736
+ if cache_update_failed:
737
+ break
738
+
739
+ # 3. Update negative_input_ids
740
+ for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
741
+ if start_idx + 1 < negative_input_ids.shape[1] - 1:
742
+ negative_input_ids[sample_idx, start_idx+1:] = \
743
+ negative_input_ids[sample_idx, start_idx:-1].clone()
744
+
745
+ correct_cnt[non_diffusion_indices] += 1
746
+
747
+ positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :]
748
+ negative_condition = negative_outputs.last_hidden_state[diffusion_indices, -1, :]
749
+
750
+ speech_latent = self.sample_speech_tokens(
751
+ positive_condition,
752
+ negative_condition,
753
+ cfg_scale=cfg_scale,
754
+ ).unsqueeze(1)
755
+
756
+ # Decode acoustic latent to audio using acoustic streaming cache
757
+ scaled_latent = speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - self.model.speech_bias_factor.to(speech_latent.device)
758
+ audio_chunk = self.model.acoustic_tokenizer.decode(
759
+ scaled_latent.to(self.model.acoustic_tokenizer.device),
760
+ cache=acoustic_cache, # Use acoustic-specific cache
761
+ sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
762
+ use_cache=True,
763
+ debug=False
764
+ )
765
+
766
+ # Store audio chunks for each sample
767
+ for i, sample_idx in enumerate(diffusion_indices):
768
+ idx = sample_idx.item()
769
+ # Only append audio chunk if the sample is not finished
770
+ if not finished_tags[idx]:
771
+ audio_chunks[idx].append(audio_chunk[i])
772
+
773
+ # Add streaming support here
774
+ if audio_streamer is not None:
775
+ # Stream the audio chunks immediately
776
+ audio_streamer.put(audio_chunk, diffusion_indices)
777
+
778
+ # Encode audio to semantic features using semantic streaming cache
779
+ semantic_features = self.model.semantic_tokenizer.encode(
780
+ audio_chunk,
781
+ cache=semantic_cache, # Use semantic-specific cache
782
+ sample_indices=diffusion_indices,
783
+ use_cache=True,
784
+ debug=False
785
+ ).mean # semantic tokenizer has no VAE.
786
+
787
+ # Combine acoustic and semantic features for next input
788
+ acoustic_embed = self.model.acoustic_connector(speech_latent)
789
+ semantic_embed = self.model.semantic_connector(semantic_features)
790
+ diffusion_embeds = acoustic_embed + semantic_embed
791
+
792
+ # Update embeddings for diffusion indices
793
+ next_inputs_embeds[diffusion_indices] = diffusion_embeds
794
+
795
+ # Set inputs_embeds for next iteration
796
+ inputs_embeds = next_inputs_embeds
797
+
798
+ if audio_streamer is not None:
799
+ audio_streamer.end()
800
+
801
+ # Concatenate audio chunks for each sample
802
+ final_audio_outputs = []
803
+ for sample_chunks in audio_chunks:
804
+ if sample_chunks:
805
+ # Concatenate all chunks along the time dimension (assumed to be the last dimension)
806
+ concatenated_audio = torch.cat(sample_chunks, dim=-1)
807
+ final_audio_outputs.append(concatenated_audio)
808
+ else:
809
+ # If no audio was generated for this sample, append None
810
+ final_audio_outputs.append(None)
811
+
812
+ return VibeVoiceGenerationOutput(
813
+ sequences=input_ids,
814
+ speech_outputs=final_audio_outputs if return_speech else None,
815
+ reach_max_step_sample=reach_max_step_sample,
816
+ )
817
+
818
+ @torch.no_grad()
819
+ def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
820
+ self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps)
821
+ condition = torch.cat([condition, neg_condition], dim=0).to(self.model.prediction_head.device)
822
+ speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to(condition)
823
+ for t in self.model.noise_scheduler.timesteps:
824
+ half = speech[: len(speech) // 2]
825
+ combined = torch.cat([half, half], dim=0)
826
+ eps = self.model.prediction_head(combined, t.repeat(combined.shape[0]).to(combined), condition=condition)
827
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
828
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
829
+ eps = torch.cat([half_eps, half_eps], dim=0)
830
+ speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample
831
+ return speech[: len(speech) // 2]
832
+
833
+
834
+ AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGenerationInference)
835
+
836
+ __all__ = [
837
+ "VibeVoiceForConditionalGenerationInference",
838
+ ]
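
The classifier-free guidance used by `sample_speech_tokens` above reduces, at each denoising step, to a single blend of two noise predictions. The snippet below is a minimal, self-contained sketch of that blend only; `toy_eps`, the tensor shapes, and the cfg value are illustrative placeholders and are not part of the uploaded model code.

```python
# Sketch of the CFG blend from sample_speech_tokens: eps = uncond + cfg * (cond - uncond).
# toy_eps stands in for self.model.prediction_head; dims are arbitrary.
import torch

def toy_eps(latents: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
    # Placeholder noise predictor: a real head maps (latent, timestep, condition) -> eps.
    return latents * 0.1 + condition[:, : latents.shape[1]] * 0.01

def cfg_step(speech: torch.Tensor, condition: torch.Tensor,
             neg_condition: torch.Tensor, cfg_scale: float = 3.0) -> torch.Tensor:
    # Predict noise under positive and negative conditioning on the same latents,
    # then push the estimate away from the negative direction by cfg_scale.
    cond_eps = toy_eps(speech, condition)
    uncond_eps = toy_eps(speech, neg_condition)
    return uncond_eps + cfg_scale * (cond_eps - uncond_eps)

if __name__ == "__main__":
    speech = torch.randn(2, 64)        # noisy acoustic latents
    cond = torch.randn(2, 128)         # positive condition (last hidden state)
    neg_cond = torch.zeros(2, 128)     # negative condition
    print(cfg_step(speech, cond, neg_cond, cfg_scale=1.3).shape)  # torch.Size([2, 64])
```

In the actual loop above, the positive and negative conditions come from the last hidden states of the positive and negative LLM passes, and the blended prediction is fed to `noise_scheduler.step` to produce the next latent.
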
vvembed/modular/modular_vibevoice_diffusion_head.py ADDED
@@ -0,0 +1,287 @@
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from transformers.models.auto import AutoModel
9
+ from transformers.modeling_utils import PreTrainedModel
10
+ # from transformers.modeling_layers import GradientCheckpointingLayer
11
+ from transformers.activations import ACT2FN
12
+ from transformers.utils import logging
13
+
14
+ from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
15
+
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+
20
+ class RMSNorm(nn.Module):
21
+ def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
22
+ super().__init__()
23
+ self.dim = dim
24
+ self.eps = eps
25
+ self.elementwise_affine = elementwise_affine
26
+ if self.elementwise_affine:
27
+ self.weight = nn.Parameter(torch.ones(dim))
28
+ else:
29
+ self.register_parameter('weight', None)
30
+
31
+ def _norm(self, x):
32
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
33
+
34
+ def forward(self, x):
35
+ output = self._norm(x.float()).type_as(x)
36
+ if self.weight is not None:
37
+ output = output * self.weight
38
+ return output
39
+
40
+ def extra_repr(self) -> str:
41
+ return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
42
+
43
+ def modulate(x, shift, scale):
44
+ """Apply modulation to input tensor."""
45
+ return x * (1 + scale) + shift
46
+
47
+
48
+ class TimestepEmbedder(nn.Module):
49
+ """
50
+ Embeds scalar timesteps into vector representations.
51
+
52
+ Args:
53
+ hidden_size (`int`): Size of the output embedding
54
+ frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
55
+ """
56
+ def __init__(self, hidden_size, frequency_embedding_size=256):
57
+ super().__init__()
58
+ self.mlp = nn.Sequential(
59
+ nn.Linear(frequency_embedding_size, hidden_size, bias=False),
60
+ # nn.SiLU(),
61
+ ACT2FN['silu'],
62
+ nn.Linear(hidden_size, hidden_size, bias=False),
63
+ )
64
+ self.frequency_embedding_size = frequency_embedding_size
65
+
66
+ @staticmethod
67
+ def timestep_embedding(t, dim, max_period=10000):
68
+ """
69
+ Create sinusoidal timestep embeddings.
70
+
71
+ Args:
72
+ t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
73
+ These may be fractional.
74
+ dim (`int`): The dimension of the output.
75
+ max_period (`int`, optional): Controls the minimum frequency of the embeddings.
76
+
77
+ Returns:
78
+ `torch.Tensor`: An [N, D] Tensor of positional embeddings.
79
+ """
80
+ half = dim // 2
81
+ freqs = torch.exp(
82
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
83
+ ).to(t.device)
84
+ args = t[:, None].float() * freqs[None]
85
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
86
+ if dim % 2:
87
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
88
+ return embedding.to(t.dtype)
89
+
90
+ def forward(self, t):
91
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
92
+ t_emb = self.mlp(t_freq)
93
+ return t_emb
94
+
95
+
96
+ class FeedForwardNetwork(nn.Module):
97
+ """
98
+ Standard feed-forward network with SwiGLU activation.
99
+
100
+ Args:
101
+ embed_dim (`int`): Input dimension
102
+ ffn_dim (`int`): Hidden dimension
103
+ """
104
+ def __init__(
105
+ self,
106
+ embed_dim,
107
+ ffn_dim,
108
+ ):
109
+ super().__init__()
110
+ self.embed_dim = embed_dim
111
+ self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
112
+ self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
113
+ self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
114
+ self.act_fn = ACT2FN['silu'] # Using SiLU as the activation function
115
+
116
+ def forward(self, x):
117
+ gate = self.gate_proj(x)
118
+ up = self.up_proj(x)
119
+
120
+ # SwiGLU activation
121
+ # gate = F.silu(gate)
122
+ gate = self.act_fn(gate)
123
+ return self.down_proj(gate * up)
124
+
125
+
126
+ class HeadLayer(nn.Module):
127
+ """
128
+ A layer in the diffusion head.
129
+
130
+ Args:
131
+ embed_dim (`int`): Input dimension
132
+ ffn_dim (`int`): Hidden dimension
133
+ cond_dim (`int`): Condition embedding dimension
134
+ norm_eps (`float`, optional): Epsilon for normalization
135
+ """
136
+ def __init__(
137
+ self,
138
+ embed_dim,
139
+ ffn_dim,
140
+ cond_dim,
141
+ norm_eps=1e-5,
142
+ ):
143
+ super().__init__()
144
+ self.embed_dim = embed_dim
145
+ self.cond_dim = cond_dim
146
+ self.ffn_dim = ffn_dim
147
+ self.ffn = FeedForwardNetwork(
148
+ self.embed_dim,
149
+ self.ffn_dim,
150
+ )
151
+ self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
152
+ self.adaLN_modulation = nn.Sequential(
153
+ # nn.SiLU(),
154
+ ACT2FN['silu'],
155
+ nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
156
+ )
157
+
158
+ def forward(self, x, c):
159
+ shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
160
+ x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn))
161
+ return x
162
+
163
+
164
+ class FinalLayer(nn.Module):
165
+ """
166
+ Final layer in the diffusion head.
167
+
168
+ Args:
169
+ hidden_size (`int`): Input dimension
170
+ output_size (`int`): Output dimension
171
+ cond_size (`int`): Condition embedding dimension
172
+ norm_eps (`float`, optional): Epsilon for normalization
173
+ """
174
+ def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
175
+ super().__init__()
176
+ self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
177
+ self.linear = nn.Linear(hidden_size, output_size, bias=False)
178
+ self.adaLN_modulation = nn.Sequential(
179
+ # nn.SiLU(),
180
+ ACT2FN['silu'],
181
+ nn.Linear(cond_size, 2 * hidden_size, bias=False)
182
+ )
183
+
184
+ def forward(self, x, c):
185
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
186
+ x = modulate(self.norm_final(x), shift, scale)
187
+ x = self.linear(x)
188
+ return x
189
+
190
+
191
+ class VibeVoiceDiffusionHead(PreTrainedModel):
192
+ """
193
+ Diffusion head model for vibevoice.
194
+
195
+ Args:
196
+ config (`VibeVoiceDiffusionHeadConfig`): Model configuration
197
+ latent_size (`int`, optional): Size of the latent space. If not provided, uses `config.latent_size`.
198
+ """
199
+ config_class = VibeVoiceDiffusionHeadConfig
200
+ supports_gradient_checkpointing = True
201
+ _supports_flash_attn_2 = True
202
+ _supports_sdpa = True
203
+
204
+ def __init__(
205
+ self,
206
+ config,
207
+ ):
208
+ super().__init__(config)
209
+ self.config = config
210
+ self.cond_dim = config.hidden_size
211
+ latent_size = config.latent_size
212
+
213
+ self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
214
+ self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
215
+ self.t_embedder = TimestepEmbedder(self.cond_dim)
216
+
217
+ ffn_dim = int(config.hidden_size * config.head_ffn_ratio)
218
+
219
+ # Create the intermediate layers
220
+ self.layers = nn.ModuleList([
221
+ HeadLayer(
222
+ embed_dim=config.hidden_size,
223
+ ffn_dim=ffn_dim,
224
+ cond_dim=self.cond_dim,
225
+ norm_eps=config.rms_norm_eps
226
+ )
227
+ for _ in range(config.head_layers)
228
+ ])
229
+
230
+ # Final layer for output
231
+ self.final_layer = FinalLayer(
232
+ hidden_size=config.hidden_size,
233
+ output_size=latent_size,
234
+ cond_size=self.cond_dim,
235
+ norm_eps=config.rms_norm_eps
236
+ )
237
+
238
+ self.initialize_weights()
239
+
240
+ def initialize_weights(self):
241
+ """Initialize the weights of the model."""
242
+ # Initialize timestep embedder
243
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
244
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
245
+
246
+ # Zero-out adaLN modulation layers
247
+ for layer in self.layers:
248
+ nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)
249
+
250
+ # Zero-out output layers
251
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
252
+ nn.init.constant_(self.final_layer.linear.weight, 0)
253
+
254
+ def forward(
255
+ self,
256
+ noisy_images,
257
+ timesteps,
258
+ condition,
259
+ ):
260
+ """
261
+ Forward pass of the prediction head.
262
+
263
+ Args:
264
+ noisy_images (`torch.Tensor`): Noisy images/latents to denoise
265
+ timesteps (`torch.Tensor`): Timesteps for diffusion
266
+ condition (`torch.Tensor`): Conditioning information
267
+
268
+ Returns:
269
+ `torch.Tensor`: The predicted noise/velocity
270
+ """
271
+ x = self.noisy_images_proj(noisy_images)
272
+ t = self.t_embedder(timesteps)
273
+ condition = self.cond_proj(condition)
274
+ c = condition + t
275
+
276
+ for layer in self.layers:
277
+ x = layer(x, c)
278
+
279
+ x = self.final_layer(x, c)
280
+ return x
281
+
282
+
283
+ AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)
284
+
285
+ __all__ = [
286
+ "VibeVoiceDiffusionHead",
287
+ ]
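
For reference, the sinusoidal timestep embedding computed by `TimestepEmbedder.timestep_embedding` above can be exercised in isolation. The sketch below restates that computation as a standalone function with the same defaults (`max_period=10000`); the sample timesteps and dimension are illustrative only and not part of the uploaded file.

```python
# Standalone restatement of TimestepEmbedder.timestep_embedding.
import math
import torch

def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    # Half the channels carry cos terms, half sin terms, at geometrically
    # spaced frequencies between 1 and 1/max_period.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad one zero channel when dim is odd
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

if __name__ == "__main__":
    t = torch.tensor([0.0, 10.0, 500.0])
    print(timestep_embedding(t, 256).shape)  # torch.Size([3, 256])
```

In the model, this embedding is passed through the two-layer SiLU MLP and added to the projected condition, and the sum drives the adaLN modulation of every `HeadLayer`.
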
vvembed/modular/modular_vibevoice_text_tokenizer.py ADDED
@@ -0,0 +1,214 @@
1
+ """Tokenization classes for vibevoice."""
2
+
3
+ from typing import List, Optional, Union
4
+
5
+ from transformers.utils import logging
6
+ from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
7
+ from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
8
+
9
+ logger = logging.get_logger(__name__)
10
+
11
+
12
+ class VibeVoiceTextTokenizer(Qwen2Tokenizer):
13
+ """
14
+ Construct a VibeVoice tokenizer. Based on the Qwen2 tokenizer with additional special tokens for speech.
15
+
16
+ Args:
17
+ vocab_file (`str`):
18
+ Path to the vocabulary file.
19
+ merges_file (`str`):
20
+ Path to the merges file.
21
+ errors (`str`, *optional*, defaults to `"replace"`):
22
+ Paradigm to follow when decoding bytes to UTF-8.
23
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
24
+ The unknown token.
25
+ bos_token (`str`, *optional*):
26
+ The beginning of sequence token. Not used for vibevoice.
27
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
28
+ The end of sequence token.
29
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
30
+ The token used for padding.
31
+ add_special_tokens (`bool`, *optional*, defaults to `True`):
32
+ Whether or not to add special tokens when encoding.
33
+ """
34
+
35
+ model_input_names = ["input_ids", "attention_mask"]
36
+
37
+ def __init__(
38
+ self,
39
+ vocab_file,
40
+ merges_file,
41
+ errors="replace",
42
+ unk_token="<|endoftext|>",
43
+ bos_token=None,
44
+ eos_token="<|endoftext|>",
45
+ pad_token="<|endoftext|>",
46
+ add_prefix_space=False,
47
+ add_special_tokens=True,
48
+ **kwargs,
49
+ ):
50
+ super().__init__(
51
+ vocab_file=vocab_file,
52
+ merges_file=merges_file,
53
+ errors=errors,
54
+ unk_token=unk_token,
55
+ bos_token=bos_token,
56
+ eos_token=eos_token,
57
+ pad_token=pad_token,
58
+ add_prefix_space=add_prefix_space,
59
+ add_special_tokens=add_special_tokens,
60
+ **kwargs,
61
+ )
62
+
63
+ # Add VibeVoice-specific special tokens
64
+ self._add_vibevoice_special_tokens()
65
+
66
+ def _add_vibevoice_special_tokens(self):
67
+ """Add VibeVoice-specific special tokens."""
68
+ special_tokens = {
69
+ "additional_special_tokens": [
70
+ "<|vision_start|>", # Speech start (reusing vision tokens)
71
+ "<|vision_end|>", # Speech end
72
+ "<|vision_pad|>", # Speech diffusion pad
73
+ ]
74
+ }
75
+ num_added = self.add_special_tokens(special_tokens)
76
+
77
+ # Cache special token IDs
78
+ self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
79
+ self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
80
+ self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
81
+
82
+ self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
83
+
84
+ return num_added
85
+
86
+ @property
87
+ def eos_id(self) -> int:
88
+ """Id of the end of sequence token."""
89
+ return self._eos_id
90
+
91
+ @property
92
+ def speech_start_id(self) -> int:
93
+ """Id of the speech start token."""
94
+ return self._speech_start_id
95
+
96
+ @property
97
+ def speech_end_id(self) -> int:
98
+ """Id of the speech end token."""
99
+ return self._speech_end_id
100
+
101
+ @property
102
+ def speech_diffusion_id(self) -> int:
103
+ """Id of the speech diffusion token."""
104
+ return self._speech_diffusion_id
105
+
106
+ @property
107
+ def pad_id(self) -> int:
108
+ """Id used for padding (returns -100 for loss masking)."""
109
+ return -100
110
+
111
+
112
+ class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
113
+ """
114
+ Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library).
115
+ Based on the Qwen2 tokenizer with additional special tokens for speech.
116
+
117
+ Args:
118
+ vocab_file (`str`, *optional*):
119
+ Path to the vocabulary file.
120
+ merges_file (`str`, *optional*):
121
+ Path to the merges file.
122
+ tokenizer_file (`str`, *optional*):
123
+ Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
124
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
125
+ The unknown token.
126
+ bos_token (`str`, *optional*):
127
+ The beginning of sequence token. Not used for vibevoice.
128
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
129
+ The end of sequence token.
130
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
131
+ The token used for padding.
132
+ """
133
+
134
+ model_input_names = ["input_ids", "attention_mask"]
135
+
136
+ def __init__(
137
+ self,
138
+ vocab_file=None,
139
+ merges_file=None,
140
+ tokenizer_file=None,
141
+ unk_token="<|endoftext|>",
142
+ bos_token=None,
143
+ eos_token="<|endoftext|>",
144
+ pad_token="<|endoftext|>",
145
+ add_prefix_space=False,
146
+ **kwargs,
147
+ ):
148
+ super().__init__(
149
+ vocab_file=vocab_file,
150
+ merges_file=merges_file,
151
+ tokenizer_file=tokenizer_file,
152
+ unk_token=unk_token,
153
+ bos_token=bos_token,
154
+ eos_token=eos_token,
155
+ pad_token=pad_token,
156
+ add_prefix_space=add_prefix_space,
157
+ **kwargs,
158
+ )
159
+
160
+ # Add VibeVoice-specific special tokens
161
+ self._add_vibevoice_special_tokens()
162
+
163
+ def _add_vibevoice_special_tokens(self):
164
+ """Add VibeVoice-specific special tokens."""
165
+ special_tokens = {
166
+ "additional_special_tokens": [
167
+ "<|vision_start|>", # Speech start (reusing vision tokens)
168
+ "<|vision_end|>", # Speech end
169
+ "<|vision_pad|>", # Speech diffusion pad
170
+ ]
171
+ }
172
+ num_added = self.add_special_tokens(special_tokens)
173
+
174
+ # Cache special token IDs
175
+ self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
176
+ self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
177
+ self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
178
+
179
+ # self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
180
+ self._eos_id = self.eos_token_id # qwen2 / qwen3
181
+ self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')
182
+
183
+ return num_added
184
+
185
+ @property
186
+ def eos_id(self) -> int:
187
+ """Id of the end of sequence token."""
188
+ return self._eos_id
189
+
190
+ @property
191
+ def speech_start_id(self) -> int:
192
+ """Id of the speech start token."""
193
+ return self._speech_start_id
194
+
195
+ @property
196
+ def speech_end_id(self) -> int:
197
+ """Id of the speech end token."""
198
+ return self._speech_end_id
199
+
200
+ @property
201
+ def speech_diffusion_id(self) -> int:
202
+ """Id of the speech diffusion token."""
203
+ return self._speech_diffusion_id
204
+
205
+ @property
206
+ def pad_id(self) -> int:
207
+ """Id used for padding (returns -100 for loss masking)."""
208
+ return self._pad_id
209
+
210
+
211
+ __all__ = [
212
+ "VibeVoiceTextTokenizer",
213
+ "VibeVoiceTextTokenizerFast",
214
+ ]
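
Both tokenizer classes above reuse Qwen's vision tokens as speech markers and cache their ids. The sketch below shows the same lookup pattern with a stock tokenizer; the checkpoint name `Qwen/Qwen2.5-1.5B` is only an assumed example, and loading it requires downloading the tokenizer files.

```python
# Look up (or register) the speech marker tokens on a Qwen2-style tokenizer,
# mirroring _add_vibevoice_special_tokens above. Checkpoint name is an example.
from transformers import AutoTokenizer

SPEECH_MARKERS = ("<|vision_start|>", "<|vision_end|>", "<|vision_pad|>")

def speech_marker_ids(name: str = "Qwen/Qwen2.5-1.5B") -> dict:
    tok = AutoTokenizer.from_pretrained(name)
    # Register any marker the vocabulary does not already contain.
    missing = [m for m in SPEECH_MARKERS
               if tok.convert_tokens_to_ids(m) in (None, tok.unk_token_id)]
    if missing:
        tok.add_special_tokens({"additional_special_tokens": missing})
    return {m: tok.convert_tokens_to_ids(m) for m in SPEECH_MARKERS}

if __name__ == "__main__":
    print(speech_marker_ids())
```

Note the one behavioral difference between the two classes: the slow tokenizer's `pad_id` returns -100 for loss masking, while the fast tokenizer returns the id of `<|image_pad|>`.
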
vvembed/modular/modular_vibevoice_tokenizer.py ADDED
@@ -0,0 +1,1195 @@
1
+ import math
2
+ import typing as tp
3
+ from functools import partial
4
+ from dataclasses import dataclass, field
5
+ from typing import Dict, List, Optional, Tuple, Union
6
+ import copy
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from transformers.models.auto import AutoModel
14
+
15
+ from transformers.configuration_utils import PretrainedConfig
16
+ from transformers.utils import logging
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.activations import ACT2FN
19
+
20
+ from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceSemanticTokenizerConfig
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ import os
25
+ # Try to import APEX FusedRMSNorm
26
+ try:
27
+ from apex.normalization.fused_layer_norm import fused_rms_norm_affine
28
+ APEX_AVAILABLE = True
29
+ logger.info("APEX FusedRMSNorm is available and will be used for optimization")
30
+ if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
31
+ APEX_AVAILABLE = False
32
+ logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
33
+ except ImportError:
34
+ APEX_AVAILABLE = False
35
+ logger.warning("APEX FusedRMSNorm not available, using native implementation")
36
+ # APEX_AVAILABLE=False
37
+
38
+ # Normalization modules
39
+ class ConvLayerNorm(nn.LayerNorm):
40
+ """
41
+ Convolution-friendly LayerNorm that moves channels to last dimensions
42
+ before running the normalization and moves them back to original position right after.
43
+ """
44
+ def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
45
+ super().__init__(normalized_shape, **kwargs)
46
+
47
+ def forward(self, x):
48
+ x = x.transpose(1, 2) # b ... t -> b t ...
49
+ x = nn.functional.layer_norm(x.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).type_as(x)
50
+ x = x.transpose(1, 2) # b t ... -> b ... t
51
+ return x
52
+
53
+ class RMSNorm(nn.Module):
54
+ def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
55
+ super().__init__()
56
+ self.dim = dim
57
+ self.eps = eps
58
+ self.elementwise_affine = elementwise_affine
59
+ if self.elementwise_affine:
60
+ weight_shape = (dim,) if weight_shape is None else weight_shape
61
+ self.weight = nn.Parameter(torch.ones(weight_shape))
62
+ else:
63
+ self.register_parameter('weight', None)
64
+
65
+ def _norm(self, x):
66
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
67
+
68
+ def forward(self, x):
69
+ output = self._norm(x.float()).type_as(x)
70
+ if self.weight is not None:
71
+ output = output * self.weight
72
+ return output
73
+
74
+ def extra_repr(self) -> str:
75
+ return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
76
+
77
+ class ConvRMSNorm(RMSNorm):
78
+ def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
79
+ super().__init__(dim, eps, elementwise_affine, weight_shape)
80
+
81
+ def forward(self, x):
82
+ x = x.transpose(1, 2) # b ... t -> b t ...
83
+ if (not APEX_AVAILABLE) or (not self.elementwise_affine):
84
+ # Fallback to native implementation
85
+ output = self._norm(x.float()).type_as(x)
86
+ if self.weight is not None:
87
+ output = output * self.weight
88
+ else:
89
+ output = fused_rms_norm_affine(x, self.weight, self.weight.shape, self.eps)
90
+ output = output.transpose(1, 2) # b t ... -> b ... t
91
+ return output
92
+
93
+ # Convolutional layers and utilities
94
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
95
+ 'time_layer_norm', 'layer_norm', 'time_group_norm'])
96
+
97
+
98
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
99
+ assert norm in CONV_NORMALIZATIONS
100
+ if norm == 'weight_norm':
101
+ return nn.utils.weight_norm(module)
102
+ elif norm == 'spectral_norm':
103
+ return nn.utils.spectral_norm(module)
104
+ else:
105
+ # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other choice
106
+ # doesn't need reparametrization.
107
+ return module
108
+
109
+
110
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
111
+ """Return the proper normalization module. If causal is True, this will ensure the returned
112
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
113
+ """
114
+ assert norm in CONV_NORMALIZATIONS
115
+ if norm == 'layer_norm':
116
+ assert isinstance(module, nn.modules.conv._ConvNd)
117
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
118
+ elif norm == 'time_group_norm':
119
+ if causal:
120
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
121
+ assert isinstance(module, nn.modules.conv._ConvNd)
122
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
123
+ else:
124
+ return nn.Identity()
125
+
126
+
127
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
128
+ padding_total: int = 0) -> int:
129
+ """Calculate extra padding needed for convolution to have the same output length"""
130
+ length = x.shape[-1]
131
+ n_frames = (length - kernel_size + padding_total) / stride + 1
132
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
133
+ return ideal_length - length
134
+
135
+
136
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
137
+ """Pad 1D input with handling for small inputs in reflect mode"""
138
+ length = x.shape[-1]
139
+ padding_left, padding_right = paddings
140
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
141
+ if mode == 'reflect':
142
+ max_pad = max(padding_left, padding_right)
143
+ extra_pad = 0
144
+ if length <= max_pad:
145
+ extra_pad = max_pad - length + 1
146
+ x = F.pad(x, (0, extra_pad))
147
+ padded = F.pad(x, paddings, mode, value)
148
+ end = padded.shape[-1] - extra_pad
149
+ return padded[..., :end]
150
+ else:
151
+ return F.pad(x, paddings, mode, value)
152
+
153
+
154
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
155
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
156
+ padding_left, padding_right = paddings
157
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
158
+ assert (padding_left + padding_right) <= x.shape[-1]
159
+ end = x.shape[-1] - padding_right
160
+ return x[..., padding_left: end]
161
+
162
+
163
+ class NormConv1d(nn.Module):
164
+ """Wrapper around Conv1d and normalization applied to this conv"""
165
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
166
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
167
+ super().__init__()
168
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
169
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
170
+ self.norm_type = norm
171
+
172
+ def forward(self, x):
173
+ x = self.conv(x)
174
+ x = self.norm(x)
175
+ return x
176
+
177
+
178
+ class NormConvTranspose1d(nn.Module):
179
+ """Wrapper around ConvTranspose1d and normalization applied to this conv"""
180
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
181
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
182
+ super().__init__()
183
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
184
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
185
+ self.norm_type = norm
186
+
187
+ def forward(self, x):
188
+ x = self.convtr(x)
189
+ x = self.norm(x)
190
+ return x
191
+
192
+
193
+ class VibeVoiceTokenizerStreamingCache:
194
+ """Cache for streaming convolution, similar to KV cache in attention"""
195
+ def __init__(self):
196
+ self.cache = {} # Dict mapping (layer_id, sample_idx) to state tensor
197
+
198
+ def get(self, layer_id: str, sample_indices: torch.Tensor) -> Optional[torch.Tensor]:
199
+ """Get cached states for given layer and sample indices"""
200
+ states = []
201
+ max_length = 0
202
+
203
+ # First pass: collect states and find max length
204
+ for idx in sample_indices.tolist():
205
+ key = (layer_id, idx)
206
+ if key not in self.cache:
207
+ return None # If any sample is missing, return None
208
+ state = self.cache[key]
209
+ states.append(state)
210
+ max_length = max(max_length, state.shape[-1])
211
+
212
+ # Second pass: pad states to max length if needed
213
+ if len(states) > 0 and states[0].dim() >= 2:
214
+ padded_states = []
215
+ for state in states:
216
+ if state.shape[-1] < max_length:
217
+ # Pad on the time dimension (last dimension)
218
+ pad_size = max_length - state.shape[-1]
219
+ # Pad with zeros on the LEFT to align the most recent samples
220
+ padded_state = F.pad(state, (pad_size, 0), mode='constant', value=0)
221
+ padded_states.append(padded_state)
222
+ else:
223
+ padded_states.append(state)
224
+ return torch.stack(padded_states, dim=0)
225
+ else:
226
+ return torch.stack(states, dim=0)
227
+
228
+ def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
229
+ """Set cached states for given layer and sample indices"""
230
+ for i, idx in enumerate(sample_indices.tolist()):
231
+ key = (layer_id, idx)
232
+ self.cache[key] = states[i].detach()
233
+
234
+ def set_to_zero(self, sample_indices: torch.Tensor):
235
+ """Set all cached states to zero for given sample indices"""
236
+ for key in list(self.cache.keys()):
237
+ layer_id, sample_idx = key
238
+ if sample_idx in sample_indices.tolist():
239
+ # Create zero tensor with same shape and dtype as cached tensor
240
+ cached_tensor = self.cache[key]
241
+ self.cache[key] = torch.zeros_like(cached_tensor)
242
+
243
+ def clear(self, layer_id: Optional[str] = None, sample_indices: Optional[torch.Tensor] = None):
244
+ """Clear cache for specific layer/samples or everything"""
245
+ if layer_id is None and sample_indices is None:
246
+ self.cache.clear()
247
+ elif layer_id is not None and sample_indices is None:
248
+ # Clear all samples for a specific layer
249
+ keys_to_remove = [k for k in self.cache.keys() if k[0] == layer_id]
250
+ for k in keys_to_remove:
251
+ del self.cache[k]
252
+ elif layer_id is not None and sample_indices is not None:
253
+ # Clear specific samples for a specific layer
254
+ for idx in sample_indices.tolist():
255
+ key = (layer_id, idx)
256
+ self.cache.pop(key, None)
257
+
258
+ class SConv1d(nn.Module):
259
+ """Conv1d with built-in handling of asymmetric or causal padding and normalization."""
260
+ def __init__(self, in_channels: int, out_channels: int,
261
+ kernel_size: int, stride: int = 1, dilation: int = 1,
262
+ groups: int = 1, bias: bool = True, causal: bool = False,
263
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
264
+ pad_mode: str = 'reflect'):
265
+ super().__init__()
266
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
267
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
268
+ norm=norm, norm_kwargs=norm_kwargs)
269
+ self.causal = causal
270
+ self.pad_mode = pad_mode
271
+
272
+ # Store configuration
273
+ self.kernel_size = kernel_size
274
+ self.dilation = dilation
275
+ self.stride = stride
276
+ self.in_channels = in_channels
277
+ self.out_channels = out_channels
278
+
279
+ # For causal convolution, we need to maintain kernel_size - 1 samples as context
280
+ # TODO: check which context_size definition is more suitable
281
+ # self.context_size = (kernel_size - 1) * dilation
282
+ self.context_size = (kernel_size - 1) * dilation - (stride - 1)
283
+
284
+ # For non-streaming mode, calculate padding
285
+ self.padding_total = (kernel_size - 1) * dilation - (stride - 1)
286
+
287
+ # Create a unique layer ID for cache management
288
+ self._layer_id = None
289
+
290
+ @property
291
+ def layer_id(self):
292
+ if self._layer_id is None:
293
+ self._layer_id = f"sconv1d_{id(self)}"
294
+ return self._layer_id
295
+
296
+ def forward(self, x: torch.Tensor,
297
+ cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
298
+ sample_indices: Optional[torch.Tensor] = None,
299
+ use_cache: bool = False,
300
+ debug: bool = False) -> torch.Tensor:
301
+ """
302
+ Forward pass with optional streaming support via cache.
303
+
304
+ Args:
305
+ x: Input tensor [batch_size, channels, time]
306
+ cache: VibeVoiceTokenizerStreamingCache object for maintaining states
307
+ sample_indices: Indices identifying each sample for cache management
308
+ use_cache: Whether to use cached states for streaming
309
+ debug: Whether to print debug information
310
+
311
+ Returns:
312
+ Output tensor
313
+ """
314
+ B, C, T = x.shape
315
+
316
+ # Non-streaming mode
317
+ if not use_cache or cache is None:
318
+ return self._forward_non_streaming(x, debug=debug)
319
+
320
+ # Streaming mode
321
+ assert self.causal, "Streaming mode is only supported for causal convolutions"
322
+ assert sample_indices is not None, "sample_indices must be provided for streaming mode"
323
+ assert len(sample_indices) == B, "sample_indices must match batch size"
324
+
325
+ return self._forward_streaming(x, cache, sample_indices, debug)
326
+
327
+ def _forward_streaming(self, x: torch.Tensor,
328
+ cache: VibeVoiceTokenizerStreamingCache,
329
+ sample_indices: torch.Tensor,
330
+ debug: bool = False) -> torch.Tensor:
331
+ """Streaming forward pass with cache operations kept separate from compiled code"""
332
+ B, C, T = x.shape
333
+
334
+ # Cache operations (not compiled)
335
+ cached_states = cache.get(self.layer_id, sample_indices)
336
+
337
+ if cached_states is None:
338
+ # First chunk - initialize with zeros for context
339
+ if self.context_size > 0:
340
+ cached_states = torch.zeros(B, C, self.context_size, device=x.device, dtype=x.dtype)
341
+ if debug:
342
+ print(f"[DEBUG] Initialized cache with shape: {cached_states.shape}, context_size={self.context_size}")
343
+ else:
344
+ cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
345
+ if debug:
346
+ print(f"[DEBUG] No context needed (kernel_size=stride)")
347
+
348
+ # Concatenate cached states with input
349
+ if cached_states.shape[2] > 0:
350
+ input_with_context = torch.cat([cached_states, x], dim=2)
351
+ else:
352
+ input_with_context = x
353
+
354
+ if debug:
355
+ print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")
356
+
357
+ # Apply convolution directly - no extra padding in streaming mode
358
+ # The conv layer will handle its own padding internally
359
+ output = self.conv(input_with_context)
360
+
361
+ if debug:
362
+ print(f"[DEBUG] Output shape: {output.shape}")
363
+
364
+ # Update cache for next chunk
365
+ if self.context_size > 0:
366
+ # Calculate how many samples to keep
367
+ total_input_length = input_with_context.shape[2]
368
+
369
+ # Keep the last context_size samples
370
+ if total_input_length >= self.context_size:
371
+ new_cache_start = total_input_length - self.context_size
372
+ new_cache = input_with_context[:, :, new_cache_start:]
373
+ else:
374
+ # If we have less than context_size samples, keep everything
375
+ new_cache = input_with_context
376
+
377
+ if debug:
378
+ print(f"[DEBUG] New cache shape: {new_cache.shape}")
379
+
380
+ cache.set(self.layer_id, sample_indices, new_cache)
381
+
382
+ return output
383
+
384
+ def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
385
+ """Standard forward pass without streaming"""
386
+ B, C, T = x.shape
387
+ kernel_size = self.kernel_size
388
+ stride = self.stride
389
+ dilation = self.dilation
390
+ padding_total = self.padding_total
391
+
392
+ # Compute extra padding for stride alignment
393
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
394
+
395
+ if debug:
396
+ print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}, padding_total={padding_total}, extra_padding={extra_padding}")
397
+
398
+ if self.causal:
399
+ # Left padding for causal
400
+ if self.pad_mode == 'constant':
401
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode, value=0)
402
+ else:
403
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
404
+ else:
405
+ # Symmetric padding for non-causal
406
+ padding_right = padding_total // 2
407
+ padding_left = padding_total - padding_right
408
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
409
+
410
+ if debug:
411
+ print(f"[DEBUG NON-STREAMING] After padding: {x.shape}")
412
+
413
+ output = self.conv(x)
414
+
415
+ if debug:
416
+ print(f"[DEBUG NON-STREAMING] Output shape: {output.shape}")
417
+
418
+ return output
419
+
420
+
421
+ class SConvTranspose1d(nn.Module):
422
+ """ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization."""
423
+ def __init__(self, in_channels: int, out_channels: int,
424
+ kernel_size: int, stride: int = 1, causal: bool = False,
425
+ norm: str = 'none', trim_right_ratio: float = 1.,
426
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, bias: bool = True):
427
+ super().__init__()
428
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
429
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias)
430
+ self.causal = causal
431
+ self.trim_right_ratio = trim_right_ratio
432
+ assert self.causal or self.trim_right_ratio == 1., \
433
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
434
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
435
+
436
+ # Store configuration
437
+ self.kernel_size = kernel_size
438
+ self.stride = stride
439
+ self.in_channels = in_channels
440
+ self.out_channels = out_channels
441
+
442
+ # For transposed convolution, padding calculation is different
443
+ self.padding_total = kernel_size - stride
444
+
445
+ # For streaming, we need to keep track of input history
446
+ # Transposed conv needs to see multiple input samples to produce correct output
447
+ self.context_size = kernel_size - 1
448
+
449
+ # Create a unique layer ID for cache management
450
+ self._layer_id = None
451
+
452
+ @property
453
+ def layer_id(self):
454
+ if self._layer_id is None:
455
+ self._layer_id = f"sconvtr1d_{id(self)}"
456
+ return self._layer_id
457
+
458
+ def forward(self, x: torch.Tensor,
459
+ cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
460
+ sample_indices: Optional[torch.Tensor] = None,
461
+ use_cache: bool = False,
462
+ debug: bool = False) -> torch.Tensor:
463
+ """
464
+ Forward pass with optional streaming support via cache.
465
+ """
466
+ B, C, T = x.shape
467
+
468
+ # Non-streaming mode
469
+ if not use_cache or cache is None:
470
+ return self._forward_non_streaming(x, debug=debug)
471
+
472
+ # Streaming mode
473
+ assert sample_indices is not None, "sample_indices must be provided for streaming mode"
474
+ assert len(sample_indices) == B, "sample_indices must match batch size"
475
+
476
+ return self._forward_streaming(x, cache, sample_indices, debug)
477
+
478
+ def _forward_streaming(self, x: torch.Tensor,
479
+ cache: VibeVoiceTokenizerStreamingCache,
480
+ sample_indices: torch.Tensor,
481
+ debug: bool = False) -> torch.Tensor:
482
+ """Streaming forward pass with cache operations kept separate from compiled code"""
483
+ B, C, T = x.shape
484
+
485
+ # Cache operations (not compiled)
486
+ cached_input = cache.get(self.layer_id, sample_indices)
487
+
488
+ if cached_input is None:
489
+ # First chunk - no history yet
490
+ cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
491
+ if debug:
492
+ print(f"[DEBUG] Initialized empty cache for transposed conv")
493
+
494
+ # Concatenate cached input with new input
495
+ full_input = torch.cat([cached_input, x], dim=2)
496
+
497
+ if debug:
498
+ print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_input.shape}, Combined: {full_input.shape}")
499
+
500
+ # First chunk or debug mode - use uncompiled version
501
+ full_output = self.convtr(full_input)
502
+
503
+ if debug:
504
+ print(f"[DEBUG] Full transposed conv output shape: {full_output.shape}")
505
+
506
+ # Calculate padding to remove
507
+ if self.causal:
508
+ padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
509
+ padding_left = self.padding_total - padding_right
510
+ else:
511
+ padding_right = self.padding_total // 2
512
+ padding_left = self.padding_total - padding_right
513
+
514
+ # Remove padding
515
+ if padding_left + padding_right > 0:
516
+ full_output = unpad1d(full_output, (padding_left, padding_right))
517
+
518
+ if debug:
519
+ print(f"[DEBUG] After unpadding: {full_output.shape}")
520
+
521
+ # Determine which part of the output corresponds to the new input
522
+ if cached_input.shape[2] == 0:
523
+ # First chunk - return all output
524
+ output = full_output
525
+ else:
526
+ # Subsequent chunks - return only the new output
527
+ expected_new_output = T * self.stride
528
+
529
+ # Take the last expected_new_output samples
530
+ if full_output.shape[2] >= expected_new_output:
531
+ output = full_output[:, :, -expected_new_output:]
532
+ else:
533
+ output = full_output
534
+
535
+ if debug:
536
+ print(f"[DEBUG] Final streaming output shape: {output.shape}")
537
+
538
+ # Update cache
539
+ if full_input.shape[2] > self.context_size:
540
+ new_cache = full_input[:, :, -self.context_size:]
541
+ else:
542
+ new_cache = full_input
543
+
544
+ if debug:
545
+ print(f"[DEBUG] New cache shape: {new_cache.shape}")
546
+
547
+ cache.set(self.layer_id, sample_indices, new_cache)
548
+
549
+ return output
550
+
551
+ def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
552
+ """Standard forward pass without streaming"""
553
+ if debug:
554
+ print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}")
555
+
556
+ # Apply transposed convolution
557
+ y = self.convtr(x)
558
+
559
+ if debug:
560
+ print(f"[DEBUG NON-STREAMING] After transposed conv: {y.shape}")
561
+
562
+ # Calculate and remove padding
563
+ if self.causal:
564
+ padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
565
+ padding_left = self.padding_total - padding_right
566
+ else:
567
+ padding_right = self.padding_total // 2
568
+ padding_left = self.padding_total - padding_right
569
+
570
+ if padding_left + padding_right > 0:
571
+ y = unpad1d(y, (padding_left, padding_right))
572
+
573
+ if debug:
574
+ print(f"[DEBUG NON-STREAMING] Final output shape: {y.shape}")
575
+
576
+ return y
577
+
578
+ # FFN
579
+ class FFN(nn.Module):
580
+ def __init__(
581
+ self,
582
+ embed_dim,
583
+ ffn_dim,
584
+ bias=False,
585
+ ):
586
+ super().__init__()
587
+ self.embed_dim = embed_dim
588
+ self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias)
589
+ self.gelu = ACT2FN["gelu"]
590
+ self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias)
591
+
592
+ def forward(self, x):
593
+ x = self.linear1(x)
594
+ x = self.gelu(x)
595
+ x = self.linear2(x)
596
+ return x
597
+
598
+
599
+ class Convlayer(nn.Module):
600
+ def __init__(
601
+ self,
602
+ in_channels,
603
+ out_channels,
604
+ kernel_size,
605
+ stride=1,
606
+ dilation=1,
607
+ groups=1,
608
+ bias=True,
609
+ pad_mode='zeros',
610
+ norm='weight_norm',
611
+ causal=True,
612
+ ):
613
+ super().__init__()
614
+ self.conv = SConv1d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
615
+ groups=groups, bias=bias, pad_mode=pad_mode, norm=norm, causal=causal)
616
+
617
+ def forward(self, x):
618
+ return self.conv(x)
619
+
620
+ class Block1D(nn.Module):
621
+ def __init__(self, dim, kernel_size=7, drop_path=0., mixer_layer='conv',
622
+ layer_scale_init_value=1e-6, **kwargs):
623
+ super().__init__()
624
+
625
+ if kwargs.get('layernorm', 'LN') == 'LN':
626
+ self.norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
627
+ self.ffn_norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
628
+ elif kwargs.get('layernorm', 'RMSNorm') == 'RMSNorm':
629
+ self.norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
630
+ self.ffn_norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
631
+
632
+ if mixer_layer == 'conv':
633
+ self.mixer = Convlayer(dim, dim, groups=kwargs.get('groups', 1),
634
+ kernel_size=kernel_size,
635
+ pad_mode=kwargs.get('pad_mode', 'reflect'),
636
+ norm=kwargs.get('norm', 'none'),
637
+ causal=kwargs.get('causal', True),
638
+ bias=kwargs.get('bias', True),
639
+ )
640
+ elif mixer_layer == 'depthwise_conv':
641
+ self.mixer = Convlayer(dim, dim, groups=dim,
642
+ kernel_size=kernel_size,
643
+ pad_mode=kwargs.get('pad_mode', 'reflect'),
644
+ norm=kwargs.get('norm', 'none'),
645
+ causal=kwargs.get('causal', True),
646
+ bias=kwargs.get('bias', True),
647
+ )
648
+ else:
649
+ raise ValueError(f"Unsupported mixer layer: {mixer_layer}")
650
+
651
+ self.ffn = FFN(
652
+ dim,
653
+ kwargs.get('ffn_expansion', 4) * dim,
654
+ bias=kwargs.get('bias', False),
655
+ )
656
+ self.drop_path = nn.Identity() if drop_path <= 0. else nn.modules.DropPath(drop_path)
657
+
658
+ if layer_scale_init_value > 0:
659
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
660
+ self.ffn_gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
661
+ else:
662
+ self.gamma = None
663
+ self.ffn_gamma = None
664
+
665
+ def forward(self, x):
666
+ # mixer
667
+ residual = x
668
+ x = self.norm(x)
669
+ x = self.mixer(x)
670
+ if self.gamma is not None:
671
+ x = x * self.gamma.unsqueeze(-1)
672
+ x = residual + self.drop_path(x)
673
+
674
+ # ffn
675
+ residual = x
676
+ x = self.ffn_norm(x)
677
+ x = x.permute(0, 2, 1)
678
+ x = self.ffn(x)
679
+ x = x.permute(0, 2, 1)
680
+ if self.ffn_gamma is not None:
681
+ x = x * self.ffn_gamma.unsqueeze(-1)
682
+ x = residual + self.drop_path(x)
683
+
684
+ return x
685
+
686
+
687
+ class TokenizerEncoder(nn.Module):
688
+ """
689
+ Encoder component for the VibeVoice tokenizer that converts audio to latent representations.
690
+
691
+ Args:
692
+ config: Configuration object with model parameters
693
+ """
694
+ def __init__(self, config):
695
+ super().__init__()
696
+
697
+ # Extract parameters from config
698
+ self.channels = config.channels
699
+ self.dimension = config.dimension
700
+ self.n_filters = config.n_filters
701
+ self.ratios = list(reversed(config.ratios))
702
+ self.depths = config.depths
703
+ self.n_residual_layers = getattr(config, "n_residual_layers", 1)
704
+ self.hop_length = np.prod(self.ratios)
705
+ self.causal = config.causal
706
+
707
+ # Additional config parameters with defaults
708
+ kernel_size = getattr(config, "kernel_size", 7)
709
+ last_kernel_size = getattr(config, "last_kernel_size", 7)
710
+ norm = getattr(config, "norm", "none")
711
+ norm_params = getattr(config, "norm_params", {})
712
+ pad_mode = getattr(config, "pad_mode", "reflect")
713
+ bias = getattr(config, "bias", True)
714
+ layernorm = getattr(config, "layernorm", "LN")
715
+ layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
716
+ layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
717
+ drop_path_rate = getattr(config, "drop_path_rate", 0.0)
718
+ mixer_layer = getattr(config, "mixer_layer", "conv")
719
+ layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
720
+ disable_last_norm = getattr(config, "disable_last_norm", False)
721
+
722
+ # determine the norm type based on layernorm
723
+ if layernorm == 'LN':
724
+ norm_type = ConvLayerNorm
725
+ elif layernorm == 'RMSNorm':
726
+ norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
727
+ else:
728
+ raise ValueError(f"Unsupported norm type: {layernorm}")
729
+
730
+ # stem and intermediate downsampling conv layers
731
+ stem = nn.Sequential(
732
+ SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
733
+ )
734
+
735
+ self.downsample_layers = nn.ModuleList()
736
+ self.downsample_layers.append(stem)
737
+ for i in range(len(self.ratios)):
738
+ in_ch = self.n_filters * (2 ** i)
739
+ out_ch = self.n_filters * (2 ** (i + 1))
740
+ downsample_layer = nn.Sequential(
741
+ SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
742
+ )
743
+ self.downsample_layers.append(downsample_layer)
744
+
745
+ # configure the transformer blocks
746
+ layer_type = partial(
747
+ Block1D,
748
+ mixer_layer=mixer_layer,
749
+ layernorm=layernorm,
750
+ eps=layernorm_eps,
751
+ causal=self.causal,
752
+ pad_mode=pad_mode,
753
+ norm=norm,
754
+ bias=bias,
755
+ layer_scale_init_value=layer_scale_init_value,
756
+ )
757
+
758
+ self.stages = nn.ModuleList()
759
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
760
+ cur = 0
761
+
762
+ for i in range(len(self.depths)):
763
+ in_ch = self.n_filters * (2 ** i)
764
+ stage = nn.Sequential(
765
+ *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
766
+ )
767
+ self.stages.append(stage)
768
+ cur += self.depths[i]
769
+
770
+ if not disable_last_norm:
771
+ self.norm = norm_type(in_ch, eps=layernorm_eps)
772
+ else:
773
+ self.norm = nn.Identity()
774
+ self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
775
+
776
+ def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
777
+ for i in range(len(self.depths)):
778
+ # Apply downsampling
779
+ for layer in self.downsample_layers[i]:
780
+ if isinstance(layer, SConv1d):
781
+ x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
782
+ else:
783
+ x = layer(x)
784
+
785
+ # Apply stage (Block1D contains Convlayer which contains SConv1d)
786
+ for block in self.stages[i]:
787
+ if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
788
+ # Block1D forward with cache support
789
+ residual = x
790
+ x = block.norm(x)
791
+ x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
792
+ if block.gamma is not None:
793
+ x = x * block.gamma.unsqueeze(-1)
794
+ x = residual + x
795
+
796
+ # FFN part
797
+ residual = x
798
+ x = block.ffn_norm(x)
799
+ x = x.permute(0, 2, 1)
800
+ x = block.ffn(x)
801
+ x = x.permute(0, 2, 1)
802
+ if block.ffn_gamma is not None:
803
+ x = x * block.ffn_gamma.unsqueeze(-1)
804
+ x = residual + x
805
+ else:
806
+ x = block(x)
807
+
808
+ return self.norm(x)
809
+
810
+ def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
811
+ x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
812
+ x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
813
+ return x
814
+
815
+
816
+ class TokenizerDecoder(nn.Module):
817
+ """
818
+ Decoder component for the VibeVoice tokenizer that converts latent representations back to audio.
819
+
820
+ Args:
821
+ config: Configuration object with model parameters
822
+ """
823
+ def __init__(self, config):
824
+ super().__init__()
825
+
826
+ # Extract parameters from config
827
+ self.dimension = config.dimension
828
+ self.channels = config.channels
829
+ self.n_filters = config.n_filters
830
+ self.ratios = config.ratios
831
+
832
+ # IMPORTANT CHANGE: Don't reverse depths again since they're already reversed in VibeVoiceAcousticTokenizerModel
833
+ self.depths = config.depths # Changed from list(reversed(config.depths))
834
+
835
+ self.n_residual_layers = getattr(config, "n_residual_layers", 1)
836
+ self.hop_length = np.prod(self.ratios)
837
+ self.causal = config.causal
838
+
839
+ # Additional config parameters with defaults
840
+ kernel_size = getattr(config, "kernel_size", 7)
841
+ last_kernel_size = getattr(config, "last_kernel_size", 7)
842
+ norm = getattr(config, "norm", "none")
843
+ norm_params = getattr(config, "norm_params", {})
844
+ pad_mode = getattr(config, "pad_mode", "reflect")
845
+ bias = getattr(config, "bias", True)
846
+ layernorm = getattr(config, "layernorm", "LN")
847
+ layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
848
+ trim_right_ratio = getattr(config, "trim_right_ratio", 1.0)
849
+ layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
850
+ drop_path_rate = getattr(config, "drop_path_rate", 0.0)
851
+ mixer_layer = getattr(config, "mixer_layer", "conv")
852
+ layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
853
+ disable_last_norm = getattr(config, "disable_last_norm", False)
854
+
855
+ # determine the norm type based on layernorm
856
+ if layernorm == 'LN':
857
+ norm_type = ConvLayerNorm
858
+ elif layernorm == 'RMSNorm':
859
+ norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
860
+ else:
861
+ raise ValueError(f"Unsupported norm type: {layernorm}")
862
+
863
+ # stem and upsampling layers
864
+ stem = nn.Sequential(
865
+ SConv1d(self.dimension, self.n_filters * 2 ** (len(self.depths) - 1), kernel_size, norm=norm,
866
+ norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
867
+ )
868
+
869
+ self.upsample_layers = nn.ModuleList()
870
+ self.upsample_layers.append(stem)
871
+ for i in range(len(self.ratios)):
872
+ in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
873
+ out_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i - 1))
874
+ upsample_layer = nn.Sequential(
875
+ SConvTranspose1d(in_ch, out_ch,
876
+ kernel_size=self.ratios[i] * 2, stride=self.ratios[i],
877
+ norm=norm, norm_kwargs=norm_params, bias=bias,
878
+ causal=self.causal, trim_right_ratio=trim_right_ratio),
879
+ )
880
+ self.upsample_layers.append(upsample_layer)
881
+
882
+ # configure transformer blocks
883
+ layer_type = partial(
884
+ Block1D,
885
+ mixer_layer=mixer_layer,
886
+ layernorm=layernorm,
887
+ eps=layernorm_eps,
888
+ causal=self.causal,
889
+ pad_mode=pad_mode,
890
+ norm=norm,
891
+ bias=bias,
892
+ layer_scale_init_value=layer_scale_init_value,
893
+ )
894
+
895
+ self.stages = nn.ModuleList()
896
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
897
+ cur = 0
898
+
899
+ # Create stages in the same order as the original model
900
+ for i in range(len(self.depths)):
901
+ in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
902
+ stage = nn.Sequential(
903
+ *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
904
+ )
905
+ self.stages.append(stage)
906
+ cur += self.depths[i]
907
+
908
+ if not disable_last_norm:
909
+ self.norm = norm_type(in_ch, eps=layernorm_eps)
910
+ else:
911
+ self.norm = nn.Identity()
912
+ self.head = SConv1d(in_ch, self.channels, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
913
+
914
+ def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
915
+ for i in range(len(self.depths)):
916
+ # Apply upsampling
917
+ for layer in self.upsample_layers[i]:
918
+ if isinstance(layer, (SConv1d, SConvTranspose1d)):
919
+ x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
920
+ else:
921
+ x = layer(x)
922
+
923
+ # Apply stage (Block1D contains Convlayer which contains SConv1d)
924
+ for block in self.stages[i]:
925
+ if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
926
+ # Block1D forward with cache support
927
+ residual = x
928
+ x = block.norm(x)
929
+ x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
930
+ if block.gamma is not None:
931
+ x = x * block.gamma.unsqueeze(-1)
932
+ x = residual + x
933
+
934
+ # FFN part
935
+ residual = x
936
+ x = block.ffn_norm(x)
937
+ x = x.permute(0, 2, 1)
938
+ x = block.ffn(x)
939
+ x = x.permute(0, 2, 1)
940
+ if block.ffn_gamma is not None:
941
+ x = x * block.ffn_gamma.unsqueeze(-1)
942
+ x = residual + x
943
+ else:
944
+ x = block(x)
945
+
946
+ return self.norm(x)
947
+
948
+ def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
949
+ x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
950
+ x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
951
+ return x
952
+
953
+
954
+ @dataclass
955
+ class VibeVoiceTokenizerEncoderOutput:
956
+ """
957
+ Output of VibeVoice tokenizer encoder, representing a Gaussian distribution with fixed variance.
958
+
959
+ Args:
960
+ mean (`torch.FloatTensor`): The mean parameters of the distribution.
961
+ std (`float` or `torch.FloatTensor`): Fixed standard deviation value.
962
+ """
963
+ mean: torch.Tensor
964
+ std: Optional[Union[float, torch.Tensor]] = None
965
+
966
+ def sample(self, dist_type='fix'):
967
+ """
968
+ Sample from the distribution.
969
+
970
+ Args:
971
+ dist_type (`str`): Sampling method, either 'fix' or 'gaussian'.
972
+
973
+ Returns:
974
+ `torch.FloatTensor`: Sampled values.
975
+ `torch.FloatTensor` (optional): Standard deviation used (only when dist_type='gaussian').
976
+ """
977
+ if dist_type == 'fix':
978
+ x = self.mean + self.std * torch.randn_like(self.mean)
979
+ return x, self.std
980
+ elif dist_type == 'gaussian':
981
+ batch_size = self.mean.size(0)
982
+ value = self.std / 0.8
983
+ std = torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * value
984
+
985
+ while std.dim() < self.mean.dim():
986
+ std = std.unsqueeze(-1)
987
+
988
+ x = self.mean + std * torch.randn_like(self.mean)
989
+ return x, std
990
+ else:
991
+ return self.mean, self.std
992
+
993
+ def kl(self):
994
+ """Compute KL divergence between this distribution and a standard normal."""
995
+ target = torch.zeros_like(self.mean)
996
+ return F.mse_loss(self.mean, target, reduction='none')
997
+
998
+ def mode(self):
999
+ """Return the distribution mode (which is the mean for Gaussian)."""
1000
+ return self.mean
1001
+
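Note on the two sampling modes above: 'fix' perturbs the mean with a constant noise scale, while 'gaussian' draws one random scale per batch element (std / 0.8) and broadcasts it. A minimal illustrative sketch, using a toy latent and assuming `torch` is imported as in this file:

mean = torch.zeros(2, 10, 64)  # toy (batch, frames, vae_dim) latents, purely illustrative
out = VibeVoiceTokenizerEncoderOutput(mean=mean, std=0.1)

x_fix, std_fix = out.sample(dist_type='fix')          # constant noise scale 0.1
x_gauss, std_gauss = out.sample(dist_type='gaussian') # per-sample scale ~ N(0, (0.1/0.8)^2), shape (2, 1, 1)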
1002
+ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
1003
+ """VibeVoice speech tokenizer model combining encoder and decoder for acoustic tokens"""
1004
+
1005
+ config_class = VibeVoiceAcousticTokenizerConfig
1006
+ base_model_prefix = "vibevoice_acoustic_tokenizer"
1007
+ _supports_flash_attn_2 = True
1008
+ _supports_sdpa = True
1009
+ _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]
1010
+
1011
+ def __init__(self, config):
1012
+ super().__init__(config)
1013
+
1014
+ self.register_buffer('fix_std', torch.tensor(config.fix_std), persistent=False)
1015
+ self.std_dist_type = getattr(config, "std_dist_type", "fix")
1016
+
1017
+ # Parse encoder depths
1018
+ if isinstance(config.encoder_depths, str):
1019
+ encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
1020
+ else:
1021
+ encoder_depths = config.encoder_depths
1022
+
1023
+ # Parse decoder depths if provided
1024
+ if config.decoder_depths is not None and isinstance(config.decoder_depths, str):
1025
+ decoder_depths = [int(d) for d in config.decoder_depths.split('-')]
1026
+ else:
1027
+ # Default: use reversed encoder depths if decoder_depths is None
1028
+ decoder_depths = list(reversed(encoder_depths))
1029
+
1030
+ # Create encoder config
1031
+ encoder_config = copy.deepcopy(config)
1032
+ encoder_config.dimension = config.vae_dim
1033
+ encoder_config.n_filters = config.encoder_n_filters
1034
+ encoder_config.ratios = config.encoder_ratios
1035
+ encoder_config.depths = encoder_depths
1036
+ encoder_config.norm = config.conv_norm
1037
+ encoder_config.pad_mode = config.pad_mode
1038
+ encoder_config.bias = config.conv_bias
1039
+ encoder_config.layernorm_eps = config.layernorm_eps
1040
+ encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
1041
+ encoder_config.mixer_layer = config.mixer_layer
1042
+ encoder_config.layer_scale_init_value = config.layer_scale_init_value
1043
+ encoder_config.disable_last_norm = config.disable_last_norm
1044
+
1045
+ # Create decoder config
1046
+ decoder_config = copy.deepcopy(config)
1047
+ decoder_config.dimension = config.vae_dim
1048
+ decoder_config.n_filters = config.decoder_n_filters
1049
+ decoder_config.ratios = config.decoder_ratios
1050
+ decoder_config.depths = decoder_depths
1051
+ decoder_config.norm = config.conv_norm
1052
+ decoder_config.pad_mode = config.pad_mode
1053
+ decoder_config.bias = config.conv_bias
1054
+ decoder_config.layernorm_eps = config.layernorm_eps
1055
+ decoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
1056
+ decoder_config.mixer_layer = config.mixer_layer
1057
+ decoder_config.layer_scale_init_value = config.layer_scale_init_value
1058
+ decoder_config.disable_last_norm = config.disable_last_norm
1059
+
1060
+ # Initialize encoder and decoder
1061
+ self.encoder = TokenizerEncoder(encoder_config)
1062
+ self.decoder = TokenizerDecoder(decoder_config)
1063
+
1064
+ # Initialize weights
1065
+ self.apply(self._init_weights)
1066
+
1067
+ def _init_weights(self, module):
1068
+ """Initialize weights for the model"""
1069
+ if isinstance(module, nn.Linear):
1070
+ nn.init.normal_(module.weight, std=self.config.weight_init_value)
1071
+ if module.bias is not None:
1072
+ nn.init.zeros_(module.bias)
1073
+ elif isinstance(module, nn.LayerNorm):
1074
+ nn.init.ones_(module.weight)
1075
+ nn.init.zeros_(module.bias)
1076
+ elif isinstance(module, nn.Conv1d):
1077
+ nn.init.normal_(module.weight, std=self.config.weight_init_value)
1078
+ if module.bias is not None:
1079
+ nn.init.zeros_(module.bias)
1080
+
1081
+ @torch.no_grad()
1082
+ def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
1083
+ """Convert audio to latent representations"""
1084
+ latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1085
+ return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std)
1086
+
1087
+ @torch.no_grad()
1088
+ def sampling(self, encoder_output, dist_type=None):
1089
+ """Sample from the encoder output distribution"""
1090
+ dist_type = dist_type or self.std_dist_type
1091
+
1092
+ if dist_type == 'fix':
1093
+ return encoder_output.sample(dist_type='fix')
1094
+ elif dist_type == 'gaussian':
1095
+ return encoder_output.sample(dist_type='gaussian')
1096
+ else:
1097
+ raise ValueError(f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'")
1098
+
1099
+ @torch.no_grad()
1100
+ def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
1101
+ """Convert latent representations back to audio"""
1102
+ if latents.shape[1] != self.config.vae_dim:
1103
+ latents = latents.permute(0, 2, 1)
1106
+
1107
+ audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1108
+ return audio
1109
+
1110
+ def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
1111
+ """Full forward pass: encode audio to latents, then decode back to audio"""
1112
+ encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1113
+ sampled_latents, _ = self.sampling(encoder_output)
1114
+ reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1115
+ return reconstructed, sampled_latents
1116
+
1117
+
1118
+ class VibeVoiceSemanticTokenizerModel(PreTrainedModel):
1119
+ """VibeVoice speech tokenizer model with only encoder for semantic tokens"""
1120
+
1121
+ config_class = VibeVoiceSemanticTokenizerConfig
1122
+ base_model_prefix = "vibevoice_semantic_tokenizer"
1123
+ _supports_flash_attn_2 = True
1124
+ _supports_sdpa = True
1125
+ _no_split_modules = ["TokenizerEncoder"]
1126
+
1127
+ def __init__(self, config):
1128
+ super().__init__(config)
1129
+
1130
+ # Parse encoder depths
1131
+ if isinstance(config.encoder_depths, str):
1132
+ encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
1133
+ else:
1134
+ encoder_depths = config.encoder_depths
1135
+
1136
+ # Create encoder config
1137
+ encoder_config = copy.deepcopy(config)
1138
+ encoder_config.dimension = config.vae_dim
1139
+ encoder_config.n_filters = config.encoder_n_filters
1140
+ encoder_config.ratios = config.encoder_ratios
1141
+ encoder_config.depths = encoder_depths
1142
+ encoder_config.norm = config.conv_norm
1143
+ encoder_config.pad_mode = config.pad_mode
1144
+ encoder_config.bias = config.conv_bias
1145
+ encoder_config.layernorm_eps = config.layernorm_eps
1146
+ encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
1147
+ encoder_config.mixer_layer = config.mixer_layer
1148
+ encoder_config.layer_scale_init_value = config.layer_scale_init_value
1149
+ encoder_config.disable_last_norm = config.disable_last_norm
1150
+
1151
+ # Initialize encoder (the semantic tokenizer has no decoder)
1152
+ self.encoder = TokenizerEncoder(encoder_config)
1153
+
1154
+ # Initialize weights
1155
+ self.apply(self._init_weights)
1156
+
1157
+ def _init_weights(self, module):
1158
+ """Initialize weights for the model"""
1159
+ if isinstance(module, nn.Linear):
1160
+ nn.init.normal_(module.weight, std=self.config.weight_init_value)
1161
+ if module.bias is not None:
1162
+ nn.init.zeros_(module.bias)
1163
+ elif isinstance(module, nn.LayerNorm):
1164
+ nn.init.ones_(module.weight)
1165
+ nn.init.zeros_(module.bias)
1166
+ elif isinstance(module, nn.Conv1d):
1167
+ nn.init.normal_(module.weight, std=self.config.weight_init_value)
1168
+ if module.bias is not None:
1169
+ nn.init.zeros_(module.bias)
1170
+
1171
+ @torch.no_grad()
1172
+ def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
1173
+ """Convert audio to latent representations"""
1174
+ latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1175
+ return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))
1176
+
1177
+ @torch.no_grad()
1178
+ def sampling(self, encoder_output, dist_type=None):
1179
+ """Sample from the encoder output distribution"""
1180
+ return encoder_output.sample(dist_type='none')
1181
+
1182
+ def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
1183
+ """Full forward pass: encode audio to latents, then decode back to audio"""
1184
+ encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
1185
+ sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
1186
+ return None, sampled_latents
1187
+
1188
+ AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)
1189
+ AutoModel.register(VibeVoiceSemanticTokenizerConfig, VibeVoiceSemanticTokenizerModel)
1190
+
1191
+ __all__ = [
1192
+ "VibeVoiceTokenizerStreamingCache",
1193
+ "VibeVoiceAcousticTokenizerModel",
1194
+ "VibeVoiceSemanticTokenizerModel",
1195
+ ]
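For orientation, a minimal usage sketch of the acoustic tokenizer defined above. The checkpoint path is a placeholder, and a single-channel 24 kHz waveform is assumed; the actual shapes and vae_dim come from the loaded config.

import torch

model = VibeVoiceAcousticTokenizerModel.from_pretrained("path/to/acoustic_tokenizer")  # placeholder path
model.eval()

wav = torch.randn(1, 1, 24000)    # (batch, channels, samples): 1 s of dummy audio
enc = model.encode(wav)           # VibeVoiceTokenizerEncoderOutput with mean of shape (B, frames, vae_dim)
latents, _ = model.sampling(enc)  # adds fixed-std noise by default (std_dist_type="fix")
audio = model.decode(latents)     # waveform of roughly the original length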
vvembed/modular/streamer.py ADDED
@@ -0,0 +1,264 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+
5
+ import asyncio
6
+ from queue import Empty, Queue
7
+ from typing import TYPE_CHECKING, Any, Optional
8
+
9
+
10
+ from transformers.generation import BaseStreamer
11
+
12
+
13
+ class AudioStreamer(BaseStreamer):
14
+ """
15
+ Audio streamer that stores audio chunks in queues for each sample in the batch.
16
+ This allows streaming audio generation for multiple samples simultaneously.
17
+
18
+ Parameters:
19
+ batch_size (`int`):
20
+ The batch size for generation
21
+ stop_signal (`Any`, *optional*):
22
+ The signal to put in the queue when generation ends. Defaults to None.
23
+ timeout (`float`, *optional*):
24
+ The timeout for the audio queue. If `None`, the queue will block indefinitely.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ batch_size: int,
30
+ stop_signal: Optional[Any] = None,
31
+ timeout: Optional[float] = None,
32
+ ):
33
+ self.batch_size = batch_size
34
+ self.stop_signal = stop_signal
35
+ self.timeout = timeout
36
+
37
+ # Create a queue for each sample in the batch
38
+ self.audio_queues = [Queue() for _ in range(batch_size)]
39
+ self.finished_flags = [False for _ in range(batch_size)]
40
+ self.sample_indices_map = {} # Maps from sample index to queue index
41
+
42
+ def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
43
+ """
44
+ Receives audio chunks and puts them in the appropriate queues.
45
+
46
+ Args:
47
+ audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks
48
+ sample_indices: Tensor indicating which samples these chunks belong to
49
+ """
50
+ for i, sample_idx in enumerate(sample_indices):
51
+ idx = sample_idx.item()
52
+ if idx < self.batch_size and not self.finished_flags[idx]:
53
+ # Detach and move the chunk to CPU (kept as a tensor)
54
+ audio_chunk = audio_chunks[i].detach().cpu()
55
+ self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)
56
+
57
+ def end(self, sample_indices: Optional[torch.Tensor] = None):
58
+ """
59
+ Signals the end of generation for specified samples or all samples.
60
+
61
+ Args:
62
+ sample_indices: Optional tensor of sample indices to end. If None, ends all.
63
+ """
64
+ if sample_indices is None:
65
+ # End all samples
66
+ for idx in range(self.batch_size):
67
+ if not self.finished_flags[idx]:
68
+ self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
69
+ self.finished_flags[idx] = True
70
+ else:
71
+ # End specific samples
72
+ for sample_idx in sample_indices:
73
+ idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx
74
+ if idx < self.batch_size and not self.finished_flags[idx]:
75
+ self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
76
+ self.finished_flags[idx] = True
77
+
78
+ def __iter__(self):
79
+ """Returns an iterator over the batch of audio streams."""
80
+ return AudioBatchIterator(self)
81
+
82
+ def get_stream(self, sample_idx: int):
83
+ """Get the audio stream for a specific sample."""
84
+ if sample_idx >= self.batch_size:
85
+ raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
86
+ return AudioSampleIterator(self, sample_idx)
87
+
88
+
89
+ class AudioSampleIterator:
90
+ """Iterator for a single audio stream from the batch."""
91
+
92
+ def __init__(self, streamer: AudioStreamer, sample_idx: int):
93
+ self.streamer = streamer
94
+ self.sample_idx = sample_idx
95
+
96
+ def __iter__(self):
97
+ return self
98
+
99
+ def __next__(self):
100
+ value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
101
+ if value == self.streamer.stop_signal:
102
+ raise StopIteration()
103
+ return value
104
+
105
+
106
+ class AudioBatchIterator:
107
+ """Iterator that yields audio chunks for all samples in the batch."""
108
+
109
+ def __init__(self, streamer: AudioStreamer):
110
+ self.streamer = streamer
111
+ self.active_samples = set(range(streamer.batch_size))
112
+
113
+ def __iter__(self):
114
+ return self
115
+
116
+ def __next__(self):
117
+ if not self.active_samples:
118
+ raise StopIteration()
119
+
120
+ batch_chunks = {}
121
+ samples_to_remove = set()
122
+
123
+ # Try to get chunks from all active samples
124
+ for idx in self.active_samples:
125
+ try:
126
+ value = self.streamer.audio_queues[idx].get(block=False)
127
+ if value == self.streamer.stop_signal:
128
+ samples_to_remove.add(idx)
129
+ else:
130
+ batch_chunks[idx] = value
131
+ except Empty:
132
+ # Queue is empty for this sample, skip it this iteration
133
+ pass
134
+
135
+ # Remove finished samples
136
+ self.active_samples -= samples_to_remove
137
+
138
+ if batch_chunks:
139
+ return batch_chunks
140
+ elif self.active_samples:
141
+ # If no chunks were ready but we still have active samples,
142
+ # wait a bit and try again
143
+ import time
144
+ time.sleep(0.01)
145
+ return self.__next__()
146
+ else:
147
+ raise StopIteration()
148
+
149
+
150
+ class AsyncAudioStreamer(AudioStreamer):
151
+ """
152
+ Async version of AudioStreamer for use in async contexts.
153
+ """
154
+
155
+ def __init__(
156
+ self,
157
+ batch_size: int,
158
+ stop_signal: Optional[Any] = None,
159
+ timeout: Optional[float] = None,
160
+ ):
161
+ super().__init__(batch_size, stop_signal, timeout)
162
+ # Replace regular queues with async queues
163
+ self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
164
+ self.loop = asyncio.get_running_loop()
165
+
166
+ def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
167
+ """Put audio chunks in the appropriate async queues."""
168
+ for i, sample_idx in enumerate(sample_indices):
169
+ idx = sample_idx.item()
170
+ if idx < self.batch_size and not self.finished_flags[idx]:
171
+ audio_chunk = audio_chunks[i].detach().cpu()
172
+ self.loop.call_soon_threadsafe(
173
+ self.audio_queues[idx].put_nowait, audio_chunk
174
+ )
175
+
176
+ def end(self, sample_indices: Optional[torch.Tensor] = None):
177
+ """Signal the end of generation for specified samples."""
178
+ if sample_indices is None:
179
+ indices_to_end = range(self.batch_size)
180
+ else:
181
+ indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
182
+
183
+ for idx in indices_to_end:
184
+ if idx < self.batch_size and not self.finished_flags[idx]:
185
+ self.loop.call_soon_threadsafe(
186
+ self.audio_queues[idx].put_nowait, self.stop_signal
187
+ )
188
+ self.finished_flags[idx] = True
189
+
190
+ async def get_stream(self, sample_idx: int):
191
+ """Get async iterator for a specific sample's audio stream."""
192
+ if sample_idx >= self.batch_size:
193
+ raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
194
+
195
+ while True:
196
+ value = await self.audio_queues[sample_idx].get()
197
+ if value == self.stop_signal:
198
+ break
199
+ yield value
200
+
201
+ def __aiter__(self):
202
+ """Returns an async iterator over all audio streams."""
203
+ return AsyncAudioBatchIterator(self)
204
+
205
+
206
+ class AsyncAudioBatchIterator:
207
+ """Async iterator for batch audio streaming."""
208
+
209
+ def __init__(self, streamer: AsyncAudioStreamer):
210
+ self.streamer = streamer
211
+ self.active_samples = set(range(streamer.batch_size))
212
+
213
+ def __aiter__(self):
214
+ return self
215
+
216
+ async def __anext__(self):
217
+ if not self.active_samples:
218
+ raise StopAsyncIteration()
219
+
220
+ batch_chunks = {}
221
+ samples_to_remove = set()
222
+
223
+ # Create tasks for all active samples
224
+ tasks = {
225
+ idx: asyncio.create_task(self._get_chunk(idx))
226
+ for idx in self.active_samples
227
+ }
228
+
229
+ # Wait for at least one chunk to be ready
230
+ done, pending = await asyncio.wait(
231
+ tasks.values(),
232
+ return_when=asyncio.FIRST_COMPLETED,
233
+ timeout=self.streamer.timeout
234
+ )
235
+
236
+ # Cancel pending tasks
237
+ for task in pending:
238
+ task.cancel()
239
+
240
+ # Process completed tasks
241
+ for idx, task in tasks.items():
242
+ if task in done:
243
+ try:
244
+ value = await task
245
+ if value == self.streamer.stop_signal:
246
+ samples_to_remove.add(idx)
247
+ else:
248
+ batch_chunks[idx] = value
249
+ except asyncio.CancelledError:
250
+ pass
251
+
252
+ self.active_samples -= samples_to_remove
253
+
254
+ if batch_chunks:
255
+ return batch_chunks
256
+ elif self.active_samples:
257
+ # Try again if we still have active samples
258
+ return await self.__anext__()
259
+ else:
260
+ raise StopAsyncIteration()
261
+
262
+ async def _get_chunk(self, idx):
263
+ """Helper to get a chunk from a specific queue."""
264
+ return await self.streamer.audio_queues[idx].get()
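A minimal sketch of the intended producer/consumer pattern for AudioStreamer. The producer thread here is a stand-in for a generation loop that calls put() and end(); chunk shapes are illustrative only.

import threading
import torch

streamer = AudioStreamer(batch_size=1)

def fake_generation():
    for _ in range(5):
        chunk = torch.randn(1, 3200)  # one chunk for sample 0
        streamer.put(chunk, sample_indices=torch.tensor([0]))
    streamer.end()  # signals the end of the stream to all consumers

threading.Thread(target=fake_generation).start()

chunks = [c for c in streamer.get_stream(0)]  # blocks until end() is called
audio = torch.cat(chunks, dim=-1)

The AsyncAudioStreamer variant below follows the same pattern, but must be constructed inside a running event loop and consumed with `async for`.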
vvembed/processor/__init__.py ADDED
File without changes
vvembed/processor/vibevoice_processor.py ADDED
@@ -0,0 +1,677 @@
1
+ import math
2
+ import warnings
3
+ from typing import List, Optional, Union, Dict, Any, Tuple
4
+ import os
5
+ import re
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
11
+ from transformers.utils import TensorType, logging
12
+ from .vibevoice_tokenizer_processor import AudioNormalizer
13
+
14
+ logger = logging.get_logger(__name__)
15
+
16
+
17
+ class VibeVoiceProcessor:
18
+ r"""
19
+ Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor.
20
+
21
+ [`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`].
22
+ See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information.
23
+
24
+ Args:
25
+ tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):
26
+ The tokenizer for text processing.
27
+ audio_processor (`VibeVoiceTokenizerProcessor`):
28
+ The audio processor for speech processing.
29
+ speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):
30
+ The compression ratio for speech tokenization.
31
+ db_normalize (`bool`, *optional*, defaults to True):
32
+ Whether to apply decibel normalization to audio inputs.
33
+ """
34
+
35
+ def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
36
+ self.tokenizer = tokenizer
37
+ self.audio_processor = audio_processor
38
+ self.speech_tok_compress_ratio = speech_tok_compress_ratio
39
+ self.db_normalize = db_normalize
40
+ self.audio_normalizer = AudioNormalizer() if db_normalize else None
41
+ self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"
42
+
43
+ @classmethod
44
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
45
+ """
46
+ Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor.
47
+
48
+ Args:
49
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
50
+ This can be either:
51
+ - a string, the *model id* of a pretrained model
52
+ - a path to a *directory* containing processor config
53
+
54
+ Returns:
55
+ [`VibeVoiceProcessor`]: The processor object instantiated from pretrained model.
56
+ """
57
+ import os
58
+ import json
59
+ from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
60
+ from modular.modular_vibevoice_text_tokenizer import (
61
+ VibeVoiceTextTokenizer,
62
+ VibeVoiceTextTokenizerFast
63
+ )
64
+
65
+ # Load processor configuration
66
+ config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
67
+ if os.path.exists(config_path):
68
+ with open(config_path, 'r') as f:
69
+ config = json.load(f)
70
+ else:
71
+ logger.warning(f"No preprocessor_config.json found at {pretrained_model_name_or_path}, using defaults")
72
+ config = {
73
+ "speech_tok_compress_ratio": 3200,
74
+ "db_normalize": True,
75
+ }
76
+
77
+ # Extract main processor parameters
78
+ speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
79
+ db_normalize = config.get("db_normalize", True)
80
+
81
+ # Load tokenizer - try from model path first, then fallback to Qwen
82
+ language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
83
+ logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
84
+ if 'qwen' in language_model_pretrained_name.lower():
85
+ tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(
86
+ language_model_pretrained_name,
87
+ **kwargs
88
+ )
89
+ else:
90
+ raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Supported types: Qwen, Llama, Gemma.")
91
+
92
+ # Load audio processor
93
+ if "audio_processor" in config:
94
+ # Create audio processor from config
95
+ audio_config = config["audio_processor"]
96
+ audio_processor = VibeVoiceTokenizerProcessor(
97
+ sampling_rate=audio_config.get("sampling_rate", 24000),
98
+ normalize_audio=audio_config.get("normalize_audio", True),
99
+ target_dB_FS=audio_config.get("target_dB_FS", -25),
100
+ eps=audio_config.get("eps", 1e-6),
101
+ )
102
+ else:
103
+ # Create default audio processor
104
+ audio_processor = VibeVoiceTokenizerProcessor()
105
+
106
+ # Create and return the processor
107
+ return cls(
108
+ tokenizer=tokenizer,
109
+ audio_processor=audio_processor,
110
+ speech_tok_compress_ratio=speech_tok_compress_ratio,
111
+ db_normalize=db_normalize,
112
+ )
113
+
114
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
115
+ """
116
+ Save a processor to a directory, so that it can be re-loaded using the
117
+ [`~VibeVoiceProcessor.from_pretrained`] class method.
118
+
119
+ Args:
120
+ save_directory (`str` or `os.PathLike`):
121
+ Directory where the processor will be saved.
122
+ """
123
+ import os
124
+ import json
125
+
126
+ os.makedirs(save_directory, exist_ok=True)
127
+
128
+ # Save processor configuration
129
+ processor_config = {
130
+ "processor_class": "VibeVoiceProcessor",
131
+ "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
132
+ "db_normalize": self.db_normalize,
133
+ "audio_processor": {
134
+ "feature_extractor_type": "VibeVoiceTokenizerProcessor",
135
+ "sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
136
+ "normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
137
+ "target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
138
+ "eps": getattr(self.audio_processor, 'eps', 1e-6),
139
+ }
140
+ }
141
+
142
+ config_path = os.path.join(save_directory, "preprocessor_config.json")
143
+ with open(config_path, 'w') as f:
144
+ json.dump(processor_config, f, indent=2)
145
+
146
+ logger.info(f"Processor configuration saved in {config_path}")
147
+
148
+ def __call__(
149
+ self,
150
+ text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
151
+ voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
152
+ padding: Union[bool, str, PaddingStrategy] = True,
153
+ truncation: Union[bool, str, TruncationStrategy] = False,
154
+ max_length: Optional[int] = None,
155
+ return_tensors: Optional[Union[str, TensorType]] = None,
156
+ return_attention_mask: bool = True,
157
+ **kwargs,
158
+ ) -> BatchEncoding:
159
+ """
160
+ Main method to process one or more podcast scripts with optional voice samples.
161
+
162
+ Args:
163
+ text (`str`, `List[str]`):
164
+ The input text(s) to process. Can be:
165
+ - A single script string
166
+ - A list of script strings for batch processing
167
+ - A path to a .json or .txt file
168
+ - A list of paths
169
+ voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*):
170
+ Voice samples for each script. Can be:
171
+ - A list of samples for a single script
172
+ - A list of lists for batch processing
173
+ padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
174
+ Whether to pad sequences to the same length
175
+ truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):
176
+ Whether to truncate sequences
177
+ max_length (`int`, *optional*):
178
+ Maximum length of the returned sequences
179
+ return_tensors (`str` or `TensorType`, *optional*):
180
+ If set, will return tensors of a particular framework
181
+ return_attention_mask (`bool`, defaults to `True`):
182
+ Whether to return the attention mask
183
+
184
+ Returns:
185
+ `BatchEncoding`: A BatchEncoding with the following fields:
186
+ - **input_ids** -- List of token id sequences or tensor
187
+ - **attention_mask** -- List of attention masks or tensor
188
+ - **speech_tensors** -- Padded speech inputs (if voice_samples provided)
189
+ - **speech_masks** -- Speech masks (if voice_samples provided)
190
+ - **speech_input_mask** -- Boolean masks indicating speech token positions
191
+ """
192
+ # Handle single vs batch input
193
+ if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
194
+ # Single input
195
+ texts = [text]
196
+ is_batched = False
197
+ else:
198
+ # Batch input
199
+ texts = text
200
+ is_batched = True
201
+
202
+ # Handle voice samples
203
+ if voice_samples is not None:
204
+ if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
205
+ # Single set of voice samples
206
+ voice_samples_list = [voice_samples]
207
+ else:
208
+ # Batch of voice samples
209
+ voice_samples_list = voice_samples
210
+ else:
211
+ voice_samples_list = [None] * len(texts)
212
+
213
+ # Process each input
214
+ all_encodings = []
215
+ for text_input, voice_input in zip(texts, voice_samples_list):
216
+ encoding = self._process_single(text_input, voice_input)
217
+ all_encodings.append(encoding)
218
+
219
+ # Combine batch
220
+ batch_encoding = self._batch_encode(
221
+ all_encodings,
222
+ padding=padding,
223
+ truncation=truncation,
224
+ max_length=max_length,
225
+ return_tensors=return_tensors,
226
+ return_attention_mask=return_attention_mask,
227
+ )
228
+
229
+ return batch_encoding
230
+
231
+ def _process_single(
232
+ self,
233
+ text: Union[str, TextInput],
234
+ voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
235
+ ) -> Dict[str, Any]:
236
+ """Process a single podcast script."""
237
+ # Determine if text is a file path or direct script
238
+ script = None
239
+ if isinstance(text, str):
240
+ # Check if it's a file path
241
+ if text.endswith('.json') and os.path.exists(text):
242
+ script = self._convert_json_to_script(text)
243
+ elif text.endswith('.txt') and os.path.exists(text):
244
+ script = self._convert_text_to_script(text)
245
+ else:
246
+ # Assume it's the script content directly
247
+ script = text
248
+
249
+ if script is None:
250
+ raise ValueError(f"Could not process input text: {text}")
251
+
252
+ # Parse the script
253
+ parsed_lines = self._parse_script(script)
254
+ all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))
255
+
256
+ # Create system prompt
257
+ # system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
258
+ system_tokens = self.tokenizer.encode(self.system_prompt)
259
+
260
+ # Process voice samples if provided
261
+ if voice_samples:
262
+ voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
263
+ else:
264
+ voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []
265
+
266
+ # Build full token sequence
267
+ full_tokens = system_tokens + voice_tokens
268
+ speech_input_mask = [False] * len(system_tokens) + voice_speech_masks
269
+
270
+ # Add text input section
271
+ full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
272
+ speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False))
273
+
274
+ for speaker_id, speaker_text in parsed_lines:
275
+ speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
276
+ full_tokens += speaker_text_tokens
277
+ speech_input_mask += [False] * len(speaker_text_tokens)
278
+
279
+ # Add speech output section
280
+ full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
281
+ speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1)
282
+
283
+ return {
284
+ "input_ids": full_tokens,
285
+ "speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
286
+ "speech_input_mask": speech_input_mask,
287
+ "parsed_script": parsed_lines,
288
+ "all_speakers": all_speakers,
289
+ }
290
+
291
+ def _batch_encode(
292
+ self,
293
+ encodings: List[Dict[str, Any]],
294
+ padding: Union[bool, str, PaddingStrategy] = True,
295
+ truncation: Union[bool, str, TruncationStrategy] = False,
296
+ max_length: Optional[int] = None,
297
+ return_tensors: Optional[Union[str, TensorType]] = None,
298
+ return_attention_mask: bool = True,
299
+ ) -> BatchEncoding:
300
+ """Combine multiple encodings into a batch with padding."""
301
+ # Extract input_ids and create attention_mask
302
+ input_ids_list = [enc["input_ids"] for enc in encodings]
303
+ speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]
304
+
305
+ # Determine padding strategy
306
+ if isinstance(padding, bool):
307
+ padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
308
+ elif isinstance(padding, str):
309
+ padding_strategy = PaddingStrategy(padding)
310
+ else:
311
+ padding_strategy = padding
312
+
313
+ # Apply padding to input_ids
314
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD:
315
+ if padding_strategy == PaddingStrategy.LONGEST:
316
+ max_len = max(len(ids) for ids in input_ids_list)
317
+ elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
318
+ max_len = max_length
319
+ else:
320
+ max_len = max(len(ids) for ids in input_ids_list)
321
+
322
+ # Pad sequences
323
+ padded_input_ids = []
324
+ attention_masks = []
325
+ padded_speech_input_masks = []
326
+
327
+ for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list):
328
+ # Truncate if needed
329
+ if truncation and len(input_ids) > max_len:
330
+ input_ids = input_ids[:max_len]
331
+ speech_mask = speech_mask[:max_len]
332
+
333
+ # Pad
334
+ padding_length = max_len - len(input_ids)
335
+ # padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids
336
+ padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids
337
+ attention_mask = [0] * padding_length + [1] * len(input_ids)
338
+ padded_speech_mask = [False] * padding_length + speech_mask
339
+
340
+ padded_input_ids.append(padded_ids)
341
+ attention_masks.append(attention_mask)
342
+ padded_speech_input_masks.append(padded_speech_mask)
343
+
344
+ input_ids_list = padded_input_ids
345
+ speech_input_masks_list = padded_speech_input_masks
346
+ else:
347
+ # No padding, just create attention masks
348
+ attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None
349
+
350
+ # Process speech inputs
351
+ all_speech_inputs = []
352
+ has_speech = False
353
+ for enc in encodings:
354
+ if enc["speech_inputs"] is not None:
355
+ all_speech_inputs.extend(enc["speech_inputs"])
356
+ has_speech = True
357
+
358
+ # Prepare batch encoding
359
+ batch_encoding = BatchEncoding()
360
+
361
+ # Handle tensor conversion
362
+ if return_tensors is not None:
363
+ batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
364
+ if return_attention_mask and attention_masks is not None:
365
+ batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
366
+ batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
367
+ else:
368
+ batch_encoding["input_ids"] = input_ids_list
369
+ if return_attention_mask and attention_masks is not None:
370
+ batch_encoding["attention_mask"] = attention_masks
371
+ batch_encoding["speech_input_mask"] = speech_input_masks_list
372
+
373
+ # Process speech tensors if present
374
+ if has_speech:
375
+ speech_dict = self.prepare_speech_inputs(
376
+ all_speech_inputs,
377
+ return_tensors=return_tensors,
378
+ )
379
+ batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
380
+ batch_encoding["speech_masks"] = speech_dict["speech_masks"]
381
+ else:
382
+ batch_encoding["speech_tensors"] = None
383
+ batch_encoding["speech_masks"] = None
384
+
385
+ # Add metadata
386
+ batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
387
+ batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]
388
+
389
+ return batch_encoding
390
+
391
+ def _create_voice_prompt(
392
+ self,
393
+ speaker_samples: List[Union[str, np.ndarray]]
394
+ ) -> Tuple[List[int], List[np.ndarray], List[bool]]:
395
+ """
396
+ Create voice prompt tokens and process audio samples.
397
+
398
+ Returns:
399
+ tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
400
+ """
401
+ vae_token_id = self.tokenizer.speech_diffusion_id
402
+
403
+ voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
404
+ voice_speech_inputs = []
405
+ voice_speech_masks = [False] * len(voice_full_tokens)
406
+
407
+ for speaker_id, speaker_audio in enumerate(speaker_samples):
408
+ prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)
409
+
410
+ # Process audio
411
+ if isinstance(speaker_audio, str):
412
+ # Load audio from file
413
+ wav = self.audio_processor._load_audio_from_path(speaker_audio)
414
+ else:
415
+ wav = np.array(speaker_audio, dtype=np.float32)
416
+
417
+ # Apply normalization if needed
418
+ if self.db_normalize and self.audio_normalizer:
419
+ wav = self.audio_normalizer(wav)
420
+
421
+ # Calculate token length based on compression ratio
422
+ # if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
423
+ # vae_tok_len = wav.shape[0]
424
+ # else:
425
+ vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
426
+
427
+ # Build tokens and masks
428
+ speaker_tokens = (prefix_tokens +
429
+ [self.tokenizer.speech_start_id] +
430
+ [vae_token_id] * vae_tok_len +
431
+ [self.tokenizer.speech_end_id] +
432
+ self.tokenizer.encode('\n', add_special_tokens=False))
433
+
434
+ vae_input_mask = ([False] * len(prefix_tokens) +
435
+ [False] +
436
+ [True] * vae_tok_len +
437
+ [False] +
438
+ [False])
439
+
440
+ voice_full_tokens.extend(speaker_tokens)
441
+ voice_speech_masks.extend(vae_input_mask)
442
+ voice_speech_inputs.append(wav)
443
+
444
+ return voice_full_tokens, voice_speech_inputs, voice_speech_masks
445
+
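The number of speech placeholder tokens per voice sample follows directly from the compression ratio: at 24 kHz with the default ratio of 3200, one token covers 3200 samples (about 133 ms, i.e. 7.5 tokens per second). A quick check of the arithmetic used above, with an assumed 4-second reference sample:

import math

speech_tok_compress_ratio = 3200
wav_len = 4 * 24000  # 4 seconds at 24 kHz
vae_tok_len = math.ceil(wav_len / speech_tok_compress_ratio)
print(vae_tok_len)   # 30 placeholder tokens between speech_start and speech_end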
446
+ def prepare_speech_inputs(
447
+ self,
448
+ speech_inputs: List[np.ndarray],
449
+ return_tensors: Optional[Union[str, TensorType]] = None,
450
+ device: Optional[Union[str, torch.device]] = None,
451
+ dtype: Optional[torch.dtype] = None,
452
+ ) -> Dict[str, Any]:
453
+ """
454
+ Prepare speech inputs for model consumption.
455
+
456
+ Args:
457
+ speech_inputs: List of speech arrays
458
+ return_tensors: Output tensor type
459
+ device: Device to place tensors on
460
+ dtype: Data type for tensors
461
+
462
+ Returns:
463
+ Dictionary with padded_speeches and speech_masks
464
+ """
465
+ if not speech_inputs:
466
+ return {"padded_speeches": None, "speech_masks": None}
467
+
468
+ # Calculate sequence lengths
469
+ vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
470
+ # vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs]
471
+ max_speech_length = max(s.shape[0] for s in speech_inputs)
472
+
473
+ # Pad speeches
474
+ if speech_inputs[0].ndim == 1:
475
+ padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
476
+ else:
477
+ padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
478
+ speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)
479
+
480
+ for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
481
+ padded_speeches[i, :len(speech)] = speech
482
+ speech_masks[i, :vae_tok_length] = True
483
+
484
+ result = {
485
+ "padded_speeches": padded_speeches,
486
+ "speech_masks": speech_masks,
487
+ }
488
+
489
+ # Convert to tensors if requested
490
+ if return_tensors == "pt":
491
+ result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
492
+ result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)
493
+
494
+ return result
495
+
496
+ def _convert_json_to_script(self, json_file: str) -> str:
497
+ """
498
+ Convert JSON format to script format.
499
+ Expected JSON format:
500
+ [
501
+ {"speaker": "1", "text": "Hello everyone..."},
502
+ {"speaker": "2", "text": "Great to be here..."}
503
+ ]
504
+ """
505
+ import json
506
+
507
+ with open(json_file, 'r', encoding='utf-8') as f:
508
+ data = json.load(f)
509
+
510
+ if not isinstance(data, list):
511
+ raise ValueError("JSON file must contain a list of speaker entries")
512
+
513
+ script_lines = []
514
+ for item in data:
515
+ if not isinstance(item, dict):
516
+ logger.warning(f"Skipping non-dict entry: {item}")
517
+ continue
518
+
519
+ speaker = item.get('speaker')
520
+ text = item.get('text')
521
+
522
+ if speaker is None or text is None:
523
+ logger.warning(f"Skipping entry missing speaker or text: {item}")
524
+ continue
525
+
526
+ # Ensure speaker ID is valid
527
+ try:
528
+ speaker_id = int(speaker)
529
+ except (ValueError, TypeError):
530
+ logger.warning(f"Invalid speaker ID: {speaker}, skipping entry")
531
+ continue
532
+
533
+ # Clean up text
534
+ text = text.strip()
535
+ if text:
536
+ script_lines.append(f"Speaker {speaker_id}: {text}")
537
+
538
+ if not script_lines:
539
+ raise ValueError("No valid entries found in JSON file")
540
+
541
+ return "\n".join(script_lines)
542
+
543
+ def _convert_text_to_script(self, text_file: str) -> str:
544
+ """
545
+ Convert text file to script format.
546
+ Handles multiple formats:
547
+ 1. Already formatted as "Speaker X: text"
548
+ 2. Plain text (assigns to Speaker 1)
549
+
550
+ Handles edge cases like multiple colons in a line.
551
+ """
552
+ with open(text_file, 'r', encoding='utf-8') as f:
553
+ lines = f.readlines()
554
+
555
+ script_lines = []
556
+ current_speaker = 1
557
+
558
+ for line in lines:
559
+ line = line.strip()
560
+ if not line:
561
+ continue
562
+
563
+ # Try to parse as "Speaker X: text" format
564
+ # Use regex to be more robust
565
+ speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)
566
+
567
+ if speaker_match:
568
+ speaker_id = int(speaker_match.group(1))
569
+ text = speaker_match.group(2).strip()
570
+ if text:
571
+ script_lines.append(f"Speaker {speaker_id}: {text}")
572
+ else:
573
+ # Treat as plain text - assign to current speaker
574
+ script_lines.append(f"Speaker {current_speaker}: {line}")
575
+
576
+ if not script_lines:
577
+ raise ValueError("No valid content found in text file")
578
+
579
+ return "\n".join(script_lines)
580
+
581
+ def _parse_script(self, script: str) -> List[Tuple[int, str]]:
582
+ """Parse script into list of (speaker_id, text) tuples."""
583
+ lines = script.strip().split("\n")
584
+ parsed_lines = []
585
+ speaker_ids = []
586
+
587
+ # First pass: parse all lines and collect speaker IDs
588
+ for line in lines:
589
+ if not line.strip():
590
+ continue
591
+
592
+ # Use regex to handle edge cases like multiple colons
593
+ match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)
594
+
595
+ if match:
596
+ speaker_id = int(match.group(1))
597
+ text = ' ' + match.group(2).strip()
598
+ parsed_lines.append((speaker_id, text))
599
+ speaker_ids.append(speaker_id)
600
+ else:
601
+ logger.warning(f"Could not parse line: '{line}'")
602
+
603
+ if not parsed_lines:
604
+ raise ValueError("No valid speaker lines found in script")
605
+
606
+ # Check if we need to normalize speaker IDs (only if all are > 0)
607
+ min_speaker_id = min(speaker_ids)
608
+ if min_speaker_id > 0:
609
+ # Shift speaker IDs down by one (turns the common 1-based numbering into 0-based)
610
+ normalized_lines = []
611
+ for speaker_id, text in parsed_lines:
612
+ normalized_lines.append((speaker_id - 1, text))
613
+ return normalized_lines
614
+ else:
615
+ # Keep original IDs
616
+ return parsed_lines
617
+
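For reference, the script format _parse_script expects and the tuples it produces. Instantiating the processor without a tokenizer or audio processor is only a way to exercise the parser in isolation; it is not how the class is normally used.

processor = VibeVoiceProcessor(tokenizer=None, audio_processor=None)
parsed = processor._parse_script("Speaker 1: Welcome.\nSpeaker 2: Thanks for having me.")
print(parsed)  # [(0, ' Welcome.'), (1, ' Thanks for having me.')]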
618
+ def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding:
619
+ """Merge text and audio inputs into a single BatchEncoding."""
620
+ # Start with text inputs
621
+ merged = BatchEncoding(text_inputs)
622
+
623
+ # Add audio-specific fields
624
+ if "audio" in audio_inputs:
625
+ merged["speech_inputs"] = audio_inputs["audio"]
626
+ if "streaming" in audio_inputs:
627
+ merged["streaming"] = audio_inputs["streaming"]
628
+
629
+ return merged
630
+
631
+ def batch_decode(self, *args, **kwargs):
632
+ """
633
+ This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`].
634
+ Please refer to the docstring of this method for more information.
635
+ """
636
+ return self.tokenizer.batch_decode(*args, **kwargs)
637
+
638
+ def decode(self, *args, **kwargs):
639
+ """
640
+ This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`].
641
+ Please refer to the docstring of this method for more information.
642
+ """
643
+ return self.tokenizer.decode(*args, **kwargs)
644
+
645
+ @property
646
+ def model_input_names(self):
647
+ """
648
+ Return the list of inputs accepted by the model.
649
+ """
650
+ tokenizer_input_names = self.tokenizer.model_input_names
651
+ audio_processor_input_names = self.audio_processor.model_input_names
652
+ return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))
653
+
654
+ def save_audio(self,
655
+ audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
656
+ output_path: str = "output.wav",
657
+ sampling_rate: Optional[int] = None,
658
+ normalize: bool = False,
659
+ batch_prefix: str = "audio_",
660
+ ) -> str:
661
+ """
662
+ Save audio data to a file.
663
+ Args:
664
+ audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
665
+ The audio data to save. Can be a single tensor/array or a list of them.
666
+ output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
667
+ sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
668
+ normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
669
+ batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".
670
+ Returns:
671
+ str: The path to the saved audio file.
672
+ """
673
+ return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix)
674
+
675
+ __all__ = [
676
+ "VibeVoiceProcessor",
677
+ ]
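A minimal sketch of how the pieces above compose. The model directory and voice sample paths are placeholders; from_pretrained reads preprocessor_config.json when present and otherwise falls back to the defaults shown earlier.

processor = VibeVoiceProcessor.from_pretrained("path/to/VibeVoice-model")  # placeholder path

script = "Speaker 1: Hello and welcome.\nSpeaker 2: Glad to be here."
inputs = processor(
    text=script,
    voice_samples=["speaker1.wav", "speaker2.wav"],  # placeholder 24 kHz files, one per speaker
    return_tensors="pt",
)
# inputs["input_ids"], inputs["speech_tensors"], inputs["speech_masks"] and
# inputs["speech_input_mask"] are then passed to the generation model.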
vvembed/processor/vibevoice_tokenizer_processor.py ADDED
@@ -0,0 +1,483 @@
1
+ """
2
+ Processor class for VibeVoice models.
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import warnings
8
+ from typing import List, Optional, Union, Dict, Any
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from transformers.feature_extraction_utils import FeatureExtractionMixin
14
+ from transformers.utils import logging
15
+
16
+ logger = logging.get_logger(__name__)
17
+
18
+
19
+ class AudioNormalizer:
20
+ """
21
+ Audio normalization class for VibeVoice tokenizer.
22
+
23
+ This class provides audio normalization to ensure consistent input levels
24
+ for the VibeVoice tokenizer while maintaining audio quality.
25
+ """
26
+
27
+ def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
28
+ """
29
+ Initialize the audio normalizer.
30
+
31
+ Args:
32
+ target_dB_FS (float): Target dB FS level for the audio. Default: -25
33
+ eps (float): Small value to avoid division by zero. Default: 1e-6
34
+ """
35
+ self.target_dB_FS = target_dB_FS
36
+ self.eps = eps
37
+
38
+ def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
39
+ """
40
+ Adjust the audio to the target dB FS level.
41
+
42
+ Args:
43
+ audio (np.ndarray): Input audio signal
44
+
45
+ Returns:
46
+ tuple: (normalized_audio, rms, scalar)
47
+ """
48
+ rms = np.sqrt(np.mean(audio**2))
49
+ scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
50
+ normalized_audio = audio * scalar
51
+ return normalized_audio, rms, scalar
52
+
53
+ def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
54
+ """
55
+ Avoid clipping by scaling down if necessary.
56
+
57
+ Args:
58
+ audio (np.ndarray): Input audio signal
59
+ scalar (float, optional): Explicit scaling factor
60
+
61
+ Returns:
62
+ tuple: (normalized_audio, scalar)
63
+ """
64
+ if scalar is None:
65
+ max_val = np.max(np.abs(audio))
66
+ if max_val > 1.0:
67
+ scalar = max_val + self.eps
68
+ else:
69
+ scalar = 1.0
70
+
71
+ return audio / scalar, scalar
72
+
73
+ def __call__(self, audio: np.ndarray) -> np.ndarray:
74
+ """
75
+ Normalize the audio by adjusting to target dB FS and avoiding clipping.
76
+
77
+ Args:
78
+ audio (np.ndarray): Input audio signal
79
+
80
+ Returns:
81
+ np.ndarray: Normalized audio signal
82
+ """
83
+ # First adjust to target dB FS
84
+ audio, _, _ = self.tailor_dB_FS(audio)
85
+ # Then avoid clipping
86
+ audio, _ = self.avoid_clipping(audio)
87
+ return audio
88
+
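The gain applied by tailor_dB_FS is simply 10 ** (target_dB_FS / 20) / rms, so the default target of -25 dBFS maps any input to an RMS of about 0.056 before the clipping check. A small numeric check, using an arbitrary synthetic signal:

import numpy as np

rng = np.random.default_rng(0)
audio = 0.5 * rng.standard_normal(24000).astype(np.float32)  # arbitrary test signal

normalizer = AudioNormalizer(target_dB_FS=-25)
out = normalizer(audio)
print(np.sqrt(np.mean(out**2)))  # ~0.056 (= 10 ** (-25 / 20)), unless the clipping guard rescales it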
89
+
90
+ # Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components
91
+ class VibeVoiceTokenizerProcessor(FeatureExtractionMixin):
92
+ """
93
+ Processor for VibeVoice acoustic tokenizer models.
94
+
95
+ This processor handles audio preprocessing for VibeVoice models, including:
96
+ - Audio format conversion (stereo to mono)
97
+ - Optional audio normalization
98
+ - Streaming support for infinite-length audio
99
+
100
+ Args:
101
+ sampling_rate (int, optional): Expected sampling rate. Defaults to 24000.
102
+ normalize_audio (bool, optional): Whether to normalize audio. Defaults to True.
103
+ target_dB_FS (float, optional): Target dB FS for normalization. Defaults to -25.
104
+ eps (float, optional): Small value for numerical stability. Defaults to 1e-6.
105
+ """
106
+ model_input_names = ["input_features"]
107
+
108
+ def __init__(
109
+ self,
110
+ sampling_rate: int = 24000,
111
+ normalize_audio: bool = True,
112
+ target_dB_FS: float = -25,
113
+ eps: float = 1e-6,
114
+ **kwargs,
115
+ ):
116
+ super().__init__(**kwargs)
117
+
118
+ self.sampling_rate = sampling_rate
119
+ self.normalize_audio = normalize_audio
120
+
121
+ # Initialize audio normalizer if needed
122
+ if self.normalize_audio:
123
+ self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps)
124
+ else:
125
+ self.normalizer = None
126
+
127
+ # Save config
128
+ self.feature_extractor_dict = {
129
+ "sampling_rate": sampling_rate,
130
+ "normalize_audio": normalize_audio,
131
+ "target_dB_FS": target_dB_FS,
132
+ "eps": eps,
133
+ }
134
+
135
+ def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
136
+ """
137
+ Convert stereo audio to mono if needed.
138
+
139
+ Args:
140
+ audio (np.ndarray): Input audio array
141
+
142
+ Returns:
143
+ np.ndarray: Mono audio array
144
+ """
145
+ if len(audio.shape) == 1:
146
+ return audio
147
+ elif len(audio.shape) == 2:
148
+ if audio.shape[0] == 2: # (2, time)
149
+ return np.mean(audio, axis=0)
150
+ elif audio.shape[1] == 2: # (time, 2)
151
+ return np.mean(audio, axis=1)
152
+ else:
153
+ # If one dimension is 1, squeeze it
154
+ if audio.shape[0] == 1:
155
+ return audio.squeeze(0)
156
+ elif audio.shape[1] == 1:
157
+ return audio.squeeze(1)
158
+ else:
159
+ raise ValueError(f"Unexpected audio shape: {audio.shape}")
160
+ else:
161
+ raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}")
162
+
163
+ def _process_single_audio(self, audio: Union[np.ndarray, List[float]]) -> np.ndarray:
164
+ """
165
+ Process a single audio array.
166
+
167
+ Args:
168
+ audio: Single audio input
169
+
170
+ Returns:
171
+ np.ndarray: Processed audio
172
+ """
173
+ # Convert to numpy array
174
+ if not isinstance(audio, np.ndarray):
175
+ audio = np.array(audio, dtype=np.float32)
176
+ else:
177
+ audio = audio.astype(np.float32)
178
+
179
+ # Ensure mono
180
+ audio = self._ensure_mono(audio)
181
+
182
+ # Normalize if requested
183
+ if self.normalize_audio and self.normalizer is not None:
184
+ audio = self.normalizer(audio)
185
+
186
+ return audio
187
+
188
+ def __call__(
189
+ self,
190
+ audio: Union[str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str]] = None,
191
+ sampling_rate: Optional[int] = None,
192
+ return_tensors: Optional[str] = None,
193
+ **kwargs,
194
+ ):
195
+ """
196
+ Process audio for VibeVoice models.
197
+
198
+ Args:
199
+ audio: Audio input(s) to process. Can be:
200
+ - str: Path to audio file
201
+ - np.ndarray: Audio array
202
+ - List[float]: Audio as list of floats
203
+ - List[np.ndarray]: Batch of audio arrays
204
+ - List[str]: Batch of audio file paths
205
+ sampling_rate (int, optional): Sampling rate of the input audio
206
+ return_tensors (str, optional): Return format ('pt' for PyTorch, 'np' for NumPy)
207
+
208
+ Returns:
209
+ dict: Processed audio inputs with keys:
210
+ - audio: Audio tensor(s) ready for the model
211
+ """
212
+ if audio is None:
213
+ raise ValueError("Audio input is required")
214
+
215
+ # Validate sampling rate
216
+ if sampling_rate is not None and sampling_rate != self.sampling_rate:
217
+ logger.warning(
218
+ f"Input sampling rate ({sampling_rate}) differs from expected "
219
+ f"sampling rate ({self.sampling_rate}). Please resample your audio."
220
+ )
221
+
222
+ # Handle different input types
223
+ if isinstance(audio, str):
224
+ # Single audio file path
225
+ audio = self._load_audio_from_path(audio)
226
+ is_batched = False
227
+ elif isinstance(audio, list):
228
+ if len(audio) == 0:
229
+ raise ValueError("Empty audio list provided")
230
+
231
+ # Check if it's a list of file paths
232
+ if all(isinstance(item, str) for item in audio):
233
+ # Batch of audio file paths
234
+ audio = [self._load_audio_from_path(path) for path in audio]
235
+ is_batched = True
236
+ else:
237
+ # Check if it's batched audio arrays
238
+ is_batched = isinstance(audio[0], (np.ndarray, list))
239
+ else:
240
+ # Single audio array or list
241
+ is_batched = False
242
+
243
+ # Process audio
244
+ if is_batched:
245
+ processed_audio = [self._process_single_audio(a) for a in audio]
246
+ else:
247
+ processed_audio = [self._process_single_audio(audio)]
248
+
249
+ # Convert to tensors if requested
250
+ if return_tensors == "pt":
251
+ if len(processed_audio) == 1:
252
+ # Add batch and channel dimensions: (T,) -> (1, 1, T)
253
+ input_features = torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1)
254
+ else:
255
+ # For batched input, stack along the batch dimension and add a channel dimension (lengths must match)
256
+ input_features = torch.stack([torch.from_numpy(a) for a in processed_audio]).unsqueeze(1)
257
+ elif return_tensors == "np":
258
+ if len(processed_audio) == 1:
259
+ input_features = processed_audio[0][np.newaxis, np.newaxis, :]
260
+ else:
261
+ input_features = np.stack(processed_audio)[:, np.newaxis, :]
262
+ else:
263
+ input_features = processed_audio[0] if len(processed_audio) == 1 else processed_audio
264
+
265
+ outputs = {
266
+ "audio": input_features, # Use "audio" instead of "input_features"
267
+ }
268
+
269
+ return outputs
270
+
271
+ def _load_audio_from_path(self, audio_path: str) -> np.ndarray:
272
+ """
273
+ Load audio from file path.
274
+
275
+ Args:
276
+ audio_path (str): Path to audio file
277
+
278
+ Returns:
279
+ np.ndarray: Loaded audio array
280
+ """
281
+ # Get file extension to determine loading method
282
+ file_ext = os.path.splitext(audio_path)[1].lower()
283
+
284
+ if file_ext in ['.wav', '.mp3', '.flac', '.m4a', '.ogg']:
285
+ # Audio file - use librosa
286
+ import librosa
287
+ audio_array, sr = librosa.load(
288
+ audio_path,
289
+ sr=self.sampling_rate,
290
+ mono=True
291
+ )
292
+ return audio_array
293
+ elif file_ext == '.pt':
294
+ # PyTorch tensor file
295
+ audio_tensor = torch.load(audio_path, map_location='cpu').squeeze()
296
+ if isinstance(audio_tensor, torch.Tensor):
297
+ audio_array = audio_tensor.numpy()
298
+ else:
299
+ audio_array = np.array(audio_tensor)
300
+ return audio_array.astype(np.float32)
301
+ elif file_ext == '.npy':
302
+ # NumPy file
303
+ audio_array = np.load(audio_path)
304
+ return audio_array.astype(np.float32)
305
+ else:
306
+ raise ValueError(
307
+ f"Unsupported file format: {file_ext}. "
308
+ f"Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy, .npz"
309
+ )
310
+
311
+ def preprocess_audio(
312
+ self,
313
+ audio_path_or_array: Union[str, np.ndarray],
314
+ normalize: Optional[bool] = None,
315
+ ) -> np.ndarray:
316
+ """
317
+ Convenience method to preprocess audio from file path or array.
318
+ This method is kept for backward compatibility but __call__ is recommended.
319
+
320
+ Args:
321
+ audio_path_or_array: Path to audio file or numpy array
322
+ normalize: Whether to normalize (overrides default setting)
323
+
324
+ Returns:
325
+ np.ndarray: Preprocessed audio array
326
+ """
327
+ if isinstance(audio_path_or_array, str):
328
+ audio_array = self._load_audio_from_path(audio_path_or_array)
329
+ else:
330
+ audio_array = np.array(audio_path_or_array, dtype=np.float32)
331
+
332
+ # Override normalization setting if specified
333
+ original_normalize = self.normalize_audio
334
+ if normalize is not None:
335
+ self.normalize_audio = normalize
336
+
337
+ try:
338
+ processed = self._process_single_audio(audio_array)
339
+ finally:
340
+ # Restore original setting
341
+ self.normalize_audio = original_normalize
342
+
343
+ return processed
344
+
345
+ # Override to_dict method for configuration saving
346
+ def to_dict(self) -> Dict[str, Any]:
347
+ """
348
+ Convert the object to a dict containing all attributes needed for serialization.
349
+ """
350
+ return self.feature_extractor_dict
351
+
352
+ def save_audio(
353
+ self,
354
+ audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
355
+ output_path: str = "output.wav",
356
+ sampling_rate: Optional[int] = None,
357
+ normalize: bool = False,
358
+ batch_prefix: str = "audio_",
359
+ ):
360
+ """
361
+ Save audio data to WAV file(s).
362
+
363
+ Args:
364
+ audio: Audio data to save. Can be:
365
+ - torch.Tensor: PyTorch tensor with shape (B, C, T) or (B, T) or (T)
366
+ - np.ndarray: NumPy array with shape (B, C, T) or (B, T) or (T)
367
+ - List of tensors or arrays
368
+ output_path: Path where to save the audio. If saving multiple files,
369
+ this is treated as a directory and individual files will be saved inside.
370
+ sampling_rate: Sampling rate for the saved audio. Defaults to the processor's rate.
371
+ normalize: Whether to normalize audio before saving.
372
+ batch_prefix: Prefix for batch files when saving multiple audios.
373
+
374
+ Returns:
375
+ List[str]: Paths to the saved audio files.
376
+ """
377
+ if sampling_rate is None:
378
+ sampling_rate = self.sampling_rate
379
+
380
+ try:
381
+ import soundfile as sf
382
+ except ImportError:
383
+ raise ImportError(
384
+ "soundfile is required to save audio files. "
385
+ "Install it with: pip install soundfile"
386
+ )
387
+
388
+ # Ensure audio is in the right format
389
+ if isinstance(audio, torch.Tensor):
390
+ # Convert PyTorch tensor to numpy
391
+ audio_np = audio.float().detach().cpu().numpy()
392
+ elif isinstance(audio, np.ndarray):
393
+ audio_np = audio
394
+ elif isinstance(audio, list):
395
+ # Handle list of tensors or arrays
396
+ if all(isinstance(a, torch.Tensor) for a in audio):
397
+ audio_np = [a.float().detach().cpu().numpy() for a in audio]
398
+ else:
399
+ audio_np = audio
400
+ else:
401
+ raise ValueError(f"Unsupported audio type: {type(audio)}")
402
+
403
+ saved_paths = []
404
+
405
+ # Handle based on shape or type
406
+ if isinstance(audio_np, list):
407
+ # Multiple separate audios to save
408
+ output_dir = output_path
409
+
410
+ # Ensure output directory exists
411
+ os.makedirs(output_dir, exist_ok=True)
412
+
413
+ # Save each audio
414
+ for i, audio_item in enumerate(audio_np):
415
+ audio_item = self._prepare_audio_for_save(audio_item, normalize)
416
+ file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
417
+ sf.write(file_path, audio_item, sampling_rate)
418
+ saved_paths.append(file_path)
419
+
420
+ else:
421
+ # Handle different dimensions
422
+ if len(audio_np.shape) >= 3: # (B, C, T) or similar
423
+ # Get batch size
424
+ batch_size = audio_np.shape[0]
425
+
426
+ if batch_size > 1:
427
+ # Multiple audios in a batch
428
+ output_dir = output_path
429
+
430
+ # Ensure output directory exists
431
+ os.makedirs(output_dir, exist_ok=True)
432
+
433
+ # Save each audio in the batch
434
+ for i in range(batch_size):
435
+ # Extract single audio and remove channel dim if present
436
+ single_audio = audio_np[i]
437
+ if len(single_audio.shape) > 1:
438
+ if single_audio.shape[0] == 1: # (1, T)
439
+ single_audio = single_audio.squeeze(0)
440
+
441
+ single_audio = self._prepare_audio_for_save(single_audio, normalize)
442
+ file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
443
+ sf.write(file_path, single_audio, sampling_rate)
444
+ saved_paths.append(file_path)
445
+ else:
446
+ # Single audio with batch and channel dims
447
+ audio_item = audio_np.squeeze() # Remove batch and channel dimensions
448
+ audio_item = self._prepare_audio_for_save(audio_item, normalize)
449
+ sf.write(output_path, audio_item, sampling_rate)
450
+ saved_paths.append(output_path)
451
+ else:
452
+ # Single audio without batch dimension
453
+ audio_item = self._prepare_audio_for_save(audio_np, normalize)
454
+ sf.write(output_path, audio_item, sampling_rate)
455
+ saved_paths.append(output_path)
456
+
457
+ return saved_paths
458
+
459
+ def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
460
+ """
461
+ Prepare audio for saving by ensuring it's the right shape and optionally normalizing.
462
+
463
+ Args:
464
+ audio: Audio data as numpy array
465
+ normalize: Whether to normalize audio
466
+
467
+ Returns:
468
+ np.ndarray: Processed audio ready for saving
469
+ """
470
+ # Ensure right dimensionality
471
+ if len(audio.shape) > 1 and audio.shape[0] == 1: # (1, T)
472
+ audio = audio.squeeze(0)
473
+
474
+ # Normalize if requested
475
+ if normalize:
476
+ max_val = np.abs(audio).max()
477
+ if max_val > 0:
478
+ audio = audio / max_val
479
+
480
+ return audio
481
+
482
+
483
+ __all__ = ["VibeVoiceTokenizerProcessor", "AudioNormalizer"]
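For orientation, a minimal usage sketch of the processor defined above. The import path, the input file name, and the output name are placeholders; loading a .wav requires librosa and saving requires soundfile, as noted in the code.
# Hypothetical usage; adjust the import to wherever this module lives in your setup.
from vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor

processor = VibeVoiceTokenizerProcessor(sampling_rate=24000, normalize_audio=True)

# Load, mono-fold and normalize a reference clip; "my_voice.wav" is a placeholder path.
inputs = processor("my_voice.wav", return_tensors="pt")
audio_tensor = inputs["audio"]  # shape (1, 1, T)

# Write the processed clip back to disk (single clip, so output_path is a file name).
processor.save_audio(audio_tensor, output_path="processed.wav")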
vvembed/schedule/__init__.py ADDED
File without changes
vvembed/schedule/dpm_solver.py ADDED
@@ -0,0 +1,1065 @@
1
+ # Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
16
+
17
+ import math
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.utils import deprecate
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
27
+
28
+ def betas_for_alpha_bar(
29
+ num_diffusion_timesteps,
30
+ max_beta=0.999,
31
+ alpha_transform_type="cosine",
32
+ ):
33
+ """
34
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
35
+ (1-beta) over time from t = [0,1].
36
+
37
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
38
+ to that part of the diffusion process.
39
+
40
+
41
+ Args:
42
+ num_diffusion_timesteps (`int`): the number of betas to produce.
43
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
44
+ prevent singularities.
45
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
46
+ Choose from `cosine` or `exp`
47
+
48
+ Returns:
49
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
50
+ """
51
+ if alpha_transform_type == "cosine":
52
+
53
+ def alpha_bar_fn(t):
54
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
55
+ # return math.cos(t * math.pi / 2 * 0.95) ** 2
56
+
57
+ elif alpha_transform_type == "exp":
58
+
59
+ def alpha_bar_fn(t):
60
+ return math.exp(t * -12.0)
61
+
62
+ elif alpha_transform_type == "cauchy":
63
+ # µ + γ tan (π (0.5 - x)) γ = 1, µ = 3
64
+ # alpha^2 = 1-1/(exp(λ)+1)
65
+ def alpha_bar_fn(t, gamma=1, mu=3):
66
+ snr = mu + gamma * math.tan(math.pi * (0.5 - t) * 0.9)
67
+ return 1 - 1 / (math.exp(snr) + 1.1)
68
+
69
+ elif alpha_transform_type == "laplace":
70
+ # µ − bsgn(0.5 − t) log(1 − 2|t − 0.5|) µ = 0, b = 1
71
+ def alpha_bar_fn(t, mu=0, b=1):
72
+ snr = mu - b * math.copysign(1, 0.5 - t) * math.log(1 - 2 * abs(t - 0.5) * 0.98)
73
+ return 1 - 1 / (math.exp(snr) + 1.02)
74
+
75
+ else:
76
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
77
+
78
+ betas = []
79
+ for i in range(num_diffusion_timesteps):
80
+ t1 = i / num_diffusion_timesteps
81
+ t2 = (i + 1) / num_diffusion_timesteps
82
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
83
+ return torch.tensor(betas, dtype=torch.float32)
84
+
85
+
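For reference, a stand-alone recomputation (illustrative only, not from the diff) of the cosine branch above: each beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, clipped at `max_beta`.
import math

def alpha_bar(t):
    # Same cosine alpha_bar as in betas_for_alpha_bar above.
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

N, max_beta = 10, 0.999
betas = [min(1 - alpha_bar((i + 1) / N) / alpha_bar(i / N), max_beta) for i in range(N)]
print(betas)  # small betas early in the schedule, hitting max_beta at the final step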
86
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
87
+ def rescale_zero_terminal_snr(betas):
88
+ """
89
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
90
+
91
+
92
+ Args:
93
+ betas (`torch.Tensor`):
94
+ the betas that the scheduler is being initialized with.
95
+
96
+ Returns:
97
+ `torch.Tensor`: rescaled betas with zero terminal SNR
98
+ """
99
+ # Convert betas to alphas_bar_sqrt
100
+ alphas = 1.0 - betas
101
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
102
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
103
+
104
+ # Store old values.
105
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
106
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
107
+
108
+ # Shift so the last timestep is zero.
109
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
110
+
111
+ # Scale so the first timestep is back to the old value.
112
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
113
+
114
+ # Convert alphas_bar_sqrt to betas
115
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
116
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
117
+ alphas = torch.cat([alphas_bar[0:1], alphas])
118
+ betas = 1 - alphas
119
+
120
+ return betas
121
+
122
+ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
123
+ """
124
+ `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
125
+
126
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
127
+ methods the library implements for all schedulers such as loading and saving.
128
+
129
+ Args:
130
+ num_train_timesteps (`int`, defaults to 1000):
131
+ The number of diffusion steps to train the model.
132
+ beta_start (`float`, defaults to 0.0001):
133
+ The starting `beta` value of inference.
134
+ beta_end (`float`, defaults to 0.02):
135
+ The final `beta` value.
136
+ beta_schedule (`str`, defaults to `"linear"`):
137
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
138
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
139
+ trained_betas (`np.ndarray`, *optional*):
140
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
141
+ solver_order (`int`, defaults to 2):
142
+ The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
143
+ sampling, and `solver_order=3` for unconditional sampling.
144
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
145
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
146
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
147
+ Video](https://imagen.research.google/video/paper.pdf) paper).
148
+ thresholding (`bool`, defaults to `False`):
149
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
150
+ as Stable Diffusion.
151
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
152
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
153
+ sample_max_value (`float`, defaults to 1.0):
154
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
155
+ `algorithm_type="dpmsolver++"`.
156
+ algorithm_type (`str`, defaults to `dpmsolver++`):
157
+ Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
158
+ `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
159
+ paper, and the `dpmsolver++` type implements the algorithms in the
160
+ [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
161
+ `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
162
+ solver_type (`str`, defaults to `midpoint`):
163
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
164
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
165
+ lower_order_final (`bool`, defaults to `True`):
166
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
167
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
168
+ euler_at_final (`bool`, defaults to `False`):
169
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
170
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
171
+ steps, but sometimes may result in blurring.
172
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
173
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
174
+ the sigmas are determined according to a sequence of noise levels {σi}.
175
+ use_lu_lambdas (`bool`, *optional*, defaults to `False`):
176
+ Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
177
+ the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
178
+ `lambda(t)`.
179
+ final_sigmas_type (`str`, defaults to `"zero"`):
180
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
181
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
182
+ lambda_min_clipped (`float`, defaults to `-inf`):
183
+ Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
184
+ cosine (`squaredcos_cap_v2`) noise schedule.
185
+ variance_type (`str`, *optional*):
186
+ Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
187
+ contains the predicted Gaussian variance.
188
+ timestep_spacing (`str`, defaults to `"linspace"`):
189
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
190
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
191
+ steps_offset (`int`, defaults to 0):
192
+ An offset added to the inference steps, as required by some model families.
193
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
194
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
195
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
196
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
197
+ """
198
+
199
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
200
+ order = 1
201
+
202
+ @register_to_config
203
+ def __init__(
204
+ self,
205
+ num_train_timesteps: int = 1000,
206
+ beta_start: float = 0.0001,
207
+ beta_end: float = 0.02,
208
+ beta_schedule: str = "linear",
209
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
210
+ solver_order: int = 2,
211
+ prediction_type: str = "epsilon",
212
+ thresholding: bool = False,
213
+ dynamic_thresholding_ratio: float = 0.995,
214
+ sample_max_value: float = 1.0,
215
+ algorithm_type: str = "dpmsolver++",
216
+ solver_type: str = "midpoint",
217
+ lower_order_final: bool = True,
218
+ euler_at_final: bool = False,
219
+ use_karras_sigmas: Optional[bool] = False,
220
+ use_lu_lambdas: Optional[bool] = False,
221
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
222
+ lambda_min_clipped: float = -float("inf"),
223
+ variance_type: Optional[str] = None,
224
+ timestep_spacing: str = "linspace",
225
+ steps_offset: int = 0,
226
+ rescale_betas_zero_snr: bool = False,
227
+ ):
228
+ if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
229
+ deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
230
+ deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
231
+
232
+ if trained_betas is not None:
233
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
234
+ elif beta_schedule == "linear":
235
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
236
+ elif beta_schedule == "scaled_linear":
237
+ # this schedule is very specific to the latent diffusion model.
238
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
239
+ elif beta_schedule == "squaredcos_cap_v2" or beta_schedule == "cosine":
240
+ # Glide cosine schedule
241
+ self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
242
+ elif beta_schedule == "cauchy":
243
+ self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cauchy")
244
+ elif beta_schedule == "laplace":
245
+ self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
246
+ else:
247
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
248
+
249
+ if rescale_betas_zero_snr:
250
+ self.betas = rescale_zero_terminal_snr(self.betas)
251
+
252
+ self.alphas = 1.0 - self.betas
253
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
254
+
255
+ if rescale_betas_zero_snr:
256
+ # Close to 0 without being 0 so first sigma is not inf
257
+ # FP16 smallest positive subnormal works well here
258
+ self.alphas_cumprod[-1] = 2**-24
259
+
260
+ # Currently we only support VP-type noise schedule
261
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
262
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
263
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
264
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
265
+
266
+ # standard deviation of the initial noise distribution
267
+ self.init_noise_sigma = 1.0
268
+
269
+ # settings for DPM-Solver
270
+ if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
271
+ if algorithm_type == "deis":
272
+ self.register_to_config(algorithm_type="dpmsolver++")
273
+ else:
274
+ raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
275
+
276
+ if solver_type not in ["midpoint", "heun"]:
277
+ if solver_type in ["logrho", "bh1", "bh2"]:
278
+ self.register_to_config(solver_type="midpoint")
279
+ else:
280
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
281
+
282
+ if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
283
+ raise ValueError(
284
+ f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
285
+ )
286
+
287
+ # setable values
288
+ self.num_inference_steps = None
289
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
290
+ self.timesteps = torch.from_numpy(timesteps)
291
+ self.model_outputs = [None] * solver_order
292
+ self.lower_order_nums = 0
293
+ self._step_index = None
294
+ self._begin_index = None
295
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
296
+
297
+ @property
298
+ def step_index(self):
299
+ """
300
+ The index counter for current timestep. It will increase 1 after each scheduler step.
301
+ """
302
+ return self._step_index
303
+
304
+ @property
305
+ def begin_index(self):
306
+ """
307
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
308
+ """
309
+ return self._begin_index
310
+
311
+ def set_begin_index(self, begin_index: int = 0):
312
+ """
313
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
314
+
315
+ Args:
316
+ begin_index (`int`):
317
+ The begin index for the scheduler.
318
+ """
319
+ self._begin_index = begin_index
320
+
321
+ def set_timesteps(
322
+ self,
323
+ num_inference_steps: int = None,
324
+ device: Union[str, torch.device] = None,
325
+ timesteps: Optional[List[int]] = None,
326
+ ):
327
+ """
328
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
329
+
330
+ Args:
331
+ num_inference_steps (`int`):
332
+ The number of diffusion steps used when generating samples with a pre-trained model.
333
+ device (`str` or `torch.device`, *optional*):
334
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
335
+ timesteps (`List[int]`, *optional*):
336
+ Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
337
+ based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps`
338
+ must be `None`, and the `timestep_spacing` attribute will be ignored.
339
+ """
340
+ if num_inference_steps is None and timesteps is None:
341
+ raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
342
+ if num_inference_steps is not None and timesteps is not None:
343
+ raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
344
+ if timesteps is not None and self.config.use_karras_sigmas:
345
+ raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
346
+ if timesteps is not None and self.config.use_lu_lambdas:
347
+ raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
348
+
349
+ if timesteps is not None:
350
+ timesteps = np.array(timesteps).astype(np.int64)
351
+ else:
352
+ # Clipping the minimum of all lambda(t) for numerical stability.
353
+ # This is critical for cosine (squaredcos_cap_v2) noise schedule.
354
+ clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
355
+ last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
356
+
357
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
358
+ if self.config.timestep_spacing == "linspace":
359
+ timesteps = (
360
+ np.linspace(0, last_timestep - 1, num_inference_steps + 1)
361
+ .round()[::-1][:-1]
362
+ .copy()
363
+ .astype(np.int64)
364
+ )
365
+ elif self.config.timestep_spacing == "leading":
366
+ step_ratio = last_timestep // (num_inference_steps + 1)
367
+ # creates integer timesteps by multiplying by ratio
368
+ # casting to int to avoid issues when num_inference_step is power of 3
369
+ timesteps = (
370
+ (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
371
+ )
372
+ timesteps += self.config.steps_offset
373
+ elif self.config.timestep_spacing == "trailing":
374
+ step_ratio = self.config.num_train_timesteps / num_inference_steps
375
+ # creates integer timesteps by multiplying by ratio
376
+ # casting to int to avoid issues when num_inference_step is power of 3
377
+ timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
378
+ timesteps -= 1
379
+ else:
380
+ raise ValueError(
381
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
382
+ )
383
+
384
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
385
+ log_sigmas = np.log(sigmas)
386
+
387
+ if self.config.use_karras_sigmas:
388
+ sigmas = np.flip(sigmas).copy()
389
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
390
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
391
+ elif self.config.use_lu_lambdas:
392
+ lambdas = np.flip(log_sigmas.copy())
393
+ lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
394
+ sigmas = np.exp(lambdas)
395
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
396
+ else:
397
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
398
+
399
+ if self.config.final_sigmas_type == "sigma_min":
400
+ sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
401
+ elif self.config.final_sigmas_type == "zero":
402
+ sigma_last = 0
403
+ else:
404
+ raise ValueError(
405
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
406
+ )
407
+
408
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
409
+
410
+ self.sigmas = torch.from_numpy(sigmas)
411
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
412
+
413
+ self.num_inference_steps = len(timesteps)
414
+
415
+ self.model_outputs = [
416
+ None,
417
+ ] * self.config.solver_order
418
+ self.lower_order_nums = 0
419
+
420
+ # add an index counter for schedulers that allow duplicated timesteps
421
+ self._step_index = None
422
+ self._begin_index = None
423
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
424
+
425
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
426
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
427
+ """
428
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
429
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
430
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
431
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
432
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
433
+
434
+ https://arxiv.org/abs/2205.11487
435
+ """
436
+ dtype = sample.dtype
437
+ batch_size, channels, *remaining_dims = sample.shape
438
+
439
+ if dtype not in (torch.float32, torch.float64):
440
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
441
+
442
+ # Flatten sample for doing quantile calculation along each image
443
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
444
+
445
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
446
+
447
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
448
+ s = torch.clamp(
449
+ s, min=1, max=self.config.sample_max_value
450
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
451
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
452
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
453
+
454
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
455
+ sample = sample.to(dtype)
456
+
457
+ return sample
458
+
459
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
460
+ def _sigma_to_t(self, sigma, log_sigmas):
461
+ # get log sigma
462
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
463
+
464
+ # get distribution
465
+ dists = log_sigma - log_sigmas[:, np.newaxis]
466
+
467
+ # get sigmas range
468
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
469
+ high_idx = low_idx + 1
470
+
471
+ low = log_sigmas[low_idx]
472
+ high = log_sigmas[high_idx]
473
+
474
+ # interpolate sigmas
475
+ w = (low - log_sigma) / (low - high)
476
+ w = np.clip(w, 0, 1)
477
+
478
+ # transform interpolation to time range
479
+ t = (1 - w) * low_idx + w * high_idx
480
+ t = t.reshape(sigma.shape)
481
+ return t
482
+
483
+ def _sigma_to_alpha_sigma_t(self, sigma):
484
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
485
+ sigma_t = sigma * alpha_t
486
+
487
+ return alpha_t, sigma_t
488
+
489
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
490
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
491
+ """Constructs the noise schedule of Karras et al. (2022)."""
492
+
493
+ # Hack to make sure that other schedulers which copy this function don't break
494
+ # TODO: Add this logic to the other schedulers
495
+ if hasattr(self.config, "sigma_min"):
496
+ sigma_min = self.config.sigma_min
497
+ else:
498
+ sigma_min = None
499
+
500
+ if hasattr(self.config, "sigma_max"):
501
+ sigma_max = self.config.sigma_max
502
+ else:
503
+ sigma_max = None
504
+
505
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
506
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
507
+
508
+ rho = 7.0 # 7.0 is the value used in the paper
509
+ ramp = np.linspace(0, 1, num_inference_steps)
510
+ min_inv_rho = sigma_min ** (1 / rho)
511
+ max_inv_rho = sigma_max ** (1 / rho)
512
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
513
+ return sigmas
514
+
515
+ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
516
+ """Constructs the noise schedule of Lu et al. (2022)."""
517
+
518
+ lambda_min: float = in_lambdas[-1].item()
519
+ lambda_max: float = in_lambdas[0].item()
520
+
521
+ rho = 1.0 # 1.0 is the value used in the paper
522
+ ramp = np.linspace(0, 1, num_inference_steps)
523
+ min_inv_rho = lambda_min ** (1 / rho)
524
+ max_inv_rho = lambda_max ** (1 / rho)
525
+ lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
526
+ return lambdas
527
+
528
+ def convert_model_output(
529
+ self,
530
+ model_output: torch.Tensor,
531
+ *args,
532
+ sample: torch.Tensor = None,
533
+ **kwargs,
534
+ ) -> torch.Tensor:
535
+ """
536
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
537
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
538
+ integral of the data prediction model.
539
+
540
+ <Tip>
541
+
542
+ The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
543
+ prediction and data prediction models.
544
+
545
+ </Tip>
546
+
547
+ Args:
548
+ model_output (`torch.Tensor`):
549
+ The direct output from the learned diffusion model.
550
+ sample (`torch.Tensor`):
551
+ A current instance of a sample created by the diffusion process.
552
+
553
+ Returns:
554
+ `torch.Tensor`:
555
+ The converted model output.
556
+ """
557
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
558
+ if sample is None:
559
+ if len(args) > 1:
560
+ sample = args[1]
561
+ else:
562
+ raise ValueError("missing `sample` as a required keyward argument")
563
+ if timestep is not None:
564
+ deprecate(
565
+ "timesteps",
566
+ "1.0.0",
567
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
568
+ )
569
+
570
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
571
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
572
+ if self.config.prediction_type == "epsilon":
573
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
574
+ if self.config.variance_type in ["learned", "learned_range"]:
575
+ model_output = model_output[:, :3]
576
+ sigma = self.sigmas[self.step_index]
577
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
578
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
579
+ elif self.config.prediction_type == "sample":
580
+ x0_pred = model_output
581
+ elif self.config.prediction_type == "v_prediction":
582
+ sigma = self.sigmas[self.step_index]
583
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
584
+ x0_pred = alpha_t * sample - sigma_t * model_output
585
+ else:
586
+ raise ValueError(
587
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
588
+ " `v_prediction` for the DPMSolverMultistepScheduler."
589
+ )
590
+
591
+ if self.config.thresholding:
592
+ x0_pred = self._threshold_sample(x0_pred)
593
+
594
+ return x0_pred
595
+
596
+ # DPM-Solver needs to solve an integral of the noise prediction model.
597
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
598
+ if self.config.prediction_type == "epsilon":
599
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
600
+ if self.config.variance_type in ["learned", "learned_range"]:
601
+ epsilon = model_output[:, :3]
602
+ else:
603
+ epsilon = model_output
604
+ elif self.config.prediction_type == "sample":
605
+ sigma = self.sigmas[self.step_index]
606
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
607
+ epsilon = (sample - alpha_t * model_output) / sigma_t
608
+ elif self.config.prediction_type == "v_prediction":
609
+ sigma = self.sigmas[self.step_index]
610
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
611
+ epsilon = alpha_t * model_output + sigma_t * sample
612
+ else:
613
+ raise ValueError(
614
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
615
+ " `v_prediction` for the DPMSolverMultistepScheduler."
616
+ )
617
+
618
+ if self.config.thresholding:
619
+ sigma = self.sigmas[self.step_index]
620
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
621
+ x0_pred = (sample - sigma_t * epsilon) / alpha_t
622
+ x0_pred = self._threshold_sample(x0_pred)
623
+ epsilon = (sample - alpha_t * x0_pred) / sigma_t
624
+
625
+ return epsilon
626
+
627
+ def dpm_solver_first_order_update(
628
+ self,
629
+ model_output: torch.Tensor,
630
+ *args,
631
+ sample: torch.Tensor = None,
632
+ noise: Optional[torch.Tensor] = None,
633
+ **kwargs,
634
+ ) -> torch.Tensor:
635
+ """
636
+ One step for the first-order DPMSolver (equivalent to DDIM).
637
+
638
+ Args:
639
+ model_output (`torch.Tensor`):
640
+ The direct output from the learned diffusion model.
641
+ sample (`torch.Tensor`):
642
+ A current instance of a sample created by the diffusion process.
643
+
644
+ Returns:
645
+ `torch.Tensor`:
646
+ The sample tensor at the previous timestep.
647
+ """
648
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
649
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
650
+ if sample is None:
651
+ if len(args) > 2:
652
+ sample = args[2]
653
+ else:
654
+ raise ValueError(" missing `sample` as a required keyward argument")
655
+ if timestep is not None:
656
+ deprecate(
657
+ "timesteps",
658
+ "1.0.0",
659
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
660
+ )
661
+
662
+ if prev_timestep is not None:
663
+ deprecate(
664
+ "prev_timestep",
665
+ "1.0.0",
666
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
667
+ )
668
+
669
+ sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
670
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
671
+ alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
672
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
673
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
674
+
675
+ h = lambda_t - lambda_s
676
+ if self.config.algorithm_type == "dpmsolver++":
677
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
678
+ elif self.config.algorithm_type == "dpmsolver":
679
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
680
+ elif self.config.algorithm_type == "sde-dpmsolver++":
681
+ assert noise is not None
682
+ x_t = (
683
+ (sigma_t / sigma_s * torch.exp(-h)) * sample
684
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
685
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
686
+ )
687
+ elif self.config.algorithm_type == "sde-dpmsolver":
688
+ assert noise is not None
689
+ x_t = (
690
+ (alpha_t / alpha_s) * sample
691
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
692
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
693
+ )
694
+ return x_t
695
+
696
+ def multistep_dpm_solver_second_order_update(
697
+ self,
698
+ model_output_list: List[torch.Tensor],
699
+ *args,
700
+ sample: torch.Tensor = None,
701
+ noise: Optional[torch.Tensor] = None,
702
+ **kwargs,
703
+ ) -> torch.Tensor:
704
+ """
705
+ One step for the second-order multistep DPMSolver.
706
+
707
+ Args:
708
+ model_output_list (`List[torch.Tensor]`):
709
+ The direct outputs from learned diffusion model at current and latter timesteps.
710
+ sample (`torch.Tensor`):
711
+ A current instance of a sample created by the diffusion process.
712
+
713
+ Returns:
714
+ `torch.Tensor`:
715
+ The sample tensor at the previous timestep.
716
+ """
717
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
718
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
719
+ if sample is None:
720
+ if len(args) > 2:
721
+ sample = args[2]
722
+ else:
723
+ raise ValueError(" missing `sample` as a required keyward argument")
724
+ if timestep_list is not None:
725
+ deprecate(
726
+ "timestep_list",
727
+ "1.0.0",
728
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
729
+ )
730
+
731
+ if prev_timestep is not None:
732
+ deprecate(
733
+ "prev_timestep",
734
+ "1.0.0",
735
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
736
+ )
737
+
738
+ sigma_t, sigma_s0, sigma_s1 = (
739
+ self.sigmas[self.step_index + 1],
740
+ self.sigmas[self.step_index],
741
+ self.sigmas[self.step_index - 1],
742
+ )
743
+
744
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
745
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
746
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
747
+
748
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
749
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
750
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
751
+
752
+ m0, m1 = model_output_list[-1], model_output_list[-2]
753
+
754
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
755
+ r0 = h_0 / h
756
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
757
+ if self.config.algorithm_type == "dpmsolver++":
758
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
759
+ if self.config.solver_type == "midpoint":
760
+ x_t = (
761
+ (sigma_t / sigma_s0) * sample
762
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
763
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
764
+ )
765
+ elif self.config.solver_type == "heun":
766
+ x_t = (
767
+ (sigma_t / sigma_s0) * sample
768
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
769
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
770
+ )
771
+ elif self.config.algorithm_type == "dpmsolver":
772
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
773
+ if self.config.solver_type == "midpoint":
774
+ x_t = (
775
+ (alpha_t / alpha_s0) * sample
776
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
777
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
778
+ )
779
+ elif self.config.solver_type == "heun":
780
+ x_t = (
781
+ (alpha_t / alpha_s0) * sample
782
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
783
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
784
+ )
785
+ elif self.config.algorithm_type == "sde-dpmsolver++":
786
+ assert noise is not None
787
+ if self.config.solver_type == "midpoint":
788
+ x_t = (
789
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
790
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
791
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
792
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
793
+ )
794
+ elif self.config.solver_type == "heun":
795
+ x_t = (
796
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
797
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
798
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
799
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
800
+ )
801
+ elif self.config.algorithm_type == "sde-dpmsolver":
802
+ assert noise is not None
803
+ if self.config.solver_type == "midpoint":
804
+ x_t = (
805
+ (alpha_t / alpha_s0) * sample
806
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
807
+ - (sigma_t * (torch.exp(h) - 1.0)) * D1
808
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
809
+ )
810
+ elif self.config.solver_type == "heun":
811
+ x_t = (
812
+ (alpha_t / alpha_s0) * sample
813
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
814
+ - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
815
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
816
+ )
817
+ return x_t
818
+
819
+ def multistep_dpm_solver_third_order_update(
820
+ self,
821
+ model_output_list: List[torch.Tensor],
822
+ *args,
823
+ sample: torch.Tensor = None,
824
+ **kwargs,
825
+ ) -> torch.Tensor:
826
+ """
827
+ One step for the third-order multistep DPMSolver.
828
+
829
+ Args:
830
+ model_output_list (`List[torch.Tensor]`):
831
+ The direct outputs from learned diffusion model at current and latter timesteps.
832
+ sample (`torch.Tensor`):
833
+ A current instance of a sample created by the diffusion process.
834
+
835
+ Returns:
836
+ `torch.Tensor`:
837
+ The sample tensor at the previous timestep.
838
+ """
839
+
840
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
841
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
842
+ if sample is None:
843
+ if len(args) > 2:
844
+ sample = args[2]
845
+ else:
846
+ raise ValueError(" missing`sample` as a required keyward argument")
847
+ if timestep_list is not None:
848
+ deprecate(
849
+ "timestep_list",
850
+ "1.0.0",
851
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
852
+ )
853
+
854
+ if prev_timestep is not None:
855
+ deprecate(
856
+ "prev_timestep",
857
+ "1.0.0",
858
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
859
+ )
860
+
861
+ sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
862
+ self.sigmas[self.step_index + 1],
863
+ self.sigmas[self.step_index],
864
+ self.sigmas[self.step_index - 1],
865
+ self.sigmas[self.step_index - 2],
866
+ )
867
+
868
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
869
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
870
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
871
+ alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
872
+
873
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
874
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
875
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
876
+ lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
877
+
878
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
879
+
880
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
881
+ r0, r1 = h_0 / h, h_1 / h
882
+ D0 = m0
883
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
884
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
885
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
886
+ if self.config.algorithm_type == "dpmsolver++":
887
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
888
+ x_t = (
889
+ (sigma_t / sigma_s0) * sample
890
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
891
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
892
+ - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
893
+ )
894
+ elif self.config.algorithm_type == "dpmsolver":
895
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
896
+ x_t = (
897
+ (alpha_t / alpha_s0) * sample
898
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
899
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
900
+ - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
901
+ )
902
+ return x_t
903
+
904
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
905
+ if schedule_timesteps is None:
906
+ schedule_timesteps = self.timesteps
907
+
908
+ index_candidates = (schedule_timesteps == timestep).nonzero()
909
+
910
+ if len(index_candidates) == 0:
911
+ step_index = len(self.timesteps) - 1
912
+ # The sigma index that is taken for the **very** first `step`
913
+ # is always the second index (or the last index if there is only 1)
914
+ # This way we can ensure we don't accidentally skip a sigma in
915
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
916
+ elif len(index_candidates) > 1:
917
+ step_index = index_candidates[1].item()
918
+ else:
919
+ step_index = index_candidates[0].item()
920
+
921
+ return step_index
922
+
923
+ def _init_step_index(self, timestep):
924
+ """
925
+ Initialize the step_index counter for the scheduler.
926
+ """
927
+
928
+ if self.begin_index is None:
929
+ if isinstance(timestep, torch.Tensor):
930
+ timestep = timestep.to(self.timesteps.device)
931
+ self._step_index = self.index_for_timestep(timestep)
932
+ else:
933
+ self._step_index = self._begin_index
934
+
935
+ def step(
936
+ self,
937
+ model_output: torch.Tensor,
938
+ timestep: int,
939
+ sample: torch.Tensor,
940
+ generator=None,
941
+ variance_noise: Optional[torch.Tensor] = None,
942
+ return_dict: bool = True,
943
+ ) -> Union[SchedulerOutput, Tuple]:
944
+ """
945
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
946
+ the multistep DPMSolver.
947
+
948
+ Args:
949
+ model_output (`torch.Tensor`):
950
+ The direct output from learned diffusion model.
951
+ timestep (`int`):
952
+ The current discrete timestep in the diffusion chain.
953
+ sample (`torch.Tensor`):
954
+ A current instance of a sample created by the diffusion process.
955
+ generator (`torch.Generator`, *optional*):
956
+ A random number generator.
957
+ variance_noise (`torch.Tensor`):
958
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
959
+ itself. Useful for methods such as [`LEdits++`].
960
+ return_dict (`bool`):
961
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
962
+
963
+ Returns:
964
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
965
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
966
+ tuple is returned where the first element is the sample tensor.
967
+
968
+ """
969
+ if self.num_inference_steps is None:
970
+ raise ValueError(
971
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
972
+ )
973
+
974
+ if self.step_index is None:
975
+ self._init_step_index(timestep)
976
+
977
+ # Improve numerical stability for small number of steps
978
+ lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
979
+ self.config.euler_at_final
980
+ or (self.config.lower_order_final and len(self.timesteps) < 15)
981
+ or self.config.final_sigmas_type == "zero"
982
+ )
983
+ lower_order_second = (
984
+ (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
985
+ )
986
+
987
+ model_output = self.convert_model_output(model_output, sample=sample)
988
+ for i in range(self.config.solver_order - 1):
989
+ self.model_outputs[i] = self.model_outputs[i + 1]
990
+ self.model_outputs[-1] = model_output
991
+
992
+ # Upcast to avoid precision issues when computing prev_sample
993
+ sample = sample.to(torch.float32)
994
+ if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
995
+ noise = randn_tensor(
996
+ model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
997
+ )
998
+ elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
999
+ noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
1000
+ else:
1001
+ noise = None
1002
+
1003
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
1004
+ prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
1005
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
1006
+ prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
1007
+ else:
1008
+ prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
1009
+
1010
+ if self.lower_order_nums < self.config.solver_order:
1011
+ self.lower_order_nums += 1
1012
+
1013
+ # Cast sample back to expected dtype
1014
+ prev_sample = prev_sample.to(model_output.dtype)
1015
+
1016
+ # upon completion increase step index by one
1017
+ self._step_index += 1
1018
+
1019
+ if not return_dict:
1020
+ return (prev_sample,)
1021
+
1022
+ return SchedulerOutput(prev_sample=prev_sample)
1023
+
1024
+ def add_noise(
1025
+ self,
1026
+ original_samples: torch.Tensor,
1027
+ noise: torch.Tensor,
1028
+ timesteps: torch.IntTensor,
1029
+ ) -> torch.Tensor:
1030
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
1031
+ # alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
1032
+ # sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
1033
+ alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
1034
+ sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
1035
+ timesteps = timesteps.to(original_samples.device)
1036
+ alpha_t = alpha_t[timesteps].flatten()
1037
+ while len(alpha_t.shape) < len(original_samples.shape):
1038
+ alpha_t = alpha_t.unsqueeze(-1)
1039
+
1040
+ sigma_t = sigma_t[timesteps].flatten()
1041
+ while len(sigma_t.shape) < len(original_samples.shape):
1042
+ sigma_t = sigma_t.unsqueeze(-1)
1043
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
1044
+ return noisy_samples
1045
+
1046
+ def get_velocity(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
1047
+ # alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
1048
+ # sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
1049
+ alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
1050
+ sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
1051
+
1052
+ timesteps = timesteps.to(original_samples.device)
1053
+ alpha_t = alpha_t[timesteps].flatten()
1054
+ while len(alpha_t.shape) < len(original_samples.shape):
1055
+ alpha_t = alpha_t.unsqueeze(-1)
1056
+
1057
+ sigma_t = sigma_t[timesteps].flatten()
1058
+ while len(sigma_t.shape) < len(original_samples.shape):
1059
+ sigma_t = sigma_t.unsqueeze(-1)
1060
+
1061
+ velocity = alpha_t * noise - sigma_t * original_samples
1062
+ return velocity
1063
+
1064
+ def __len__(self):
1065
+ return self.config.num_train_timesteps
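
A minimal sketch of how a multistep DPM-Solver scheduler exposing the interface above (set_timesteps, timesteps, step) is typically driven at inference time. The denoiser callable, the function name sample_with_dpm_solver, and the exact set_timesteps signature are illustrative assumptions, not part of this repository.

import torch

@torch.no_grad()
def sample_with_dpm_solver(denoiser, scheduler, shape, num_inference_steps=20, device="cpu"):
    # Populate scheduler.timesteps (and the matching sigmas) for the chosen step count.
    scheduler.set_timesteps(num_inference_steps)
    # Reverse diffusion starts from pure Gaussian noise.
    latents = torch.randn(shape, device=device)
    for t in scheduler.timesteps:
        # The denoiser predicts epsilon, x0 or v depending on config.prediction_type;
        # convert_model_output() inside step() normalizes the parameterization.
        model_output = denoiser(latents, t)
        # step() applies the first-, second- or third-order update selected above and
        # returns a SchedulerOutput whose prev_sample is the less-noisy sample.
        latents = scheduler.step(model_output, t, latents).prev_sample
    return latents
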
vvembed/schedule/timestep_sampler.py ADDED
@@ -0,0 +1,19 @@
1
+ import math
2
+ import torch
3
+
4
+
5
+ class UniformSampler:
6
+ def __init__(self, timesteps = 1000):
7
+ self.timesteps = timesteps
8
+ def sample(self, batch_size, device):
9
+ return torch.randint(0, self.timesteps, (batch_size,), device=device)
10
+
11
+ class LogitNormalSampler:
12
+ def __init__(self, timesteps = 1000, m = 0, s = 1):
13
+ self.timesteps = timesteps
14
+ timesteps = torch.linspace(0, 1, timesteps)
15
+ logit = torch.log(timesteps / (1 - timesteps))
16
+ self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s ** 2) / (s * math.sqrt(2 * math.pi))
17
+ def sample(self, batch_size, device):
18
+ return torch.multinomial(self.prob, batch_size, replacement=True).to(device)
19
+
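
Both samplers above return a batch of integer timestep indices in [0, timesteps). UniformSampler weights every index equally, while LogitNormalSampler weights the linspace grid with a logit-normal density, biasing draws toward the middle of the schedule (the two endpoints receive zero weight because their logits are infinite). A small usage sketch, assuming vvembed is on sys.path so the module imports as schedule.timestep_sampler:

from schedule.timestep_sampler import UniformSampler, LogitNormalSampler

batch_size = 8
uniform = UniformSampler(timesteps=1000)
logit_normal = LogitNormalSampler(timesteps=1000, m=0, s=1)

# Each call returns a shape-(batch_size,) LongTensor of training timesteps.
t_flat = uniform.sample(batch_size, device="cpu")      # flat over the whole schedule
t_mid = logit_normal.sample(batch_size, device="cpu")  # concentrated mid-schedule
print(t_flat.tolist(), t_mid.tolist())
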
vvembed/scripts/__init__.py ADDED
File without changes
vvembed/scripts/convert_nnscaler_checkpoint_to_transformers.py ADDED
@@ -0,0 +1,166 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ import argparse
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ import re
9
+ import torch
10
+ from typing import Dict, List, Tuple
11
+
12
+ from modular.configuration_vibevoice import (
13
+ VibeVoiceConfig
14
+ )
15
+ from modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
16
+ from transformers.utils import logging
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+ def convert_vibevoice_nnscaler_checkpoint_to_hf(
21
+ checkpoint_path: str,
22
+ pytorch_dump_folder_path: str,
23
+ config_path: str = None,
24
+ ):
25
+ """
26
+ Convert a nnscaler VibeVoice checkpoint to HuggingFace format.
27
+ Supports both regular checkpoints and tensor parallel checkpoints.
28
+ """
29
+
30
+ # Load regular checkpoint
31
+ logger.info(f"Loading regular checkpoint from {checkpoint_path}")
32
+ checkpoint = torch.load(checkpoint_path, map_location="cpu") # ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader']
33
+
34
+ # config = checkpoint['train_args']
35
+ init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path']
36
+ pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path']
37
+
38
+ init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1]
39
+ if init_config_path.exists():
40
+ logger.info(f"Loading initial config from {init_config_path}")
41
+ with open(init_config_path, 'r') as f:
42
+ init_config = json.load(f)
43
+ else:
44
+ raise FileNotFoundError(f"Initial config file {init_config_path} not found. Please provide a valid path.")
45
+
46
+ tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True)
47
+ logger.info(f"Tie word embeddings: {tie_word_embeddings}")
48
+
49
+ init_config['decoder_config']['use_cache'] = True
50
+ config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings)
51
+
52
+ # # Extract the model state dict
53
+ model_state_dict = {k.replace('model.model.', 'model.'): v for k, v in checkpoint["model"].items() if k.startswith('model.model.')}
54
+ if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys():
55
+ # If not tying weights, we need to add the lm_head weight separately
56
+ model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight']
57
+
58
+ # Override with provided config if available
59
+ if config_path:
60
+ logger.info(f"Loading config from {config_path}")
61
+ with open(config_path, 'r') as f:
62
+ config_dict = json.load(f)
63
+ config = VibeVoiceConfig.from_dict(config_dict)
64
+
65
+ # Set the default dtype to bfloat16 before creating the model
66
+ original_dtype = torch.get_default_dtype()
67
+ torch.set_default_dtype(torch.bfloat16)
68
+
69
+ # Create the HuggingFace model
70
+ logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model")
71
+ model = VibeVoiceForConditionalGeneration(config)
72
+
73
+ # Restore original dtype
74
+ torch.set_default_dtype(original_dtype)
75
+
76
+ # Load the state dict
77
+ logger.info("Loading weights into model")
78
+ missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
79
+
80
+ if missing_keys:
81
+ logger.warning(f"Missing keys: {missing_keys}")
82
+ if unexpected_keys:
83
+ logger.warning(f"Unexpected keys: {unexpected_keys}")
84
+
85
+ # Create output directory
86
+ os.makedirs(pytorch_dump_folder_path, exist_ok=True)
87
+
88
+ # Save the model and config
89
+ logger.info(f"Saving model to {pytorch_dump_folder_path}")
90
+
91
+ # Save config
92
+ config.save_pretrained(pytorch_dump_folder_path)
93
+
94
+ # Save VibeVoiceProcessor configuration
95
+ logger.info("Saving VibeVoiceProcessor configuration")
96
+ processor_config = {
97
+ "processor_class": "VibeVoiceProcessor",
98
+ "speech_tok_compress_ratio": 3200,
99
+ "db_normalize": True,
100
+ # Audio processor configuration
101
+ "audio_processor": {
102
+ "feature_extractor_type": "VibeVoiceTokenizerProcessor",
103
+ "sampling_rate": 24000,
104
+ "normalize_audio": True,
105
+ "target_dB_FS": -25,
106
+ "eps": 1e-6,
107
+ },
108
+ "language_model_pretrained_name": pretrained_name,
109
+ }
110
+
111
+ processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json")
112
+ with open(processor_config_path, 'w') as f:
113
+ json.dump(processor_config, f, indent=2)
114
+ logger.info(f"Saved processor config to {processor_config_path}")
115
+
116
+ # Save model with sharding
117
+ # save_pretrained handles tied weights automatically
118
+ logger.info("Saving model weights with sharding...")
119
+ model.save_pretrained(
120
+ pytorch_dump_folder_path,
121
+ max_shard_size="2GB", # Set maximum size for each shard
122
+ safe_serialization=True # Ensure saving in .safetensors format
123
+ )
124
+ logger.info(f"Model weights saved to {pytorch_dump_folder_path}")
125
+
126
+ logger.info("Conversion complete!")
127
+
128
+ # Verify the saved model can be loaded
129
+ logger.info("Verifying saved model...")
130
+ loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path)
131
+ logger.info("Model successfully loaded from saved checkpoint!")
132
+
133
+ def main():
134
+ parser = argparse.ArgumentParser()
135
+ parser.add_argument(
136
+ "--nnscaler_checkpoint_path",
137
+ type=str,
138
+ required=True,
139
+ help="Path to the fairseq checkpoint (.pt file). For tensor parallel checkpoints, "
140
+ "provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), "
141
+ "and the script will automatically detect and merge all parts.",
142
+ )
143
+ parser.add_argument(
144
+ "--pytorch_dump_folder_path",
145
+ type=str,
146
+ required=True,
147
+ help="Path to the output PyTorch model directory",
148
+ )
149
+ parser.add_argument(
150
+ "--config_path",
151
+ type=str,
152
+ default=None,
153
+ help="Optional path to a config JSON file to override extracted config",
154
+ )
155
+
156
+ args = parser.parse_args()
157
+
158
+ convert_vibevoice_nnscaler_checkpoint_to_hf(
159
+ args.nnscaler_checkpoint_path,
160
+ args.pytorch_dump_folder_path,
161
+ args.config_path,
162
+ )
163
+
164
+
165
+ if __name__ == "__main__":
166
+ main()
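
The converter can also be driven programmatically instead of through argparse; the snippet below is a hedged sketch in which the checkpoint and output paths are placeholders, and the import paths assume vvembed is on sys.path (the same assumption the script itself makes when importing modular.*). The equivalent CLI call passes --nnscaler_checkpoint_path and --pytorch_dump_folder_path as defined in main().

from scripts.convert_nnscaler_checkpoint_to_transformers import (
    convert_vibevoice_nnscaler_checkpoint_to_hf,
)
from modular.modeling_vibevoice import VibeVoiceForConditionalGeneration

# Placeholder paths -- substitute a real nnscaler checkpoint and output folder.
convert_vibevoice_nnscaler_checkpoint_to_hf(
    checkpoint_path="checkpoints/checkpoint_last.pt",
    pytorch_dump_folder_path="converted/vibevoice-hf",
    config_path=None,  # fall back to the config referenced inside the checkpoint
)

# The output folder now holds config.json, preprocessor_config.json and sharded
# .safetensors weights, so it reloads like any Hugging Face-style checkpoint.
model = VibeVoiceForConditionalGeneration.from_pretrained("converted/vibevoice-hf")
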