Spaces:
Running
Running
tianfengping.tfp
commited on
Commit
·
149fbcd
1
Parent(s):
9741af6
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- CODE_OF_CONDUCT.md +76 -0
- FAQ.md +16 -0
- LICENSE +201 -0
- README.md +237 -13
- cosyvoice/__init__.py +0 -0
- cosyvoice/__pycache__/__init__.cpython-310.pyc +0 -0
- cosyvoice/__pycache__/__init__.cpython-38.pyc +0 -0
- cosyvoice/bin/average_model.py +92 -0
- cosyvoice/bin/export_jit.py +91 -0
- cosyvoice/bin/export_onnx.py +116 -0
- cosyvoice/bin/export_trt.sh +10 -0
- cosyvoice/bin/inference.py +115 -0
- cosyvoice/bin/train.py +170 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/__pycache__/__init__.cpython-310.pyc +0 -0
- cosyvoice/cli/__pycache__/__init__.cpython-38.pyc +0 -0
- cosyvoice/cli/__pycache__/model.cpython-310.pyc +0 -0
- cosyvoice/cli/__pycache__/model.cpython-38.pyc +0 -0
- cosyvoice/cli/cosyvoice.py +173 -0
- cosyvoice/cli/frontend.py +211 -0
- cosyvoice/cli/model.py +411 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
- cosyvoice/dataset/__pycache__/__init__.cpython-38.pyc +0 -0
- cosyvoice/dataset/__pycache__/processor.cpython-310.pyc +0 -0
- cosyvoice/dataset/__pycache__/processor.cpython-38.pyc +0 -0
- cosyvoice/dataset/dataset.py +164 -0
- cosyvoice/dataset/processor.py +435 -0
- cosyvoice/flow/__pycache__/decoder.cpython-310.pyc +0 -0
- cosyvoice/flow/__pycache__/decoder.cpython-38.pyc +0 -0
- cosyvoice/flow/__pycache__/flow.cpython-310.pyc +0 -0
- cosyvoice/flow/__pycache__/flow.cpython-38.pyc +0 -0
- cosyvoice/flow/__pycache__/flow_matching.cpython-310.pyc +0 -0
- cosyvoice/flow/__pycache__/flow_matching.cpython-38.pyc +0 -0
- cosyvoice/flow/__pycache__/length_regulator.cpython-310.pyc +0 -0
- cosyvoice/flow/__pycache__/length_regulator.cpython-38.pyc +0 -0
- cosyvoice/flow/decoder.py +301 -0
- cosyvoice/flow/flow.py +240 -0
- cosyvoice/flow/flow_matching.py +217 -0
- cosyvoice/flow/length_regulator.py +69 -0
- cosyvoice/flow_speaker_minus/__pycache__/decoder.cpython-310.pyc +0 -0
- cosyvoice/flow_speaker_minus/__pycache__/flow.cpython-310.pyc +0 -0
- cosyvoice/flow_speaker_minus/__pycache__/flow.cpython-38.pyc +0 -0
- cosyvoice/flow_speaker_minus/__pycache__/flow_matching.cpython-310.pyc +0 -0
- cosyvoice/flow_speaker_minus/__pycache__/length_regulator.cpython-310.pyc +0 -0
- cosyvoice/flow_speaker_minus/decoder.py +301 -0
- cosyvoice/flow_speaker_minus/flow.py +184 -0
- cosyvoice/flow_speaker_minus/flow_matching.py +217 -0
- cosyvoice/flow_speaker_minus/length_regulator.py +69 -0
- cosyvoice/hifigan/__pycache__/discriminator.cpython-310.pyc +0 -0
CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributor Covenant Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to making participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies both within project spaces and in public spaces
|
| 49 |
+
when an individual is representing the project or its community. Examples of
|
| 50 |
+
representing a project or community include using an official project e-mail
|
| 51 |
+
address, posting via an official social media account, or acting as an appointed
|
| 52 |
+
representative at an online or offline event. Representation of a project may be
|
| 53 |
+
further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
## Enforcement
|
| 56 |
+
|
| 57 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 58 |
+
reported by contacting the project team at mikelei@mobvoi.com. All
|
| 59 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 60 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 61 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 62 |
+
Further details of specific enforcement policies may be posted separately.
|
| 63 |
+
|
| 64 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 65 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 66 |
+
members of the project's leadership.
|
| 67 |
+
|
| 68 |
+
## Attribution
|
| 69 |
+
|
| 70 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
| 71 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
| 72 |
+
|
| 73 |
+
[homepage]: https://www.contributor-covenant.org
|
| 74 |
+
|
| 75 |
+
For answers to common questions about this code of conduct, see
|
| 76 |
+
https://www.contributor-covenant.org/faq
|
FAQ.md
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## ModuleNotFoundError: No module named 'matcha'
|
| 2 |
+
|
| 3 |
+
Matcha-TTS is a third_party module. Please check `third_party` directory. If there is no `Matcha-TTS`, execute `git submodule update --init --recursive`.
|
| 4 |
+
|
| 5 |
+
run `export PYTHONPATH=third_party/Matcha-TTS` if you want to use `from cosyvoice.cli.cosyvoice import CosyVoice` in python script.
|
| 6 |
+
|
| 7 |
+
## cannot find resource.zip or cannot unzip resource.zip
|
| 8 |
+
|
| 9 |
+
Please make sure you have git-lfs installed. Execute
|
| 10 |
+
|
| 11 |
+
```sh
|
| 12 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
|
| 13 |
+
cd pretrained_models/CosyVoice-ttsfrd/
|
| 14 |
+
unzip resource.zip -d .
|
| 15 |
+
pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
|
| 16 |
+
```
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,13 +1,237 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[](https://github.com/Akshay090/svg-banners)
|
| 2 |
+
|
| 3 |
+
## 👉🏻 CosyVoice 👈🏻
|
| 4 |
+
**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/abs/2412.10117); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/spaces/FunAudioLLM/CosyVoice2-0.5B)
|
| 5 |
+
|
| 6 |
+
**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice-300M)
|
| 7 |
+
|
| 8 |
+
## Highlight🔥
|
| 9 |
+
|
| 10 |
+
**CosyVoice 2.0** has been released! Compared to version 1.0, the new version offers more accurate, more stable, faster, and better speech generation capabilities.
|
| 11 |
+
### Multilingual
|
| 12 |
+
- **Supported Language**: Chinese, English, Japanese, Korean, Chinese dialects (Cantonese, Sichuanese, Shanghainese, Tianjinese, Wuhanese, etc.)
|
| 13 |
+
- **Crosslingual & Mixlingual**:Support zero-shot voice cloning for cross-lingual and code-switching scenarios.
|
| 14 |
+
### Ultra-Low Latency
|
| 15 |
+
- **Bidirectional Streaming Support**: CosyVoice 2.0 integrates offline and streaming modeling technologies.
|
| 16 |
+
- **Rapid First Packet Synthesis**: Achieves latency as low as 150ms while maintaining high-quality audio output.
|
| 17 |
+
### High Accuracy
|
| 18 |
+
- **Improved Pronunciation**: Reduces pronunciation errors by 30% to 50% compared to CosyVoice 1.0.
|
| 19 |
+
- **Benchmark Achievements**: Attains the lowest character error rate on the hard test set of the Seed-TTS evaluation set.
|
| 20 |
+
### Strong Stability
|
| 21 |
+
- **Consistency in Timbre**: Ensures reliable voice consistency for zero-shot and cross-language speech synthesis.
|
| 22 |
+
- **Cross-language Synthesis**: Marked improvements compared to version 1.0.
|
| 23 |
+
### Natural Experience
|
| 24 |
+
- **Enhanced Prosody and Sound Quality**: Improved alignment of synthesized audio, raising MOS evaluation scores from 5.4 to 5.53.
|
| 25 |
+
- **Emotional and Dialectal Flexibility**: Now supports more granular emotional controls and accent adjustments.
|
| 26 |
+
|
| 27 |
+
## Roadmap
|
| 28 |
+
|
| 29 |
+
- [x] 2024/12
|
| 30 |
+
|
| 31 |
+
- [x] 25hz cosyvoice 2.0 released
|
| 32 |
+
|
| 33 |
+
- [x] 2024/09
|
| 34 |
+
|
| 35 |
+
- [x] 25hz cosyvoice base model
|
| 36 |
+
- [x] 25hz cosyvoice voice conversion model
|
| 37 |
+
|
| 38 |
+
- [x] 2024/08
|
| 39 |
+
|
| 40 |
+
- [x] Repetition Aware Sampling(RAS) inference for llm stability
|
| 41 |
+
- [x] Streaming inference mode support, including kv cache and sdpa for rtf optimization
|
| 42 |
+
|
| 43 |
+
- [x] 2024/07
|
| 44 |
+
|
| 45 |
+
- [x] Flow matching training support
|
| 46 |
+
- [x] WeTextProcessing support when ttsfrd is not available
|
| 47 |
+
- [x] Fastapi server and client
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
## Install
|
| 51 |
+
|
| 52 |
+
**Clone and install**
|
| 53 |
+
|
| 54 |
+
- Clone the repo
|
| 55 |
+
``` sh
|
| 56 |
+
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
| 57 |
+
# If you failed to clone submodule due to network failures, please run following command until success
|
| 58 |
+
cd CosyVoice
|
| 59 |
+
git submodule update --init --recursive
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
|
| 63 |
+
- Create Conda env:
|
| 64 |
+
|
| 65 |
+
``` sh
|
| 66 |
+
conda create -n cosyvoice -y python=3.10
|
| 67 |
+
conda activate cosyvoice
|
| 68 |
+
# pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platform.
|
| 69 |
+
conda install -y -c conda-forge pynini==2.1.5
|
| 70 |
+
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
| 71 |
+
|
| 72 |
+
# If you encounter sox compatibility issues
|
| 73 |
+
# ubuntu
|
| 74 |
+
sudo apt-get install sox libsox-dev
|
| 75 |
+
# centos
|
| 76 |
+
sudo yum install sox sox-devel
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
**Model download**
|
| 80 |
+
|
| 81 |
+
We strongly recommend that you download our pretrained `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
|
| 82 |
+
|
| 83 |
+
``` python
|
| 84 |
+
# SDK模型下载
|
| 85 |
+
from modelscope import snapshot_download
|
| 86 |
+
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
|
| 87 |
+
snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
|
| 88 |
+
snapshot_download('iic/CosyVoice-300M-25Hz', local_dir='pretrained_models/CosyVoice-300M-25Hz')
|
| 89 |
+
snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
|
| 90 |
+
snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
|
| 91 |
+
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
``` sh
|
| 95 |
+
# git模型下载,请确保已安装git lfs
|
| 96 |
+
mkdir -p pretrained_models
|
| 97 |
+
git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B
|
| 98 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
|
| 99 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-300M-25Hz.git pretrained_models/CosyVoice-300M-25Hz
|
| 100 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
|
| 101 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
|
| 102 |
+
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
Optionally, you can unzip `ttsfrd` resouce and install `ttsfrd` package for better text normalization performance.
|
| 106 |
+
|
| 107 |
+
Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use WeTextProcessing by default.
|
| 108 |
+
|
| 109 |
+
``` sh
|
| 110 |
+
cd pretrained_models/CosyVoice-ttsfrd/
|
| 111 |
+
unzip resource.zip -d .
|
| 112 |
+
pip install ttsfrd_dependency-0.1-py3-none-any.whl
|
| 113 |
+
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
**Basic Usage**
|
| 117 |
+
|
| 118 |
+
We strongly recommend using `CosyVoice2-0.5B` for better performance.
|
| 119 |
+
Follow code below for detailed usage of each model.
|
| 120 |
+
|
| 121 |
+
``` python
|
| 122 |
+
import sys
|
| 123 |
+
sys.path.append('third_party/Matcha-TTS')
|
| 124 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
| 125 |
+
from cosyvoice.utils.file_utils import load_wav
|
| 126 |
+
import torchaudio
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
**CosyVoice2 Usage**
|
| 130 |
+
```python
|
| 131 |
+
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
|
| 132 |
+
|
| 133 |
+
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
|
| 134 |
+
# zero_shot usage
|
| 135 |
+
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
|
| 136 |
+
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
| 137 |
+
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 138 |
+
|
| 139 |
+
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
|
| 140 |
+
for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
|
| 141 |
+
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 142 |
+
|
| 143 |
+
# instruct usage
|
| 144 |
+
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
|
| 145 |
+
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 146 |
+
|
| 147 |
+
# bistream usage, you can use generator as input, this is useful when using text llm model as input
|
| 148 |
+
# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
|
| 149 |
+
def text_generator():
|
| 150 |
+
yield '收到好友从远方寄来的生日礼物,'
|
| 151 |
+
yield '那份意外的惊喜与深深的祝福'
|
| 152 |
+
yield '让我心中充满了甜蜜的快乐,'
|
| 153 |
+
yield '笑容如花儿般绽放。'
|
| 154 |
+
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
| 155 |
+
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
**CosyVoice Usage**
|
| 159 |
+
```python
|
| 160 |
+
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False)
|
| 161 |
+
# sft usage
|
| 162 |
+
print(cosyvoice.list_available_spks())
|
| 163 |
+
# change stream=True for chunk stream inference
|
| 164 |
+
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
|
| 165 |
+
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 166 |
+
|
| 167 |
+
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M') # or change to pretrained_models/CosyVoice-300M-25Hz for 25Hz inference
|
| 168 |
+
# zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
|
| 169 |
+
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
|
| 170 |
+
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
| 171 |
+
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 172 |
+
# cross_lingual usage
|
| 173 |
+
prompt_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
|
| 174 |
+
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
|
| 175 |
+
torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 176 |
+
# vc usage
|
| 177 |
+
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
|
| 178 |
+
source_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
|
| 179 |
+
for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
|
| 180 |
+
torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 181 |
+
|
| 182 |
+
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
|
| 183 |
+
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
|
| 184 |
+
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
|
| 185 |
+
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
**Start web demo**
|
| 189 |
+
|
| 190 |
+
You can use our web demo page to get familiar with CosyVoice quickly.
|
| 191 |
+
|
| 192 |
+
Please see the demo website for details.
|
| 193 |
+
|
| 194 |
+
``` python
|
| 195 |
+
# change iic/CosyVoice-300M-SFT for sft inference, or iic/CosyVoice-300M-Instruct for instruct inference
|
| 196 |
+
python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
**Advanced Usage**
|
| 200 |
+
|
| 201 |
+
For advanced user, we have provided train and inference scripts in `examples/libritts/cosyvoice/run.sh`.
|
| 202 |
+
|
| 203 |
+
**Build for deployment**
|
| 204 |
+
|
| 205 |
+
Optionally, if you want service deployment,
|
| 206 |
+
you can run following steps.
|
| 207 |
+
|
| 208 |
+
``` sh
|
| 209 |
+
cd runtime/python
|
| 210 |
+
docker build -t cosyvoice:v1.0 .
|
| 211 |
+
# change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
|
| 212 |
+
# for grpc usage
|
| 213 |
+
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
|
| 214 |
+
cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
| 215 |
+
# for fastapi usage
|
| 216 |
+
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
|
| 217 |
+
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
+
## Discussion & Communication
|
| 221 |
+
|
| 222 |
+
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
|
| 223 |
+
|
| 224 |
+
You can also scan the QR code to join our official Dingding chat group.
|
| 225 |
+
|
| 226 |
+
<img src="./asset/dingding.png" width="250px">
|
| 227 |
+
|
| 228 |
+
## Acknowledge
|
| 229 |
+
|
| 230 |
+
1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR).
|
| 231 |
+
2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec).
|
| 232 |
+
3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
|
| 233 |
+
4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
|
| 234 |
+
5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
|
| 235 |
+
|
| 236 |
+
## Disclaimer
|
| 237 |
+
The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
|
cosyvoice/__init__.py
ADDED
|
File without changes
|
cosyvoice/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (148 Bytes). View file
|
|
|
cosyvoice/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (146 Bytes). View file
|
|
|
cosyvoice/bin/average_model.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
|
| 2 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import argparse
|
| 18 |
+
import glob
|
| 19 |
+
|
| 20 |
+
import yaml
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_args():
|
| 25 |
+
parser = argparse.ArgumentParser(description='average model')
|
| 26 |
+
parser.add_argument('--dst_model', required=True, help='averaged model')
|
| 27 |
+
parser.add_argument('--src_path',
|
| 28 |
+
required=True,
|
| 29 |
+
help='src model path for average')
|
| 30 |
+
parser.add_argument('--val_best',
|
| 31 |
+
action="store_true",
|
| 32 |
+
help='averaged model')
|
| 33 |
+
parser.add_argument('--num',
|
| 34 |
+
default=5,
|
| 35 |
+
type=int,
|
| 36 |
+
help='nums for averaged model')
|
| 37 |
+
|
| 38 |
+
args = parser.parse_args()
|
| 39 |
+
print(args)
|
| 40 |
+
return args
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main():
|
| 44 |
+
args = get_args()
|
| 45 |
+
val_scores = []
|
| 46 |
+
if args.val_best:
|
| 47 |
+
yamls = glob.glob('{}/*.yaml'.format(args.src_path))
|
| 48 |
+
yamls = [
|
| 49 |
+
f for f in yamls
|
| 50 |
+
if not (os.path.basename(f).startswith('train')
|
| 51 |
+
or os.path.basename(f).startswith('init'))
|
| 52 |
+
]
|
| 53 |
+
for y in yamls:
|
| 54 |
+
with open(y, 'r') as f:
|
| 55 |
+
dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
|
| 56 |
+
loss = float(dic_yaml['loss_dict']['loss'])
|
| 57 |
+
epoch = int(dic_yaml['epoch'])
|
| 58 |
+
step = int(dic_yaml['step'])
|
| 59 |
+
tag = dic_yaml['tag']
|
| 60 |
+
val_scores += [[epoch, step, loss, tag]]
|
| 61 |
+
sorted_val_scores = sorted(val_scores,
|
| 62 |
+
key=lambda x: x[2],
|
| 63 |
+
reverse=False)
|
| 64 |
+
print("best val (epoch, step, loss, tag) = " +
|
| 65 |
+
str(sorted_val_scores[:args.num]))
|
| 66 |
+
path_list = [
|
| 67 |
+
args.src_path + '/epoch_{}_whole.pt'.format(score[0])
|
| 68 |
+
for score in sorted_val_scores[:args.num]
|
| 69 |
+
]
|
| 70 |
+
print(path_list)
|
| 71 |
+
avg = {}
|
| 72 |
+
num = args.num
|
| 73 |
+
assert num == len(path_list)
|
| 74 |
+
for path in path_list:
|
| 75 |
+
print('Processing {}'.format(path))
|
| 76 |
+
states = torch.load(path, map_location=torch.device('cpu'))
|
| 77 |
+
for k in states.keys():
|
| 78 |
+
if k not in avg.keys():
|
| 79 |
+
avg[k] = states[k].clone()
|
| 80 |
+
else:
|
| 81 |
+
avg[k] += states[k]
|
| 82 |
+
# average
|
| 83 |
+
for k in avg.keys():
|
| 84 |
+
if avg[k] is not None:
|
| 85 |
+
# pytorch 1.6 use true_divide instead of /=
|
| 86 |
+
avg[k] = torch.true_divide(avg[k], num)
|
| 87 |
+
print('Saving to {}'.format(args.dst_model))
|
| 88 |
+
torch.save(avg, args.dst_model)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == '__main__':
|
| 92 |
+
main()
|
cosyvoice/bin/export_jit.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import print_function
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import logging
|
| 19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
import torch
|
| 23 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 24 |
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
| 25 |
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
| 26 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_args():
|
| 30 |
+
parser = argparse.ArgumentParser(description='export your model for deployment')
|
| 31 |
+
parser.add_argument('--model_dir',
|
| 32 |
+
type=str,
|
| 33 |
+
default='pretrained_models/CosyVoice-300M',
|
| 34 |
+
help='local path')
|
| 35 |
+
args = parser.parse_args()
|
| 36 |
+
print(args)
|
| 37 |
+
return args
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_optimized_script(model, preserved_attrs=[]):
|
| 41 |
+
script = torch.jit.script(model)
|
| 42 |
+
if preserved_attrs != []:
|
| 43 |
+
script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
|
| 44 |
+
else:
|
| 45 |
+
script = torch.jit.freeze(script)
|
| 46 |
+
script = torch.jit.optimize_for_inference(script)
|
| 47 |
+
return script
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def main():
|
| 51 |
+
args = get_args()
|
| 52 |
+
logging.basicConfig(level=logging.DEBUG,
|
| 53 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
| 54 |
+
|
| 55 |
+
torch._C._jit_set_fusion_strategy([('STATIC', 1)])
|
| 56 |
+
torch._C._jit_set_profiling_mode(False)
|
| 57 |
+
torch._C._jit_set_profiling_executor(False)
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
model = CosyVoice(args.model_dir)
|
| 61 |
+
except Exception:
|
| 62 |
+
try:
|
| 63 |
+
model = CosyVoice2(args.model_dir)
|
| 64 |
+
except Exception:
|
| 65 |
+
raise TypeError('no valid model_type!')
|
| 66 |
+
|
| 67 |
+
if not isinstance(model, CosyVoice2):
|
| 68 |
+
# 1. export llm text_encoder
|
| 69 |
+
llm_text_encoder = model.model.llm.text_encoder
|
| 70 |
+
script = get_optimized_script(llm_text_encoder)
|
| 71 |
+
script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
|
| 72 |
+
script = get_optimized_script(llm_text_encoder.half())
|
| 73 |
+
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
|
| 74 |
+
|
| 75 |
+
# 2. export llm llm
|
| 76 |
+
llm_llm = model.model.llm.llm
|
| 77 |
+
script = get_optimized_script(llm_llm, ['forward_chunk'])
|
| 78 |
+
script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
|
| 79 |
+
script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
|
| 80 |
+
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
| 81 |
+
|
| 82 |
+
# 3. export flow encoder
|
| 83 |
+
flow_encoder = model.model.flow.encoder
|
| 84 |
+
script = get_optimized_script(flow_encoder)
|
| 85 |
+
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
| 86 |
+
script = get_optimized_script(flow_encoder.half())
|
| 87 |
+
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
if __name__ == '__main__':
|
| 91 |
+
main()
|
cosyvoice/bin/export_onnx.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
|
| 2 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import print_function
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import logging
|
| 20 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
import onnxruntime
|
| 24 |
+
import random
|
| 25 |
+
import torch
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 28 |
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
| 29 |
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
| 30 |
+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_dummy_input(batch_size, seq_len, out_channels, device):
|
| 34 |
+
x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
| 35 |
+
mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
|
| 36 |
+
mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
| 37 |
+
t = torch.rand((batch_size), dtype=torch.float32, device=device)
|
| 38 |
+
spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
|
| 39 |
+
cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
| 40 |
+
return x, mask, mu, t, spks, cond
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_args():
|
| 44 |
+
parser = argparse.ArgumentParser(description='export your model for deployment')
|
| 45 |
+
parser.add_argument('--model_dir',
|
| 46 |
+
type=str,
|
| 47 |
+
default='pretrained_models/CosyVoice-300M',
|
| 48 |
+
help='local path')
|
| 49 |
+
args = parser.parse_args()
|
| 50 |
+
print(args)
|
| 51 |
+
return args
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def main():
|
| 55 |
+
args = get_args()
|
| 56 |
+
logging.basicConfig(level=logging.DEBUG,
|
| 57 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
model = CosyVoice(args.model_dir)
|
| 61 |
+
except Exception:
|
| 62 |
+
try:
|
| 63 |
+
model = CosyVoice2(args.model_dir)
|
| 64 |
+
except Exception:
|
| 65 |
+
raise TypeError('no valid model_type!')
|
| 66 |
+
|
| 67 |
+
# 1. export flow decoder estimator
|
| 68 |
+
estimator = model.model.flow.decoder.estimator
|
| 69 |
+
|
| 70 |
+
device = model.model.device
|
| 71 |
+
batch_size, seq_len = 2, 256
|
| 72 |
+
out_channels = model.model.flow.decoder.estimator.out_channels
|
| 73 |
+
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
| 74 |
+
torch.onnx.export(
|
| 75 |
+
estimator,
|
| 76 |
+
(x, mask, mu, t, spks, cond),
|
| 77 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
| 78 |
+
export_params=True,
|
| 79 |
+
opset_version=18,
|
| 80 |
+
do_constant_folding=True,
|
| 81 |
+
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
| 82 |
+
output_names=['estimator_out'],
|
| 83 |
+
dynamic_axes={
|
| 84 |
+
'x': {2: 'seq_len'},
|
| 85 |
+
'mask': {2: 'seq_len'},
|
| 86 |
+
'mu': {2: 'seq_len'},
|
| 87 |
+
'cond': {2: 'seq_len'},
|
| 88 |
+
'estimator_out': {2: 'seq_len'},
|
| 89 |
+
}
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# 2. test computation consistency
|
| 93 |
+
option = onnxruntime.SessionOptions()
|
| 94 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 95 |
+
option.intra_op_num_threads = 1
|
| 96 |
+
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
| 97 |
+
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
| 98 |
+
sess_options=option, providers=providers)
|
| 99 |
+
|
| 100 |
+
for _ in tqdm(range(10)):
|
| 101 |
+
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
|
| 102 |
+
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
| 103 |
+
ort_inputs = {
|
| 104 |
+
'x': x.cpu().numpy(),
|
| 105 |
+
'mask': mask.cpu().numpy(),
|
| 106 |
+
'mu': mu.cpu().numpy(),
|
| 107 |
+
't': t.cpu().numpy(),
|
| 108 |
+
'spks': spks.cpu().numpy(),
|
| 109 |
+
'cond': cond.cpu().numpy()
|
| 110 |
+
}
|
| 111 |
+
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
| 112 |
+
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
if __name__ == "__main__":
|
| 116 |
+
main()
|
cosyvoice/bin/export_trt.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Copyright 2024 Alibaba Inc. All Rights Reserved.
|
| 3 |
+
# download tensorrt from https://developer.nvidia.com/tensorrt/download/10x, check your system and cuda for compatibability
|
| 4 |
+
# for example for linux + cuda12.4, you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
|
| 5 |
+
TRT_DIR=<YOUR_TRT_DIR>
|
| 6 |
+
MODEL_DIR=<COSYVOICE2_MODEL_DIR>
|
| 7 |
+
|
| 8 |
+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
|
| 9 |
+
$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw --outputIOFormats=fp32:chw
|
| 10 |
+
$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
|
cosyvoice/bin/inference.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import print_function
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import logging
|
| 19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 20 |
+
import os
|
| 21 |
+
import torch
|
| 22 |
+
from torch.utils.data import DataLoader
|
| 23 |
+
import torchaudio
|
| 24 |
+
from hyperpyyaml import load_hyperpyyaml
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
from cosyvoice.cli.model import CosyVoiceModel
|
| 27 |
+
from cosyvoice.dataset.dataset import Dataset
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_args():
|
| 31 |
+
parser = argparse.ArgumentParser(description='inference with your model')
|
| 32 |
+
parser.add_argument('--config', required=True, help='config file')
|
| 33 |
+
parser.add_argument('--prompt_data', required=True, help='prompt data file')
|
| 34 |
+
parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
|
| 35 |
+
parser.add_argument('--tts_text', required=True, help='tts input file')
|
| 36 |
+
parser.add_argument('--llm_model', required=True, help='llm model file')
|
| 37 |
+
parser.add_argument('--flow_model', required=True, help='flow model file')
|
| 38 |
+
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
|
| 39 |
+
parser.add_argument('--gpu',
|
| 40 |
+
type=int,
|
| 41 |
+
default=-1,
|
| 42 |
+
help='gpu id for this rank, -1 for cpu')
|
| 43 |
+
parser.add_argument('--mode',
|
| 44 |
+
default='sft',
|
| 45 |
+
choices=['sft', 'zero_shot'],
|
| 46 |
+
help='inference mode')
|
| 47 |
+
parser.add_argument('--result_dir', required=True, help='asr result file')
|
| 48 |
+
args = parser.parse_args()
|
| 49 |
+
print(args)
|
| 50 |
+
return args
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main():
|
| 54 |
+
args = get_args()
|
| 55 |
+
logging.basicConfig(level=logging.DEBUG,
|
| 56 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
| 57 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
| 58 |
+
|
| 59 |
+
# Init cosyvoice models from configs
|
| 60 |
+
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
|
| 61 |
+
device = torch.device('cuda' if use_cuda else 'cpu')
|
| 62 |
+
with open(args.config, 'r') as f:
|
| 63 |
+
configs = load_hyperpyyaml(f)
|
| 64 |
+
|
| 65 |
+
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
|
| 66 |
+
model.load(args.llm_model, args.flow_model, args.hifigan_model)
|
| 67 |
+
|
| 68 |
+
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
|
| 69 |
+
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
|
| 70 |
+
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
|
| 71 |
+
|
| 72 |
+
del configs
|
| 73 |
+
os.makedirs(args.result_dir, exist_ok=True)
|
| 74 |
+
fn = os.path.join(args.result_dir, 'wav.scp')
|
| 75 |
+
f = open(fn, 'w')
|
| 76 |
+
with torch.no_grad():
|
| 77 |
+
for _, batch in tqdm(enumerate(test_data_loader)):
|
| 78 |
+
utts = batch["utts"]
|
| 79 |
+
assert len(utts) == 1, "inference mode only support batchsize 1"
|
| 80 |
+
text_token = batch["text_token"].to(device)
|
| 81 |
+
text_token_len = batch["text_token_len"].to(device)
|
| 82 |
+
tts_index = batch["tts_index"]
|
| 83 |
+
tts_text_token = batch["tts_text_token"].to(device)
|
| 84 |
+
tts_text_token_len = batch["tts_text_token_len"].to(device)
|
| 85 |
+
speech_token = batch["speech_token"].to(device)
|
| 86 |
+
speech_token_len = batch["speech_token_len"].to(device)
|
| 87 |
+
speech_feat = batch["speech_feat"].to(device)
|
| 88 |
+
speech_feat_len = batch["speech_feat_len"].to(device)
|
| 89 |
+
utt_embedding = batch["utt_embedding"].to(device)
|
| 90 |
+
spk_embedding = batch["spk_embedding"].to(device)
|
| 91 |
+
if args.mode == 'sft':
|
| 92 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
| 93 |
+
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
|
| 94 |
+
else:
|
| 95 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
| 96 |
+
'prompt_text': text_token, 'prompt_text_len': text_token_len,
|
| 97 |
+
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
| 98 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
| 99 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
| 100 |
+
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
|
| 101 |
+
tts_speeches = []
|
| 102 |
+
for model_output in model.tts(**model_input):
|
| 103 |
+
tts_speeches.append(model_output['tts_speech'])
|
| 104 |
+
tts_speeches = torch.concat(tts_speeches, dim=1)
|
| 105 |
+
tts_key = '{}_{}'.format(utts[0], tts_index[0])
|
| 106 |
+
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
|
| 107 |
+
torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
|
| 108 |
+
f.write('{} {}\n'.format(tts_key, tts_fn))
|
| 109 |
+
f.flush()
|
| 110 |
+
f.close()
|
| 111 |
+
logging.info('Result wav.scp saved in {}'.format(fn))
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
if __name__ == '__main__':
|
| 115 |
+
main()
|
cosyvoice/bin/train.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import print_function
|
| 16 |
+
import argparse
|
| 17 |
+
import datetime
|
| 18 |
+
import logging
|
| 19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 20 |
+
from copy import deepcopy
|
| 21 |
+
import os
|
| 22 |
+
import torch
|
| 23 |
+
import torch.distributed as dist
|
| 24 |
+
import deepspeed
|
| 25 |
+
|
| 26 |
+
from hyperpyyaml import load_hyperpyyaml
|
| 27 |
+
|
| 28 |
+
from torch.distributed.elastic.multiprocessing.errors import record
|
| 29 |
+
|
| 30 |
+
from cosyvoice.utils.executor import Executor
|
| 31 |
+
from cosyvoice.utils.train_utils import (
|
| 32 |
+
init_distributed,
|
| 33 |
+
init_dataset_and_dataloader,
|
| 34 |
+
init_optimizer_and_scheduler,
|
| 35 |
+
init_summarywriter, save_model,
|
| 36 |
+
wrap_cuda_model, check_modify_and_save_config)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_args():
|
| 40 |
+
parser = argparse.ArgumentParser(description='training your network')
|
| 41 |
+
parser.add_argument('--train_engine',
|
| 42 |
+
default='torch_ddp',
|
| 43 |
+
choices=['torch_ddp', 'deepspeed'],
|
| 44 |
+
help='Engine for paralleled training')
|
| 45 |
+
parser.add_argument('--model', required=True, help='model which will be trained')
|
| 46 |
+
parser.add_argument('--config', required=True, help='config file')
|
| 47 |
+
parser.add_argument('--train_data', required=True, help='train data file')
|
| 48 |
+
parser.add_argument('--cv_data', required=True, help='cv data file')
|
| 49 |
+
parser.add_argument('--checkpoint', help='checkpoint model')
|
| 50 |
+
parser.add_argument('--model_dir', required=True, help='save model dir')
|
| 51 |
+
parser.add_argument('--tensorboard_dir',
|
| 52 |
+
default='tensorboard',
|
| 53 |
+
help='tensorboard log dir')
|
| 54 |
+
parser.add_argument('--ddp.dist_backend',
|
| 55 |
+
dest='dist_backend',
|
| 56 |
+
default='nccl',
|
| 57 |
+
choices=['nccl', 'gloo'],
|
| 58 |
+
help='distributed backend')
|
| 59 |
+
parser.add_argument('--num_workers',
|
| 60 |
+
default=0,
|
| 61 |
+
type=int,
|
| 62 |
+
help='num of subprocess workers for reading')
|
| 63 |
+
parser.add_argument('--prefetch',
|
| 64 |
+
default=100,
|
| 65 |
+
type=int,
|
| 66 |
+
help='prefetch number')
|
| 67 |
+
parser.add_argument('--pin_memory',
|
| 68 |
+
action='store_true',
|
| 69 |
+
default=False,
|
| 70 |
+
help='Use pinned memory buffers used for reading')
|
| 71 |
+
parser.add_argument('--use_amp',
|
| 72 |
+
action='store_true',
|
| 73 |
+
default=False,
|
| 74 |
+
help='Use automatic mixed precision training')
|
| 75 |
+
parser.add_argument('--deepspeed.save_states',
|
| 76 |
+
dest='save_states',
|
| 77 |
+
default='model_only',
|
| 78 |
+
choices=['model_only', 'model+optimizer'],
|
| 79 |
+
help='save model/optimizer states')
|
| 80 |
+
parser.add_argument('--timeout',
|
| 81 |
+
default=60,
|
| 82 |
+
type=int,
|
| 83 |
+
help='timeout (in seconds) of cosyvoice_join.')
|
| 84 |
+
parser = deepspeed.add_config_arguments(parser)
|
| 85 |
+
args = parser.parse_args()
|
| 86 |
+
return args
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@record
|
| 90 |
+
def main():
|
| 91 |
+
args = get_args()
|
| 92 |
+
logging.basicConfig(level=logging.DEBUG,
|
| 93 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
| 94 |
+
# gan train has some special initialization logic
|
| 95 |
+
gan = True if args.model == 'hifigan' else False
|
| 96 |
+
|
| 97 |
+
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
|
| 98 |
+
if gan is True:
|
| 99 |
+
override_dict.pop('hift')
|
| 100 |
+
with open(args.config, 'r') as f:
|
| 101 |
+
configs = load_hyperpyyaml(f, overrides=override_dict)
|
| 102 |
+
if gan is True:
|
| 103 |
+
configs['train_conf'] = configs['train_conf_gan']
|
| 104 |
+
configs['train_conf'].update(vars(args))
|
| 105 |
+
|
| 106 |
+
# Init env for ddp
|
| 107 |
+
init_distributed(args)
|
| 108 |
+
|
| 109 |
+
# Get dataset & dataloader
|
| 110 |
+
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
|
| 111 |
+
init_dataset_and_dataloader(args, configs, gan)
|
| 112 |
+
|
| 113 |
+
# Do some sanity checks and save config to arsg.model_dir
|
| 114 |
+
configs = check_modify_and_save_config(args, configs)
|
| 115 |
+
|
| 116 |
+
# Tensorboard summary
|
| 117 |
+
writer = init_summarywriter(args)
|
| 118 |
+
|
| 119 |
+
# load checkpoint
|
| 120 |
+
model = configs[args.model]
|
| 121 |
+
start_step, start_epoch = 0, -1
|
| 122 |
+
if args.checkpoint is not None:
|
| 123 |
+
if os.path.exists(args.checkpoint):
|
| 124 |
+
state_dict = torch.load(args.checkpoint, map_location='cpu')
|
| 125 |
+
model.load_state_dict(state_dict, strict=False)
|
| 126 |
+
if 'step' in state_dict:
|
| 127 |
+
start_step = state_dict['step']
|
| 128 |
+
if 'epoch' in state_dict:
|
| 129 |
+
start_epoch = state_dict['epoch']
|
| 130 |
+
else:
|
| 131 |
+
logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
|
| 132 |
+
|
| 133 |
+
# Dispatch model from cpu to gpu
|
| 134 |
+
model = wrap_cuda_model(args, model)
|
| 135 |
+
|
| 136 |
+
# Get optimizer & scheduler
|
| 137 |
+
model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
|
| 138 |
+
scheduler.set_step(start_step)
|
| 139 |
+
if scheduler_d is not None:
|
| 140 |
+
scheduler_d.set_step(start_step)
|
| 141 |
+
|
| 142 |
+
# Save init checkpoints
|
| 143 |
+
info_dict = deepcopy(configs['train_conf'])
|
| 144 |
+
info_dict['step'] = start_step
|
| 145 |
+
info_dict['epoch'] = start_epoch
|
| 146 |
+
save_model(model, 'init', info_dict)
|
| 147 |
+
|
| 148 |
+
# Get executor
|
| 149 |
+
executor = Executor(gan=gan)
|
| 150 |
+
executor.step = start_step
|
| 151 |
+
|
| 152 |
+
# Init scaler, used for pytorch amp mixed precision training
|
| 153 |
+
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
|
| 154 |
+
print('start step {} start epoch {}'.format(start_step, start_epoch))
|
| 155 |
+
# Start training loop
|
| 156 |
+
for epoch in range(start_epoch + 1, info_dict['max_epoch']):
|
| 157 |
+
executor.epoch = epoch
|
| 158 |
+
train_dataset.set_epoch(epoch)
|
| 159 |
+
dist.barrier()
|
| 160 |
+
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
|
| 161 |
+
if gan is True:
|
| 162 |
+
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
| 163 |
+
writer, info_dict, scaler, group_join)
|
| 164 |
+
else:
|
| 165 |
+
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
|
| 166 |
+
dist.destroy_process_group(group_join)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
if __name__ == '__main__':
|
| 170 |
+
main()
|
cosyvoice/cli/__init__.py
ADDED
|
File without changes
|
cosyvoice/cli/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (152 Bytes). View file
|
|
|
cosyvoice/cli/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (150 Bytes). View file
|
|
|
cosyvoice/cli/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
cosyvoice/cli/__pycache__/model.cpython-38.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
cosyvoice/cli/cosyvoice.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
from typing import Generator
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from hyperpyyaml import load_hyperpyyaml
|
| 19 |
+
from modelscope import snapshot_download
|
| 20 |
+
import torch
|
| 21 |
+
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
| 22 |
+
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
| 23 |
+
from cosyvoice.utils.file_utils import logging
|
| 24 |
+
from cosyvoice.utils.class_utils import get_model_type
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class CosyVoice:
|
| 28 |
+
|
| 29 |
+
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
|
| 30 |
+
self.instruct = True if '-Instruct' in model_dir else False
|
| 31 |
+
self.model_dir = model_dir
|
| 32 |
+
self.fp16 = fp16
|
| 33 |
+
if not os.path.exists(model_dir):
|
| 34 |
+
model_dir = snapshot_download(model_dir)
|
| 35 |
+
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
| 36 |
+
configs = load_hyperpyyaml(f)
|
| 37 |
+
assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
|
| 38 |
+
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
| 39 |
+
configs['feat_extractor'],
|
| 40 |
+
'{}/campplus.onnx'.format(model_dir),
|
| 41 |
+
'{}/speech_tokenizer_v1.onnx'.format(model_dir),
|
| 42 |
+
'{}/spk2info.pt'.format(model_dir),
|
| 43 |
+
configs['allowed_special'])
|
| 44 |
+
self.sample_rate = configs['sample_rate']
|
| 45 |
+
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
| 46 |
+
load_jit, load_trt, fp16 = False, False, False
|
| 47 |
+
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
| 48 |
+
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
|
| 49 |
+
self.model.load('{}/llm.pt'.format(model_dir),
|
| 50 |
+
'{}/flow.pt'.format(model_dir),
|
| 51 |
+
'{}/hift.pt'.format(model_dir))
|
| 52 |
+
if load_jit:
|
| 53 |
+
self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
| 54 |
+
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
| 55 |
+
'{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
| 56 |
+
if load_trt:
|
| 57 |
+
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
| 58 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
| 59 |
+
self.fp16)
|
| 60 |
+
del configs
|
| 61 |
+
|
| 62 |
+
def list_available_spks(self):
|
| 63 |
+
spks = list(self.frontend.spk2info.keys())
|
| 64 |
+
return spks
|
| 65 |
+
|
| 66 |
+
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
| 67 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 68 |
+
model_input = self.frontend.frontend_sft(i, spk_id)
|
| 69 |
+
start_time = time.time()
|
| 70 |
+
logging.info('synthesis text {}'.format(i))
|
| 71 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 72 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 73 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 74 |
+
yield model_output
|
| 75 |
+
start_time = time.time()
|
| 76 |
+
|
| 77 |
+
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
| 78 |
+
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
|
| 79 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 80 |
+
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
|
| 81 |
+
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
|
| 82 |
+
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
|
| 83 |
+
start_time = time.time()
|
| 84 |
+
logging.info('synthesis text {}'.format(i))
|
| 85 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 86 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 87 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 88 |
+
yield model_output
|
| 89 |
+
start_time = time.time()
|
| 90 |
+
|
| 91 |
+
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
| 92 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 93 |
+
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
|
| 94 |
+
start_time = time.time()
|
| 95 |
+
logging.info('synthesis text {}'.format(i))
|
| 96 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 97 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 98 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 99 |
+
yield model_output
|
| 100 |
+
start_time = time.time()
|
| 101 |
+
|
| 102 |
+
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
|
| 103 |
+
assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
|
| 104 |
+
if self.instruct is False:
|
| 105 |
+
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
| 106 |
+
instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
|
| 107 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 108 |
+
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
| 109 |
+
start_time = time.time()
|
| 110 |
+
logging.info('synthesis text {}'.format(i))
|
| 111 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 112 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 113 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 114 |
+
yield model_output
|
| 115 |
+
start_time = time.time()
|
| 116 |
+
|
| 117 |
+
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
|
| 118 |
+
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
|
| 119 |
+
start_time = time.time()
|
| 120 |
+
for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
|
| 121 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 122 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 123 |
+
yield model_output
|
| 124 |
+
start_time = time.time()
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class CosyVoice2(CosyVoice):
|
| 128 |
+
|
| 129 |
+
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
|
| 130 |
+
self.instruct = True if '-Instruct' in model_dir else False
|
| 131 |
+
self.model_dir = model_dir
|
| 132 |
+
self.fp16 = fp16
|
| 133 |
+
if not os.path.exists(model_dir):
|
| 134 |
+
model_dir = snapshot_download(model_dir)
|
| 135 |
+
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
| 136 |
+
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
| 137 |
+
assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
|
| 138 |
+
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
| 139 |
+
configs['feat_extractor'],
|
| 140 |
+
'{}/campplus.onnx'.format(model_dir),
|
| 141 |
+
'{}/speech_tokenizer_v2.onnx'.format(model_dir),
|
| 142 |
+
'{}/spk2info.pt'.format(model_dir),
|
| 143 |
+
configs['allowed_special'])
|
| 144 |
+
self.sample_rate = configs['sample_rate']
|
| 145 |
+
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
| 146 |
+
load_jit, load_trt, fp16 = False, False, False
|
| 147 |
+
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
| 148 |
+
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
|
| 149 |
+
self.model.load('{}/llm.pt'.format(model_dir),
|
| 150 |
+
'{}/flow.pt'.format(model_dir),
|
| 151 |
+
'{}/hift.pt'.format(model_dir))
|
| 152 |
+
if load_jit:
|
| 153 |
+
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
| 154 |
+
if load_trt:
|
| 155 |
+
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
| 156 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
| 157 |
+
self.fp16)
|
| 158 |
+
del configs
|
| 159 |
+
|
| 160 |
+
def inference_instruct(self, *args, **kwargs):
|
| 161 |
+
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
|
| 162 |
+
|
| 163 |
+
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
| 164 |
+
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
|
| 165 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
| 166 |
+
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
|
| 167 |
+
start_time = time.time()
|
| 168 |
+
logging.info('synthesis text {}'.format(i))
|
| 169 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 170 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 171 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 172 |
+
yield model_output
|
| 173 |
+
start_time = time.time()
|
cosyvoice/cli/frontend.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from functools import partial
|
| 15 |
+
from typing import Generator
|
| 16 |
+
import json
|
| 17 |
+
import onnxruntime
|
| 18 |
+
import torch
|
| 19 |
+
import numpy as np
|
| 20 |
+
import whisper
|
| 21 |
+
from typing import Callable
|
| 22 |
+
import torchaudio.compliance.kaldi as kaldi
|
| 23 |
+
import torchaudio
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
import inflect
|
| 27 |
+
try:
|
| 28 |
+
import ttsfrd
|
| 29 |
+
use_ttsfrd = True
|
| 30 |
+
except ImportError:
|
| 31 |
+
print("failed to import ttsfrd, use WeTextProcessing instead")
|
| 32 |
+
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
| 33 |
+
from tn.english.normalizer import Normalizer as EnNormalizer
|
| 34 |
+
use_ttsfrd = False
|
| 35 |
+
from cosyvoice.utils.file_utils import logging
|
| 36 |
+
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class CosyVoiceFrontEnd:
|
| 40 |
+
|
| 41 |
+
def __init__(self,
|
| 42 |
+
get_tokenizer: Callable,
|
| 43 |
+
feat_extractor: Callable,
|
| 44 |
+
campplus_model: str,
|
| 45 |
+
speech_tokenizer_model: str,
|
| 46 |
+
spk2info: str = '',
|
| 47 |
+
allowed_special: str = 'all'):
|
| 48 |
+
self.tokenizer = get_tokenizer()
|
| 49 |
+
self.feat_extractor = feat_extractor
|
| 50 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 51 |
+
option = onnxruntime.SessionOptions()
|
| 52 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 53 |
+
option.intra_op_num_threads = 1
|
| 54 |
+
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
| 55 |
+
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
|
| 56 |
+
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
|
| 57 |
+
"CPUExecutionProvider"])
|
| 58 |
+
if os.path.exists(spk2info):
|
| 59 |
+
self.spk2info = torch.load(spk2info, map_location=self.device)
|
| 60 |
+
else:
|
| 61 |
+
self.spk2info = {}
|
| 62 |
+
self.allowed_special = allowed_special
|
| 63 |
+
self.use_ttsfrd = use_ttsfrd
|
| 64 |
+
if self.use_ttsfrd:
|
| 65 |
+
self.frd = ttsfrd.TtsFrontendEngine()
|
| 66 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 67 |
+
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
|
| 68 |
+
'failed to initialize ttsfrd resource'
|
| 69 |
+
self.frd.set_lang_type('pinyinvg')
|
| 70 |
+
else:
|
| 71 |
+
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
|
| 72 |
+
self.en_tn_model = EnNormalizer()
|
| 73 |
+
self.inflect_parser = inflect.engine()
|
| 74 |
+
|
| 75 |
+
def _extract_text_token(self, text):
|
| 76 |
+
if isinstance(text, Generator):
|
| 77 |
+
logging.info('get tts_text generator, will return _extract_text_token_generator!')
|
| 78 |
+
# NOTE add a dummy text_token_len for compatibility
|
| 79 |
+
return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
|
| 80 |
+
else:
|
| 81 |
+
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
|
| 82 |
+
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
|
| 83 |
+
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
|
| 84 |
+
return text_token, text_token_len
|
| 85 |
+
|
| 86 |
+
def _extract_text_token_generator(self, text_generator):
|
| 87 |
+
for text in text_generator:
|
| 88 |
+
text_token, _ = self._extract_text_token(text)
|
| 89 |
+
for i in range(text_token.shape[1]):
|
| 90 |
+
yield text_token[:, i: i + 1]
|
| 91 |
+
|
| 92 |
+
def _extract_speech_token(self, speech):
|
| 93 |
+
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
|
| 94 |
+
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
|
| 95 |
+
speech_token = self.speech_tokenizer_session.run(None,
|
| 96 |
+
{self.speech_tokenizer_session.get_inputs()[0].name:
|
| 97 |
+
feat.detach().cpu().numpy(),
|
| 98 |
+
self.speech_tokenizer_session.get_inputs()[1].name:
|
| 99 |
+
np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
|
| 100 |
+
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
|
| 101 |
+
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
|
| 102 |
+
return speech_token, speech_token_len
|
| 103 |
+
|
| 104 |
+
def _extract_spk_embedding(self, speech):
|
| 105 |
+
feat = kaldi.fbank(speech,
|
| 106 |
+
num_mel_bins=80,
|
| 107 |
+
dither=0,
|
| 108 |
+
sample_frequency=16000)
|
| 109 |
+
feat = feat - feat.mean(dim=0, keepdim=True)
|
| 110 |
+
embedding = self.campplus_session.run(None,
|
| 111 |
+
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
| 112 |
+
embedding = torch.tensor([embedding]).to(self.device)
|
| 113 |
+
return embedding
|
| 114 |
+
|
| 115 |
+
def _extract_speech_feat(self, speech):
|
| 116 |
+
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
|
| 117 |
+
speech_feat = speech_feat.unsqueeze(dim=0)
|
| 118 |
+
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
|
| 119 |
+
return speech_feat, speech_feat_len
|
| 120 |
+
|
| 121 |
+
def text_normalize(self, text, split=True, text_frontend=True):
|
| 122 |
+
if isinstance(text, Generator):
|
| 123 |
+
logging.info('get tts_text generator, will skip text_normalize!')
|
| 124 |
+
return [text]
|
| 125 |
+
if text_frontend is False:
|
| 126 |
+
return [text] if split is True else text
|
| 127 |
+
text = text.strip()
|
| 128 |
+
if self.use_ttsfrd:
|
| 129 |
+
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
|
| 130 |
+
text = ''.join(texts)
|
| 131 |
+
else:
|
| 132 |
+
if contains_chinese(text):
|
| 133 |
+
text = self.zh_tn_model.normalize(text)
|
| 134 |
+
text = text.replace("\n", "")
|
| 135 |
+
text = replace_blank(text)
|
| 136 |
+
text = replace_corner_mark(text)
|
| 137 |
+
text = text.replace(".", "。")
|
| 138 |
+
text = text.replace(" - ", ",")
|
| 139 |
+
text = remove_bracket(text)
|
| 140 |
+
text = re.sub(r'[,,、]+$', '。', text)
|
| 141 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
|
| 142 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
| 143 |
+
else:
|
| 144 |
+
text = self.en_tn_model.normalize(text)
|
| 145 |
+
text = spell_out_number(text, self.inflect_parser)
|
| 146 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
|
| 147 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
| 148 |
+
texts = [i for i in texts if not is_only_punctuation(i)]
|
| 149 |
+
return texts if split is True else text
|
| 150 |
+
|
| 151 |
+
def frontend_sft(self, tts_text, spk_id):
|
| 152 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
| 153 |
+
embedding = self.spk2info[spk_id]['embedding']
|
| 154 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
|
| 155 |
+
return model_input
|
| 156 |
+
|
| 157 |
+
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
|
| 158 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
| 159 |
+
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
| 160 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
| 161 |
+
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
| 162 |
+
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
| 163 |
+
if resample_rate == 24000:
|
| 164 |
+
# cosyvoice2, force speech_feat % speech_token = 2
|
| 165 |
+
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
|
| 166 |
+
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
|
| 167 |
+
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
|
| 168 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
| 169 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
| 170 |
+
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
|
| 171 |
+
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
| 172 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
| 173 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
| 174 |
+
'llm_embedding': embedding, 'flow_embedding': embedding}
|
| 175 |
+
return model_input
|
| 176 |
+
|
| 177 |
+
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
|
| 178 |
+
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
|
| 179 |
+
# in cross lingual mode, we remove prompt in llm
|
| 180 |
+
del model_input['prompt_text']
|
| 181 |
+
del model_input['prompt_text_len']
|
| 182 |
+
del model_input['llm_prompt_speech_token']
|
| 183 |
+
del model_input['llm_prompt_speech_token_len']
|
| 184 |
+
return model_input
|
| 185 |
+
|
| 186 |
+
def frontend_instruct(self, tts_text, spk_id, instruct_text):
|
| 187 |
+
model_input = self.frontend_sft(tts_text, spk_id)
|
| 188 |
+
# in instruct mode, we remove spk_embedding in llm due to information leakage
|
| 189 |
+
del model_input['llm_embedding']
|
| 190 |
+
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
|
| 191 |
+
model_input['prompt_text'] = instruct_text_token
|
| 192 |
+
model_input['prompt_text_len'] = instruct_text_token_len
|
| 193 |
+
return model_input
|
| 194 |
+
|
| 195 |
+
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
|
| 196 |
+
model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
|
| 197 |
+
del model_input['llm_prompt_speech_token']
|
| 198 |
+
del model_input['llm_prompt_speech_token_len']
|
| 199 |
+
return model_input
|
| 200 |
+
|
| 201 |
+
def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
|
| 202 |
+
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
| 203 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
| 204 |
+
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
| 205 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
| 206 |
+
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
|
| 207 |
+
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
|
| 208 |
+
'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
|
| 209 |
+
'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
|
| 210 |
+
'flow_embedding': embedding}
|
| 211 |
+
return model_input
|
cosyvoice/cli/model.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
from typing import Generator
|
| 16 |
+
import torch
|
| 17 |
+
import numpy as np
|
| 18 |
+
import threading
|
| 19 |
+
import time
|
| 20 |
+
from torch.nn import functional as F
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
import uuid
|
| 23 |
+
from cosyvoice.utils.common import fade_in_out
|
| 24 |
+
from cosyvoice.utils.file_utils import convert_onnx_to_trt
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class CosyVoiceModel:
|
| 28 |
+
|
| 29 |
+
def __init__(self,
|
| 30 |
+
llm: torch.nn.Module,
|
| 31 |
+
flow: torch.nn.Module,
|
| 32 |
+
hift: torch.nn.Module,
|
| 33 |
+
fp16: bool):
|
| 34 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 35 |
+
self.llm = llm
|
| 36 |
+
self.flow = flow
|
| 37 |
+
self.hift = hift
|
| 38 |
+
self.fp16 = fp16
|
| 39 |
+
self.llm.fp16 = fp16
|
| 40 |
+
self.flow.fp16 = fp16
|
| 41 |
+
if self.fp16 is True:
|
| 42 |
+
self.llm.half()
|
| 43 |
+
self.flow.half()
|
| 44 |
+
self.token_min_hop_len = 2 * self.flow.input_frame_rate
|
| 45 |
+
self.token_max_hop_len = 4 * self.flow.input_frame_rate
|
| 46 |
+
self.token_overlap_len = 20
|
| 47 |
+
# here we fix set flow.decoder.estimator.static_chunk_size = 0 for compatibability
|
| 48 |
+
self.flow.decoder.estimator.static_chunk_size = 0
|
| 49 |
+
# mel fade in out
|
| 50 |
+
self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
|
| 51 |
+
self.mel_window = np.hamming(2 * self.mel_overlap_len)
|
| 52 |
+
# hift cache
|
| 53 |
+
self.mel_cache_len = 20
|
| 54 |
+
self.source_cache_len = int(self.mel_cache_len * 256)
|
| 55 |
+
# speech fade in out
|
| 56 |
+
self.speech_window = np.hamming(2 * self.source_cache_len)
|
| 57 |
+
# rtf and decoding related
|
| 58 |
+
self.stream_scale_factor = 1
|
| 59 |
+
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
| 60 |
+
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
| 61 |
+
self.lock = threading.Lock()
|
| 62 |
+
# dict used to store session related variable
|
| 63 |
+
self.tts_speech_token_dict = {}
|
| 64 |
+
self.llm_end_dict = {}
|
| 65 |
+
self.mel_overlap_dict = {}
|
| 66 |
+
self.flow_cache_dict = {}
|
| 67 |
+
self.hift_cache_dict = {}
|
| 68 |
+
|
| 69 |
+
def load(self, llm_model, flow_model, hift_model):
|
| 70 |
+
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
| 71 |
+
self.llm.to(self.device).eval()
|
| 72 |
+
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
| 73 |
+
self.flow.to(self.device).eval()
|
| 74 |
+
# in case hift_model is a hifigan model
|
| 75 |
+
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
| 76 |
+
self.hift.load_state_dict(hift_state_dict, strict=True)
|
| 77 |
+
self.hift.to(self.device).eval()
|
| 78 |
+
|
| 79 |
+
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
|
| 80 |
+
llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
|
| 81 |
+
self.llm.text_encoder = llm_text_encoder
|
| 82 |
+
llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
|
| 83 |
+
self.llm.llm = llm_llm
|
| 84 |
+
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
| 85 |
+
self.flow.encoder = flow_encoder
|
| 86 |
+
|
| 87 |
+
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
|
| 88 |
+
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
| 89 |
+
if not os.path.exists(flow_decoder_estimator_model):
|
| 90 |
+
convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
|
| 91 |
+
if os.path.getsize(flow_decoder_estimator_model) == 0:
|
| 92 |
+
raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
|
| 93 |
+
del self.flow.decoder.estimator
|
| 94 |
+
import tensorrt as trt
|
| 95 |
+
with open(flow_decoder_estimator_model, 'rb') as f:
|
| 96 |
+
self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
| 97 |
+
if self.flow.decoder.estimator_engine is None:
|
| 98 |
+
raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
|
| 99 |
+
self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
|
| 100 |
+
|
| 101 |
+
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
| 102 |
+
with self.llm_context:
|
| 103 |
+
if isinstance(text, Generator):
|
| 104 |
+
assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
|
| 105 |
+
for i in self.llm.inference_bistream(text=text,
|
| 106 |
+
prompt_text=prompt_text.to(self.device),
|
| 107 |
+
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
| 108 |
+
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
| 109 |
+
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
| 110 |
+
embedding=llm_embedding.to(self.device)):
|
| 111 |
+
self.tts_speech_token_dict[uuid].append(i)
|
| 112 |
+
else:
|
| 113 |
+
for i in self.llm.inference(text=text.to(self.device),
|
| 114 |
+
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
| 115 |
+
prompt_text=prompt_text.to(self.device),
|
| 116 |
+
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
| 117 |
+
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
| 118 |
+
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
| 119 |
+
embedding=llm_embedding.to(self.device)):
|
| 120 |
+
self.tts_speech_token_dict[uuid].append(i)
|
| 121 |
+
self.llm_end_dict[uuid] = True
|
| 122 |
+
|
| 123 |
+
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
|
| 124 |
+
tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
|
| 125 |
+
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
| 126 |
+
prompt_token=prompt_token.to(self.device),
|
| 127 |
+
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
| 128 |
+
prompt_feat=prompt_feat.to(self.device),
|
| 129 |
+
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
| 130 |
+
embedding=embedding.to(self.device),
|
| 131 |
+
flow_cache=self.flow_cache_dict[uuid])
|
| 132 |
+
self.flow_cache_dict[uuid] = flow_cache
|
| 133 |
+
|
| 134 |
+
# mel overlap fade in out
|
| 135 |
+
if self.mel_overlap_dict[uuid].shape[2] != 0:
|
| 136 |
+
tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
|
| 137 |
+
# append hift cache
|
| 138 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 139 |
+
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
| 140 |
+
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
| 141 |
+
else:
|
| 142 |
+
hift_cache_source = torch.zeros(1, 1, 0)
|
| 143 |
+
# keep overlap mel and hift cache
|
| 144 |
+
if finalize is False:
|
| 145 |
+
self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
|
| 146 |
+
tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
|
| 147 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
| 148 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 149 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
| 150 |
+
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
| 151 |
+
'source': tts_source[:, :, -self.source_cache_len:],
|
| 152 |
+
'speech': tts_speech[:, -self.source_cache_len:]}
|
| 153 |
+
tts_speech = tts_speech[:, :-self.source_cache_len]
|
| 154 |
+
else:
|
| 155 |
+
if speed != 1.0:
|
| 156 |
+
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
| 157 |
+
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
| 158 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
| 159 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 160 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
| 161 |
+
return tts_speech
|
| 162 |
+
|
| 163 |
+
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
| 164 |
+
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
| 165 |
+
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
| 166 |
+
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
| 167 |
+
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
| 168 |
+
# this_uuid is used to track variables related to this inference thread
|
| 169 |
+
this_uuid = str(uuid.uuid1())
|
| 170 |
+
with self.lock:
|
| 171 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
| 172 |
+
self.hift_cache_dict[this_uuid] = None
|
| 173 |
+
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
| 174 |
+
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
| 175 |
+
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
| 176 |
+
p.start()
|
| 177 |
+
if stream is True:
|
| 178 |
+
token_hop_len = self.token_min_hop_len
|
| 179 |
+
while True:
|
| 180 |
+
time.sleep(0.1)
|
| 181 |
+
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
| 182 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
| 183 |
+
.unsqueeze(dim=0)
|
| 184 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 185 |
+
prompt_token=flow_prompt_speech_token,
|
| 186 |
+
prompt_feat=prompt_speech_feat,
|
| 187 |
+
embedding=flow_embedding,
|
| 188 |
+
uuid=this_uuid,
|
| 189 |
+
finalize=False)
|
| 190 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 191 |
+
with self.lock:
|
| 192 |
+
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
| 193 |
+
# increase token_hop_len for better speech quality
|
| 194 |
+
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
| 195 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
| 196 |
+
break
|
| 197 |
+
p.join()
|
| 198 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
| 199 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 200 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 201 |
+
prompt_token=flow_prompt_speech_token,
|
| 202 |
+
prompt_feat=prompt_speech_feat,
|
| 203 |
+
embedding=flow_embedding,
|
| 204 |
+
uuid=this_uuid,
|
| 205 |
+
finalize=True)
|
| 206 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 207 |
+
else:
|
| 208 |
+
# deal with all tokens
|
| 209 |
+
p.join()
|
| 210 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 211 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 212 |
+
prompt_token=flow_prompt_speech_token,
|
| 213 |
+
prompt_feat=prompt_speech_feat,
|
| 214 |
+
embedding=flow_embedding,
|
| 215 |
+
uuid=this_uuid,
|
| 216 |
+
finalize=True,
|
| 217 |
+
speed=speed)
|
| 218 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 219 |
+
with self.lock:
|
| 220 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
| 221 |
+
self.llm_end_dict.pop(this_uuid)
|
| 222 |
+
self.mel_overlap_dict.pop(this_uuid)
|
| 223 |
+
self.hift_cache_dict.pop(this_uuid)
|
| 224 |
+
self.flow_cache_dict.pop(this_uuid)
|
| 225 |
+
torch.cuda.empty_cache()
|
| 226 |
+
|
| 227 |
+
def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
|
| 228 |
+
# this_uuid is used to track variables related to this inference thread
|
| 229 |
+
this_uuid = str(uuid.uuid1())
|
| 230 |
+
with self.lock:
|
| 231 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
|
| 232 |
+
self.hift_cache_dict[this_uuid] = None
|
| 233 |
+
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
| 234 |
+
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
| 235 |
+
if stream is True:
|
| 236 |
+
token_hop_len = self.token_min_hop_len
|
| 237 |
+
while True:
|
| 238 |
+
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
| 239 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
| 240 |
+
.unsqueeze(dim=0)
|
| 241 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 242 |
+
prompt_token=flow_prompt_speech_token,
|
| 243 |
+
prompt_feat=prompt_speech_feat,
|
| 244 |
+
embedding=flow_embedding,
|
| 245 |
+
uuid=this_uuid,
|
| 246 |
+
finalize=False)
|
| 247 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 248 |
+
with self.lock:
|
| 249 |
+
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
| 250 |
+
# increase token_hop_len for better speech quality
|
| 251 |
+
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
| 252 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
| 253 |
+
break
|
| 254 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
| 255 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 256 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 257 |
+
prompt_token=flow_prompt_speech_token,
|
| 258 |
+
prompt_feat=prompt_speech_feat,
|
| 259 |
+
embedding=flow_embedding,
|
| 260 |
+
uuid=this_uuid,
|
| 261 |
+
finalize=True)
|
| 262 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 263 |
+
else:
|
| 264 |
+
# deal with all tokens
|
| 265 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 266 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 267 |
+
prompt_token=flow_prompt_speech_token,
|
| 268 |
+
prompt_feat=prompt_speech_feat,
|
| 269 |
+
embedding=flow_embedding,
|
| 270 |
+
uuid=this_uuid,
|
| 271 |
+
finalize=True,
|
| 272 |
+
speed=speed)
|
| 273 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 274 |
+
with self.lock:
|
| 275 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
| 276 |
+
self.llm_end_dict.pop(this_uuid)
|
| 277 |
+
self.mel_overlap_dict.pop(this_uuid)
|
| 278 |
+
self.hift_cache_dict.pop(this_uuid)
|
| 279 |
+
torch.cuda.empty_cache()
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class CosyVoice2Model(CosyVoiceModel):
|
| 283 |
+
|
| 284 |
+
def __init__(self,
|
| 285 |
+
llm: torch.nn.Module,
|
| 286 |
+
flow: torch.nn.Module,
|
| 287 |
+
hift: torch.nn.Module,
|
| 288 |
+
fp16: bool):
|
| 289 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 290 |
+
self.llm = llm
|
| 291 |
+
self.flow = flow
|
| 292 |
+
self.hift = hift
|
| 293 |
+
self.fp16 = fp16
|
| 294 |
+
self.llm.fp16 = fp16
|
| 295 |
+
self.flow.fp16 = fp16
|
| 296 |
+
if self.fp16 is True:
|
| 297 |
+
self.llm.half()
|
| 298 |
+
self.flow.half()
|
| 299 |
+
self.token_hop_len = 2 * self.flow.input_frame_rate
|
| 300 |
+
# here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
|
| 301 |
+
self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
|
| 302 |
+
self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
|
| 303 |
+
# hift cache
|
| 304 |
+
self.mel_cache_len = 8
|
| 305 |
+
self.source_cache_len = int(self.mel_cache_len * 480)
|
| 306 |
+
# speech fade in out
|
| 307 |
+
self.speech_window = np.hamming(2 * self.source_cache_len)
|
| 308 |
+
# rtf and decoding related
|
| 309 |
+
self.stream_scale_factor = 1
|
| 310 |
+
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
| 311 |
+
self.lock = threading.Lock()
|
| 312 |
+
# dict used to store session related variable
|
| 313 |
+
self.tts_speech_token_dict = {}
|
| 314 |
+
self.llm_end_dict = {}
|
| 315 |
+
self.hift_cache_dict = {}
|
| 316 |
+
|
| 317 |
+
def load_jit(self, flow_encoder_model):
|
| 318 |
+
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
| 319 |
+
self.flow.encoder = flow_encoder
|
| 320 |
+
|
| 321 |
+
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
|
| 322 |
+
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
| 323 |
+
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
| 324 |
+
prompt_token=prompt_token.to(self.device),
|
| 325 |
+
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
| 326 |
+
prompt_feat=prompt_feat.to(self.device),
|
| 327 |
+
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
| 328 |
+
embedding=embedding.to(self.device),
|
| 329 |
+
finalize=finalize)
|
| 330 |
+
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
| 331 |
+
# append hift cache
|
| 332 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 333 |
+
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
| 334 |
+
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
| 335 |
+
else:
|
| 336 |
+
hift_cache_source = torch.zeros(1, 1, 0)
|
| 337 |
+
# keep overlap mel and hift cache
|
| 338 |
+
if finalize is False:
|
| 339 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
| 340 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 341 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
| 342 |
+
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
| 343 |
+
'source': tts_source[:, :, -self.source_cache_len:],
|
| 344 |
+
'speech': tts_speech[:, -self.source_cache_len:]}
|
| 345 |
+
tts_speech = tts_speech[:, :-self.source_cache_len]
|
| 346 |
+
else:
|
| 347 |
+
if speed != 1.0:
|
| 348 |
+
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
| 349 |
+
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
| 350 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
| 351 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 352 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
| 353 |
+
return tts_speech
|
| 354 |
+
|
| 355 |
+
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
| 356 |
+
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
| 357 |
+
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
| 358 |
+
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
| 359 |
+
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
| 360 |
+
# this_uuid is used to track variables related to this inference thread
|
| 361 |
+
this_uuid = str(uuid.uuid1())
|
| 362 |
+
with self.lock:
|
| 363 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
| 364 |
+
self.hift_cache_dict[this_uuid] = None
|
| 365 |
+
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
| 366 |
+
p.start()
|
| 367 |
+
if stream is True:
|
| 368 |
+
token_offset = 0
|
| 369 |
+
while True:
|
| 370 |
+
time.sleep(0.1)
|
| 371 |
+
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
|
| 372 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
|
| 373 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 374 |
+
prompt_token=flow_prompt_speech_token,
|
| 375 |
+
prompt_feat=prompt_speech_feat,
|
| 376 |
+
embedding=flow_embedding,
|
| 377 |
+
uuid=this_uuid,
|
| 378 |
+
token_offset=token_offset,
|
| 379 |
+
finalize=False)
|
| 380 |
+
token_offset += self.token_hop_len
|
| 381 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 382 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
|
| 383 |
+
break
|
| 384 |
+
p.join()
|
| 385 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
| 386 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 387 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 388 |
+
prompt_token=flow_prompt_speech_token,
|
| 389 |
+
prompt_feat=prompt_speech_feat,
|
| 390 |
+
embedding=flow_embedding,
|
| 391 |
+
uuid=this_uuid,
|
| 392 |
+
token_offset=token_offset,
|
| 393 |
+
finalize=True)
|
| 394 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 395 |
+
else:
|
| 396 |
+
# deal with all tokens
|
| 397 |
+
p.join()
|
| 398 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 399 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 400 |
+
prompt_token=flow_prompt_speech_token,
|
| 401 |
+
prompt_feat=prompt_speech_feat,
|
| 402 |
+
embedding=flow_embedding,
|
| 403 |
+
uuid=this_uuid,
|
| 404 |
+
token_offset=0,
|
| 405 |
+
finalize=True,
|
| 406 |
+
speed=speed)
|
| 407 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 408 |
+
with self.lock:
|
| 409 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
| 410 |
+
self.llm_end_dict.pop(this_uuid)
|
| 411 |
+
torch.cuda.empty_cache()
|
cosyvoice/dataset/__init__.py
ADDED
|
File without changes
|
cosyvoice/dataset/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (156 Bytes). View file
|
|
|
cosyvoice/dataset/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (154 Bytes). View file
|
|
|
cosyvoice/dataset/__pycache__/processor.cpython-310.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
cosyvoice/dataset/__pycache__/processor.cpython-38.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
cosyvoice/dataset/dataset.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
| 2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import random
|
| 17 |
+
import json
|
| 18 |
+
import math
|
| 19 |
+
from functools import partial
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
from torch.utils.data import IterableDataset
|
| 24 |
+
from cosyvoice.utils.file_utils import read_lists, read_json_lists
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Processor(IterableDataset):
|
| 28 |
+
|
| 29 |
+
def __init__(self, source, f, *args, **kw):
|
| 30 |
+
assert callable(f)
|
| 31 |
+
self.source = source
|
| 32 |
+
self.f = f
|
| 33 |
+
self.args = args
|
| 34 |
+
self.kw = kw
|
| 35 |
+
|
| 36 |
+
def set_epoch(self, epoch):
|
| 37 |
+
self.source.set_epoch(epoch)
|
| 38 |
+
|
| 39 |
+
def __iter__(self):
|
| 40 |
+
""" Return an iterator over the source dataset processed by the
|
| 41 |
+
given processor.
|
| 42 |
+
"""
|
| 43 |
+
assert self.source is not None
|
| 44 |
+
assert callable(self.f)
|
| 45 |
+
return self.f(iter(self.source), *self.args, **self.kw)
|
| 46 |
+
|
| 47 |
+
def apply(self, f):
|
| 48 |
+
assert callable(f)
|
| 49 |
+
return Processor(self, f, *self.args, **self.kw)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class DistributedSampler:
|
| 53 |
+
|
| 54 |
+
def __init__(self, shuffle=True, partition=True):
|
| 55 |
+
self.epoch = -1
|
| 56 |
+
self.update()
|
| 57 |
+
self.shuffle = shuffle
|
| 58 |
+
self.partition = partition
|
| 59 |
+
|
| 60 |
+
def update(self):
|
| 61 |
+
assert dist.is_available()
|
| 62 |
+
if dist.is_initialized():
|
| 63 |
+
self.rank = dist.get_rank()
|
| 64 |
+
self.world_size = dist.get_world_size()
|
| 65 |
+
else:
|
| 66 |
+
self.rank = 0
|
| 67 |
+
self.world_size = 1
|
| 68 |
+
worker_info = torch.utils.data.get_worker_info()
|
| 69 |
+
if worker_info is None:
|
| 70 |
+
self.worker_id = 0
|
| 71 |
+
self.num_workers = 1
|
| 72 |
+
else:
|
| 73 |
+
self.worker_id = worker_info.id
|
| 74 |
+
self.num_workers = worker_info.num_workers
|
| 75 |
+
return dict(rank=self.rank,
|
| 76 |
+
world_size=self.world_size,
|
| 77 |
+
worker_id=self.worker_id,
|
| 78 |
+
num_workers=self.num_workers)
|
| 79 |
+
|
| 80 |
+
def set_epoch(self, epoch):
|
| 81 |
+
self.epoch = epoch
|
| 82 |
+
|
| 83 |
+
def sample(self, data):
|
| 84 |
+
""" Sample data according to rank/world_size/num_workers
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
data(List): input data list
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
List: data list after sample
|
| 91 |
+
"""
|
| 92 |
+
data = list(range(len(data)))
|
| 93 |
+
# force datalist even
|
| 94 |
+
if self.partition:
|
| 95 |
+
if self.shuffle:
|
| 96 |
+
random.Random(self.epoch).shuffle(data)
|
| 97 |
+
if len(data) < self.world_size:
|
| 98 |
+
data = data * math.ceil(self.world_size / len(data))
|
| 99 |
+
data = data[:self.world_size]
|
| 100 |
+
data = data[self.rank::self.world_size]
|
| 101 |
+
if len(data) < self.num_workers:
|
| 102 |
+
data = data * math.ceil(self.num_workers / len(data))
|
| 103 |
+
data = data[:self.num_workers]
|
| 104 |
+
data = data[self.worker_id::self.num_workers]
|
| 105 |
+
return data
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class DataList(IterableDataset):
|
| 109 |
+
|
| 110 |
+
def __init__(self, lists, shuffle=True, partition=True):
|
| 111 |
+
self.lists = lists
|
| 112 |
+
self.sampler = DistributedSampler(shuffle, partition)
|
| 113 |
+
|
| 114 |
+
def set_epoch(self, epoch):
|
| 115 |
+
self.sampler.set_epoch(epoch)
|
| 116 |
+
|
| 117 |
+
def __iter__(self):
|
| 118 |
+
sampler_info = self.sampler.update()
|
| 119 |
+
indexes = self.sampler.sample(self.lists)
|
| 120 |
+
for index in indexes:
|
| 121 |
+
data = dict(src=self.lists[index])
|
| 122 |
+
data.update(sampler_info)
|
| 123 |
+
yield data
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def Dataset(data_list_file,
|
| 127 |
+
data_pipeline,
|
| 128 |
+
mode='train',
|
| 129 |
+
gan=False,
|
| 130 |
+
shuffle=True,
|
| 131 |
+
partition=True,
|
| 132 |
+
tts_file='',
|
| 133 |
+
prompt_utt2data=''):
|
| 134 |
+
""" Construct dataset from arguments
|
| 135 |
+
|
| 136 |
+
We have two shuffle stage in the Dataset. The first is global
|
| 137 |
+
shuffle at shards tar/raw file level. The second is global shuffle
|
| 138 |
+
at training samples level.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
data_type(str): raw/shard
|
| 142 |
+
tokenizer (BaseTokenizer): tokenizer to tokenize
|
| 143 |
+
partition(bool): whether to do data partition in terms of rank
|
| 144 |
+
"""
|
| 145 |
+
assert mode in ['train', 'inference']
|
| 146 |
+
lists = read_lists(data_list_file)
|
| 147 |
+
if mode == 'inference':
|
| 148 |
+
with open(tts_file) as f:
|
| 149 |
+
tts_data = json.load(f)
|
| 150 |
+
utt2lists = read_json_lists(prompt_utt2data)
|
| 151 |
+
# filter unnecessary file in inference mode
|
| 152 |
+
lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
|
| 153 |
+
dataset = DataList(lists,
|
| 154 |
+
shuffle=shuffle,
|
| 155 |
+
partition=partition)
|
| 156 |
+
if mode == 'inference':
|
| 157 |
+
# map partial arg to parquet_opener func in inference mode
|
| 158 |
+
data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
|
| 159 |
+
if gan is True:
|
| 160 |
+
# map partial arg to padding func in gan mode
|
| 161 |
+
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
|
| 162 |
+
for func in data_pipeline:
|
| 163 |
+
dataset = Processor(dataset, func, mode=mode)
|
| 164 |
+
return dataset
|
cosyvoice/dataset/processor.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import logging
|
| 15 |
+
import random
|
| 16 |
+
|
| 17 |
+
import pyarrow.parquet as pq
|
| 18 |
+
from io import BytesIO
|
| 19 |
+
import torch
|
| 20 |
+
import torchaudio
|
| 21 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import pyworld as pw
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def parquet_opener(data, mode='train', tts_data={}):
|
| 30 |
+
""" Give url or local file, return file descriptor
|
| 31 |
+
Inplace operation.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
data(Iterable[str]): url or local file list
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
Iterable[{src, stream}]
|
| 38 |
+
"""
|
| 39 |
+
for sample in data:
|
| 40 |
+
assert 'src' in sample
|
| 41 |
+
url = sample['src']
|
| 42 |
+
try:
|
| 43 |
+
for df in pq.ParquetFile(url).iter_batches(batch_size=64):
|
| 44 |
+
df = df.to_pandas()
|
| 45 |
+
for i in range(len(df)):
|
| 46 |
+
if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
|
| 47 |
+
continue
|
| 48 |
+
sample.update(dict(df.loc[i]))
|
| 49 |
+
if mode == 'train':
|
| 50 |
+
# NOTE do not return sample directly, must initialize a new dict
|
| 51 |
+
yield {**sample}
|
| 52 |
+
else:
|
| 53 |
+
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
|
| 54 |
+
yield {**sample, 'tts_index': index, 'tts_text': text}
|
| 55 |
+
except Exception as ex:
|
| 56 |
+
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def filter(data,
|
| 60 |
+
max_length=10240,
|
| 61 |
+
min_length=10,
|
| 62 |
+
token_max_length=200,
|
| 63 |
+
token_min_length=1,
|
| 64 |
+
min_output_input_ratio=0.0005,
|
| 65 |
+
max_output_input_ratio=1,
|
| 66 |
+
mode='train'):
|
| 67 |
+
""" Filter sample according to feature and label length
|
| 68 |
+
Inplace operation.
|
| 69 |
+
|
| 70 |
+
Args::
|
| 71 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
| 72 |
+
max_length: drop utterance which is greater than max_length(10ms)
|
| 73 |
+
min_length: drop utterance which is less than min_length(10ms)
|
| 74 |
+
token_max_length: drop utterance which is greater than
|
| 75 |
+
token_max_length, especially when use char unit for
|
| 76 |
+
english modeling
|
| 77 |
+
token_min_length: drop utterance which is
|
| 78 |
+
less than token_max_length
|
| 79 |
+
min_output_input_ratio: minimal ration of
|
| 80 |
+
token_length / feats_length(10ms)
|
| 81 |
+
max_output_input_ratio: maximum ration of
|
| 82 |
+
token_length / feats_length(10ms)
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
Iterable[{key, wav, label, sample_rate}]
|
| 86 |
+
"""
|
| 87 |
+
for sample in data:
|
| 88 |
+
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
|
| 89 |
+
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
| 90 |
+
del sample['audio_data']
|
| 91 |
+
# sample['wav'] is torch.Tensor, we have 100 frames every second
|
| 92 |
+
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
|
| 93 |
+
if num_frames < min_length:
|
| 94 |
+
continue
|
| 95 |
+
if num_frames > max_length:
|
| 96 |
+
continue
|
| 97 |
+
if len(sample['text_token']) < token_min_length:
|
| 98 |
+
continue
|
| 99 |
+
if len(sample['text_token']) > token_max_length:
|
| 100 |
+
continue
|
| 101 |
+
if len(sample['speech_token']) == 0:
|
| 102 |
+
continue
|
| 103 |
+
if num_frames != 0:
|
| 104 |
+
if len(sample['text_token']) / num_frames < min_output_input_ratio:
|
| 105 |
+
continue
|
| 106 |
+
if len(sample['text_token']) / num_frames > max_output_input_ratio:
|
| 107 |
+
continue
|
| 108 |
+
yield sample
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
|
| 112 |
+
""" Resample data.
|
| 113 |
+
Inplace operation.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
| 117 |
+
resample_rate: target resample rate
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
Iterable[{key, wav, label, sample_rate}]
|
| 121 |
+
"""
|
| 122 |
+
for sample in data:
|
| 123 |
+
assert 'sample_rate' in sample
|
| 124 |
+
assert 'speech' in sample
|
| 125 |
+
sample_rate = sample['sample_rate']
|
| 126 |
+
waveform = sample['speech']
|
| 127 |
+
if sample_rate != resample_rate:
|
| 128 |
+
if sample_rate < min_sample_rate:
|
| 129 |
+
continue
|
| 130 |
+
sample['sample_rate'] = resample_rate
|
| 131 |
+
sample['speech'] = torchaudio.transforms.Resample(
|
| 132 |
+
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
| 133 |
+
max_val = sample['speech'].abs().max()
|
| 134 |
+
if max_val > 1:
|
| 135 |
+
sample['speech'] /= max_val
|
| 136 |
+
yield sample
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def truncate(data, truncate_length=24576, mode='train'):
|
| 140 |
+
""" Truncate data.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
| 144 |
+
truncate_length: truncate length
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
Iterable[{key, wav, label, sample_rate}]
|
| 148 |
+
"""
|
| 149 |
+
for sample in data:
|
| 150 |
+
waveform = sample['speech']
|
| 151 |
+
if waveform.shape[1] > truncate_length:
|
| 152 |
+
start = random.randint(0, waveform.shape[1] - truncate_length)
|
| 153 |
+
waveform = waveform[:, start: start + truncate_length]
|
| 154 |
+
else:
|
| 155 |
+
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
|
| 156 |
+
sample['speech'] = waveform
|
| 157 |
+
yield sample
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def compute_fbank(data,
|
| 161 |
+
feat_extractor,
|
| 162 |
+
mode='train'):
|
| 163 |
+
""" Extract fbank
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
Iterable[{key, feat, label}]
|
| 170 |
+
"""
|
| 171 |
+
for sample in data:
|
| 172 |
+
assert 'sample_rate' in sample
|
| 173 |
+
assert 'speech' in sample
|
| 174 |
+
assert 'utt' in sample
|
| 175 |
+
assert 'text_token' in sample
|
| 176 |
+
waveform = sample['speech']
|
| 177 |
+
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
| 178 |
+
sample['speech_feat'] = mat
|
| 179 |
+
yield sample
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def compute_f0(data, sample_rate, hop_size, mode='train'):
|
| 183 |
+
""" Extract f0
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
Iterable[{key, feat, label}]
|
| 190 |
+
"""
|
| 191 |
+
frame_period = hop_size * 1000 / sample_rate
|
| 192 |
+
for sample in data:
|
| 193 |
+
assert 'sample_rate' in sample
|
| 194 |
+
assert 'speech' in sample
|
| 195 |
+
assert 'utt' in sample
|
| 196 |
+
assert 'text_token' in sample
|
| 197 |
+
waveform = sample['speech']
|
| 198 |
+
_f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
|
| 199 |
+
if sum(_f0 != 0) < 5: # this happens when the algorithm fails
|
| 200 |
+
_f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
|
| 201 |
+
f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
|
| 202 |
+
f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
|
| 203 |
+
sample['pitch_feat'] = f0
|
| 204 |
+
yield sample
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def parse_embedding(data, normalize, mode='train'):
|
| 208 |
+
""" Parse utt_embedding/spk_embedding
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
Iterable[{key, feat, label}]
|
| 215 |
+
"""
|
| 216 |
+
for sample in data:
|
| 217 |
+
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
| 218 |
+
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
| 219 |
+
if normalize:
|
| 220 |
+
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
|
| 221 |
+
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
|
| 222 |
+
yield sample
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
|
| 226 |
+
""" Decode text to chars or BPE
|
| 227 |
+
Inplace operation
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
data: Iterable[{key, wav, txt, sample_rate}]
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
Iterable[{key, wav, txt, tokens, label, sample_rate}]
|
| 234 |
+
"""
|
| 235 |
+
tokenizer = get_tokenizer()
|
| 236 |
+
for sample in data:
|
| 237 |
+
assert 'text' in sample
|
| 238 |
+
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
|
| 239 |
+
if mode == 'inference':
|
| 240 |
+
sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
|
| 241 |
+
yield sample
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def shuffle(data, shuffle_size=10000, mode='train'):
|
| 245 |
+
""" Local shuffle the data
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
data: Iterable[{key, feat, label}]
|
| 249 |
+
shuffle_size: buffer size for shuffle
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
Iterable[{key, feat, label}]
|
| 253 |
+
"""
|
| 254 |
+
buf = []
|
| 255 |
+
for sample in data:
|
| 256 |
+
buf.append(sample)
|
| 257 |
+
if len(buf) >= shuffle_size:
|
| 258 |
+
random.shuffle(buf)
|
| 259 |
+
for x in buf:
|
| 260 |
+
yield x
|
| 261 |
+
buf = []
|
| 262 |
+
# The sample left over
|
| 263 |
+
random.shuffle(buf)
|
| 264 |
+
for x in buf:
|
| 265 |
+
yield x
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def sort(data, sort_size=500, mode='train'):
|
| 269 |
+
""" Sort the data by feature length.
|
| 270 |
+
Sort is used after shuffle and before batch, so we can group
|
| 271 |
+
utts with similar lengths into a batch, and `sort_size` should
|
| 272 |
+
be less than `shuffle_size`
|
| 273 |
+
|
| 274 |
+
Args:
|
| 275 |
+
data: Iterable[{key, feat, label}]
|
| 276 |
+
sort_size: buffer size for sort
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
Iterable[{key, feat, label}]
|
| 280 |
+
"""
|
| 281 |
+
|
| 282 |
+
buf = []
|
| 283 |
+
for sample in data:
|
| 284 |
+
buf.append(sample)
|
| 285 |
+
if len(buf) >= sort_size:
|
| 286 |
+
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
| 287 |
+
for x in buf:
|
| 288 |
+
yield x
|
| 289 |
+
buf = []
|
| 290 |
+
# The sample left over
|
| 291 |
+
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
| 292 |
+
for x in buf:
|
| 293 |
+
yield x
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def static_batch(data, batch_size=16):
|
| 297 |
+
""" Static batch the data by `batch_size`
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
data: Iterable[{key, feat, label}]
|
| 301 |
+
batch_size: batch size
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
Iterable[List[{key, feat, label}]]
|
| 305 |
+
"""
|
| 306 |
+
buf = []
|
| 307 |
+
for sample in data:
|
| 308 |
+
buf.append(sample)
|
| 309 |
+
if len(buf) >= batch_size:
|
| 310 |
+
yield buf
|
| 311 |
+
buf = []
|
| 312 |
+
if len(buf) > 0:
|
| 313 |
+
yield buf
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
|
| 317 |
+
""" Dynamic batch the data until the total frames in batch
|
| 318 |
+
reach `max_frames_in_batch`
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
data: Iterable[{key, feat, label}]
|
| 322 |
+
max_frames_in_batch: max_frames in one batch
|
| 323 |
+
|
| 324 |
+
Returns:
|
| 325 |
+
Iterable[List[{key, feat, label}]]
|
| 326 |
+
"""
|
| 327 |
+
buf = []
|
| 328 |
+
longest_frames = 0
|
| 329 |
+
for sample in data:
|
| 330 |
+
assert 'speech_feat' in sample
|
| 331 |
+
assert isinstance(sample['speech_feat'], torch.Tensor)
|
| 332 |
+
new_sample_frames = sample['speech_feat'].size(0)
|
| 333 |
+
longest_frames = max(longest_frames, new_sample_frames)
|
| 334 |
+
frames_after_padding = longest_frames * (len(buf) + 1)
|
| 335 |
+
if frames_after_padding > max_frames_in_batch:
|
| 336 |
+
yield buf
|
| 337 |
+
buf = [sample]
|
| 338 |
+
longest_frames = new_sample_frames
|
| 339 |
+
else:
|
| 340 |
+
buf.append(sample)
|
| 341 |
+
if len(buf) > 0:
|
| 342 |
+
yield buf
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
|
| 346 |
+
""" Wrapper for static/dynamic batch
|
| 347 |
+
"""
|
| 348 |
+
if mode == 'inference':
|
| 349 |
+
return static_batch(data, 1)
|
| 350 |
+
else:
|
| 351 |
+
if batch_type == 'static':
|
| 352 |
+
return static_batch(data, batch_size)
|
| 353 |
+
elif batch_type == 'dynamic':
|
| 354 |
+
return dynamic_batch(data, max_frames_in_batch)
|
| 355 |
+
else:
|
| 356 |
+
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def padding(data, use_spk_embedding, mode='train', gan=False):
|
| 360 |
+
""" Padding the data into training data
|
| 361 |
+
|
| 362 |
+
Args:
|
| 363 |
+
data: Iterable[List[{key, feat, label}]]
|
| 364 |
+
|
| 365 |
+
Returns:
|
| 366 |
+
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
|
| 367 |
+
"""
|
| 368 |
+
for sample in data:
|
| 369 |
+
assert isinstance(sample, list)
|
| 370 |
+
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
|
| 371 |
+
dtype=torch.int32)
|
| 372 |
+
order = torch.argsort(speech_feat_len, descending=True)
|
| 373 |
+
|
| 374 |
+
utts = [sample[i]['utt'] for i in order]
|
| 375 |
+
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
|
| 376 |
+
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
|
| 377 |
+
speech = pad_sequence(speech, batch_first=True, padding_value=0)
|
| 378 |
+
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
| 379 |
+
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
| 380 |
+
speech_token = pad_sequence(speech_token,
|
| 381 |
+
batch_first=True,
|
| 382 |
+
padding_value=0)
|
| 383 |
+
speech_feat = [sample[i]['speech_feat'] for i in order]
|
| 384 |
+
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
|
| 385 |
+
speech_feat = pad_sequence(speech_feat,
|
| 386 |
+
batch_first=True,
|
| 387 |
+
padding_value=0)
|
| 388 |
+
text = [sample[i]['text'] for i in order]
|
| 389 |
+
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
|
| 390 |
+
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
| 391 |
+
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
|
| 392 |
+
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
|
| 393 |
+
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
| 394 |
+
batch = {
|
| 395 |
+
"utts": utts,
|
| 396 |
+
"speech": speech,
|
| 397 |
+
"speech_len": speech_len,
|
| 398 |
+
"speech_token": speech_token,
|
| 399 |
+
"speech_token_len": speech_token_len,
|
| 400 |
+
"speech_feat": speech_feat,
|
| 401 |
+
"speech_feat_len": speech_feat_len,
|
| 402 |
+
"text": text,
|
| 403 |
+
"text_token": text_token,
|
| 404 |
+
"text_token_len": text_token_len,
|
| 405 |
+
"utt_embedding": utt_embedding,
|
| 406 |
+
"spk_embedding": spk_embedding,
|
| 407 |
+
}
|
| 408 |
+
if gan is True:
|
| 409 |
+
# in gan train, we need pitch_feat
|
| 410 |
+
pitch_feat = [sample[i]['pitch_feat'] for i in order]
|
| 411 |
+
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
|
| 412 |
+
pitch_feat = pad_sequence(pitch_feat,
|
| 413 |
+
batch_first=True,
|
| 414 |
+
padding_value=0)
|
| 415 |
+
batch["pitch_feat"] = pitch_feat
|
| 416 |
+
batch["pitch_feat_len"] = pitch_feat_len
|
| 417 |
+
else:
|
| 418 |
+
# only gan train needs speech, delete it to save memory
|
| 419 |
+
del batch["speech"]
|
| 420 |
+
del batch["speech_len"]
|
| 421 |
+
if mode == 'inference':
|
| 422 |
+
tts_text = [sample[i]['tts_text'] for i in order]
|
| 423 |
+
tts_index = [sample[i]['tts_index'] for i in order]
|
| 424 |
+
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
|
| 425 |
+
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
|
| 426 |
+
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
|
| 427 |
+
batch.update({'tts_text': tts_text,
|
| 428 |
+
'tts_index': tts_index,
|
| 429 |
+
'tts_text_token': tts_text_token,
|
| 430 |
+
'tts_text_token_len': tts_text_token_len})
|
| 431 |
+
if use_spk_embedding is True:
|
| 432 |
+
batch["embedding"] = batch["spk_embedding"]
|
| 433 |
+
else:
|
| 434 |
+
batch["embedding"] = batch["utt_embedding"]
|
| 435 |
+
yield batch
|
cosyvoice/flow/__pycache__/decoder.cpython-310.pyc
ADDED
|
Binary file (8.14 kB). View file
|
|
|
cosyvoice/flow/__pycache__/decoder.cpython-38.pyc
ADDED
|
Binary file (8.16 kB). View file
|
|
|
cosyvoice/flow/__pycache__/flow.cpython-310.pyc
ADDED
|
Binary file (6.45 kB). View file
|
|
|
cosyvoice/flow/__pycache__/flow.cpython-38.pyc
ADDED
|
Binary file (6.36 kB). View file
|
|
|
cosyvoice/flow/__pycache__/flow_matching.cpython-310.pyc
ADDED
|
Binary file (6.89 kB). View file
|
|
|
cosyvoice/flow/__pycache__/flow_matching.cpython-38.pyc
ADDED
|
Binary file (6.87 kB). View file
|
|
|
cosyvoice/flow/__pycache__/length_regulator.cpython-310.pyc
ADDED
|
Binary file (2.19 kB). View file
|
|
|
cosyvoice/flow/__pycache__/length_regulator.cpython-38.pyc
ADDED
|
Binary file (2.18 kB). View file
|
|
|
cosyvoice/flow/decoder.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from einops import pack, rearrange, repeat
|
| 18 |
+
from cosyvoice.utils.common import mask_to_bias
|
| 19 |
+
from cosyvoice.utils.mask import add_optional_chunk_mask
|
| 20 |
+
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
| 21 |
+
from matcha.models.components.transformer import BasicTransformerBlock
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Transpose(torch.nn.Module):
|
| 25 |
+
def __init__(self, dim0: int, dim1: int):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.dim0 = dim0
|
| 28 |
+
self.dim1 = dim1
|
| 29 |
+
|
| 30 |
+
def forward(self, x: torch.Tensor):
|
| 31 |
+
x = torch.transpose(x, self.dim0, self.dim1)
|
| 32 |
+
return x
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class CausalBlock1D(Block1D):
|
| 36 |
+
def __init__(self, dim: int, dim_out: int):
|
| 37 |
+
super(CausalBlock1D, self).__init__(dim, dim_out)
|
| 38 |
+
self.block = torch.nn.Sequential(
|
| 39 |
+
CausalConv1d(dim, dim_out, 3),
|
| 40 |
+
Transpose(1, 2),
|
| 41 |
+
nn.LayerNorm(dim_out),
|
| 42 |
+
Transpose(1, 2),
|
| 43 |
+
nn.Mish(),
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor):
|
| 47 |
+
output = self.block(x * mask)
|
| 48 |
+
return output * mask
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class CausalResnetBlock1D(ResnetBlock1D):
|
| 52 |
+
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
| 53 |
+
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
| 54 |
+
self.block1 = CausalBlock1D(dim, dim_out)
|
| 55 |
+
self.block2 = CausalBlock1D(dim_out, dim_out)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class CausalConv1d(torch.nn.Conv1d):
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
in_channels: int,
|
| 62 |
+
out_channels: int,
|
| 63 |
+
kernel_size: int,
|
| 64 |
+
stride: int = 1,
|
| 65 |
+
dilation: int = 1,
|
| 66 |
+
groups: int = 1,
|
| 67 |
+
bias: bool = True,
|
| 68 |
+
padding_mode: str = 'zeros',
|
| 69 |
+
device=None,
|
| 70 |
+
dtype=None
|
| 71 |
+
) -> None:
|
| 72 |
+
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
| 73 |
+
kernel_size, stride,
|
| 74 |
+
padding=0, dilation=dilation,
|
| 75 |
+
groups=groups, bias=bias,
|
| 76 |
+
padding_mode=padding_mode,
|
| 77 |
+
device=device, dtype=dtype)
|
| 78 |
+
assert stride == 1
|
| 79 |
+
self.causal_padding = (kernel_size - 1, 0)
|
| 80 |
+
|
| 81 |
+
def forward(self, x: torch.Tensor):
|
| 82 |
+
x = F.pad(x, self.causal_padding)
|
| 83 |
+
x = super(CausalConv1d, self).forward(x)
|
| 84 |
+
return x
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class ConditionalDecoder(nn.Module):
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
in_channels,
|
| 91 |
+
out_channels,
|
| 92 |
+
causal=False,
|
| 93 |
+
channels=(256, 256),
|
| 94 |
+
dropout=0.05,
|
| 95 |
+
attention_head_dim=64,
|
| 96 |
+
n_blocks=1,
|
| 97 |
+
num_mid_blocks=2,
|
| 98 |
+
num_heads=4,
|
| 99 |
+
act_fn="snake",
|
| 100 |
+
):
|
| 101 |
+
"""
|
| 102 |
+
This decoder requires an input with the same shape of the target. So, if your text content
|
| 103 |
+
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
| 104 |
+
"""
|
| 105 |
+
super().__init__()
|
| 106 |
+
channels = tuple(channels)
|
| 107 |
+
self.in_channels = in_channels
|
| 108 |
+
self.out_channels = out_channels
|
| 109 |
+
self.causal = causal
|
| 110 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
| 111 |
+
time_embed_dim = channels[0] * 4
|
| 112 |
+
self.time_mlp = TimestepEmbedding(
|
| 113 |
+
in_channels=in_channels,
|
| 114 |
+
time_embed_dim=time_embed_dim,
|
| 115 |
+
act_fn="silu",
|
| 116 |
+
)
|
| 117 |
+
self.down_blocks = nn.ModuleList([])
|
| 118 |
+
self.mid_blocks = nn.ModuleList([])
|
| 119 |
+
self.up_blocks = nn.ModuleList([])
|
| 120 |
+
|
| 121 |
+
output_channel = in_channels
|
| 122 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
| 123 |
+
input_channel = output_channel
|
| 124 |
+
output_channel = channels[i]
|
| 125 |
+
is_last = i == len(channels) - 1
|
| 126 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
| 127 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 128 |
+
transformer_blocks = nn.ModuleList(
|
| 129 |
+
[
|
| 130 |
+
BasicTransformerBlock(
|
| 131 |
+
dim=output_channel,
|
| 132 |
+
num_attention_heads=num_heads,
|
| 133 |
+
attention_head_dim=attention_head_dim,
|
| 134 |
+
dropout=dropout,
|
| 135 |
+
activation_fn=act_fn,
|
| 136 |
+
)
|
| 137 |
+
for _ in range(n_blocks)
|
| 138 |
+
]
|
| 139 |
+
)
|
| 140 |
+
downsample = (
|
| 141 |
+
Downsample1D(output_channel) if not is_last else
|
| 142 |
+
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
| 143 |
+
)
|
| 144 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
| 145 |
+
|
| 146 |
+
for _ in range(num_mid_blocks):
|
| 147 |
+
input_channel = channels[-1]
|
| 148 |
+
out_channels = channels[-1]
|
| 149 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
| 150 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 151 |
+
|
| 152 |
+
transformer_blocks = nn.ModuleList(
|
| 153 |
+
[
|
| 154 |
+
BasicTransformerBlock(
|
| 155 |
+
dim=output_channel,
|
| 156 |
+
num_attention_heads=num_heads,
|
| 157 |
+
attention_head_dim=attention_head_dim,
|
| 158 |
+
dropout=dropout,
|
| 159 |
+
activation_fn=act_fn,
|
| 160 |
+
)
|
| 161 |
+
for _ in range(n_blocks)
|
| 162 |
+
]
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
| 166 |
+
|
| 167 |
+
channels = channels[::-1] + (channels[0],)
|
| 168 |
+
for i in range(len(channels) - 1):
|
| 169 |
+
input_channel = channels[i] * 2
|
| 170 |
+
output_channel = channels[i + 1]
|
| 171 |
+
is_last = i == len(channels) - 2
|
| 172 |
+
resnet = CausalResnetBlock1D(
|
| 173 |
+
dim=input_channel,
|
| 174 |
+
dim_out=output_channel,
|
| 175 |
+
time_emb_dim=time_embed_dim,
|
| 176 |
+
) if self.causal else ResnetBlock1D(
|
| 177 |
+
dim=input_channel,
|
| 178 |
+
dim_out=output_channel,
|
| 179 |
+
time_emb_dim=time_embed_dim,
|
| 180 |
+
)
|
| 181 |
+
transformer_blocks = nn.ModuleList(
|
| 182 |
+
[
|
| 183 |
+
BasicTransformerBlock(
|
| 184 |
+
dim=output_channel,
|
| 185 |
+
num_attention_heads=num_heads,
|
| 186 |
+
attention_head_dim=attention_head_dim,
|
| 187 |
+
dropout=dropout,
|
| 188 |
+
activation_fn=act_fn,
|
| 189 |
+
)
|
| 190 |
+
for _ in range(n_blocks)
|
| 191 |
+
]
|
| 192 |
+
)
|
| 193 |
+
upsample = (
|
| 194 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
| 195 |
+
if not is_last
|
| 196 |
+
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
| 197 |
+
)
|
| 198 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
| 199 |
+
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
| 200 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
| 201 |
+
self.initialize_weights()
|
| 202 |
+
|
| 203 |
+
def initialize_weights(self):
|
| 204 |
+
for m in self.modules():
|
| 205 |
+
if isinstance(m, nn.Conv1d):
|
| 206 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
| 207 |
+
if m.bias is not None:
|
| 208 |
+
nn.init.constant_(m.bias, 0)
|
| 209 |
+
elif isinstance(m, nn.GroupNorm):
|
| 210 |
+
nn.init.constant_(m.weight, 1)
|
| 211 |
+
nn.init.constant_(m.bias, 0)
|
| 212 |
+
elif isinstance(m, nn.Linear):
|
| 213 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
| 214 |
+
if m.bias is not None:
|
| 215 |
+
nn.init.constant_(m.bias, 0)
|
| 216 |
+
|
| 217 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
| 218 |
+
"""Forward pass of the UNet1DConditional model.
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
| 222 |
+
mask (_type_): shape (batch_size, 1, time)
|
| 223 |
+
t (_type_): shape (batch_size)
|
| 224 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
| 225 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
| 226 |
+
|
| 227 |
+
Raises:
|
| 228 |
+
ValueError: _description_
|
| 229 |
+
ValueError: _description_
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
_type_: _description_
|
| 233 |
+
"""
|
| 234 |
+
|
| 235 |
+
t = self.time_embeddings(t).to(t.dtype)
|
| 236 |
+
t = self.time_mlp(t)
|
| 237 |
+
|
| 238 |
+
x = pack([x, mu], "b * t")[0]
|
| 239 |
+
|
| 240 |
+
if spks is not None:
|
| 241 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
| 242 |
+
x = pack([x, spks], "b * t")[0]
|
| 243 |
+
if cond is not None:
|
| 244 |
+
x = pack([x, cond], "b * t")[0]
|
| 245 |
+
|
| 246 |
+
hiddens = []
|
| 247 |
+
masks = [mask]
|
| 248 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
| 249 |
+
mask_down = masks[-1]
|
| 250 |
+
x = resnet(x, mask_down, t)
|
| 251 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 252 |
+
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
| 253 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1)
|
| 254 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 255 |
+
for transformer_block in transformer_blocks:
|
| 256 |
+
x = transformer_block(
|
| 257 |
+
hidden_states=x,
|
| 258 |
+
attention_mask=attn_mask,
|
| 259 |
+
timestep=t,
|
| 260 |
+
)
|
| 261 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 262 |
+
hiddens.append(x) # Save hidden states for skip connections
|
| 263 |
+
x = downsample(x * mask_down)
|
| 264 |
+
masks.append(mask_down[:, :, ::2])
|
| 265 |
+
masks = masks[:-1]
|
| 266 |
+
mask_mid = masks[-1]
|
| 267 |
+
|
| 268 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
| 269 |
+
x = resnet(x, mask_mid, t)
|
| 270 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 271 |
+
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
| 272 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1)
|
| 273 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 274 |
+
for transformer_block in transformer_blocks:
|
| 275 |
+
x = transformer_block(
|
| 276 |
+
hidden_states=x,
|
| 277 |
+
attention_mask=attn_mask,
|
| 278 |
+
timestep=t,
|
| 279 |
+
)
|
| 280 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 281 |
+
|
| 282 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
| 283 |
+
mask_up = masks.pop()
|
| 284 |
+
skip = hiddens.pop()
|
| 285 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
| 286 |
+
x = resnet(x, mask_up, t)
|
| 287 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 288 |
+
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
| 289 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1)
|
| 290 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 291 |
+
for transformer_block in transformer_blocks:
|
| 292 |
+
x = transformer_block(
|
| 293 |
+
hidden_states=x,
|
| 294 |
+
attention_mask=attn_mask,
|
| 295 |
+
timestep=t,
|
| 296 |
+
)
|
| 297 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 298 |
+
x = upsample(x * mask_up)
|
| 299 |
+
x = self.final_block(x, mask_up)
|
| 300 |
+
output = self.final_proj(x * mask_up)
|
| 301 |
+
return output * mask
|
cosyvoice/flow/flow.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import logging
|
| 15 |
+
import random
|
| 16 |
+
from typing import Dict, Optional
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from torch.nn import functional as F
|
| 20 |
+
from omegaconf import DictConfig
|
| 21 |
+
from cosyvoice.utils.mask import make_pad_mask
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MaskedDiffWithXvec(torch.nn.Module):
|
| 25 |
+
def __init__(self,
|
| 26 |
+
input_size: int = 512,
|
| 27 |
+
output_size: int = 80,
|
| 28 |
+
spk_embed_dim: int = 192,
|
| 29 |
+
output_type: str = "mel",
|
| 30 |
+
vocab_size: int = 4096,
|
| 31 |
+
input_frame_rate: int = 50,
|
| 32 |
+
only_mask_loss: bool = True,
|
| 33 |
+
encoder: torch.nn.Module = None,
|
| 34 |
+
length_regulator: torch.nn.Module = None,
|
| 35 |
+
decoder: torch.nn.Module = None,
|
| 36 |
+
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
| 37 |
+
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
| 38 |
+
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
| 39 |
+
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
| 40 |
+
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
| 41 |
+
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
| 42 |
+
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.input_size = input_size
|
| 45 |
+
self.output_size = output_size
|
| 46 |
+
self.decoder_conf = decoder_conf
|
| 47 |
+
self.mel_feat_conf = mel_feat_conf
|
| 48 |
+
self.vocab_size = vocab_size
|
| 49 |
+
self.output_type = output_type
|
| 50 |
+
self.input_frame_rate = input_frame_rate
|
| 51 |
+
logging.info(f"input frame rate={self.input_frame_rate}")
|
| 52 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
| 53 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
| 54 |
+
self.encoder = encoder
|
| 55 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
| 56 |
+
self.decoder = decoder
|
| 57 |
+
self.length_regulator = length_regulator
|
| 58 |
+
self.only_mask_loss = only_mask_loss
|
| 59 |
+
|
| 60 |
+
def forward(
|
| 61 |
+
self,
|
| 62 |
+
batch: dict,
|
| 63 |
+
device: torch.device,
|
| 64 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
| 65 |
+
token = batch['speech_token'].to(device)
|
| 66 |
+
token_len = batch['speech_token_len'].to(device)
|
| 67 |
+
feat = batch['speech_feat'].to(device)
|
| 68 |
+
feat_len = batch['speech_feat_len'].to(device)
|
| 69 |
+
embedding = batch['embedding'].to(device)
|
| 70 |
+
|
| 71 |
+
# xvec projection
|
| 72 |
+
embedding = F.normalize(embedding, dim=1)
|
| 73 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
| 74 |
+
|
| 75 |
+
# concat text and prompt_text
|
| 76 |
+
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
| 77 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 78 |
+
|
| 79 |
+
# text encode
|
| 80 |
+
h, h_lengths = self.encoder(token, token_len)
|
| 81 |
+
h = self.encoder_proj(h)
|
| 82 |
+
h, h_lengths = self.length_regulator(h, feat_len)
|
| 83 |
+
|
| 84 |
+
# get conditions
|
| 85 |
+
conds = torch.zeros(feat.shape, device=token.device)
|
| 86 |
+
for i, j in enumerate(feat_len):
|
| 87 |
+
if random.random() < 0.5:
|
| 88 |
+
continue
|
| 89 |
+
index = random.randint(0, int(0.3 * j))
|
| 90 |
+
conds[i, :index] = feat[i, :index]
|
| 91 |
+
conds = conds.transpose(1, 2)
|
| 92 |
+
|
| 93 |
+
mask = (~make_pad_mask(feat_len)).to(h)
|
| 94 |
+
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
| 95 |
+
loss, _ = self.decoder.compute_loss(
|
| 96 |
+
feat.transpose(1, 2).contiguous(),
|
| 97 |
+
mask.unsqueeze(1),
|
| 98 |
+
h.transpose(1, 2).contiguous(),
|
| 99 |
+
embedding,
|
| 100 |
+
cond=conds
|
| 101 |
+
)
|
| 102 |
+
return {'loss': loss}
|
| 103 |
+
|
| 104 |
+
@torch.inference_mode()
|
| 105 |
+
def inference(self,
|
| 106 |
+
token,
|
| 107 |
+
token_len,
|
| 108 |
+
prompt_token,
|
| 109 |
+
prompt_token_len,
|
| 110 |
+
prompt_feat,
|
| 111 |
+
prompt_feat_len,
|
| 112 |
+
embedding,
|
| 113 |
+
flow_cache):
|
| 114 |
+
# # if self.fp16 is True:
|
| 115 |
+
# prompt_feat = prompt_feat.half()
|
| 116 |
+
# embedding = embedding.half()
|
| 117 |
+
|
| 118 |
+
assert token.shape[0] == 1
|
| 119 |
+
# xvec projection
|
| 120 |
+
embedding = F.normalize(embedding, dim=1)
|
| 121 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
| 122 |
+
|
| 123 |
+
# concat text and prompt_text
|
| 124 |
+
print("prompt_token:", prompt_token, "token:", token)
|
| 125 |
+
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
| 126 |
+
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
| 127 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
| 128 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 129 |
+
|
| 130 |
+
# text encode
|
| 131 |
+
h, h_lengths = self.encoder(token, token_len)
|
| 132 |
+
h = self.encoder_proj(h)
|
| 133 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
|
| 134 |
+
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
|
| 135 |
+
|
| 136 |
+
# get conditions
|
| 137 |
+
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
| 138 |
+
conds[:, :mel_len1] = prompt_feat
|
| 139 |
+
conds = conds.transpose(1, 2)
|
| 140 |
+
|
| 141 |
+
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
| 142 |
+
feat, flow_cache = self.decoder(
|
| 143 |
+
mu=h.transpose(1, 2).contiguous(),
|
| 144 |
+
mask=mask.unsqueeze(1),
|
| 145 |
+
spks=embedding,
|
| 146 |
+
cond=conds,
|
| 147 |
+
n_timesteps=10,
|
| 148 |
+
prompt_len=mel_len1,
|
| 149 |
+
flow_cache=flow_cache
|
| 150 |
+
)
|
| 151 |
+
feat = feat[:, :, mel_len1:]
|
| 152 |
+
assert feat.shape[2] == mel_len2
|
| 153 |
+
return feat.float(), flow_cache
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
| 157 |
+
def __init__(self,
|
| 158 |
+
input_size: int = 512,
|
| 159 |
+
output_size: int = 80,
|
| 160 |
+
spk_embed_dim: int = 192,
|
| 161 |
+
output_type: str = "mel",
|
| 162 |
+
vocab_size: int = 4096,
|
| 163 |
+
input_frame_rate: int = 50,
|
| 164 |
+
only_mask_loss: bool = True,
|
| 165 |
+
token_mel_ratio: int = 2,
|
| 166 |
+
pre_lookahead_len: int = 3,
|
| 167 |
+
encoder: torch.nn.Module = None,
|
| 168 |
+
decoder: torch.nn.Module = None,
|
| 169 |
+
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
| 170 |
+
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
| 171 |
+
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
| 172 |
+
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
| 173 |
+
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
| 174 |
+
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
| 175 |
+
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
| 176 |
+
super().__init__()
|
| 177 |
+
self.input_size = input_size
|
| 178 |
+
self.output_size = output_size
|
| 179 |
+
self.decoder_conf = decoder_conf
|
| 180 |
+
self.mel_feat_conf = mel_feat_conf
|
| 181 |
+
self.vocab_size = vocab_size
|
| 182 |
+
self.output_type = output_type
|
| 183 |
+
self.input_frame_rate = input_frame_rate
|
| 184 |
+
logging.info(f"input frame rate={self.input_frame_rate}")
|
| 185 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
| 186 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
| 187 |
+
self.encoder = encoder
|
| 188 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
| 189 |
+
self.decoder = decoder
|
| 190 |
+
self.only_mask_loss = only_mask_loss
|
| 191 |
+
self.token_mel_ratio = token_mel_ratio
|
| 192 |
+
self.pre_lookahead_len = pre_lookahead_len
|
| 193 |
+
|
| 194 |
+
@torch.inference_mode()
|
| 195 |
+
def inference(self,
|
| 196 |
+
token,
|
| 197 |
+
token_len,
|
| 198 |
+
prompt_token,
|
| 199 |
+
prompt_token_len,
|
| 200 |
+
prompt_feat,
|
| 201 |
+
prompt_feat_len,
|
| 202 |
+
embedding,
|
| 203 |
+
finalize):
|
| 204 |
+
# if self.fp16 is True:
|
| 205 |
+
# prompt_feat = prompt_feat.half()
|
| 206 |
+
# embedding = embedding.half()
|
| 207 |
+
|
| 208 |
+
assert token.shape[0] == 1
|
| 209 |
+
# xvec projection
|
| 210 |
+
embedding = F.normalize(embedding, dim=1)
|
| 211 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
| 212 |
+
|
| 213 |
+
# concat text and prompt_text
|
| 214 |
+
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
| 215 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
| 216 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 217 |
+
|
| 218 |
+
# text encode
|
| 219 |
+
h, h_lengths = self.encoder(token, token_len)
|
| 220 |
+
if finalize is False:
|
| 221 |
+
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
|
| 222 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
| 223 |
+
h = self.encoder_proj(h)
|
| 224 |
+
|
| 225 |
+
# get conditions
|
| 226 |
+
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
| 227 |
+
conds[:, :mel_len1] = prompt_feat
|
| 228 |
+
conds = conds.transpose(1, 2)
|
| 229 |
+
|
| 230 |
+
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
| 231 |
+
feat, _ = self.decoder(
|
| 232 |
+
mu=h.transpose(1, 2).contiguous(),
|
| 233 |
+
mask=mask.unsqueeze(1),
|
| 234 |
+
spks=embedding,
|
| 235 |
+
cond=conds,
|
| 236 |
+
n_timesteps=10
|
| 237 |
+
)
|
| 238 |
+
feat = feat[:, :, mel_len1:]
|
| 239 |
+
assert feat.shape[2] == mel_len2
|
| 240 |
+
return feat.float(), None
|
cosyvoice/flow/flow_matching.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import threading
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from matcha.models.components.flow_matching import BASECFM
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ConditionalCFM(BASECFM):
|
| 21 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
| 22 |
+
super().__init__(
|
| 23 |
+
n_feats=in_channels,
|
| 24 |
+
cfm_params=cfm_params,
|
| 25 |
+
n_spks=n_spks,
|
| 26 |
+
spk_emb_dim=spk_emb_dim,
|
| 27 |
+
)
|
| 28 |
+
self.t_scheduler = cfm_params.t_scheduler
|
| 29 |
+
self.training_cfg_rate = cfm_params.training_cfg_rate
|
| 30 |
+
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
| 31 |
+
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
| 32 |
+
# Just change the architecture of the estimator here
|
| 33 |
+
self.estimator = estimator
|
| 34 |
+
self.lock = threading.Lock()
|
| 35 |
+
|
| 36 |
+
@torch.inference_mode()
|
| 37 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
|
| 38 |
+
"""Forward diffusion
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
mu (torch.Tensor): output of encoder
|
| 42 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 43 |
+
mask (torch.Tensor): output_mask
|
| 44 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 45 |
+
n_timesteps (int): number of diffusion steps
|
| 46 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
| 47 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 48 |
+
shape: (batch_size, spk_emb_dim)
|
| 49 |
+
cond: Not used but kept for future purposes
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
sample: generated mel-spectrogram
|
| 53 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
|
| 57 |
+
cache_size = flow_cache.shape[2]
|
| 58 |
+
# fix prompt and overlap part mu and z
|
| 59 |
+
if cache_size != 0:
|
| 60 |
+
z[:, :, :cache_size] = flow_cache[:, :, :, 0]
|
| 61 |
+
mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
|
| 62 |
+
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
|
| 63 |
+
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
|
| 64 |
+
flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
|
| 65 |
+
|
| 66 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
| 67 |
+
if self.t_scheduler == 'cosine':
|
| 68 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
| 69 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
|
| 70 |
+
|
| 71 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
| 72 |
+
"""
|
| 73 |
+
Fixed euler solver for ODEs.
|
| 74 |
+
Args:
|
| 75 |
+
x (torch.Tensor): random noise
|
| 76 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
| 77 |
+
shape: (n_timesteps + 1,)
|
| 78 |
+
mu (torch.Tensor): output of encoder
|
| 79 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 80 |
+
mask (torch.Tensor): output_mask
|
| 81 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 82 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 83 |
+
shape: (batch_size, spk_emb_dim)
|
| 84 |
+
cond: Not used but kept for future purposes
|
| 85 |
+
"""
|
| 86 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
| 87 |
+
t = t.unsqueeze(dim=0)
|
| 88 |
+
|
| 89 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
| 90 |
+
# Or in future might add like a return_all_steps flag
|
| 91 |
+
sol = []
|
| 92 |
+
|
| 93 |
+
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
| 94 |
+
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
| 95 |
+
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
| 96 |
+
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
| 97 |
+
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
| 98 |
+
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
| 99 |
+
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
| 100 |
+
for step in range(1, len(t_span)):
|
| 101 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
| 102 |
+
x_in[:] = x
|
| 103 |
+
mask_in[:] = mask
|
| 104 |
+
mu_in[0] = mu
|
| 105 |
+
t_in[:] = t.unsqueeze(0)
|
| 106 |
+
spks_in[0] = spks
|
| 107 |
+
cond_in[0] = cond
|
| 108 |
+
dphi_dt = self.forward_estimator(
|
| 109 |
+
x_in, mask_in,
|
| 110 |
+
mu_in, t_in,
|
| 111 |
+
spks_in,
|
| 112 |
+
cond_in
|
| 113 |
+
)
|
| 114 |
+
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
| 115 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
| 116 |
+
x = x + dt * dphi_dt
|
| 117 |
+
t = t + dt
|
| 118 |
+
sol.append(x)
|
| 119 |
+
if step < len(t_span) - 1:
|
| 120 |
+
dt = t_span[step + 1] - t
|
| 121 |
+
|
| 122 |
+
return sol[-1].float()
|
| 123 |
+
|
| 124 |
+
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
| 125 |
+
if isinstance(self.estimator, torch.nn.Module):
|
| 126 |
+
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
| 127 |
+
else:
|
| 128 |
+
with self.lock:
|
| 129 |
+
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
| 130 |
+
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
| 131 |
+
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
| 132 |
+
self.estimator.set_input_shape('t', (2,))
|
| 133 |
+
self.estimator.set_input_shape('spks', (2, 80))
|
| 134 |
+
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
| 135 |
+
# run trt engine
|
| 136 |
+
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
| 137 |
+
mask.contiguous().data_ptr(),
|
| 138 |
+
mu.contiguous().data_ptr(),
|
| 139 |
+
t.contiguous().data_ptr(),
|
| 140 |
+
spks.contiguous().data_ptr(),
|
| 141 |
+
cond.contiguous().data_ptr(),
|
| 142 |
+
x.data_ptr()])
|
| 143 |
+
return x
|
| 144 |
+
|
| 145 |
+
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
| 146 |
+
"""Computes diffusion loss
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
x1 (torch.Tensor): Target
|
| 150 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 151 |
+
mask (torch.Tensor): target mask
|
| 152 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 153 |
+
mu (torch.Tensor): output of encoder
|
| 154 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 155 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
| 156 |
+
shape: (batch_size, spk_emb_dim)
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
loss: conditional flow matching loss
|
| 160 |
+
y: conditional flow
|
| 161 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 162 |
+
"""
|
| 163 |
+
b, _, t = mu.shape
|
| 164 |
+
|
| 165 |
+
# random timestep
|
| 166 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
| 167 |
+
if self.t_scheduler == 'cosine':
|
| 168 |
+
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
| 169 |
+
# sample noise p(x_0)
|
| 170 |
+
z = torch.randn_like(x1)
|
| 171 |
+
|
| 172 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
| 173 |
+
u = x1 - (1 - self.sigma_min) * z
|
| 174 |
+
|
| 175 |
+
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
|
| 176 |
+
if self.training_cfg_rate > 0:
|
| 177 |
+
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
| 178 |
+
mu = mu * cfg_mask.view(-1, 1, 1)
|
| 179 |
+
spks = spks * cfg_mask.view(-1, 1)
|
| 180 |
+
cond = cond * cfg_mask.view(-1, 1, 1)
|
| 181 |
+
|
| 182 |
+
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
| 183 |
+
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
| 184 |
+
return loss, y
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class CausalConditionalCFM(ConditionalCFM):
|
| 188 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
| 189 |
+
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
| 190 |
+
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
| 191 |
+
|
| 192 |
+
@torch.inference_mode()
|
| 193 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
| 194 |
+
"""Forward diffusion
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
mu (torch.Tensor): output of encoder
|
| 198 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 199 |
+
mask (torch.Tensor): output_mask
|
| 200 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 201 |
+
n_timesteps (int): number of diffusion steps
|
| 202 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
| 203 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 204 |
+
shape: (batch_size, spk_emb_dim)
|
| 205 |
+
cond: Not used but kept for future purposes
|
| 206 |
+
|
| 207 |
+
Returns:
|
| 208 |
+
sample: generated mel-spectrogram
|
| 209 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
|
| 213 |
+
# fix prompt and overlap part mu and z
|
| 214 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
| 215 |
+
if self.t_scheduler == 'cosine':
|
| 216 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
| 217 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
cosyvoice/flow/length_regulator.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Tuple
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch
|
| 17 |
+
from torch.nn import functional as F
|
| 18 |
+
from cosyvoice.utils.mask import make_pad_mask
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class InterpolateRegulator(nn.Module):
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
channels: int,
|
| 25 |
+
sampling_ratios: Tuple,
|
| 26 |
+
out_channels: int = None,
|
| 27 |
+
groups: int = 1,
|
| 28 |
+
):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.sampling_ratios = sampling_ratios
|
| 31 |
+
out_channels = out_channels or channels
|
| 32 |
+
model = nn.ModuleList([])
|
| 33 |
+
if len(sampling_ratios) > 0:
|
| 34 |
+
for _ in sampling_ratios:
|
| 35 |
+
module = nn.Conv1d(channels, channels, 3, 1, 1)
|
| 36 |
+
norm = nn.GroupNorm(groups, channels)
|
| 37 |
+
act = nn.Mish()
|
| 38 |
+
model.extend([module, norm, act])
|
| 39 |
+
model.append(
|
| 40 |
+
nn.Conv1d(channels, out_channels, 1, 1)
|
| 41 |
+
)
|
| 42 |
+
self.model = nn.Sequential(*model)
|
| 43 |
+
|
| 44 |
+
def forward(self, x, ylens=None):
|
| 45 |
+
# x in (B, T, D)
|
| 46 |
+
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
|
| 47 |
+
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
|
| 48 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
| 49 |
+
olens = ylens
|
| 50 |
+
return out * mask, olens
|
| 51 |
+
|
| 52 |
+
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
|
| 53 |
+
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
|
| 54 |
+
# x in (B, T, D)
|
| 55 |
+
if x2.shape[1] > 40:
|
| 56 |
+
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
| 57 |
+
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
|
| 58 |
+
mode='linear')
|
| 59 |
+
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
| 60 |
+
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
|
| 61 |
+
else:
|
| 62 |
+
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
|
| 63 |
+
if x1.shape[1] != 0:
|
| 64 |
+
x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
|
| 65 |
+
x = torch.concat([x1, x2], dim=2)
|
| 66 |
+
else:
|
| 67 |
+
x = x2
|
| 68 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
| 69 |
+
return out, mel_len1 + mel_len2
|
cosyvoice/flow_speaker_minus/__pycache__/decoder.cpython-310.pyc
ADDED
|
Binary file (8.14 kB). View file
|
|
|
cosyvoice/flow_speaker_minus/__pycache__/flow.cpython-310.pyc
ADDED
|
Binary file (5.05 kB). View file
|
|
|
cosyvoice/flow_speaker_minus/__pycache__/flow.cpython-38.pyc
ADDED
|
Binary file (5.01 kB). View file
|
|
|
cosyvoice/flow_speaker_minus/__pycache__/flow_matching.cpython-310.pyc
ADDED
|
Binary file (6.89 kB). View file
|
|
|
cosyvoice/flow_speaker_minus/__pycache__/length_regulator.cpython-310.pyc
ADDED
|
Binary file (2.19 kB). View file
|
|
|
cosyvoice/flow_speaker_minus/decoder.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from einops import pack, rearrange, repeat
|
| 18 |
+
from cosyvoice.utils.common import mask_to_bias
|
| 19 |
+
from cosyvoice.utils.mask import add_optional_chunk_mask
|
| 20 |
+
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
| 21 |
+
from matcha.models.components.transformer import BasicTransformerBlock
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Transpose(torch.nn.Module):
|
| 25 |
+
def __init__(self, dim0: int, dim1: int):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.dim0 = dim0
|
| 28 |
+
self.dim1 = dim1
|
| 29 |
+
|
| 30 |
+
def forward(self, x: torch.Tensor):
|
| 31 |
+
x = torch.transpose(x, self.dim0, self.dim1)
|
| 32 |
+
return x
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class CausalBlock1D(Block1D):
|
| 36 |
+
def __init__(self, dim: int, dim_out: int):
|
| 37 |
+
super(CausalBlock1D, self).__init__(dim, dim_out)
|
| 38 |
+
self.block = torch.nn.Sequential(
|
| 39 |
+
CausalConv1d(dim, dim_out, 3),
|
| 40 |
+
Transpose(1, 2),
|
| 41 |
+
nn.LayerNorm(dim_out),
|
| 42 |
+
Transpose(1, 2),
|
| 43 |
+
nn.Mish(),
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor):
|
| 47 |
+
output = self.block(x * mask)
|
| 48 |
+
return output * mask
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class CausalResnetBlock1D(ResnetBlock1D):
|
| 52 |
+
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
| 53 |
+
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
| 54 |
+
self.block1 = CausalBlock1D(dim, dim_out)
|
| 55 |
+
self.block2 = CausalBlock1D(dim_out, dim_out)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class CausalConv1d(torch.nn.Conv1d):
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
in_channels: int,
|
| 62 |
+
out_channels: int,
|
| 63 |
+
kernel_size: int,
|
| 64 |
+
stride: int = 1,
|
| 65 |
+
dilation: int = 1,
|
| 66 |
+
groups: int = 1,
|
| 67 |
+
bias: bool = True,
|
| 68 |
+
padding_mode: str = 'zeros',
|
| 69 |
+
device=None,
|
| 70 |
+
dtype=None
|
| 71 |
+
) -> None:
|
| 72 |
+
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
| 73 |
+
kernel_size, stride,
|
| 74 |
+
padding=0, dilation=dilation,
|
| 75 |
+
groups=groups, bias=bias,
|
| 76 |
+
padding_mode=padding_mode,
|
| 77 |
+
device=device, dtype=dtype)
|
| 78 |
+
assert stride == 1
|
| 79 |
+
self.causal_padding = (kernel_size - 1, 0)
|
| 80 |
+
|
| 81 |
+
def forward(self, x: torch.Tensor):
|
| 82 |
+
x = F.pad(x, self.causal_padding)
|
| 83 |
+
x = super(CausalConv1d, self).forward(x)
|
| 84 |
+
return x
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class ConditionalDecoder(nn.Module):
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
in_channels,
|
| 91 |
+
out_channels,
|
| 92 |
+
causal=False,
|
| 93 |
+
channels=(256, 256),
|
| 94 |
+
dropout=0.05,
|
| 95 |
+
attention_head_dim=64,
|
| 96 |
+
n_blocks=1,
|
| 97 |
+
num_mid_blocks=2,
|
| 98 |
+
num_heads=4,
|
| 99 |
+
act_fn="snake",
|
| 100 |
+
):
|
| 101 |
+
"""
|
| 102 |
+
This decoder requires an input with the same shape of the target. So, if your text content
|
| 103 |
+
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
| 104 |
+
"""
|
| 105 |
+
super().__init__()
|
| 106 |
+
channels = tuple(channels)
|
| 107 |
+
self.in_channels = in_channels
|
| 108 |
+
self.out_channels = out_channels
|
| 109 |
+
self.causal = causal
|
| 110 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
| 111 |
+
time_embed_dim = channels[0] * 4
|
| 112 |
+
self.time_mlp = TimestepEmbedding(
|
| 113 |
+
in_channels=in_channels,
|
| 114 |
+
time_embed_dim=time_embed_dim,
|
| 115 |
+
act_fn="silu",
|
| 116 |
+
)
|
| 117 |
+
self.down_blocks = nn.ModuleList([])
|
| 118 |
+
self.mid_blocks = nn.ModuleList([])
|
| 119 |
+
self.up_blocks = nn.ModuleList([])
|
| 120 |
+
|
| 121 |
+
output_channel = in_channels
|
| 122 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
| 123 |
+
input_channel = output_channel
|
| 124 |
+
output_channel = channels[i]
|
| 125 |
+
is_last = i == len(channels) - 1
|
| 126 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
| 127 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 128 |
+
transformer_blocks = nn.ModuleList(
|
| 129 |
+
[
|
| 130 |
+
BasicTransformerBlock(
|
| 131 |
+
dim=output_channel,
|
| 132 |
+
num_attention_heads=num_heads,
|
| 133 |
+
attention_head_dim=attention_head_dim,
|
| 134 |
+
dropout=dropout,
|
| 135 |
+
activation_fn=act_fn,
|
| 136 |
+
)
|
| 137 |
+
for _ in range(n_blocks)
|
| 138 |
+
]
|
| 139 |
+
)
|
| 140 |
+
downsample = (
|
| 141 |
+
Downsample1D(output_channel) if not is_last else
|
| 142 |
+
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
| 143 |
+
)
|
| 144 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
| 145 |
+
|
| 146 |
+
for _ in range(num_mid_blocks):
|
| 147 |
+
input_channel = channels[-1]
|
| 148 |
+
out_channels = channels[-1]
|
| 149 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
| 150 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 151 |
+
|
| 152 |
+
transformer_blocks = nn.ModuleList(
|
| 153 |
+
[
|
| 154 |
+
BasicTransformerBlock(
|
| 155 |
+
dim=output_channel,
|
| 156 |
+
num_attention_heads=num_heads,
|
| 157 |
+
attention_head_dim=attention_head_dim,
|
| 158 |
+
dropout=dropout,
|
| 159 |
+
activation_fn=act_fn,
|
| 160 |
+
)
|
| 161 |
+
for _ in range(n_blocks)
|
| 162 |
+
]
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
| 166 |
+
|
| 167 |
+
channels = channels[::-1] + (channels[0],)
|
| 168 |
+
for i in range(len(channels) - 1):
|
| 169 |
+
input_channel = channels[i] * 2
|
| 170 |
+
output_channel = channels[i + 1]
|
| 171 |
+
is_last = i == len(channels) - 2
|
| 172 |
+
resnet = CausalResnetBlock1D(
|
| 173 |
+
dim=input_channel,
|
| 174 |
+
dim_out=output_channel,
|
| 175 |
+
time_emb_dim=time_embed_dim,
|
| 176 |
+
) if self.causal else ResnetBlock1D(
|
| 177 |
+
dim=input_channel,
|
| 178 |
+
dim_out=output_channel,
|
| 179 |
+
time_emb_dim=time_embed_dim,
|
| 180 |
+
)
|
| 181 |
+
transformer_blocks = nn.ModuleList(
|
| 182 |
+
[
|
| 183 |
+
BasicTransformerBlock(
|
| 184 |
+
dim=output_channel,
|
| 185 |
+
num_attention_heads=num_heads,
|
| 186 |
+
attention_head_dim=attention_head_dim,
|
| 187 |
+
dropout=dropout,
|
| 188 |
+
activation_fn=act_fn,
|
| 189 |
+
)
|
| 190 |
+
for _ in range(n_blocks)
|
| 191 |
+
]
|
| 192 |
+
)
|
| 193 |
+
upsample = (
|
| 194 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
| 195 |
+
if not is_last
|
| 196 |
+
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
| 197 |
+
)
|
| 198 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
| 199 |
+
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
| 200 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
| 201 |
+
self.initialize_weights()
|
| 202 |
+
|
| 203 |
+
def initialize_weights(self):
|
| 204 |
+
for m in self.modules():
|
| 205 |
+
if isinstance(m, nn.Conv1d):
|
| 206 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
| 207 |
+
if m.bias is not None:
|
| 208 |
+
nn.init.constant_(m.bias, 0)
|
| 209 |
+
elif isinstance(m, nn.GroupNorm):
|
| 210 |
+
nn.init.constant_(m.weight, 1)
|
| 211 |
+
nn.init.constant_(m.bias, 0)
|
| 212 |
+
elif isinstance(m, nn.Linear):
|
| 213 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
| 214 |
+
if m.bias is not None:
|
| 215 |
+
nn.init.constant_(m.bias, 0)
|
| 216 |
+
|
| 217 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
| 218 |
+
"""Forward pass of the UNet1DConditional model.
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
| 222 |
+
mask (_type_): shape (batch_size, 1, time)
|
| 223 |
+
t (_type_): shape (batch_size)
|
| 224 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
| 225 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
| 226 |
+
|
| 227 |
+
Raises:
|
| 228 |
+
ValueError: _description_
|
| 229 |
+
ValueError: _description_
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
_type_: _description_
|
| 233 |
+
"""
|
| 234 |
+
|
| 235 |
+
t = self.time_embeddings(t).to(t.dtype)
|
| 236 |
+
t = self.time_mlp(t)
|
| 237 |
+
|
| 238 |
+
x = pack([x, mu], "b * t")[0]
|
| 239 |
+
|
| 240 |
+
if spks is not None:
|
| 241 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
| 242 |
+
x = pack([x, spks], "b * t")[0]
|
| 243 |
+
if cond is not None:
|
| 244 |
+
x = pack([x, cond], "b * t")[0]
|
| 245 |
+
|
| 246 |
+
hiddens = []
|
| 247 |
+
masks = [mask]
|
| 248 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
| 249 |
+
mask_down = masks[-1]
|
| 250 |
+
x = resnet(x, mask_down, t)
|
| 251 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 252 |
+
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
| 253 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 254 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 255 |
+
for transformer_block in transformer_blocks:
|
| 256 |
+
x = transformer_block(
|
| 257 |
+
hidden_states=x,
|
| 258 |
+
attention_mask=attn_mask,
|
| 259 |
+
timestep=t,
|
| 260 |
+
)
|
| 261 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 262 |
+
hiddens.append(x) # Save hidden states for skip connections
|
| 263 |
+
x = downsample(x * mask_down)
|
| 264 |
+
masks.append(mask_down[:, :, ::2])
|
| 265 |
+
masks = masks[:-1]
|
| 266 |
+
mask_mid = masks[-1]
|
| 267 |
+
|
| 268 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
| 269 |
+
x = resnet(x, mask_mid, t)
|
| 270 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 271 |
+
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
| 272 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 273 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 274 |
+
for transformer_block in transformer_blocks:
|
| 275 |
+
x = transformer_block(
|
| 276 |
+
hidden_states=x,
|
| 277 |
+
attention_mask=attn_mask,
|
| 278 |
+
timestep=t,
|
| 279 |
+
)
|
| 280 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 281 |
+
|
| 282 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
| 283 |
+
mask_up = masks.pop()
|
| 284 |
+
skip = hiddens.pop()
|
| 285 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
| 286 |
+
x = resnet(x, mask_up, t)
|
| 287 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 288 |
+
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
| 289 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 290 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 291 |
+
for transformer_block in transformer_blocks:
|
| 292 |
+
x = transformer_block(
|
| 293 |
+
hidden_states=x,
|
| 294 |
+
attention_mask=attn_mask,
|
| 295 |
+
timestep=t,
|
| 296 |
+
)
|
| 297 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 298 |
+
x = upsample(x * mask_up)
|
| 299 |
+
x = self.final_block(x, mask_up)
|
| 300 |
+
output = self.final_proj(x * mask_up)
|
| 301 |
+
return output * mask
|
cosyvoice/flow_speaker_minus/flow.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import random
|
| 3 |
+
from typing import Dict, Optional
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
from omegaconf import DictConfig
|
| 8 |
+
from cosyvoice.utils.mask import make_pad_mask
|
| 9 |
+
from cosyvoice.utils.losses import OrthogonalityLoss
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MaskedDiffWithXvec(torch.nn.Module):
|
| 13 |
+
def __init__(self,
|
| 14 |
+
input_size: int = 512,
|
| 15 |
+
output_size: int = 80,
|
| 16 |
+
spk_embed_dim: int = 192,
|
| 17 |
+
output_type: str = "mel",
|
| 18 |
+
vocab_size: int = 4096,
|
| 19 |
+
input_frame_rate: int = 50,
|
| 20 |
+
only_mask_loss: bool = True,
|
| 21 |
+
encoder: torch.nn.Module = None,
|
| 22 |
+
length_regulator: torch.nn.Module = None,
|
| 23 |
+
decoder: torch.nn.Module = None,
|
| 24 |
+
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
| 25 |
+
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
| 26 |
+
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
| 27 |
+
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
| 28 |
+
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
| 29 |
+
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
| 30 |
+
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000},
|
| 31 |
+
flow_emotion_embedding: bool = False, # 新增 flow_emotion_embedding
|
| 32 |
+
flow_orth_loss: bool = False,
|
| 33 |
+
cross_orth_loss: bool = False): # 新增 flow_orth_loss
|
| 34 |
+
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.input_size = input_size
|
| 37 |
+
self.output_size = output_size
|
| 38 |
+
self.decoder_conf = decoder_conf
|
| 39 |
+
self.mel_feat_conf = mel_feat_conf
|
| 40 |
+
self.vocab_size = vocab_size
|
| 41 |
+
self.output_type = output_type
|
| 42 |
+
self.input_frame_rate = input_frame_rate
|
| 43 |
+
logging.info(f"input frame rate={self.input_frame_rate}")
|
| 44 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
| 45 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
| 46 |
+
self.encoder = encoder
|
| 47 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
| 48 |
+
self.decoder = decoder
|
| 49 |
+
self.length_regulator = length_regulator
|
| 50 |
+
self.only_mask_loss = only_mask_loss
|
| 51 |
+
self.flow_emotion_embedding = flow_emotion_embedding
|
| 52 |
+
self.flow_orth_loss = flow_orth_loss
|
| 53 |
+
self.cross_orth_loss = cross_orth_loss
|
| 54 |
+
|
| 55 |
+
# 如果启用 flow_emotion_embedding,增加情感嵌入的投影层
|
| 56 |
+
if self.flow_emotion_embedding:
|
| 57 |
+
self.flow_emotion_embedding_proj = torch.nn.Linear(spk_embed_dim, spk_embed_dim)
|
| 58 |
+
self.speaker_projector = nn.Linear(spk_embed_dim, spk_embed_dim)
|
| 59 |
+
|
| 60 |
+
# 如果启用 flow_orth_loss,增加正交损失的计算
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def forward(
|
| 64 |
+
self,
|
| 65 |
+
batch: dict,
|
| 66 |
+
device: torch.device,
|
| 67 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
| 68 |
+
token = batch['speech_token'].to(device)
|
| 69 |
+
token_len = batch['speech_token_len'].to(device)
|
| 70 |
+
feat = batch['speech_feat'].to(device)
|
| 71 |
+
feat_len = batch['speech_feat_len'].to(device)
|
| 72 |
+
embedding = batch['embedding'].to(device)
|
| 73 |
+
|
| 74 |
+
# 处理 flow_emotion_embedding
|
| 75 |
+
if self.flow_emotion_embedding:
|
| 76 |
+
flow_emotion_embedding = batch['emotion_embedding'].to(device)
|
| 77 |
+
flow_emotion_embedding = F.normalize(flow_emotion_embedding, dim=1)
|
| 78 |
+
flow_emotion_embedding = self.flow_emotion_embedding_proj(flow_emotion_embedding)
|
| 79 |
+
embedding = self.speaker_projector(embedding)
|
| 80 |
+
embedding += flow_emotion_embedding # 将情感嵌入加到说话人嵌入中
|
| 81 |
+
if self.cross_orth_loss:
|
| 82 |
+
orth_loss = 0.0
|
| 83 |
+
batch_size = embedding.size(0)
|
| 84 |
+
for i in range(batch_size):
|
| 85 |
+
for j in range(i + 1, batch_size):
|
| 86 |
+
# 计算 embedding[i] 和 emotion_embedding[j] 之间的正交损失
|
| 87 |
+
orth_loss += torch.abs(torch.dot(embedding[i], emotion_embedding[j]))
|
| 88 |
+
orth_loss /= (batch_size * (batch_size - 1)) / 2
|
| 89 |
+
else:
|
| 90 |
+
orth_loss = OrthogonalityLoss(embedding, flow_emotion_embedding)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# xvec projection
|
| 94 |
+
embedding = F.normalize(embedding, dim=1)
|
| 95 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
| 96 |
+
|
| 97 |
+
# concat text and prompt_text
|
| 98 |
+
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
| 99 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 100 |
+
|
| 101 |
+
# text encode
|
| 102 |
+
h, h_lengths = self.encoder(token, token_len)
|
| 103 |
+
h = self.encoder_proj(h)
|
| 104 |
+
h, h_lengths = self.length_regulator(h, feat_len)
|
| 105 |
+
|
| 106 |
+
# get conditions
|
| 107 |
+
conds = torch.zeros(feat.shape, device=token.device)
|
| 108 |
+
for i, j in enumerate(feat_len):
|
| 109 |
+
if random.random() < 0.5:
|
| 110 |
+
continue
|
| 111 |
+
index = random.randint(0, int(0.3 * j))
|
| 112 |
+
conds[i, :index] = feat[i, :index]
|
| 113 |
+
conds = conds.transpose(1, 2)
|
| 114 |
+
|
| 115 |
+
mask = (~make_pad_mask(feat_len)).to(h)
|
| 116 |
+
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
| 117 |
+
loss, _ = self.decoder.compute_loss(
|
| 118 |
+
feat.transpose(1, 2).contiguous(),
|
| 119 |
+
mask.unsqueeze(1),
|
| 120 |
+
h.transpose(1, 2).contiguous(),
|
| 121 |
+
embedding,
|
| 122 |
+
cond=conds
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# 计算正交损失(如果启用)
|
| 126 |
+
if self.flow_orth_loss and self.flow_emotion_embedding:
|
| 127 |
+
|
| 128 |
+
loss += orth_loss
|
| 129 |
+
|
| 130 |
+
return {'loss': loss}
|
| 131 |
+
|
| 132 |
+
@torch.inference_mode()
|
| 133 |
+
def inference(self,
|
| 134 |
+
token,
|
| 135 |
+
token_len,
|
| 136 |
+
prompt_token,
|
| 137 |
+
prompt_token_len,
|
| 138 |
+
prompt_feat,
|
| 139 |
+
prompt_feat_len,
|
| 140 |
+
embedding,
|
| 141 |
+
flow_cache,
|
| 142 |
+
flow_emotion_embedding=None): # 新增 flow_emotion_embedding
|
| 143 |
+
assert token.shape[0] == 1
|
| 144 |
+
# 处理 flow_emotion_embedding
|
| 145 |
+
if self.flow_emotion_embedding and flow_emotion_embedding is not None:
|
| 146 |
+
flow_emotion_embedding = F.normalize(flow_emotion_embedding.unsqueeze(0).to(torch.float16), dim=1)
|
| 147 |
+
flow_emotion_embedding = self.flow_emotion_embedding_proj(flow_emotion_embedding) # * 1.5
|
| 148 |
+
embedding = self.speaker_projector(embedding)
|
| 149 |
+
embedding += flow_emotion_embedding # 将情感嵌入加到说话人嵌入中
|
| 150 |
+
|
| 151 |
+
# xvec projection
|
| 152 |
+
embedding = F.normalize(embedding, dim=1)
|
| 153 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
| 154 |
+
|
| 155 |
+
# concat text and prompt_text
|
| 156 |
+
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
| 157 |
+
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
| 158 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
| 159 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 160 |
+
|
| 161 |
+
# text encode
|
| 162 |
+
h, h_lengths = self.encoder(token, token_len)
|
| 163 |
+
h = self.encoder_proj(h)
|
| 164 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
|
| 165 |
+
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
|
| 166 |
+
|
| 167 |
+
# get conditions
|
| 168 |
+
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
|
| 169 |
+
conds[:, :mel_len1] = prompt_feat
|
| 170 |
+
conds = conds.transpose(1, 2)
|
| 171 |
+
|
| 172 |
+
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
| 173 |
+
feat, flow_cache = self.decoder(
|
| 174 |
+
mu=h.transpose(1, 2).contiguous(),
|
| 175 |
+
mask=mask.unsqueeze(1),
|
| 176 |
+
spks=embedding,
|
| 177 |
+
cond=conds,
|
| 178 |
+
n_timesteps=10,
|
| 179 |
+
prompt_len=mel_len1,
|
| 180 |
+
flow_cache=flow_cache
|
| 181 |
+
)
|
| 182 |
+
feat = feat[:, :, mel_len1:]
|
| 183 |
+
assert feat.shape[2] == mel_len2
|
| 184 |
+
return feat, flow_cache
|
cosyvoice/flow_speaker_minus/flow_matching.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import threading
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from matcha.models.components.flow_matching import BASECFM
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ConditionalCFM(BASECFM):
|
| 21 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
| 22 |
+
super().__init__(
|
| 23 |
+
n_feats=in_channels,
|
| 24 |
+
cfm_params=cfm_params,
|
| 25 |
+
n_spks=n_spks,
|
| 26 |
+
spk_emb_dim=spk_emb_dim,
|
| 27 |
+
)
|
| 28 |
+
self.t_scheduler = cfm_params.t_scheduler
|
| 29 |
+
self.training_cfg_rate = cfm_params.training_cfg_rate
|
| 30 |
+
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
| 31 |
+
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
| 32 |
+
# Just change the architecture of the estimator here
|
| 33 |
+
self.estimator = estimator
|
| 34 |
+
self.lock = threading.Lock()
|
| 35 |
+
|
| 36 |
+
@torch.inference_mode()
|
| 37 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
|
| 38 |
+
"""Forward diffusion
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
mu (torch.Tensor): output of encoder
|
| 42 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 43 |
+
mask (torch.Tensor): output_mask
|
| 44 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 45 |
+
n_timesteps (int): number of diffusion steps
|
| 46 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
| 47 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 48 |
+
shape: (batch_size, spk_emb_dim)
|
| 49 |
+
cond: Not used but kept for future purposes
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
sample: generated mel-spectrogram
|
| 53 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
|
| 57 |
+
cache_size = flow_cache.shape[2]
|
| 58 |
+
# fix prompt and overlap part mu and z
|
| 59 |
+
if cache_size != 0:
|
| 60 |
+
z[:, :, :cache_size] = flow_cache[:, :, :, 0]
|
| 61 |
+
mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
|
| 62 |
+
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
|
| 63 |
+
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
|
| 64 |
+
flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
|
| 65 |
+
|
| 66 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
| 67 |
+
if self.t_scheduler == 'cosine':
|
| 68 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
| 69 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
|
| 70 |
+
|
| 71 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
| 72 |
+
"""
|
| 73 |
+
Fixed euler solver for ODEs.
|
| 74 |
+
Args:
|
| 75 |
+
x (torch.Tensor): random noise
|
| 76 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
| 77 |
+
shape: (n_timesteps + 1,)
|
| 78 |
+
mu (torch.Tensor): output of encoder
|
| 79 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 80 |
+
mask (torch.Tensor): output_mask
|
| 81 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 82 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 83 |
+
shape: (batch_size, spk_emb_dim)
|
| 84 |
+
cond: Not used but kept for future purposes
|
| 85 |
+
"""
|
| 86 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
| 87 |
+
t = t.unsqueeze(dim=0)
|
| 88 |
+
|
| 89 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
| 90 |
+
# Or in future might add like a return_all_steps flag
|
| 91 |
+
sol = []
|
| 92 |
+
|
| 93 |
+
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
| 94 |
+
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
| 95 |
+
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
| 96 |
+
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
| 97 |
+
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
| 98 |
+
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
| 99 |
+
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
| 100 |
+
for step in range(1, len(t_span)):
|
| 101 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
| 102 |
+
x_in[:] = x
|
| 103 |
+
mask_in[:] = mask
|
| 104 |
+
mu_in[0] = mu
|
| 105 |
+
t_in[:] = t.unsqueeze(0)
|
| 106 |
+
spks_in[0] = spks
|
| 107 |
+
cond_in[0] = cond
|
| 108 |
+
dphi_dt = self.forward_estimator(
|
| 109 |
+
x_in, mask_in,
|
| 110 |
+
mu_in, t_in,
|
| 111 |
+
spks_in,
|
| 112 |
+
cond_in
|
| 113 |
+
)
|
| 114 |
+
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
| 115 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
| 116 |
+
x = x + dt * dphi_dt
|
| 117 |
+
t = t + dt
|
| 118 |
+
sol.append(x)
|
| 119 |
+
if step < len(t_span) - 1:
|
| 120 |
+
dt = t_span[step + 1] - t
|
| 121 |
+
|
| 122 |
+
return sol[-1].float()
|
| 123 |
+
|
| 124 |
+
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
| 125 |
+
if isinstance(self.estimator, torch.nn.Module):
|
| 126 |
+
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
| 127 |
+
else:
|
| 128 |
+
with self.lock:
|
| 129 |
+
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
| 130 |
+
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
| 131 |
+
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
| 132 |
+
self.estimator.set_input_shape('t', (2,))
|
| 133 |
+
self.estimator.set_input_shape('spks', (2, 80))
|
| 134 |
+
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
| 135 |
+
# run trt engine
|
| 136 |
+
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
| 137 |
+
mask.contiguous().data_ptr(),
|
| 138 |
+
mu.contiguous().data_ptr(),
|
| 139 |
+
t.contiguous().data_ptr(),
|
| 140 |
+
spks.contiguous().data_ptr(),
|
| 141 |
+
cond.contiguous().data_ptr(),
|
| 142 |
+
x.data_ptr()])
|
| 143 |
+
return x
|
| 144 |
+
|
| 145 |
+
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
| 146 |
+
"""Computes diffusion loss
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
x1 (torch.Tensor): Target
|
| 150 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 151 |
+
mask (torch.Tensor): target mask
|
| 152 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 153 |
+
mu (torch.Tensor): output of encoder
|
| 154 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 155 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
| 156 |
+
shape: (batch_size, spk_emb_dim)
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
loss: conditional flow matching loss
|
| 160 |
+
y: conditional flow
|
| 161 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 162 |
+
"""
|
| 163 |
+
b, _, t = mu.shape
|
| 164 |
+
|
| 165 |
+
# random timestep
|
| 166 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
| 167 |
+
if self.t_scheduler == 'cosine':
|
| 168 |
+
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
| 169 |
+
# sample noise p(x_0)
|
| 170 |
+
z = torch.randn_like(x1)
|
| 171 |
+
|
| 172 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
| 173 |
+
u = x1 - (1 - self.sigma_min) * z
|
| 174 |
+
|
| 175 |
+
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
|
| 176 |
+
if self.training_cfg_rate > 0:
|
| 177 |
+
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
| 178 |
+
mu = mu * cfg_mask.view(-1, 1, 1)
|
| 179 |
+
spks = spks * cfg_mask.view(-1, 1)
|
| 180 |
+
cond = cond * cfg_mask.view(-1, 1, 1)
|
| 181 |
+
|
| 182 |
+
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
| 183 |
+
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
| 184 |
+
return loss, y
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class CausalConditionalCFM(ConditionalCFM):
|
| 188 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
| 189 |
+
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
| 190 |
+
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
| 191 |
+
|
| 192 |
+
@torch.inference_mode()
|
| 193 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
| 194 |
+
"""Forward diffusion
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
mu (torch.Tensor): output of encoder
|
| 198 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 199 |
+
mask (torch.Tensor): output_mask
|
| 200 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 201 |
+
n_timesteps (int): number of diffusion steps
|
| 202 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
| 203 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 204 |
+
shape: (batch_size, spk_emb_dim)
|
| 205 |
+
cond: Not used but kept for future purposes
|
| 206 |
+
|
| 207 |
+
Returns:
|
| 208 |
+
sample: generated mel-spectrogram
|
| 209 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
|
| 213 |
+
# fix prompt and overlap part mu and z
|
| 214 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
| 215 |
+
if self.t_scheduler == 'cosine':
|
| 216 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
| 217 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
cosyvoice/flow_speaker_minus/length_regulator.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Tuple
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch
|
| 17 |
+
from torch.nn import functional as F
|
| 18 |
+
from cosyvoice.utils.mask import make_pad_mask
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class InterpolateRegulator(nn.Module):
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
channels: int,
|
| 25 |
+
sampling_ratios: Tuple,
|
| 26 |
+
out_channels: int = None,
|
| 27 |
+
groups: int = 1,
|
| 28 |
+
):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.sampling_ratios = sampling_ratios
|
| 31 |
+
out_channels = out_channels or channels
|
| 32 |
+
model = nn.ModuleList([])
|
| 33 |
+
if len(sampling_ratios) > 0:
|
| 34 |
+
for _ in sampling_ratios:
|
| 35 |
+
module = nn.Conv1d(channels, channels, 3, 1, 1)
|
| 36 |
+
norm = nn.GroupNorm(groups, channels)
|
| 37 |
+
act = nn.Mish()
|
| 38 |
+
model.extend([module, norm, act])
|
| 39 |
+
model.append(
|
| 40 |
+
nn.Conv1d(channels, out_channels, 1, 1)
|
| 41 |
+
)
|
| 42 |
+
self.model = nn.Sequential(*model)
|
| 43 |
+
|
| 44 |
+
def forward(self, x, ylens=None):
|
| 45 |
+
# x in (B, T, D)
|
| 46 |
+
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
|
| 47 |
+
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
|
| 48 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
| 49 |
+
olens = ylens
|
| 50 |
+
return out * mask, olens
|
| 51 |
+
|
| 52 |
+
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
|
| 53 |
+
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
|
| 54 |
+
# x in (B, T, D)
|
| 55 |
+
if x2.shape[1] > 40:
|
| 56 |
+
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
| 57 |
+
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
|
| 58 |
+
mode='linear')
|
| 59 |
+
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
| 60 |
+
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
|
| 61 |
+
else:
|
| 62 |
+
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
|
| 63 |
+
if x1.shape[1] != 0:
|
| 64 |
+
x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
|
| 65 |
+
x = torch.concat([x1, x2], dim=2)
|
| 66 |
+
else:
|
| 67 |
+
x = x2
|
| 68 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
| 69 |
+
return out, mel_len1 + mel_len2
|
cosyvoice/hifigan/__pycache__/discriminator.cpython-310.pyc
ADDED
|
Binary file (5.84 kB). View file
|
|
|