tianfengping.tfp committed
Commit efacc59 · 1 Parent(s): 26dff53

modify emotion type to english

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. app.py +195 -255
  2. cosyvoice_rodis/__init__.py +2 -0
  3. cosyvoice_rodis/__pycache__/__init__.cpython-310.pyc +0 -0
  4. cosyvoice_rodis/__pycache__/__init__.cpython-312.pyc +0 -0
  5. cosyvoice_rodis/__pycache__/__init__.cpython-38.pyc +0 -0
  6. cosyvoice_rodis/__pycache__/__init__.cpython-39.pyc +0 -0
  7. cosyvoice_rodis/bin/average_model.py +91 -0
  8. cosyvoice_rodis/bin/export_jit.py +73 -0
  9. cosyvoice_rodis/bin/export_onnx.py +110 -0
  10. cosyvoice_rodis/bin/inference.py +114 -0
  11. cosyvoice_rodis/bin/train.py +159 -0
  12. cosyvoice_rodis/cli/__init__.py +2 -0
  13. cosyvoice_rodis/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  14. cosyvoice_rodis/cli/__pycache__/__init__.cpython-312.pyc +0 -0
  15. cosyvoice_rodis/cli/__pycache__/__init__.cpython-38.pyc +0 -0
  16. cosyvoice_rodis/cli/__pycache__/__init__.cpython-39.pyc +0 -0
  17. cosyvoice_rodis/cli/__pycache__/cosyvoice.cpython-310.pyc +0 -0
  18. cosyvoice_rodis/cli/__pycache__/cosyvoice.cpython-312.pyc +0 -0
  19. cosyvoice_rodis/cli/__pycache__/cosyvoice.cpython-38.pyc +0 -0
  20. cosyvoice_rodis/cli/__pycache__/cosyvoice.cpython-39.pyc +0 -0
  21. cosyvoice_rodis/cli/__pycache__/frontend.cpython-310.pyc +0 -0
  22. cosyvoice_rodis/cli/__pycache__/frontend.cpython-38.pyc +0 -0
  23. cosyvoice_rodis/cli/__pycache__/frontend.cpython-39.pyc +0 -0
  24. cosyvoice_rodis/cli/__pycache__/model.cpython-310.pyc +0 -0
  25. cosyvoice_rodis/cli/__pycache__/model.cpython-38.pyc +0 -0
  26. cosyvoice_rodis/cli/__pycache__/model.cpython-39.pyc +0 -0
  27. cosyvoice_rodis/cli/cosyvoice.py +114 -0
  28. cosyvoice_rodis/cli/frontend.py +192 -0
  29. cosyvoice_rodis/cli/model.py +257 -0
  30. cosyvoice_rodis/dataset/__init__.py +2 -0
  31. cosyvoice_rodis/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
  32. cosyvoice_rodis/dataset/__pycache__/__init__.cpython-38.pyc +0 -0
  33. cosyvoice_rodis/dataset/__pycache__/__init__.cpython-39.pyc +0 -0
  34. cosyvoice_rodis/dataset/__pycache__/dataset.cpython-310.pyc +0 -0
  35. cosyvoice_rodis/dataset/__pycache__/dataset.cpython-38.pyc +0 -0
  36. cosyvoice_rodis/dataset/__pycache__/processor.cpython-310.pyc +0 -0
  37. cosyvoice_rodis/dataset/__pycache__/processor.cpython-38.pyc +0 -0
  38. cosyvoice_rodis/dataset/__pycache__/processor.cpython-39.pyc +0 -0
  39. cosyvoice_rodis/dataset/dataset.py +163 -0
  40. cosyvoice_rodis/dataset/processor.py +427 -0
  41. cosyvoice_rodis/flow/__pycache__/decoder.cpython-310.pyc +0 -0
  42. cosyvoice_rodis/flow/__pycache__/decoder.cpython-38.pyc +0 -0
  43. cosyvoice_rodis/flow/__pycache__/decoder.cpython-39.pyc +0 -0
  44. cosyvoice_rodis/flow/__pycache__/flow.cpython-310.pyc +0 -0
  45. cosyvoice_rodis/flow/__pycache__/flow.cpython-38.pyc +0 -0
  46. cosyvoice_rodis/flow/__pycache__/flow.cpython-39.pyc +0 -0
  47. cosyvoice_rodis/flow/__pycache__/flow_matching.cpython-310.pyc +0 -0
  48. cosyvoice_rodis/flow/__pycache__/flow_matching.cpython-38.pyc +0 -0
  49. cosyvoice_rodis/flow/__pycache__/flow_matching.cpython-39.pyc +0 -0
  50. cosyvoice_rodis/flow/__pycache__/length_regulator.cpython-310.pyc +0 -0
app.py CHANGED
@@ -1,13 +1,12 @@
1
  import gradio as gr
2
  import sys, os
 
 
3
  import torch
4
  from cosyvoice.utils.file_utils import load_wav
5
- from tts_model.base_model.cosyvoice import CosyVoice as CosyVoiceTTS_base
6
- from tts_model.sft_model.cosyvoice import CosyVoice as CosyVoiceTTS_sft
7
  from uuid import uuid1
8
  import uuid
9
- from tts_model.speaker_minus.cosyvoice import CosyVoice as CosyVoiceTTS_speakerminus
10
- # from tts_model.model_cosy2_instruct import CosyVoiceTTS as CosyVoiceTTS_cosy2
11
  from pydub import AudioSegment
12
  import tempfile
13
  import soundfile as sf
@@ -17,20 +16,56 @@ import random
17
  import numpy
18
 
19
 
20
- from pydub import AudioSegment
21
- # AudioSegment.converter = "/mnt/by079416/fengping/ffmpeg-7.0.2-amd64-static/ffmpeg"
22
- # AudioSegment.ffprobe = "/mnt/by079416/fengping/ffmpeg-7.0.2-amd64-static/ffprobe"
23
 
24
- ffmpeg_path = os.path.expanduser("/mnt/by079416/fengping/ffmpeg-7.0.2-amd64-static/ffmpeg/")
25
- os.environ["PATH"] += os.pathsep + ffmpeg_path
26
 
27
  sys.path.append('third_party/Matcha-TTS')
28
  os.system('export PYTHONPATH=third_party/Matcha-TTS')
29
 
30
- tts_base = CosyVoiceTTS_base(model_dir="./pretrained_models/CosyVoice-300M/")
31
- tts_speakerminus = CosyVoiceTTS_speakerminus(model_dir="./pretrained_models/CosyVoice-300M-speakerminus/")
32
- # tts_cosy2_instruct = CosyVoiceTTS_cosy2(model_path="./pretrained_models/CosyVoice-300M-Instruct_cosy2/")
33
- tts_sft = CosyVoiceTTS_base(model_dir="./pretrained_models/CosyVoice-300M-SFT/")
34
 
35
  text_prompt = {
36
  "翟佳宁": "这个节目就是把四个男嘉宾,四个女嘉宾放一个大别墅里让他们朝夕相处一整个月,月末选择心动的彼此。",
@@ -58,7 +93,7 @@ audio_prompt = {
58
  "赵晓卉": "zhaoxiaohui",
59
  "徐志胜": "xuzhisheng"
60
  }
61
- audio_prompt_path = "/mnt/by079416/fengping/CosyVoice2/talk_show_prompt/"
62
 
63
  def load_audio_and_convert_to_16bit(file_path, target_sample_rate=16000):
64
  audio = AudioSegment.from_file(file_path)
@@ -88,11 +123,11 @@ def convert_audio_with_sox(input_file, output_file, target_sample_rate=16000):
88
  # ]
89
  command = [
90
  './ffmpeg-7.0.2-amd64-static/ffmpeg',
91
- '-i', input_file, # 必须显式指定 -i 标记输入文件
92
- '-ar', str(target_sample_rate), # 设置音频采样率
93
- '-ac', '1', # 设置单通道 (mono)
94
- '-b:a', '16k', # 设置音频比特率为 16kbps
95
- '-f', 'wav', # 强制输出格式为 WAV
96
  output_file
97
  ]
98
 
@@ -103,88 +138,45 @@ def convert_audio_with_sox(input_file, output_file, target_sample_rate=16000):
103
 
104
  os.makedirs("./tmp", exist_ok=True)
105
 
106
- def generate_speech_sft(tts_text, speaker):
107
- # if not ref_audio and not ref_text:
108
- # ref_text = text_prompt.get(speaker, "")
109
- # ref_audio = os.path.join(audio_prompt_path, f"{audio_prompt.get(speaker)}.wav")
110
- # else:
111
- # random_int = random.randint(0, 90)
112
- # soxsed_ref_audio = "/tmp/{random_int}_ref.wav"
113
- # convert_audio_with_sox(ref_audio, soxsed_ref_audio)
114
- # ref_audio = load_wav(ref_audio, 16000)
115
- # # ref_audio, target_sample_rate = load_audio_and_convert_to_16bit(ref_audio)
116
-
117
- sample_rate, full_audio = tts_sft.inference_sft(
118
- tts_text,
119
- spk_id = speaker
120
- # instruct = instruct
121
- # prompt_text = ref_text,
122
- # prompt_speech_16k = ref_audio,
123
- # speed=speed,
124
- # speaker=speaker,
125
- # emotion=emotion,
126
-
127
- )
128
- full_audio = full_audio.astype(np.float32)
129
- if full_audio.max() > 1.0 or full_audio.min() < -1.0:
130
- full_audio /= 32768.0 # int16 → [-1,1]
131
-
132
- print("dtype:", full_audio.dtype,
133
- "shape:", full_audio.shape,
134
- "max:", full_audio.max(), "min:", full_audio.min())
135
-
136
- out_path = os.path.join("./tmp", f"{uuid.uuid4().hex}.wav")
137
-
138
- audio_segment = AudioSegment(
139
- full_audio.tobytes(),
140
- frame_rate=sample_rate,
141
- sample_width=full_audio.dtype.itemsize,
142
- channels=1
143
- )
144
- audio_segment.export(out_path, format="wav")
145
-
146
- print(">>> audio path:", os.path.abspath(out_path))
147
- # return out_path
148
- return (sample_rate, full_audio)
149
- # with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
150
- # output_audio_path = temp_audio_file.name
151
- # audio_segment = AudioSegment(
152
- # full_audio.tobytes(),
153
- # frame_rate=sample_rate,
154
- # sample_width=full_audio.dtype.itemsize,
155
- # channels=1
156
- # )
157
- # audio_segment.export(output_audio_path, format="wav")
158
- # print(f"Audio saved to {output_audio_path}")
159
-
160
- # return output_audio_path
161
-
162
- # def generate_speech_sft(tts_text, speaker):
163
- # sr = 22050 # 采样率
164
- # t = np.linspace(0, 1, sr, dtype=np.float32)
165
- # audio_np = np.sin(2 * np.pi * 440 * t) # 1 秒 440 Hz 正弦波
166
- # return (sr, audio_np)
167
-
168
- def generate_speech_base(tts_text, speed, speaker, ref_audio, ref_text):
169
  # import pdb;pdb.set_trace()
170
  if not ref_audio and not ref_text:
171
  ref_text = text_prompt.get(speaker, "")
172
- ref_audio = os.path.join(audio_prompt_path, f"{audio_prompt.get(speaker)}.wav")
173
- ref_audio = load_wav(ref_audio, 16000)
 
 
 
174
  else:
175
- random_int = random.randint(0, 90000)
176
- soxsed_ref_audio = f"/tmp/{random_int}_ref.wav"
177
  convert_audio_with_sox(ref_audio, soxsed_ref_audio)
178
- ref_audio = load_wav(soxsed_ref_audio, 16000)
179
- # ref_audio, target_sample_rate = load_audio_and_convert_to_16bit(ref_audio)
180
- sample_rate, full_audio = tts_base.inference_zero_shot(
181
  tts_text,
182
  prompt_text = ref_text,
183
- prompt_speech_16k = ref_audio,
184
- speed=speed,
185
  # speaker=speaker,
186
- # emotion=emotion,
187
-
 
 
 
 
188
  )
189
  print("sample_rate:", sample_rate, "full_audio:", full_audio.min(), full_audio.max())
190
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
@@ -197,37 +189,45 @@ def generate_speech_base(tts_text, speed, speaker, ref_audio, ref_text):
197
  )
198
  audio_segment.export(output_audio_path, format="wav")
199
  print(f"Audio saved to {output_audio_path}")
200
-
201
  return output_audio_path
202
 
203
- def generate_speech_speakerminus(tts_text, speed, speaker, key, ref_audio, ref_text):
 
204
  # import pdb;pdb.set_trace()
205
  if not ref_audio and not ref_text:
206
  ref_text = text_prompt.get(speaker, "")
207
- ref_audio = os.path.join(audio_prompt_path, f"{audio_prompt.get(speaker)}.wav")
 
 
 
 
208
  else:
209
  random_int = random.randint(0, 90)
210
- soxsed_ref_audio = f"/tmp/{random_int}_ref.wav"
211
  convert_audio_with_sox(ref_audio, soxsed_ref_audio)
212
- # print("output_file:", output_file)
213
- # ref_audio, target_sample_rate = load_audio_and_convert_to_16bit(ref_audio)
 
 
214
  ref_audio = load_wav(ref_audio, 16000)
215
- # if key == "Surprise":
216
- # emotion_info = torch.load("/mnt/by079416/surprise.pt")
217
- # if key == "Sad":
218
- # emotion_info = torch.load("/mnt/by079416/sad.pt")
219
- # if key == "Angry":
220
- # emotion_info = torch.load("/mnt/by079416/angry.pt")
221
- # if key == "Happy":
222
- # emotion_info = torch.load("/mnt/by079416/happy.pt")
223
-
224
- emotion_info = torch.load("/mnt/by079416/fengping/CosyVoice2/embedding_info.pt")["0002"][key]
225
- sample_rate, full_audio = tts_speakerminus.inference_zero_shot(
 
 
226
  tts_text,
227
  prompt_text = ref_text,
228
  # speaker=speaker,
229
  prompt_speech_16k = ref_audio,
230
- key = key,
231
  emotion_speakerminus=emotion_info,
232
  # ref_audio = ref_audio,
233
  speed=speed
@@ -328,7 +328,7 @@ body {
328
  flex-grow: 1; /* 占据所有剩余空间 */
329
  }
330
 
331
- /* 6. 标题文本样式 */
332
  #header-title h1 {
333
  color: white;
334
  font-size: 28px;
@@ -561,38 +561,29 @@ input[type="text"]:focus, textarea:focus {
561
  }
562
  """
563
 
564
- # 创建界面
565
- logo_path = "/mnt/by079416/fengping/CosyVoice2/logo2.png"
566
 
567
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
568
  with gr.Column(elem_classes="header"):
569
- # 行容器,用于左右布局
570
  with gr.Row(elem_id="header-row", variant="compact"):
571
- # 左侧:Logo
572
  gr.Image(value=logo_path,
573
  elem_id="logo-container",
574
  show_label=False,
575
  show_download_button=False,
576
- show_share_button=False) # 隐藏分享按钮
577
 
578
- # 右侧:标题区域
579
  with gr.Column(elem_id="title-area"):
580
- gr.Markdown("# 🎤 Marco-Voice 语音合成系统", elem_id="header-title")
581
 
582
- # gr.Markdown("")
583
-
584
- # 标签页
585
  with gr.Tabs(elem_classes="tabs") as tabs:
586
- # Tab 1: 音色克隆
587
- with gr.TabItem("🎭 音色克隆", id=0):
588
  with gr.Row():
589
  with gr.Column(scale=2, elem_classes="input-section"):
590
- gr.Markdown("### 输入设置")
591
  tts_text_v1 = gr.Textbox(
592
  lines=3,
593
- placeholder="请输入要合成的文本内容...",
594
- label="合成文本",
595
- value="大家好,欢迎使用Marco Voice语音合成系统,这是一个强大的语音生成工具。"
596
  )
597
 
598
  with gr.Row():
@@ -602,174 +593,130 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
602
  maximum=2.0,
603
  value=1.0,
604
  step=0.1,
605
- label="语速控制",
606
- interactive=True
 
 
 
 
 
607
  )
 
 
608
  with gr.Column():
609
  speaker_v1 = gr.Dropdown(
610
  choices=names,
611
  value="徐志胜",
612
- label="预设音色",
613
- info="选择脱口秀演员音色"
614
- )
615
- # [tts_text_v1, speed_v1, speaker_v1, emotion, ref_audio_v1, ref_text_v1]
616
- with gr.Accordion("高级设置", open=False, elem_classes="accordion"):
617
- gr.Markdown("上传3-10秒清晰人声作为参考音频")
618
- with gr.Row():
619
- ref_audio_v1 = gr.Audio(
620
- type="filepath",
621
- label="上传参考音频",
622
- elem_classes="audio-upload"
623
- )
624
- ref_text_v1 = gr.Textbox(
625
- lines=2,
626
- placeholder="参考音频对应的文本...",
627
- label="参考文本"
628
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
629
 
630
  gr.Markdown("""
631
  <div class="model-info">
632
- <p><span class="info-icon">ℹ️</span> <strong>模型说明:</strong> 此模型使用零样本音色克隆技术,只需3-10秒参考音频即可模仿目标音色。</p>
 
633
  </div>
634
  """)
635
 
636
  with gr.Column(scale=1, elem_classes="output-section"):
637
- gr.Markdown("### 输出结果")
638
- tts_base_output = gr.Audio(
639
  type="filepath",
640
- label="生成语音",
641
- elem_id="tts_output_audio",
642
  interactive=False
643
  )
644
- tts_base_button = gr.Button(
645
- "🚀 生成语音",
646
  variant="primary",
647
  elem_classes="btn-generate"
648
  )
649
  gr.Examples(
650
  examples=[
651
- ["大家好,欢迎使用Marco-Voice语音合成系统,这是一个强大的语音生成工具。", "徐志胜"],
652
- ["科技改变生活,创新引领未来。人工智能正在深刻改变我们的世界。", "李雪琴"],
653
- ["在这个充满机遇的时代,我们要勇于探索,敢于创新,不断突破自我。", "范志毅"]
654
- ],
655
- inputs=[tts_text_v1, speaker_v1],
656
- label="示例文本"
657
- )
658
-
659
- # Tab 2: 多语种合成
660
- with gr.TabItem("🌍 多语种合成", id=1):
661
- with gr.Row():
662
- with gr.Column(scale=2, elem_classes="input-section"):
663
- gr.Markdown("### 输入设置")
664
- tts_text_sft = gr.Textbox(
665
- lines=3,
666
- placeholder="请输入要合成的文本内容...",
667
- label="合成文本",
668
- value="Hello, welcome to Marco-Voice text-to-speech system. This is a powerful multilingual TTS tool."
669
- )
670
-
671
- speaker_sft = gr.Dropdown(
672
- choices=["中文男", "中文女", "英文男", "英文女", "韩语女", "日语男"],
673
- value="英文男",
674
- label="说话人",
675
- info="选择语言和性别"
676
- )
677
-
678
- gr.Markdown("""
679
- <div class="model-info">
680
- <p><span class="info-icon">ℹ️</span> <strong>模型说明:</strong> 此模型支持多个语种,无需参考音频即可生成自然语音。</p>
681
- <p><span class="info-icon">💡</span> <strong>使用技巧:</strong> 输入文本语言应与选择的说话人语言一致以获得最佳效果。</p>
682
- </div>
683
- """)
684
-
685
- with gr.Column(scale=1, elem_classes="output-section"):
686
- gr.Markdown("### 输出结果")
687
- tts_sft_output = gr.Audio(
688
- type="numpy",
689
- label="生成语音",
690
- interactive=False
691
- )
692
- tts_sft_button = gr.Button(
693
- "🚀 生成语音",
694
- variant="primary",
695
- elem_classes="btn-generate"
696
- )
697
- gr.Examples(
698
- examples=[
699
- ["Hello, welcome to Marco-Voice text-to-speech system.", "英文男"],
700
- ["こんにちは、Marco-Voiceテキスト読み上げシステムへようこそ。", "日语男"],
701
- ["안녕하세요, Marco-Voice 텍스트 음성 변환 시스템에 오신 것을 환영합니다.", "韩语女"]
702
  ],
703
- inputs=[tts_text_sft, speaker_sft],
704
- label="多语种示例"
705
  )
706
-
707
- # Tab 3: 情感控制
708
- with gr.TabItem("😄 情感控制", id=2):
709
  with gr.Row():
710
  with gr.Column(scale=2, elem_classes="input-section"):
711
- gr.Markdown("### 输入设置")
712
- tts_text_v3 = gr.Textbox(
713
  lines=3,
714
- placeholder="请输入要合成的文本内容...",
715
- label="合成文本",
716
  value="这真是太令人兴奋了!我们刚刚完成了一个重大突破!"
717
  )
718
 
719
  with gr.Row():
720
  with gr.Column():
721
- speed_v3 = gr.Slider(
722
  minimum=0.5,
723
  maximum=2.0,
724
  value=1.0,
725
  step=0.1,
726
- label="语速控制"
727
  )
728
  with gr.Column():
729
- emotion_v3 = gr.Radio(
730
- choices=["Angry", "Happy", "Surprise", "Sad"],
731
  value="Happy",
732
- label="情感选择"
733
  )
734
 
735
  with gr.Row():
736
  with gr.Column():
737
- speaker_v3 = gr.Dropdown(
738
  choices=names,
739
  value="徐志胜",
740
- label="预设音色"
741
  )
742
  with gr.Column():
743
- gr.Markdown("### 或使用自定义音色")
744
- with gr.Accordion("上传参考音频", open=False, elem_classes="accordion"):
745
- gr.Markdown("上传3-10秒清晰人声作为参考音频")
746
- ref_audio_v3 = gr.Audio(
747
  type="filepath",
748
- label="上传参考音频",
749
  elem_classes="audio-upload"
750
  )
751
- ref_text_v3 = gr.Textbox(
752
  lines=2,
753
- placeholder="参考音频对应的文本...",
754
- label="参考文本"
755
  )
756
 
757
  gr.Markdown("""
758
  <div class="model-info">
759
- <p><span class="info-icon">ℹ️</span> <strong>模型说明:</strong> 此模型在音色克隆基础上增加了情感控制能力,可生成带有特定情感的语音。</p>
760
- <p><span class="info-icon">💡</span> <strong>使用技巧:</strong> 情感表达效果与文本内容相关,请确保文本与所选情感匹配。</p>
761
  </div>
762
  """)
763
 
764
  with gr.Column(scale=1, elem_classes="output-section"):
765
- gr.Markdown("### 输出结果")
766
- tts_v3_output = gr.Audio(
767
  type="filepath",
768
- label="生成语音",
769
  interactive=False
770
  )
771
- tts_v3_button = gr.Button(
772
- "🚀 生成语音",
773
  variant="primary",
774
  elem_classes="btn-generate"
775
  )
@@ -779,35 +726,28 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
779
  ["我简直不敢相信!这怎么可能发生?", "Surprise", "李雪琴"],
780
  ["这太让人失望了,我们所有的努力都白费了。", "Sad", "范志毅"]
781
  ],
782
- inputs=[tts_text_v3, emotion_v3, speaker_v3],
783
- label="情感示例"
784
  )
785
 
786
- # 页脚
787
  gr.Markdown("""
788
  <div class="footer">
789
- <p>Marco-Voice 语音合成系统 v1.0 | 基于优秀的tts 模型 | 技术支持: tech@marco-voice.com</p>
790
- <p>注意: 生成内容仅用于技术演示,请勿用于非法用途</p>
791
  </div>
792
  """)
793
 
794
- # 绑定事件 # tts_text, speed, speaker, emotion, ref_audio, ref_text
795
- tts_base_button.click(
796
- fn=generate_speech_base,
797
- inputs=[tts_text_v1, speed_v1, speaker_v1, ref_audio_v1, ref_text_v1],
798
- outputs=tts_base_output
799
- )
800
 
801
- tts_sft_button.click(
802
- fn=generate_speech_sft,
803
- inputs=[tts_text_sft, speaker_sft],
804
- outputs=tts_sft_output
805
  )
806
  # tts_text, speed, speaker, key, ref_audio, ref_text
807
- tts_v3_button.click(
808
- fn=generate_speech_speakerminus,
809
- inputs=[tts_text_v3, speed_v3, speaker_v3, emotion_v3, ref_audio_v3, ref_text_v3],
810
- outputs=tts_v3_output
811
  )
812
 
813
  if __name__ == "__main__":
@@ -815,5 +755,5 @@ if __name__ == "__main__":
815
  server_name="0.0.0.0",
816
  server_port=10163,
817
  share=True,
818
- favicon_path="/mnt/by079416/fengping/CosyVoice2/logo.png"
819
  )
 
1
  import gradio as gr
2
  import sys, os
3
+ from huggingface_hub import snapshot_download, hf_hub_download
4
+
5
  import torch
6
  from cosyvoice.utils.file_utils import load_wav
 
 
7
  from uuid import uuid1
8
  import uuid
9
+ from cosyvoice_rodis.cli.cosyvoice import CosyVoice as CosyVoiceTTS_speakerminus
 
10
  from pydub import AudioSegment
11
  import tempfile
12
  import soundfile as sf
 
16
  import numpy
17
 
18
 
19
+ import imageio_ffmpeg
20
+
21
+ ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
22
+ print(f"FFmpeg path: {ffmpeg_path}")
23
+ user_bin = os.path.expanduser("~/bin")
24
+ if not os.path.exists(user_bin):
25
+ os.makedirs(user_bin)
26
+ ffmpeg_link = os.path.join(user_bin, "ffmpeg")
27
+ if os.path.exists(ffmpeg_link):
28
+ os.remove(ffmpeg_link)
29
+ os.symlink(ffmpeg_path, ffmpeg_link)
30
+ print(f"create symbolic link: {ffmpeg_link}")
31
+ os.environ["PATH"] = f"{user_bin}:{os.environ.get('PATH', '')}"
32
 
 
 
33
 
34
  sys.path.append('third_party/Matcha-TTS')
35
  os.system('export PYTHONPATH=third_party/Matcha-TTS')
36
 
37
+ assets_dir = snapshot_download(
38
+ repo_id="tienfeng/prompt",
39
+ repo_type="dataset",
40
+ )
41
+
42
+ from huggingface_hub import hf_hub_download
43
+
44
+ model_repo_id = "AIDC-AI/Marco-Voice"
45
+ local_model = snapshot_download(
46
+ repo_id=model_repo_id,
47
+ repo_type="model"
48
+ # token=os.getenv("HF_TOKEN")
49
+ )
50
+
51
+ local_model_path = os.path.join(local_model, "marco_voice")
52
+ local_model_path_enhenced = os.path.join(local_model, "marco_voice_enhenced")
53
+
54
+
55
+ logo_path = hf_hub_download(
56
+ repo_id="tienfeng/prompt",
57
+ filename="logo2.png",
58
+ repo_type="dataset",
59
+ )
60
+
61
+ logo_path2 = hf_hub_download(
62
+ repo_id="tienfeng/prompt",
63
+ filename="logo.png",
64
+ repo_type="dataset",
65
+ )
66
+
67
+ tts_speakerminus = CosyVoiceTTS_speakerminus(model_dir=local_model_path)
68
+ tts_sft = CosyVoiceTTS_speakerminus(model_dir=local_model_path_enhenced)
69
 
70
  text_prompt = {
71
  "翟佳宁": "这个节目就是把四个男嘉宾,四个女嘉宾放一个大别墅里让他们朝夕相处一整个月,月末选择心动的彼此。",
 
93
  "赵晓卉": "zhaoxiaohui",
94
  "徐志胜": "xuzhisheng"
95
  }
96
+ audio_prompt_path = assets_dir
97
 
98
  def load_audio_and_convert_to_16bit(file_path, target_sample_rate=16000):
99
  audio = AudioSegment.from_file(file_path)
 
123
  # ]
124
  command = [
125
  './ffmpeg-7.0.2-amd64-static/ffmpeg',
126
+ '-i', input_file,
127
+ '-ar', str(target_sample_rate),
128
+ '-ac', '1',
129
+ '-b:a', '16k',
130
+ '-f', 'wav',
131
  output_file
132
  ]
133
 
 
138
 
139
  os.makedirs("./tmp", exist_ok=True)
140
 
141
+ def generate_speech_speakerminus(tts_text, speed, speaker, key, ref_audio, ref_text):
142
  # import pdb;pdb.set_trace()
143
  if not ref_audio and not ref_text:
144
  ref_text = text_prompt.get(speaker, "")
145
+ speaker_audio_name = audio_prompt.get(speaker)
146
+ if speaker_audio_name:
147
+ ref_audio = os.path.join(audio_prompt_path, f"{speaker_audio_name}.wav")
148
+ else:
149
+ raise ValueError(f"Speaker '{speaker}' not found in audio_prompt dictionary")
150
  else:
151
+ random_int = random.randint(0, 90)
152
+ soxsed_ref_audio = f"./tmp/{random_int}_ref.wav"
153
  convert_audio_with_sox(ref_audio, soxsed_ref_audio)
154
+ ref_audio = soxsed_ref_audio
155
+
156
+ if not ref_audio:
157
+ raise ValueError("Reference audio is required but not provided")
158
+ ref_audio = load_wav(ref_audio, 16000)
159
+ emo = {"Sad": "伤心", "Fearful": "恐惧", "Happy": "快乐", "Surprise": "惊喜", "Angry": "生气", "Jolliest": "戏谑"}
160
+ # key="快乐"
161
+ if key in ["Angry", "Surprise", "Happy"]:
162
+ emotion_info = torch.load("./emotion_info.pt")["male005"][key]
163
+ elif key in ["Sad"]:
164
+ emotion_info = torch.load("./emotion_info.pt")["female005"][key]
165
+ elif key in ["Fearful"]:
166
+ emotion_info = torch.load("./emotion_info.pt")["female003"][key]
167
+ else:
168
+ emotion_info = torch.load("./emotion_info.pt")["male005"][key]
169
+
170
+ sample_rate, full_audio = tts_sft.inference_zero_shot(
171
  tts_text,
172
  prompt_text = ref_text,
 
 
173
  # speaker=speaker,
174
+ prompt_speech_16k = ref_audio,
175
+ key = emo.get(key),
176
+ emotion_speakerminus=emotion_info,
177
+ # ref_audio = ref_audio,
178
+ speed=speed
179
+
180
  )
181
  print("sample_rate:", sample_rate, "full_audio:", full_audio.min(), full_audio.max())
182
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
 
189
  )
190
  audio_segment.export(output_audio_path, format="wav")
191
  print(f"Audio saved to {output_audio_path}")
 
192
  return output_audio_path
193
 
194
+
195
+ def generate_speech_sft(tts_text, speed, speaker, key, ref_audio, ref_text):
196
  # import pdb;pdb.set_trace()
197
  if not ref_audio and not ref_text:
198
  ref_text = text_prompt.get(speaker, "")
199
+ speaker_audio_name = audio_prompt.get(speaker)
200
+ if speaker_audio_name:
201
+ ref_audio = os.path.join(audio_prompt_path, f"{speaker_audio_name}.wav")
202
+ else:
203
+ raise ValueError(f"Speaker '{speaker}' not found in audio_prompt dictionary")
204
  else:
205
  random_int = random.randint(0, 90)
206
+ soxsed_ref_audio = f"./tmp/{random_int}_ref.wav"
207
  convert_audio_with_sox(ref_audio, soxsed_ref_audio)
208
+ ref_audio = soxsed_ref_audio
209
+
210
+ if not ref_audio:
211
+ raise ValueError("Reference audio is required but not provided")
212
  ref_audio = load_wav(ref_audio, 16000)
213
+
214
+ emo = {"Sad": "伤心", "Fearful": "恐惧", "Happy": "快乐", "Surprise": "惊喜", "Angry": "生气", "Jolliest": "戏谑"}
215
+ # key="快乐"
216
+ if key in ["Angry", "Surprise", "Happy"]:
217
+ emotion_info = torch.load("./emotion_info.pt")["male005"][key]
218
+ elif key in ["Sad"]:
219
+ emotion_info = torch.load("./emotion_info.pt")["female005"][key]
220
+ elif key in ["Fearful"]:
221
+ emotion_info = torch.load("./emotion_info.pt")["female003"][key]
222
+ else:
223
+ emotion_info = torch.load("./emotion_info.pt")["male005"][key]
224
+
225
+ sample_rate, full_audio = tts_sft.inference_zero_shot(
226
  tts_text,
227
  prompt_text = ref_text,
228
  # speaker=speaker,
229
  prompt_speech_16k = ref_audio,
230
+ key = emo.get(key),
231
  emotion_speakerminus=emotion_info,
232
  # ref_audio = ref_audio,
233
  speed=speed
 
328
  flex-grow: 1; /* 占据所有剩余空间 */
329
  }
330
 
331
+ /* 6. title */
332
  #header-title h1 {
333
  color: white;
334
  font-size: 28px;
 
561
  }
562
  """
563
 
 
 
564
 
565
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
566
  with gr.Column(elem_classes="header"):
 
567
  with gr.Row(elem_id="header-row", variant="compact"):
 
568
  gr.Image(value=logo_path,
569
  elem_id="logo-container",
570
  show_label=False,
571
  show_download_button=False,
572
+ show_share_button=False)
573
 
 
574
  with gr.Column(elem_id="title-area"):
575
+ gr.Markdown("# 🎤 Marco-Voice ", elem_id="header-title")
576
 
 
 
 
577
  with gr.Tabs(elem_classes="tabs") as tabs:
578
+ with gr.TabItem("😄 Emotion Control", id=0):
 
579
  with gr.Row():
580
  with gr.Column(scale=2, elem_classes="input-section"):
581
+ gr.Markdown("### Input Settings")
582
  tts_text_v1 = gr.Textbox(
583
  lines=3,
584
+ placeholder="Enter the text to synthesize...",
585
+ label="Text to synthesize",
586
+ value="这真是太令人兴奋了!我们刚刚完成了一个重大突破!"
587
  )
588
 
589
  with gr.Row():
 
593
  maximum=2.0,
594
  value=1.0,
595
  step=0.1,
596
+ label="Speaking rate control"
597
+ )
598
+ with gr.Column():
599
+ emotion_v1 = gr.Radio(
600
+ choices=["Angry", "Happy", "Surprise", "Sad", "Fearful", "Jolliest"],
601
+ value="Happy",
602
+ label="Emotion selection"
603
  )
604
+
605
+ with gr.Row():
606
  with gr.Column():
607
  speaker_v1 = gr.Dropdown(
608
  choices=names,
609
  value="徐志胜",
610
+ label="Preset timbre"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
611
  )
612
+ with gr.Column():
613
+ gr.Markdown("### Or use a custom timbre")
614
+ with gr.Accordion("Upload reference audio", open=False, elem_classes="accordion"):
615
+ gr.Markdown("Upload 3-10 seconds of clear human voice as reference audio")
616
+ ref_audio_v1 = gr.Audio(
617
+ type="filepath",
618
+ label="Upload reference audio",
619
+ elem_classes="audio-upload"
620
+ )
621
+ ref_text_v1 = gr.Textbox(
622
+ lines=2,
623
+ placeholder="Text corresponding to the reference audio...",
624
+ label="Reference text"
625
+ )
626
 
627
  gr.Markdown("""
628
  <div class="model-info">
629
+ <p><span class="info-icon">ℹ️</span> <strong>Model notes:</strong> This model adds emotion control on top of voice cloning and can generate speech with a specified emotion.</p>
630
+ <p><span class="info-icon">💡</span> <strong>Tip:</strong> The emotional effect depends on the text content, so make sure the text matches the selected emotion.</p>
631
  </div>
632
  """)
633
 
634
  with gr.Column(scale=1, elem_classes="output-section"):
635
+ gr.Markdown("### Output")
636
+ tts_v1_output = gr.Audio(
637
  type="filepath",
638
+ label="Generated speech",
 
639
  interactive=False
640
  )
641
+ tts_v1_button = gr.Button(
642
+ "🚀 Generate speech",
643
  variant="primary",
644
  elem_classes="btn-generate"
645
  )
646
  gr.Examples(
647
  examples=[
648
+ ["这真是太令人兴奋了!我们刚刚完成了一个重大突破!", "Happy", "徐志胜"],
649
+ ["我简直不敢相信!这怎么可能发生?", "Surprise", "李雪琴"],
650
+ ["这太让人失望了,我们所有的努力都白费了。", "Sad", "范志毅"]
651
  ],
652
+ inputs=[tts_text_v1, emotion_v1, speaker_v1],
653
+ label="Emotion example"
654
  )
655
+ with gr.TabItem("😄 Emotion Control (Enhanced)", id=1):
 
 
656
  with gr.Row():
657
  with gr.Column(scale=2, elem_classes="input-section"):
658
+ gr.Markdown("### Input Settings")
659
+ tts_text_v2 = gr.Textbox(
660
  lines=3,
661
+ placeholder="Enter the text to synthesize...",
662
+ label="Text to synthesize",
663
  value="这真是太令人兴奋了!我们刚刚完成了一个重大突破!"
664
  )
665
 
666
  with gr.Row():
667
  with gr.Column():
668
+ speed_v2 = gr.Slider(
669
  minimum=0.5,
670
  maximum=2.0,
671
  value=1.0,
672
  step=0.1,
673
+ label="Speaking rate control"
674
  )
675
  with gr.Column():
676
+ emotion_v2 = gr.Radio(
677
+ choices=["Angry", "Happy", "Surprise", "Sad", "Fearful", "Jolliest"],
678
  value="Happy",
679
+ label="Emotion selection"
680
  )
681
 
682
  with gr.Row():
683
  with gr.Column():
684
+ speaker_v2 = gr.Dropdown(
685
  choices=names,
686
  value="徐志胜",
687
+ label="Preset timbre"
688
  )
689
  with gr.Column():
690
+ gr.Markdown("### Or use a custom timbre")
691
+ with gr.Accordion("Upload reference audio", open=False, elem_classes="accordion"):
692
+ gr.Markdown("Upload 3-10 seconds of clear human voice as reference audio")
693
+ ref_audio_v2 = gr.Audio(
694
  type="filepath",
695
+ label="Upload reference audio",
696
  elem_classes="audio-upload"
697
  )
698
+ ref_text_v2 = gr.Textbox(
699
  lines=2,
700
+ placeholder="Text corresponding to the reference audio...",
701
+ label="Reference text"
702
  )
703
 
704
  gr.Markdown("""
705
  <div class="model-info">
706
+ <p><span class="info-icon">ℹ️</span> <strong>Model notes:</strong> This model adds emotion control on top of voice cloning and can generate speech with a specified emotion.</p>
707
+ <p><span class="info-icon">💡</span> <strong>Tip:</strong> The emotional effect depends on the text content, so make sure the text matches the selected emotion.</p>
708
  </div>
709
  """)
710
 
711
  with gr.Column(scale=1, elem_classes="output-section"):
712
+ gr.Markdown("### Output")
713
+ tts_v2_output = gr.Audio(
714
  type="filepath",
715
+ label="Generated speech",
716
  interactive=False
717
  )
718
+ tts_v2_button = gr.Button(
719
+ "🚀 Generate speech",
720
  variant="primary",
721
  elem_classes="btn-generate"
722
  )
 
726
  ["我简直不敢相信!这怎么可能发生?", "Surprise", "李雪琴"],
727
  ["这太让人失望了,我们所有的努力都白费了。", "Sad", "范志毅"]
728
  ],
729
+ inputs=[tts_text_v2, emotion_v2, speaker_v2],
730
+ label="Emotion example"
731
  )
732
 
 
733
  gr.Markdown("""
734
  <div class="footer">
735
+ <p>Marco-Voice text-to-speech v1.0 | based on excellent open-source TTS models | tech support: tech@marco-voice.com</p>
736
+ <p>Note: the synthesized speech is for technical demonstration only.</p>
737
  </div>
738
  """)
739
 
 
 
 
 
 
 
740
 
741
+ tts_v1_button.click(
742
+ fn=generate_speech_speakerminus,
743
+ inputs=[tts_text_v1, speed_v1, speaker_v1, emotion_v1, ref_audio_v1, ref_text_v1],
744
+ outputs=tts_v1_output
745
  )
746
  # tts_text, speed, speaker, key, ref_audio, ref_text
747
+ tts_v2_button.click(
748
+ fn=generate_speech_sft,
749
+ inputs=[tts_text_v2, speed_v2, speaker_v2, emotion_v2, ref_audio_v2, ref_text_v2],
750
+ outputs=tts_v2_output
751
  )
752
 
753
  if __name__ == "__main__":
 
755
  server_name="0.0.0.0",
756
  server_port=10163,
757
  share=True,
758
+ favicon_path=logo_path2
759
  )
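
For orientation, the sketch below is not part of the commit; it restates the emotion-key flow that the new `generate_speech_speakerminus` / `generate_speech_sft` callbacks implement: the English radio-button keys select a speaker-specific embedding from `./emotion_info.pt` and are mapped back to the Chinese prompt tags the model expects. It assumes `tts_sft`, `ref_text`, and `ref_audio` are prepared exactly as in app.py above.

```python
import torch

# English UI keys -> Chinese prompt tags (copied from the `emo` dict in app.py)
EMO_ZH = {"Sad": "伤心", "Fearful": "恐惧", "Happy": "快乐",
          "Surprise": "惊喜", "Angry": "生气", "Jolliest": "戏谑"}

def pick_emotion_embedding(key: str):
    """Same routing as app.py: a few emotions come from different reference speakers."""
    bank = torch.load("./emotion_info.pt")
    if key == "Sad":
        return bank["female005"][key]
    if key == "Fearful":
        return bank["female003"][key]
    return bank["male005"][key]  # Angry / Surprise / Happy and the fallback case

# Usage mirroring the Gradio callback:
# sample_rate, audio = tts_sft.inference_zero_shot(
#     tts_text,
#     prompt_text=ref_text,
#     prompt_speech_16k=ref_audio,
#     key=EMO_ZH.get(key),
#     emotion_speakerminus=pick_emotion_embedding(key),
#     speed=1.0,
# )
```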
cosyvoice_rodis/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ #
2
+
cosyvoice_rodis/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (165 Bytes).
 
cosyvoice_rodis/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (186 Bytes).
 
cosyvoice_rodis/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (163 Bytes).
 
cosyvoice_rodis/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (140 Bytes).
 
cosyvoice_rodis/bin/average_model.py ADDED
@@ -0,0 +1,91 @@
1
+ #
2
+
3
+ # Copyright (c) 2020 Mobvoi Inc (Di Wu)
4
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import os
19
+ import argparse
20
+ import glob
21
+
22
+ import yaml
23
+ import torch
24
+
25
+ def get_args():
26
+ parser = argparse.ArgumentParser(description='average model')
27
+ parser.add_argument('--dst_model', required=True, help='averaged model')
28
+ parser.add_argument('--src_path',
29
+ required=True,
30
+ help='src model path for average')
31
+ parser.add_argument('--val_best',
32
+ action="store_true",
33
+ help='averaged model')
34
+ parser.add_argument('--num',
35
+ default=5,
36
+ type=int,
37
+ help='nums for averaged model')
38
+
39
+ args = parser.parse_args()
40
+ print(args)
41
+ return args
42
+
43
+ def main():
44
+ args = get_args()
45
+ val_scores = []
46
+ if args.val_best:
47
+ yamls = glob.glob('{}/*.yaml'.format(args.src_path))
48
+ yamls = [
49
+ f for f in yamls
50
+ if not (os.path.basename(f).startswith('train')
51
+ or os.path.basename(f).startswith('init'))
52
+ ]
53
+ for y in yamls:
54
+ with open(y, 'r') as f:
55
+ dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
56
+ loss = float(dic_yaml['loss_dict']['loss'])
57
+ epoch = int(dic_yaml['epoch'])
58
+ step = int(dic_yaml['step'])
59
+ tag = dic_yaml['tag']
60
+ val_scores += [[epoch, step, loss, tag]]
61
+ sorted_val_scores = sorted(val_scores,
62
+ key=lambda x: x[2],
63
+ reverse=False)
64
+ print("best val (epoch, step, loss, tag) = " +
65
+ str(sorted_val_scores[:args.num]))
66
+ path_list = [
67
+ args.src_path + '/epoch_{}_whole.pt'.format(score[0])
68
+ for score in sorted_val_scores[:args.num]
69
+ ]
70
+ print(path_list)
71
+ avg = {}
72
+ num = args.num
73
+ assert num == len(path_list)
74
+ for path in path_list:
75
+ print('Processing {}'.format(path))
76
+ states = torch.load(path, map_location=torch.device('cpu'))
77
+ for k in states.keys():
78
+ if k not in avg.keys():
79
+ avg[k] = states[k].clone()
80
+ else:
81
+ avg[k] += states[k]
82
+ # average
83
+ for k in avg.keys():
84
+ if avg[k] is not None:
85
+ # pytorch 1.6 use true_divide instead of /=
86
+ avg[k] = torch.true_divide(avg[k], num)
87
+ print('Saving to {}'.format(args.dst_model))
88
+ torch.save(avg, args.dst_model)
89
+
90
+ if __name__ == '__main__':
91
+ main()
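
For reference, a hypothetical invocation of this checkpoint-averaging script; every path below is a placeholder.

```python
# Hypothetical driver for cosyvoice_rodis/bin/average_model.py (paths are placeholders).
import subprocess

subprocess.run([
    "python", "cosyvoice_rodis/bin/average_model.py",
    "--dst_model", "exp/llm/llm_avg5.pt",  # where the averaged weights are written
    "--src_path", "exp/llm",               # directory holding epoch_*_whole.pt and per-epoch *.yaml files
    "--num", "5",                          # average the 5 best checkpoints ...
    "--val_best",                          # ... ranked by the 'loss' recorded in the yaml files
], check=True)
```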
cosyvoice_rodis/bin/export_jit.py ADDED
@@ -0,0 +1,73 @@
1
+ #
2
+
3
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from __future__ import print_function
18
+
19
+ import argparse
20
+ import logging
21
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
22
+ import os
23
+ import sys
24
+ import torch
25
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
26
+ sys.path.append('{}/../..'.format(ROOT_DIR))
27
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
28
+ from cosyvoice_rodis.cli.cosyvoice import CosyVoice
29
+
30
+ def get_args():
31
+ parser = argparse.ArgumentParser(description='export your model for deployment')
32
+ parser.add_argument('--model_dir',
33
+ type=str,
34
+ default='pretrained_models/CosyVoice-300M',
35
+ help='local path')
36
+ args = parser.parse_args()
37
+ print(args)
38
+ return args
39
+
40
+ def main():
41
+ args = get_args()
42
+ logging.basicConfig(level=logging.DEBUG,
43
+ format='%(asctime)s %(levelname)s %(message)s')
44
+
45
+ torch._C._jit_set_fusion_strategy([('STATIC', 1)])
46
+ torch._C._jit_set_profiling_mode(False)
47
+ torch._C._jit_set_profiling_executor(False)
48
+
49
+ cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
50
+
51
+ # 1. export llm text_encoder
52
+ llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
53
+ script = torch.jit.script(llm_text_encoder)
54
+ script = torch.jit.freeze(script)
55
+ script = torch.jit.optimize_for_inference(script)
56
+ script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
57
+
58
+ # 2. export llm llm
59
+ llm_llm = cosyvoice.model.llm.llm.half()
60
+ script = torch.jit.script(llm_llm)
61
+ script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
62
+ script = torch.jit.optimize_for_inference(script)
63
+ script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
64
+
65
+ # 3. export flow encoder
66
+ flow_encoder = cosyvoice.model.flow.encoder
67
+ script = torch.jit.script(flow_encoder)
68
+ script = torch.jit.freeze(script)
69
+ script = torch.jit.optimize_for_inference(script)
70
+ script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
71
+
72
+ if __name__ == '__main__':
73
+ main()
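
A quick sanity check (not in the commit) that the three TorchScript archives written by this script deserialize; the directory matches the script's default `--model_dir`.

```python
import torch

model_dir = "pretrained_models/CosyVoice-300M"  # default --model_dir of export_jit.py
for name in ("llm.text_encoder.fp16.zip", "llm.llm.fp16.zip", "flow.encoder.fp32.zip"):
    module = torch.jit.load(f"{model_dir}/{name}", map_location="cpu")
    print(name, "->", type(module).__name__)
```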
cosyvoice_rodis/bin/export_onnx.py ADDED
@@ -0,0 +1,110 @@
1
+ #
2
+
3
+ # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
4
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from __future__ import print_function
19
+
20
+ import argparse
21
+ import logging
22
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
23
+ import os
24
+ import sys
25
+ import onnxruntime
26
+ import random
27
+ import torch
28
+ from tqdm import tqdm
29
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
30
+ sys.path.append('{}/../..'.format(ROOT_DIR))
31
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
32
+ from cosyvoice_rodis.cli.cosyvoice import CosyVoice
33
+
34
+ def get_dummy_input(batch_size, seq_len, out_channels, device):
35
+ x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
36
+ mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
37
+ mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
38
+ t = torch.rand((batch_size), dtype=torch.float32, device=device)
39
+ spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
40
+ cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
41
+ return x, mask, mu, t, spks, cond
42
+
43
+ def get_args():
44
+ parser = argparse.ArgumentParser(description='export your model for deployment')
45
+ parser.add_argument('--model_dir',
46
+ type=str,
47
+ default='pretrained_models/CosyVoice-300M',
48
+ help='local path')
49
+ args = parser.parse_args()
50
+ print(args)
51
+ return args
52
+
53
+ def main():
54
+ args = get_args()
55
+ logging.basicConfig(level=logging.DEBUG,
56
+ format='%(asctime)s %(levelname)s %(message)s')
57
+
58
+ cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
59
+
60
+ # 1. export flow decoder estimator
61
+ estimator = cosyvoice.model.flow.decoder.estimator
62
+
63
+ device = cosyvoice.model.device
64
+ batch_size, seq_len = 1, 256
65
+ out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
66
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
67
+ torch.onnx.export(
68
+ estimator,
69
+ (x, mask, mu, t, spks, cond),
70
+ '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
71
+ export_params=True,
72
+ opset_version=18,
73
+ do_constant_folding=True,
74
+ input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
75
+ output_names=['estimator_out'],
76
+ dynamic_axes={
77
+ 'x': {0: 'batch_size', 2: 'seq_len'},
78
+ 'mask': {0: 'batch_size', 2: 'seq_len'},
79
+ 'mu': {0: 'batch_size', 2: 'seq_len'},
80
+ 'cond': {0: 'batch_size', 2: 'seq_len'},
81
+ 't': {0: 'batch_size'},
82
+ 'spks': {0: 'batch_size'},
83
+ 'estimator_out': {0: 'batch_size', 2: 'seq_len'},
84
+ }
85
+ )
86
+
87
+ # 2. test computation consistency
88
+ option = onnxruntime.SessionOptions()
89
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
90
+ option.intra_op_num_threads = 1
91
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
92
+ estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
93
+ sess_options=option, providers=providers)
94
+
95
+ for _ in tqdm(range(10)):
96
+ x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
97
+ output_pytorch = estimator(x, mask, mu, t, spks, cond)
98
+ ort_inputs = {
99
+ 'x': x.cpu().numpy(),
100
+ 'mask': mask.cpu().numpy(),
101
+ 'mu': mu.cpu().numpy(),
102
+ 't': t.cpu().numpy(),
103
+ 'spks': spks.cpu().numpy(),
104
+ 'cond': cond.cpu().numpy()
105
+ }
106
+ output_onnx = estimator_onnx.run(None, ort_inputs)[0]
107
+ torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
108
+
109
+ if __name__ == "__main__":
110
+ main()
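
To run the exported estimator outside the consistency test above, something like the following sketch works. The input names and shapes follow `get_dummy_input`; `channels = 80` is an assumption here, whereas the script reads `out_channels` from the loaded model.

```python
import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    "pretrained_models/CosyVoice-300M/flow.decoder.estimator.fp32.onnx",
    providers=["CPUExecutionProvider"],
)

batch, channels, seq_len = 1, 80, 256  # channels=80 is assumed, not read from the model
feed = {
    "x":    np.random.rand(batch, channels, seq_len).astype(np.float32),
    "mask": np.ones((batch, 1, seq_len), dtype=np.float32),
    "mu":   np.random.rand(batch, channels, seq_len).astype(np.float32),
    "t":    np.random.rand(batch).astype(np.float32),
    "spks": np.random.rand(batch, channels).astype(np.float32),
    "cond": np.random.rand(batch, channels, seq_len).astype(np.float32),
}
print(sess.run(None, feed)[0].shape)  # expected: (batch, channels, seq_len)
```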
cosyvoice_rodis/bin/inference.py ADDED
@@ -0,0 +1,114 @@
1
+ #
2
+
3
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from __future__ import print_function
18
+
19
+ import argparse
20
+ import logging
21
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
22
+ import os
23
+ import torch
24
+ from torch.utils.data import DataLoader
25
+ import torchaudio
26
+ from hyperpyyaml import load_hyperpyyaml
27
+ from tqdm import tqdm
28
+ from cosyvoice_rodis.cli.model import CosyVoiceModel
29
+ from cosyvoice_rodis.dataset.dataset import Dataset
30
+
31
+ def get_args():
32
+ parser = argparse.ArgumentParser(description='inference with your model')
33
+ parser.add_argument('--config', required=True, help='config file')
34
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
35
+ parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
36
+ parser.add_argument('--tts_text', required=True, help='tts input file')
37
+ parser.add_argument('--llm_model', required=True, help='llm model file')
38
+ parser.add_argument('--flow_model', required=True, help='flow model file')
39
+ parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
40
+ parser.add_argument('--gpu',
41
+ type=int,
42
+ default=-1,
43
+ help='gpu id for this rank, -1 for cpu')
44
+ parser.add_argument('--mode',
45
+ default='sft',
46
+ choices=['sft', 'zero_shot'],
47
+ help='inference mode')
48
+ parser.add_argument('--result_dir', required=True, help='asr result file')
49
+ args = parser.parse_args()
50
+ print(args)
51
+ return args
52
+
53
+ def main():
54
+ args = get_args()
55
+ logging.basicConfig(level=logging.DEBUG,
56
+ format='%(asctime)s %(levelname)s %(message)s')
57
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
58
+
59
+ # Init cosyvoice models from configs
60
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
61
+ device = torch.device('cuda' if use_cuda else 'cpu')
62
+ with open(args.config, 'r') as f:
63
+ configs = load_hyperpyyaml(f)
64
+
65
+ model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], True)
66
+ model.load(args.llm_model, args.flow_model, args.hifigan_model)
67
+
68
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
69
+ tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
70
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
71
+
72
+ del configs
73
+ os.makedirs(args.result_dir, exist_ok=True)
74
+ fn = os.path.join(args.result_dir, 'wav.scp')
75
+ f = open(fn, 'w')
76
+ with torch.no_grad():
77
+ for _, batch in tqdm(enumerate(test_data_loader)):
78
+ utts = batch["utts"]
79
+ assert len(utts) == 1, "inference mode only support batchsize 1"
80
+ text_token = batch["text_token"].to(device)
81
+ text_token_len = batch["text_token_len"].to(device)
82
+ tts_index = batch["tts_index"]
83
+ tts_text_token = batch["tts_text_token"].to(device)
84
+ tts_text_token_len = batch["tts_text_token_len"].to(device)
85
+ speech_token = batch["speech_token"].to(device)
86
+ speech_token_len = batch["speech_token_len"].to(device)
87
+ speech_feat = batch["speech_feat"].to(device)
88
+ speech_feat_len = batch["speech_feat_len"].to(device)
89
+ utt_embedding = batch["utt_embedding"].to(device)
90
+ spk_embedding = batch["spk_embedding"].to(device)
91
+ if args.mode == 'sft':
92
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
93
+ 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
94
+ else:
95
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
96
+ 'prompt_text': text_token, 'prompt_text_len': text_token_len,
97
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
98
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
99
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
100
+ 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
101
+ tts_speeches = []
102
+ for model_output in model.tts(**model_input):
103
+ tts_speeches.append(model_output['tts_speech'])
104
+ tts_speeches = torch.concat(tts_speeches, dim=1)
105
+ tts_key = '{}_{}'.format(utts[0], tts_index[0])
106
+ tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
107
+ torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
108
+ f.write('{} {}\n'.format(tts_key, tts_fn))
109
+ f.flush()
110
+ f.close()
111
+ logging.info('Result wav.scp saved in {}'.format(fn))
112
+
113
+ if __name__ == '__main__':
114
+ main()
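
A hypothetical zero-shot invocation of this inference script; the Kaldi-style data files and model paths below are placeholders.

```python
# Hypothetical driver for cosyvoice_rodis/bin/inference.py (all paths are placeholders).
import subprocess

subprocess.run([
    "python", "cosyvoice_rodis/bin/inference.py",
    "--mode", "zero_shot",
    "--gpu", "0",
    "--config", "pretrained_models/CosyVoice-300M/cosyvoice.yaml",
    "--prompt_data", "data/test/parquet/data.list",
    "--prompt_utt2data", "data/test/parquet/utt2data.list",
    "--tts_text", "data/test/tts_text.json",
    "--llm_model", "pretrained_models/CosyVoice-300M/llm.pt",
    "--flow_model", "pretrained_models/CosyVoice-300M/flow.pt",
    "--hifigan_model", "pretrained_models/CosyVoice-300M/hift.pt",
    "--result_dir", "exp/results",  # wav files plus a wav.scp index are written here
], check=True)
```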
cosyvoice_rodis/bin/train.py ADDED
@@ -0,0 +1,159 @@
1
+ #
2
+
3
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from __future__ import print_function
18
+ import sys
19
+ print("_+++++++++++++++++++++")
20
+ print(sys.path)
21
+ sys.path.append("/mnt/workspace/baipeng/project/Marco-Voice/Models/marco_voice")
22
+ import argparse
23
+ import datetime
24
+ import logging
25
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
26
+ from copy import deepcopy
27
+ import os
28
+ import torch
29
+ import torch.distributed as dist
30
+ import deepspeed
31
+
32
+ from hyperpyyaml import load_hyperpyyaml
33
+
34
+ from torch.distributed.elastic.multiprocessing.errors import record
35
+
36
+ from cosyvoice_rodis.utils.executor import Executor
37
+ from cosyvoice_rodis.utils.train_utils import (
38
+ init_distributed,
39
+ init_dataset_and_dataloader,
40
+ init_optimizer_and_scheduler,
41
+ init_summarywriter, save_model,
42
+ wrap_cuda_model, check_modify_and_save_config)
43
+ def get_args():
44
+ parser = argparse.ArgumentParser(description='training your network')
45
+ parser.add_argument('--train_engine',
46
+ default='torch_ddp',
47
+ choices=['torch_ddp', 'deepspeed'],
48
+ help='Engine for paralleled training')
49
+ parser.add_argument('--model', required=True, help='model which will be trained')
50
+ parser.add_argument('--config', required=True, help='config file')
51
+ parser.add_argument('--train_data', required=True, help='train data file')
52
+ parser.add_argument('--cv_data', required=True, help='cv data file')
53
+ parser.add_argument('--checkpoint', help='checkpoint model')
54
+ parser.add_argument('--model_dir', required=True, help='save model dir')
55
+ parser.add_argument('--tensorboard_dir',
56
+ default='tensorboard',
57
+ help='tensorboard log dir')
58
+ parser.add_argument('--ddp.dist_backend',
59
+ dest='dist_backend',
60
+ default='nccl',
61
+ choices=['nccl', 'gloo'],
62
+ help='distributed backend')
63
+ parser.add_argument('--num_workers',
64
+ default=0,
65
+ type=int,
66
+ help='num of subprocess workers for reading')
67
+ parser.add_argument('--prefetch',
68
+ default=100,
69
+ type=int,
70
+ help='prefetch number')
71
+ parser.add_argument('--pin_memory',
72
+ action='store_true',
73
+ default=False,
74
+ help='Use pinned memory buffers used for reading')
75
+ parser.add_argument('--use_amp',
76
+ action='store_true',
77
+ default=False,
78
+ help='Use automatic mixed precision training')
79
+ parser.add_argument('--deepspeed.save_states',
80
+ dest='save_states',
81
+ default='model_only',
82
+ choices=['model_only', 'model+optimizer'],
83
+ help='save model/optimizer states')
84
+ parser.add_argument('--timeout',
85
+ default=60,
86
+ type=int,
87
+ help='timeout (in seconds) of cosyvoice_join.')
88
+ parser = deepspeed.add_config_arguments(parser)
89
+ args = parser.parse_args()
90
+ return args
91
+
92
+ @record
93
+ def main():
94
+ args = get_args()
95
+ logging.basicConfig(level=logging.DEBUG,
96
+ format='%(asctime)s %(levelname)s %(message)s')
97
+ # gan train has some special initialization logic
98
+ gan = True if args.model == 'hifigan' else False
99
+
100
+ override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
101
+ if gan is True:
102
+ override_dict.pop('hift')
103
+ with open(args.config, 'r') as f:
104
+ configs = load_hyperpyyaml(f, overrides=override_dict)
105
+ if gan is True:
106
+ configs['train_conf'] = configs['train_conf_gan']
107
+ configs['train_conf'].update(vars(args))
108
+
109
+ # Init env for ddp
110
+ init_distributed(args)
111
+
112
+ # Get dataset & dataloader
113
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
114
+ init_dataset_and_dataloader(args, configs, gan)
115
+ # Do some sanity checks and save config to args.model_dir
116
+ configs = check_modify_and_save_config(args, configs)
117
+
118
+ # Tensorboard summary
119
+ writer = init_summarywriter(args)
120
+
121
+ # load checkpoint
122
+ model = configs[args.model]
123
+ if args.checkpoint is not None:
124
+ if os.path.exists(args.checkpoint):
125
+ model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
126
+ else:
127
+ logging.warning('checkpoint {} does not exist!'.format(args.checkpoint))
128
+
129
+ # Dispatch model from cpu to gpu
130
+ model = wrap_cuda_model(args, model)
131
+
132
+ # Get optimizer & scheduler
133
+ model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
134
+
135
+ # Save init checkpoints
136
+ info_dict = deepcopy(configs['train_conf'])
137
+ save_model(model, 'init', info_dict)
138
+
139
+ # Get executor
140
+ executor = Executor(gan=gan)
141
+
142
+ # Init scaler, used for pytorch amp mixed precision training
143
+ scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
144
+
145
+ # Start training loop
146
+ for epoch in range(info_dict['max_epoch']):
147
+ executor.epoch = epoch
148
+ train_dataset.set_epoch(epoch)
149
+ dist.barrier()
150
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
151
+ if gan is True:
152
+ executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
153
+ writer, info_dict, scaler, group_join)
154
+ else:
155
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)  # train one epoch
156
+ dist.destroy_process_group(group_join)
157
+
158
+ if __name__ == '__main__':
159
+ main()
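
A hypothetical single-GPU launch of this training entry point via torchrun; the config and data-list paths are placeholders.

```python
# Hypothetical launcher for cosyvoice_rodis/bin/train.py (paths are placeholders).
import subprocess

subprocess.run([
    "torchrun", "--standalone", "--nproc_per_node", "1",
    "cosyvoice_rodis/bin/train.py",
    "--train_engine", "torch_ddp",
    "--model", "llm",                        # one of the model sections defined in the config
    "--config", "conf/cosyvoice.yaml",       # placeholder
    "--train_data", "data/train/data.list",  # placeholder
    "--cv_data", "data/dev/data.list",       # placeholder
    "--model_dir", "exp/llm",
    "--tensorboard_dir", "tensorboard/llm",
], check=True)
```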
cosyvoice_rodis/cli/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ #
2
+
cosyvoice_rodis/cli/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (169 Bytes).
 
cosyvoice_rodis/cli/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (190 Bytes).
 
cosyvoice_rodis/cli/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (167 Bytes).
 
cosyvoice_rodis/cli/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (144 Bytes).
 
cosyvoice_rodis/cli/__pycache__/cosyvoice.cpython-310.pyc ADDED
Binary file (4.57 kB).
 
cosyvoice_rodis/cli/__pycache__/cosyvoice.cpython-312.pyc ADDED
Binary file (8.91 kB).
 
cosyvoice_rodis/cli/__pycache__/cosyvoice.cpython-38.pyc ADDED
Binary file (4.49 kB).
 
cosyvoice_rodis/cli/__pycache__/cosyvoice.cpython-39.pyc ADDED
Binary file (4.57 kB).
 
cosyvoice_rodis/cli/__pycache__/frontend.cpython-310.pyc ADDED
Binary file (7.18 kB).
 
cosyvoice_rodis/cli/__pycache__/frontend.cpython-38.pyc ADDED
Binary file (7.19 kB).
 
cosyvoice_rodis/cli/__pycache__/frontend.cpython-39.pyc ADDED
Binary file (7.16 kB).
 
cosyvoice_rodis/cli/__pycache__/model.cpython-310.pyc ADDED
Binary file (7.88 kB).
 
cosyvoice_rodis/cli/__pycache__/model.cpython-38.pyc ADDED
Binary file (7.71 kB).
 
cosyvoice_rodis/cli/__pycache__/model.cpython-39.pyc ADDED
Binary file (7.76 kB).
 
cosyvoice_rodis/cli/cosyvoice.py ADDED
@@ -0,0 +1,114 @@
1
+ #
2
+
3
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import os
17
+ import time
18
+ from tqdm import tqdm
19
+ from hyperpyyaml import load_hyperpyyaml
20
+ from modelscope import snapshot_download
21
+ from .frontend import CosyVoiceFrontEnd
22
+ from .model import CosyVoiceModel
23
+ from ..utils.file_utils import logging
24
+
25
+ class CosyVoice:
26
+
27
+ def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
28
+ instruct = True if '-Instruct' in model_dir else False
29
+ self.model_dir = model_dir
30
+ if not os.path.exists(model_dir):
31
+ model_dir = snapshot_download(model_dir)
32
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
33
+ configs = load_hyperpyyaml(f)
34
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
35
+ configs['feat_extractor'],
36
+ '{}/campplus.onnx'.format(model_dir),
37
+ '{}/speech_tokenizer_v1.onnx'.format(model_dir),
38
+ '{}/spk2info.pt'.format(model_dir),
39
+ instruct,
40
+ configs['allowed_special'])
41
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
42
+ self.model.load('{}/llm.pt'.format(model_dir),
43
+ '{}/flow.pt'.format(model_dir),
44
+ '{}/hift.pt'.format(model_dir))
45
+ if load_jit:
46
+ self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
47
+ '{}/llm.llm.fp16.zip'.format(model_dir),
48
+ '{}/flow.encoder.fp32.zip'.format(model_dir))
49
+ if load_onnx:
50
+ self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
51
+ del configs
52
+
53
+ def list_avaliable_spks(self):
54
+ spks = list(self.frontend.spk2info.keys())
55
+ return spks
56
+
57
+ def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
58
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
59
+ model_input = self.frontend.frontend_sft(i, spk_id)
60
+ start_time = time.time()
61
+ logging.info('synthesis text {}'.format(i))
62
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
63
+ speech_len = model_output['tts_speech'].shape[1] / 22050
64
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
65
+ yield model_output
66
+ start_time = time.time()
67
+ def synthesize(self, tts_text, prompt_text, prompt_speech_16k, key, emotion_embedding, stream=False, speed=1.0):
68
+ prompt_text = self.frontend.text_normalize(key+'<endofprompt>' + prompt_text, split=False)
69
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
70
+ if len(i) < 0.5 * len(prompt_text):
71
+ logging.warning('synthesis text {} is much shorter than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
72
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k,emotion_embedding)
73
+ print("input:", model_input)
74
+ start_time = time.time()
75
+ logging.info('synthesis text {}'.format(i))
76
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
77
+ speech_len = model_output['tts_speech'].shape[1] / 22050
78
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
79
+ yield model_output
80
+ start_time = time.time()
81
+
82
+ def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
83
+ if self.frontend.instruct is True:
84
+ raise ValueError('{} does not support cross_lingual inference'.format(self.model_dir))
85
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
86
+ model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
87
+ start_time = time.time()
88
+ logging.info('synthesis text {}'.format(i))
89
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
90
+ speech_len = model_output['tts_speech'].shape[1] / 22050
91
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
92
+ yield model_output
93
+ start_time = time.time()
94
+
95
+ def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0):
96
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False)
97
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
98
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
99
+ start_time = time.time()
100
+ logging.info('synthesis text {}'.format(i))
101
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
102
+ speech_len = model_output['tts_speech'].shape[1] / 22050
103
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
104
+ yield model_output
105
+ start_time = time.time()
106
+
107
+ def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
108
+ model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
109
+ start_time = time.time()
110
+ for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
111
+ speech_len = model_output['tts_speech'].shape[1] / 22050
112
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
113
+ yield model_output
114
+ start_time = time.time()
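
The new synthesize entry point above differs from the stock zero-shot path in two ways: an emotion key is prepended to the prompt text before <endofprompt>, and an emotion_embedding tensor is forwarded to the LLM. A minimal usage sketch, assuming a locally available model directory and a precomputed emotion embedding (both placeholders, not provided by this diff):

    import torch
    import torchaudio
    from cosyvoice_rodis.cli.cosyvoice import CosyVoice

    cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M', load_jit=False, fp16=False)  # hypothetical model dir
    wav, sr = torchaudio.load('prompt.wav')                                                # hypothetical reference recording
    prompt_speech_16k = torchaudio.transforms.Resample(sr, 16000)(wav.mean(dim=0, keepdim=True))
    emotion_embedding = torch.zeros(1, 192)   # placeholder; normally extracted from an emotion reference

    for idx, out in enumerate(cosyvoice.synthesize(
            tts_text='Hello, nice to meet you.',
            prompt_text='This is the transcript of the reference recording.',
            prompt_speech_16k=prompt_speech_16k,
            key='Happy',                      # English emotion key, per this commit
            emotion_embedding=emotion_embedding,
            stream=False)):
        torchaudio.save('output_{}.wav'.format(idx), out['tts_speech'], 22050)             # outputs are 22.05 kHz
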
cosyvoice_rodis/cli/frontend.py ADDED
@@ -0,0 +1,192 @@
1
+ #
2
+
3
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ from functools import partial
18
+ import onnxruntime
19
+ import torch
20
+ import numpy as np
21
+ import whisper
22
+ from typing import Callable
23
+ import torchaudio.compliance.kaldi as kaldi
24
+ import torchaudio
25
+ import os
26
+ import re
27
+ import inflect
28
+ # try:
29
+ # import ttsfrd
30
+ # use_ttsfrd = True
31
+ # except ImportError:
32
+ # print("failed to import ttsfrd, use WeTextProcessing instead")
33
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
34
+ from tn.english.normalizer import Normalizer as EnNormalizer
35
+ use_ttsfrd = False
36
+ from ..utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
37
+
38
+ class CosyVoiceFrontEnd:
39
+
40
+ def __init__(self,
41
+ get_tokenizer: Callable,
42
+ feat_extractor: Callable,
43
+ campplus_model: str,
44
+ speech_tokenizer_model: str,
45
+ spk2info: str = '',
46
+ instruct: bool = False,
47
+ allowed_special: str = 'all'):
48
+ self.tokenizer = get_tokenizer()
49
+ self.feat_extractor = feat_extractor
50
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51
+ option = onnxruntime.SessionOptions()
52
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
53
+ option.intra_op_num_threads = 1
54
+ self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
55
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
56
+ providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
57
+ "CPUExecutionProvider"])
58
+ if os.path.exists(spk2info):
59
+ self.spk2info = torch.load(spk2info, map_location=self.device)
60
+ else:
61
+ self.spk2info = {}
62
+ self.instruct = instruct
63
+ self.allowed_special = allowed_special
64
+ self.inflect_parser = inflect.engine()
65
+ self.use_ttsfrd = use_ttsfrd
66
+ if self.use_ttsfrd:
67
+ self.frd = ttsfrd.TtsFrontendEngine()
68
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
69
+ assert self.frd.initialize('/mnt/workspace/baipeng/project/Marco-Voice/Models/marco_voice/utils/pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
70
+ 'failed to initialize ttsfrd resource'
71
+
72
+ self.frd.set_lang_type('pinyinvg')
73
+ self.frd.enable_pinyin_mix(True)
74
+ self.frd.set_breakmodel_index(1)
75
+ else:
76
+ self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
77
+ self.en_tn_model = EnNormalizer()
78
+
79
+ def _extract_text_token(self, text):
80
+ text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
81
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
82
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device) # 14 21
83
+ return text_token, text_token_len
84
+
85
+ def _extract_speech_token(self, speech):
86
+ assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
87
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
88
+ speech_token = self.speech_tokenizer_session.run(None,
89
+ {self.speech_tokenizer_session.get_inputs()[0].name:
90
+ feat.detach().cpu().numpy(),
91
+ self.speech_tokenizer_session.get_inputs()[1].name:
92
+ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
93
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
94
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
95
+ return speech_token, speech_token_len
96
+
97
+ def _extract_spk_embedding(self, speech):
98
+ feat = kaldi.fbank(speech,
99
+ num_mel_bins=80,
100
+ dither=0,
101
+ sample_frequency=16000)
102
+ feat = feat - feat.mean(dim=0, keepdim=True)
103
+ embedding = self.campplus_session.run(None,
104
+ {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
105
+ embedding = torch.tensor([embedding]).to(self.device)
106
+ return embedding
107
+
108
+ def _extract_speech_feat(self, speech):
109
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
110
+ speech_feat = speech_feat.unsqueeze(dim=0)
111
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
112
+ return speech_feat, speech_feat_len
113
+
114
+ def text_normalize(self, text, split=True):
115
+ text = text.strip()
116
+ if contains_chinese(text):
117
+ if self.use_ttsfrd:
118
+ text = self.frd.get_frd_extra_info(text, 'input')
119
+ else:
120
+ text = self.zh_tn_model.normalize(text)
121
+ text = text.replace("\n", "")
122
+ text = replace_blank(text)
123
+ text = replace_corner_mark(text)
124
+ text = text.replace(".", "。")
125
+ text = text.replace(" - ", ",")
126
+ text = remove_bracket(text)
127
+ text = re.sub(r'[,,、]+$', '。', text)
128
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
129
+ token_min_n=60, merge_len=20, comma_split=False))
130
+ else:
131
+ if self.use_ttsfrd:
132
+ text = self.frd.get_frd_extra_info(text, 'input')
133
+ else:
134
+ text = self.en_tn_model.normalize(text)
135
+ text = spell_out_number(text, self.inflect_parser)
136
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
137
+ token_min_n=60, merge_len=20, comma_split=False))
138
+ if split is False:
139
+ return text
140
+
141
+ return texts
142
+
143
+ def frontend_sft(self, tts_text, spk_id):
144
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
145
+ embedding = self.spk2info[spk_id]['embedding']
146
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
147
+ return model_input
148
+
149
+ def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, emotion_speakerminus):
150
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
151
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
152
+ prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
153
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
154
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
155
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
156
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
157
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
158
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
159
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
160
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
161
+ 'llm_embedding': embedding, 'emotion_embedding': emotion_speakerminus, 'flow_embedding': embedding}
162
+ return model_input
163
+
164
+ def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
165
+ model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
166
+ # in cross lingual mode, we remove prompt in llm
167
+ del model_input['prompt_text']
168
+ del model_input['prompt_text_len']
169
+ del model_input['llm_prompt_speech_token']
170
+ del model_input['llm_prompt_speech_token_len']
171
+ return model_input
172
+
173
+ def frontend_instruct(self, tts_text, spk_id, instruct_text):
174
+ model_input = self.frontend_sft(tts_text, spk_id)
175
+ # in instruct mode, we remove spk_embedding in llm due to information leakage
176
+ del model_input['llm_embedding']
177
+ instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
178
+ model_input['prompt_text'] = instruct_text_token
179
+ model_input['prompt_text_len'] = instruct_text_token_len
180
+ return model_input
181
+
182
+ def frontend_vc(self, source_speech_16k, prompt_speech_16k):
183
+ prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
184
+ prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
185
+ prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
186
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
187
+ source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
188
+ model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
189
+ 'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
190
+ 'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
191
+ 'flow_embedding': embedding}
192
+ return model_input
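
One behaviour of text_normalize worth keeping in mind: with split=False it returns a single normalized string, while with split=True it returns a list of sentence-level segments produced by split_paragraph; synthesize above relies on both forms. A small sketch, assuming `frontend` is an already-initialized CosyVoiceFrontEnd:

    # `frontend` is assumed to be a constructed CosyVoiceFrontEnd instance
    prompt = frontend.text_normalize('Happy<endofprompt>This is the prompt transcript.', split=False)
    segments = frontend.text_normalize('A longer paragraph to synthesize. It is split into chunks for the LLM.', split=True)
    print(type(prompt))     # <class 'str'>  - one normalized string
    print(type(segments))   # <class 'list'> - one normalized segment per chunk
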
cosyvoice_rodis/cli/model.py ADDED
@@ -0,0 +1,257 @@
1
+ #
2
+
3
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import torch
17
+ import numpy as np
18
+ import threading
19
+ import time
20
+ from torch.nn import functional as F
21
+ from contextlib import nullcontext
22
+ import uuid
23
+ from ..utils.common import fade_in_out
24
+
25
+ class CosyVoiceModel:
26
+
27
+ def __init__(self,
28
+ llm: torch.nn.Module,
29
+ flow: torch.nn.Module,
30
+ hift: torch.nn.Module,
31
+ fp16: bool):
32
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
+ self.llm = llm
34
+ self.flow = flow
35
+ self.hift = hift
36
+ self.fp16 = fp16
37
+ self.token_min_hop_len = 2 * self.flow.input_frame_rate
38
+ self.token_max_hop_len = 4 * self.flow.input_frame_rate
39
+ self.token_overlap_len = 20
40
+ # mel fade in out
41
+ self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
42
+ self.mel_window = np.hamming(2 * self.mel_overlap_len)
43
+ # hift cache
44
+ self.mel_cache_len = 20
45
+ self.source_cache_len = int(self.mel_cache_len * 256)
46
+ # speech fade in out
47
+ self.speech_window = np.hamming(2 * self.source_cache_len)
48
+ # rtf and decoding related
49
+ self.stream_scale_factor = 1
50
+ assert self.stream_scale_factor >= 1, 'stream_scale_factor should be no less than 1, change it according to your actual rtf'
51
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
52
+ self.lock = threading.Lock()
53
+ # dict used to store session related variable
54
+ self.tts_speech_token_dict = {}
55
+ self.llm_end_dict = {}
56
+ self.mel_overlap_dict = {}
57
+ self.flow_cache_dict = {}
58
+ self.hift_cache_dict = {}
59
+
60
+ def load(self, llm_model, flow_model, hift_model):
61
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
62
+ self.llm.to(self.device).eval()
63
+ if self.fp16 is True:
64
+ self.llm.half()
65
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
66
+ self.flow.to(self.device).eval()
67
+ # in case hift_model is a hifigan model
68
+ hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
69
+ self.hift.load_state_dict(hift_state_dict, strict=False)
70
+ self.hift.to(self.device).eval()
71
+
72
+ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
73
+ assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
74
+ llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
75
+ self.llm.text_encoder = llm_text_encoder
76
+ llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
77
+ self.llm.llm = llm_llm
78
+ flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
79
+ self.flow.encoder = flow_encoder
80
+
81
+ def load_onnx(self, flow_decoder_estimator_model):
82
+ import onnxruntime
83
+ option = onnxruntime.SessionOptions()
84
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
85
+ option.intra_op_num_threads = 1
86
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
87
+ del self.flow.decoder.estimator
88
+ self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
89
+
90
+ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, emotion_embedding, uuid):
91
+ if self.fp16 is True:
92
+ llm_embedding = llm_embedding.half()
93
+ with self.llm_context:
94
+ for i in self.llm.inference(text=text.to(self.device),
95
+ text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
96
+ prompt_text=prompt_text.to(self.device),
97
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
98
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
99
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
100
+ embedding=llm_embedding.to(self.device),
101
+ emotion_embedding = emotion_embedding.to(self.device)):
102
+ self.tts_speech_token_dict[uuid].append(i)
103
+ self.llm_end_dict[uuid] = True
104
+
105
+ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
106
+ tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
107
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
108
+ prompt_token=prompt_token.to(self.device),
109
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
110
+ prompt_feat=prompt_feat.to(self.device),
111
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
112
+ embedding=embedding.to(self.device),
113
+ flow_cache=self.flow_cache_dict[uuid])
114
+ self.flow_cache_dict[uuid] = flow_cache
115
+
116
+ # mel overlap fade in out
117
+ if self.mel_overlap_dict[uuid].shape[2] != 0:
118
+ tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
119
+ # append hift cache
120
+ if self.hift_cache_dict[uuid] is not None:
121
+ hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
122
+ tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
123
+ else:
124
+ hift_cache_source = torch.zeros(1, 1, 0)
125
+ # keep overlap mel and hift cache
126
+ if finalize is False:
127
+ self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
128
+ tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
129
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
130
+ if self.hift_cache_dict[uuid] is not None:
131
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
132
+ self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
133
+ 'source': tts_source[:, :, -self.source_cache_len:],
134
+ 'speech': tts_speech[:, -self.source_cache_len:]}
135
+ tts_speech = tts_speech[:, :-self.source_cache_len]
136
+ else:
137
+ if speed != 1.0:
138
+ assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
139
+ tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
140
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
141
+ if self.hift_cache_dict[uuid] is not None:
142
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
143
+ return tts_speech
144
+
145
+ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192), emotion_embedding=torch.zeros(0, 192),
146
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32),
147
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
148
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
149
+ prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
150
+ # this_uuid is used to track variables related to this inference thread
151
+ #print("tts函数中")
152
+ #print(text)
153
+ this_uuid = str(uuid.uuid1())
154
+ with self.lock:
155
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
156
+ self.hift_cache_dict[this_uuid] = None
157
+ self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
158
+ self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
159
+ p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, emotion_embedding, this_uuid))
160
+ p.start()
161
+ if stream is True:
162
+ token_hop_len = self.token_min_hop_len
163
+ while True:
164
+ time.sleep(0.1)
165
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
166
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
167
+ .unsqueeze(dim=0)
168
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
169
+ prompt_token=flow_prompt_speech_token,
170
+ prompt_feat=prompt_speech_feat,
171
+ embedding=flow_embedding,
172
+ uuid=this_uuid,
173
+ finalize=False)
174
+ yield {'tts_speech': this_tts_speech.cpu()}
175
+ with self.lock:
176
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
177
+ # increase token_hop_len for better speech quality
178
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
179
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
180
+ break
181
+ p.join()
182
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
183
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
184
+ prompt_token=flow_prompt_speech_token,
185
+ prompt_feat=prompt_speech_feat,
186
+ embedding=flow_embedding,
187
+ uuid=this_uuid,
188
+ finalize=True)
189
+ yield {'tts_speech': this_tts_speech.cpu()}
190
+ else:
191
+ p.join()
192
+
193
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
194
+
195
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
196
+ prompt_token=flow_prompt_speech_token,
197
+ prompt_feat=prompt_speech_feat,
198
+ embedding=flow_embedding,
199
+ uuid=this_uuid,
200
+ finalize=True,
201
+ speed=speed)
202
+ yield {'tts_speech': this_tts_speech.cpu()}
203
+ with self.lock:
204
+ self.tts_speech_token_dict.pop(this_uuid)
205
+ self.llm_end_dict.pop(this_uuid)
206
+ self.mel_overlap_dict.pop(this_uuid)
207
+ self.hift_cache_dict.pop(this_uuid)
208
+ self.flow_cache_dict.pop(this_uuid)
209
+
210
+ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
211
+ this_uuid = str(uuid.uuid1())
212
+ with self.lock:
213
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
214
+ self.hift_cache_dict[this_uuid] = None
215
+ self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
216
+ self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
217
+ if stream is True:
218
+ token_hop_len = self.token_min_hop_len
219
+ while True:
220
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
221
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
222
+ .unsqueeze(dim=0)
223
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
224
+ prompt_token=flow_prompt_speech_token,
225
+ prompt_feat=prompt_speech_feat,
226
+ embedding=flow_embedding,
227
+ uuid=this_uuid,
228
+ finalize=False)
229
+ yield {'tts_speech': this_tts_speech.cpu()}
230
+ with self.lock:
231
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
232
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
233
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
234
+ break
235
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
236
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
237
+ prompt_token=flow_prompt_speech_token,
238
+ prompt_feat=prompt_speech_feat,
239
+ embedding=flow_embedding,
240
+ uuid=this_uuid,
241
+ finalize=True)
242
+ yield {'tts_speech': this_tts_speech.cpu()}
243
+ else:
244
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
245
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
246
+ prompt_token=flow_prompt_speech_token,
247
+ prompt_feat=prompt_speech_feat,
248
+ embedding=flow_embedding,
249
+ uuid=this_uuid,
250
+ finalize=True,
251
+ speed=speed)
252
+ yield {'tts_speech': this_tts_speech.cpu()}
253
+ with self.lock:
254
+ self.tts_speech_token_dict.pop(this_uuid)
255
+ self.llm_end_dict.pop(this_uuid)
256
+ self.mel_overlap_dict.pop(this_uuid)
257
+ self.hift_cache_dict.pop(this_uuid)
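
Both tts and vc above are generators: each yielded dict carries a tts_speech chunk of shape (1, num_samples) at 22.05 kHz, and in streaming mode consecutive chunks are already overlap-faded internally, so a caller can simply concatenate or play them as they arrive. A consumption sketch with placeholder inputs (the tensors would normally be built by CosyVoiceFrontEnd):

    import torch

    chunks = []
    for out in model.tts(text=text_token,                    # placeholder tensors from the frontend
                         flow_embedding=spk_embedding,
                         flow_prompt_speech_token=prompt_token,
                         prompt_speech_feat=prompt_feat,
                         stream=True):
        chunks.append(out['tts_speech'])                     # (1, num_samples) chunk at 22.05 kHz
    speech = torch.concat(chunks, dim=1)                     # full utterance
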
cosyvoice_rodis/dataset/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ #
2
+
cosyvoice_rodis/dataset/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (190 Bytes). View file
 
cosyvoice_rodis/dataset/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (192 Bytes). View file
 
cosyvoice_rodis/dataset/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (154 Bytes). View file
 
cosyvoice_rodis/dataset/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (4.96 kB). View file
 
cosyvoice_rodis/dataset/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (4.96 kB). View file
 
cosyvoice_rodis/dataset/__pycache__/processor.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
cosyvoice_rodis/dataset/__pycache__/processor.cpython-38.pyc ADDED
Binary file (13.2 kB). View file
 
cosyvoice_rodis/dataset/__pycache__/processor.cpython-39.pyc ADDED
Binary file (12.8 kB). View file
 
cosyvoice_rodis/dataset/dataset.py ADDED
@@ -0,0 +1,163 @@
1
+ #
2
+
3
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
4
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import random
19
+ import json
20
+ import math
21
+ from functools import partial
22
+
23
+ import torch
24
+ import torch.distributed as dist
25
+ from torch.utils.data import IterableDataset
26
+ from cosyvoice_rodis.utils.file_utils import read_lists, read_json_lists
27
+
28
+ class Processor(IterableDataset):
29
+
30
+ def __init__(self, source, f, *args, **kw):
31
+ assert callable(f)
32
+ self.source = source
33
+ self.f = f
34
+ self.args = args
35
+ self.kw = kw
36
+
37
+ def set_epoch(self, epoch):
38
+ self.source.set_epoch(epoch)
39
+
40
+ def __iter__(self):
41
+ """ Return an iterator over the source dataset processed by the
42
+ given processor.
43
+ """
44
+ assert self.source is not None
45
+ assert callable(self.f)
46
+ return self.f(iter(self.source), *self.args, **self.kw)
47
+
48
+ def apply(self, f):
49
+ assert callable(f)
50
+ return Processor(self, f, *self.args, **self.kw)
51
+
52
+ class DistributedSampler:
53
+
54
+ def __init__(self, shuffle=True, partition=True):
55
+ self.epoch = -1
56
+ self.update()
57
+ self.shuffle = shuffle
58
+ self.partition = partition
59
+
60
+ def update(self):
61
+ assert dist.is_available()
62
+ if dist.is_initialized():
63
+ self.rank = dist.get_rank()
64
+ self.world_size = dist.get_world_size()
65
+ else:
66
+ self.rank = 0
67
+ self.world_size = 1
68
+ worker_info = torch.utils.data.get_worker_info()
69
+ if worker_info is None:
70
+ self.worker_id = 0
71
+ self.num_workers = 1
72
+ else:
73
+ self.worker_id = worker_info.id
74
+ self.num_workers = worker_info.num_workers
75
+ return dict(rank=self.rank,
76
+ world_size=self.world_size,
77
+ worker_id=self.worker_id,
78
+ num_workers=self.num_workers)
79
+
80
+ def set_epoch(self, epoch):
81
+ self.epoch = epoch
82
+
83
+ def sample(self, data):
84
+ """ Sample data according to rank/world_size/num_workers
85
+
86
+ Args:
87
+ data(List): input data list
88
+
89
+ Returns:
90
+ List: data list after sample
91
+ """
92
+ data = list(range(len(data)))
93
+ # force datalist even
94
+ if self.partition:
95
+ if self.shuffle:
96
+ random.Random(self.epoch).shuffle(data)
97
+ if len(data) < self.world_size:
98
+ data = data * math.ceil(self.world_size / len(data))
99
+ data = data[:self.world_size]
100
+ data = data[self.rank::self.world_size]
101
+ if len(data) < self.num_workers:
102
+ data = data * math.ceil(self.num_workers / len(data))
103
+ data = data[:self.num_workers]
104
+ data = data[self.worker_id::self.num_workers]
105
+ return data
106
+
107
+ class DataList(IterableDataset):
108
+
109
+ def __init__(self, lists, shuffle=True, partition=True):
110
+ self.lists = lists
111
+ self.sampler = DistributedSampler(shuffle, partition)
112
+
113
+ def set_epoch(self, epoch):
114
+ self.sampler.set_epoch(epoch)
115
+
116
+ def __iter__(self):
117
+ sampler_info = self.sampler.update()
118
+ indexes = self.sampler.sample(self.lists)
119
+ for index in indexes:
120
+ data = dict(src=self.lists[index])
121
+ data.update(sampler_info)
122
+ yield data
123
+
124
+ def Dataset(data_list_file,
125
+ data_pipeline,
126
+ mode='train',
127
+ gan=False,
128
+ shuffle=True,
129
+ partition=True,
130
+ tts_file='',
131
+ prompt_utt2data=''):
132
+ """ Construct dataset from arguments
133
+
134
+ We have two shuffle stage in the Dataset. The first is global
135
+ shuffle at shards tar/raw file level. The second is global shuffle
136
+ at training samples level.
137
+
138
+ Args:
139
+ data_type(str): raw/shard
140
+ tokenizer (BaseTokenizer): tokenizer to tokenize
141
+ partition(bool): whether to do data partition in terms of rank
142
+ """
143
+ # import pdb;pdb.set_trace()
144
+ assert mode in ['train', 'inference']
145
+ lists = read_lists(data_list_file)  # read the list of data shard files
146
+ if mode == 'inference':
147
+ with open(tts_file) as f:
148
+ tts_data = json.load(f)
149
+ utt2lists = read_json_lists(prompt_utt2data)
150
+ # filter unnecessary file in inference mode
151
+ lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
152
+ dataset = DataList(lists,
153
+ shuffle=shuffle,
154
+ partition=partition)  # each entry in lists is a tar/parquet shard file
155
+ if mode == 'inference':
156
+ # map partial arg to parquet_opener func in inference mode
157
+ data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
158
+ if gan is True:
159
+ # map partial arg to padding func in gan mode
160
+ data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
161
+ for func in data_pipeline:
162
+ dataset = Processor(dataset, func, mode=mode)
163
+ return dataset
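
The Dataset helper wraps a DataList of shard paths with the chain of processor functions defined in cosyvoice_rodis/dataset/processor.py (added below). In this repo the concrete pipeline is assembled from the training YAML via hyperpyyaml, which is not part of this hunk; the sketch below only illustrates the composition pattern, with get_tokenizer and feat_extractor left as unresolved placeholders:

    from functools import partial
    from cosyvoice_rodis.dataset.dataset import Dataset
    from cosyvoice_rodis.dataset import processor

    # Illustrative pipeline; the real ordering/arguments come from the training YAML.
    data_pipeline = [
        processor.parquet_opener,
        partial(processor.tokenize, get_tokenizer=get_tokenizer, allowed_special='all'),
        processor.filter,
        processor.resample,
        partial(processor.compute_fbank, feat_extractor=feat_extractor),
        partial(processor.parse_embedding, normalize=True),
        processor.shuffle,
        processor.sort,
        partial(processor.batch, batch_type='dynamic', max_frames_in_batch=12000),
        partial(processor.padding, use_spk_embedding=False),
    ]
    train_dataset = Dataset('data/train.data.list', data_pipeline=data_pipeline,
                            mode='train', gan=False, shuffle=True, partition=True)
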
cosyvoice_rodis/dataset/processor.py ADDED
@@ -0,0 +1,427 @@
1
+ #
2
+
3
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import logging
17
+ import random
18
+
19
+ import pyarrow.parquet as pq
20
+ from io import BytesIO
21
+ import torch
22
+ import torchaudio
23
+ from torch.nn.utils.rnn import pad_sequence
24
+ import torch.nn.functional as F
25
+
26
+ torchaudio.set_audio_backend('soundfile')
27
+
28
+ AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
29
+
30
+ def parquet_opener(data, mode='train', tts_data={}):
31
+ """ Give url or local file, return file descriptor
32
+ Inplace operation.
33
+
34
+ Args:
35
+ data(Iterable[str]): url or local file list
36
+
37
+ Returns:
38
+ Iterable[{src, stream}]
39
+ """
40
+ for sample in data:
41
+ assert 'src' in sample
42
+ url = sample['src'] #'/mnt/workspace/baipeng/project/Marco-Voice/Dataset/hunhe_data/LZED/processed_xiaoyu30_new/train/parquet/parquet_000000001.tar'
43
+ try:
44
+ df = pq.read_table(url).to_pandas()
45
+ for i in range(len(df)):
46
+ if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
47
+ continue
48
+ sample.update(dict(df.loc[i]))
49
+ if mode == 'train':
50
+ # NOTE do not return sample directly, must initialize a new dict
51
+ yield {**sample}
52
+ else:
53
+ for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
54
+ yield {**sample, 'tts_index': index, 'tts_text': text}
55
+ except Exception as ex:
56
+ logging.warning('Failed to open {}, ex info {}'.format(url, ex))
57
+
58
+ def filter(data,
59
+ max_length=10240,
60
+ min_length=10,
61
+ token_max_length=200,
62
+ token_min_length=1,
63
+ min_output_input_ratio=0.0005,
64
+ max_output_input_ratio=1,
65
+ mode='train'):
66
+ """ Filter sample according to feature and label length
67
+ Inplace operation.
68
+
69
+ Args:
70
+ data: Iterable[{key, wav, label, sample_rate}]
71
+ max_length: drop utterance which is greater than max_length(10ms)
72
+ min_length: drop utterance which is less than min_length(10ms)
73
+ token_max_length: drop utterance which is greater than
74
+ token_max_length, especially when use char unit for
75
+ english modeling
76
+ token_min_length: drop utterance which is
77
+ less than token_min_length
78
+ min_output_input_ratio: minimal ratio of
79
+ token_length / feats_length(10ms)
80
+ max_output_input_ratio: maximum ratio of
81
+ token_length / feats_length(10ms)
82
+
83
+ Returns:
84
+ Iterable[{key, wav, label, sample_rate}]
85
+ """
86
+ for sample in data:
87
+ sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
88
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
89
+ del sample['audio_data']
90
+ # sample['wav'] is torch.Tensor, we have 100 frames every second
91
+ num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
92
+ if num_frames < min_length:
93
+ continue
94
+ if num_frames > max_length:
95
+ continue
96
+ if len(sample['text_token']) < token_min_length:
97
+ continue
98
+ if len(sample['text_token']) > token_max_length:
99
+ continue
100
+ if len(sample['speech_token']) == 0:
101
+ continue
102
+ if num_frames != 0:
103
+ if len(sample['text_token']) / num_frames < min_output_input_ratio:
104
+ continue
105
+ if len(sample['text_token']) / num_frames > max_output_input_ratio:
106
+ continue
107
+ yield sample
108
+
109
+ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
110
+ """ Resample data.
111
+ Inplace operation.
112
+
113
+ Args:
114
+ data: Iterable[{key, wav, label, sample_rate}]
115
+ resample_rate: target resample rate
116
+
117
+ Returns:
118
+ Iterable[{key, wav, label, sample_rate}]
119
+ """
120
+ for sample in data:
121
+ assert 'sample_rate' in sample
122
+ assert 'speech' in sample
123
+ sample_rate = sample['sample_rate']
124
+ waveform = sample['speech']
125
+ if sample_rate != resample_rate:
126
+ if sample_rate < min_sample_rate:
127
+ continue
128
+ sample['sample_rate'] = resample_rate
129
+ sample['speech'] = torchaudio.transforms.Resample(
130
+ orig_freq=sample_rate, new_freq=resample_rate)(waveform)
131
+ max_val = sample['speech'].abs().max()
132
+ if max_val > 1:
133
+ sample['speech'] /= max_val
134
+ yield sample
135
+
136
+ def truncate(data, truncate_length=24576, mode='train'):
137
+ """ Truncate data.
138
+
139
+ Args:
140
+ data: Iterable[{key, wav, label, sample_rate}]
141
+ truncate_length: truncate length
142
+
143
+ Returns:
144
+ Iterable[{key, wav, label, sample_rate}]
145
+ """
146
+ for sample in data:
147
+ waveform = sample['speech']
148
+ if waveform.shape[1] > truncate_length:
149
+ start = random.randint(0, waveform.shape[1] - truncate_length)
150
+ waveform = waveform[:, start: start + truncate_length]
151
+ else:
152
+ waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
153
+ sample['speech'] = waveform
154
+ yield sample
155
+
156
+ def compute_fbank(data,
157
+ feat_extractor,
158
+ mode='train'):
159
+ """ Extract fbank
160
+
161
+ Args:
162
+ data: Iterable[{key, wav, label, sample_rate}]
163
+
164
+ Returns:
165
+ Iterable[{key, feat, label}]
166
+ """
167
+ for sample in data:
168
+ assert 'sample_rate' in sample
169
+ assert 'speech' in sample
170
+ assert 'utt' in sample
171
+ assert 'text_token' in sample
172
+ waveform = sample['speech']
173
+ mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
174
+ sample['speech_feat'] = mat
175
+ yield sample
176
+
177
+ def compute_f0(data, pitch_extractor, mode='train'):
178
+ """ Extract f0
179
+
180
+ Args:
181
+ data: Iterable[{key, wav, label, sample_rate}]
182
+
183
+ Returns:
184
+ Iterable[{key, feat, label}]
185
+ """
186
+ for sample in data:
187
+ assert 'sample_rate' in sample
188
+ assert 'speech' in sample
189
+ assert 'utt' in sample
190
+ assert 'text_token' in sample
191
+ waveform = sample['speech']
192
+ mat = pitch_extractor(waveform).transpose(1, 2)
193
+ mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
194
+ sample['pitch_feat'] = mat[0, 0]
195
+ yield sample
196
+
197
+ def parse_embedding(data, normalize, mode='train'):
198
+ """ Parse utt_embedding/spk_embedding/emotion_embedding
199
+
200
+ Args:
201
+ data: Iterable[{key, wav, label, sample_rate}]
202
+
203
+ Returns:
204
+ Iterable[{key, feat, label}]
205
+ """
206
+ for sample in data:
207
+ # print("sample:", sample)
208
+ sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
209
+ sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
210
+ if 'emotion_embedding' in sample:
211
+ sample['emotion_embedding'] = torch.tensor(sample['emotion_embedding'], dtype=torch.float32)
212
+ if normalize:
213
+ sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
214
+ sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
215
+ if 'emotion_embedding' in sample:
216
+ sample['emotion_embedding'] = F.normalize(sample['emotion_embedding'], dim=0)
217
+ yield sample
218
+
219
+ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
220
+ """ Decode text to chars or BPE
221
+ Inplace operation
222
+
223
+ Args:
224
+ data: Iterable[{key, wav, txt, sample_rate}]
225
+
226
+ Returns:
227
+ Iterable[{key, wav, txt, tokens, label, sample_rate}]
228
+ """
229
+ tokenizer = get_tokenizer()
230
+ for sample in data:
231
+ assert 'text' in sample
232
+ sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
233
+ if mode == 'inference':
234
+ sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
235
+ yield sample
236
+
237
+ def shuffle(data, shuffle_size=10000, mode='train'):
238
+ """ Local shuffle the data
239
+
240
+ Args:
241
+ data: Iterable[{key, feat, label}]
242
+ shuffle_size: buffer size for shuffle
243
+
244
+ Returns:
245
+ Iterable[{key, feat, label}]
246
+ """
247
+ buf = []
248
+ for sample in data:
249
+ buf.append(sample)
250
+ if len(buf) >= shuffle_size:
251
+ random.shuffle(buf)
252
+ for x in buf:
253
+ yield x
254
+ buf = []
255
+ # The sample left over
256
+ random.shuffle(buf)
257
+ for x in buf:
258
+ yield x
259
+
260
+ def sort(data, sort_size=500, mode='train'):
261
+ """ Sort the data by feature length.
262
+ Sort is used after shuffle and before batch, so we can group
263
+ utts with similar lengths into a batch, and `sort_size` should
264
+ be less than `shuffle_size`
265
+
266
+ Args:
267
+ data: Iterable[{key, feat, label}]
268
+ sort_size: buffer size for sort
269
+
270
+ Returns:
271
+ Iterable[{key, feat, label}]
272
+ """
273
+
274
+ buf = []
275
+ for sample in data:
276
+ buf.append(sample)
277
+ if len(buf) >= sort_size:
278
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
279
+ for x in buf:
280
+ yield x
281
+ buf = []
282
+ # The sample left over
283
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
284
+ for x in buf:
285
+ yield x
286
+
287
+ def static_batch(data, batch_size=16):
288
+ """ Static batch the data by `batch_size`
289
+
290
+ Args:
291
+ data: Iterable[{key, feat, label}]
292
+ batch_size: batch size
293
+
294
+ Returns:
295
+ Iterable[List[{key, feat, label}]]
296
+ """
297
+ buf = []
298
+ for sample in data:
299
+ buf.append(sample)
300
+ if len(buf) >= batch_size:
301
+ yield buf
302
+ buf = []
303
+ if len(buf) > 0:
304
+ yield buf
305
+
306
+ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
307
+ """ Dynamic batch the data until the total frames in batch
308
+ reach `max_frames_in_batch`
309
+
310
+ Args:
311
+ data: Iterable[{key, feat, label}]
312
+ max_frames_in_batch: max_frames in one batch
313
+
314
+ Returns:
315
+ Iterable[List[{key, feat, label}]]
316
+ """
317
+ buf = []
318
+ longest_frames = 0
319
+ for sample in data:
320
+ assert 'speech_feat' in sample
321
+ assert isinstance(sample['speech_feat'], torch.Tensor)
322
+ new_sample_frames = sample['speech_feat'].size(0)
323
+ longest_frames = max(longest_frames, new_sample_frames)
324
+ frames_after_padding = longest_frames * (len(buf) + 1)
325
+ if frames_after_padding > max_frames_in_batch:
326
+ yield buf
327
+ buf = [sample]
328
+ longest_frames = new_sample_frames
329
+ else:
330
+ buf.append(sample)
331
+ if len(buf) > 0:
332
+ yield buf
333
+
334
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
335
+ """ Wrapper for static/dynamic batch
336
+ """
337
+ if mode == 'inference':
338
+ return static_batch(data, 1)
339
+ else:
340
+ if batch_type == 'static':
341
+ return static_batch(data, batch_size)
342
+ elif batch_type == 'dynamic':
343
+ return dynamic_batch(data, max_frames_in_batch)
344
+ else:
345
+ logging.fatal('Unsupported batch type {}'.format(batch_type))
346
+
347
+ def padding(data, use_spk_embedding, mode='train', gan=False):
348
+ """ Padding the data into training data
349
+
350
+ Args:
351
+ data: Iterable[List[{key, feat, label}]]
352
+
353
+ Returns:
354
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
355
+ """
356
+ for sample in data:
357
+ assert isinstance(sample, list)
358
+ speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
359
+ dtype=torch.int32)
360
+ order = torch.argsort(speech_feat_len, descending=True)
361
+ # print("sample:", sample) #spk_embedding
362
+ utts = [sample[i]['utt'] for i in order]
363
+ speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
364
+ speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
365
+ speech = pad_sequence(speech, batch_first=True, padding_value=0)
366
+ speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
367
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
368
+ speech_token = pad_sequence(speech_token,
369
+ batch_first=True,
370
+ padding_value=0)
371
+ speech_feat = [sample[i]['speech_feat'] for i in order]
372
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
373
+ speech_feat = pad_sequence(speech_feat,
374
+ batch_first=True,
375
+ padding_value=0)
376
+ text = [sample[i]['text'] for i in order]
377
+ text_token = [torch.tensor(sample[i]['text_token']) for i in order]
378
+ text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
379
+ text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
380
+ utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
381
+ spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
382
+ if 'emotion_embedding' in sample[0]:
383
+ emotion_embedding = torch.stack([sample[i]['emotion_embedding'] for i in order], dim=0)
384
+ batch = {
385
+ "utts": utts,
386
+ "speech": speech,
387
+ "speech_len": speech_len,
388
+ "speech_token": speech_token,
389
+ "speech_token_len": speech_token_len,
390
+ "speech_feat": speech_feat,
391
+ "speech_feat_len": speech_feat_len,
392
+ "text": text,
393
+ "text_token": text_token,
394
+ "text_token_len": text_token_len,
395
+ "utt_embedding": utt_embedding,
396
+ "spk_embedding": spk_embedding,
397
+ }
398
+ if 'emotion_embedding' in sample[0]:
399
+ batch["emotion_embedding"] = emotion_embedding
400
+ if gan is True:
401
+ # in gan train, we need pitch_feat
402
+ pitch_feat = [sample[i]['pitch_feat'] for i in order]
403
+ pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
404
+ pitch_feat = pad_sequence(pitch_feat,
405
+ batch_first=True,
406
+ padding_value=0)
407
+ batch["pitch_feat"] = pitch_feat
408
+ batch["pitch_feat_len"] = pitch_feat_len
409
+ else:
410
+ # only gan train needs speech, delete it to save memory
411
+ del batch["speech"]
412
+ del batch["speech_len"]
413
+ if mode == 'inference':
414
+ tts_text = [sample[i]['tts_text'] for i in order]
415
+ tts_index = [sample[i]['tts_index'] for i in order]
416
+ tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
417
+ tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
418
+ tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
419
+ batch.update({'tts_text': tts_text,
420
+ 'tts_index': tts_index,
421
+ 'tts_text_token': tts_text_token,
422
+ 'tts_text_token_len': tts_text_token_len})
423
+ if use_spk_embedding is True:
424
+ batch["embedding"] = batch["spk_embedding"]
425
+ else:
426
+ batch["embedding"] = batch["utt_embedding"]
427
+ yield batch
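
As a quick illustration of dynamic_batch above: samples accumulate until the padded cost (longest sample length times batch size) would exceed max_frames_in_batch, at which point the current buffer is emitted. A toy trace with fake speech_feat tensors:

    import torch
    from cosyvoice_rodis.dataset.processor import dynamic_batch

    samples = [{'speech_feat': torch.zeros(n, 80)} for n in (100, 120, 300, 90)]
    for b in dynamic_batch(iter(samples), max_frames_in_batch=500):
        print([s['speech_feat'].size(0) for s in b])
    # [100, 120]  adding the 300-frame sample would cost 300 * 3 = 900 > 500
    # [300]       adding the 90-frame sample would cost 300 * 2 = 600 > 500
    # [90]        leftover buffer flushed at the end
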
cosyvoice_rodis/flow/__pycache__/decoder.cpython-310.pyc ADDED
Binary file (5.19 kB). View file
 
cosyvoice_rodis/flow/__pycache__/decoder.cpython-38.pyc ADDED
Binary file (5.27 kB). View file
 
cosyvoice_rodis/flow/__pycache__/decoder.cpython-39.pyc ADDED
Binary file (5.23 kB). View file
 
cosyvoice_rodis/flow/__pycache__/flow.cpython-310.pyc ADDED
Binary file (5.22 kB). View file
 
cosyvoice_rodis/flow/__pycache__/flow.cpython-38.pyc ADDED
Binary file (5.2 kB). View file
 
cosyvoice_rodis/flow/__pycache__/flow.cpython-39.pyc ADDED
Binary file (4.21 kB). View file
 
cosyvoice_rodis/flow/__pycache__/flow_matching.cpython-310.pyc ADDED
Binary file (5.6 kB). View file
 
cosyvoice_rodis/flow/__pycache__/flow_matching.cpython-38.pyc ADDED
Binary file (5.61 kB). View file
 
cosyvoice_rodis/flow/__pycache__/flow_matching.cpython-39.pyc ADDED
Binary file (5.45 kB). View file
 
cosyvoice_rodis/flow/__pycache__/length_regulator.cpython-310.pyc ADDED
Binary file (2.23 kB). View file