Spaces: Running on Zero
Upload 21 files
- .gitattributes +6 -0
- LICENSE +9 -0
- LICENSE-APACHE +203 -0
- app.py +0 -0
- autoencoder.py +1227 -0
- inference.py +290 -0
- model.py +650 -0
- packages.txt +1 -0
- prompt_audio/EARS p004 freeform.mp3 +3 -0
- prompt_audio/EARS p005 freeform.mp3 +3 -0
- prompt_audio/EARS p028 freeform.mp3 +3 -0
- prompt_audio/EARS p036 freeform.mp3 +3 -0
- prompt_audio/expresso_02_ex03-ex01_calm_005.wav +3 -0
- prompt_audio/freesound_demon_chant(use_forcespeaker).mp3 +3 -0
- requirements.txt +8 -0
- sampler_presets.json +120 -0
- samplers.py +690 -0
- silentcipher/__init__.py +3 -0
- silentcipher/model.py +95 -0
- silentcipher/server.py +480 -0
- silentcipher/stft.py +40 -0
- text_presets.txt +42 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+prompt_audio/EARS[[:space:]]p004[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
+prompt_audio/EARS[[:space:]]p005[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
+prompt_audio/EARS[[:space:]]p028[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
+prompt_audio/EARS[[:space:]]p036[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
+prompt_audio/expresso_02_ex03-ex01_calm_005.wav filter=lfs diff=lfs merge=lfs -text
+prompt_audio/freesound_demon_chant(use_forcespeaker).mp3 filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,9 @@


Copyright 2025 Jordan Darefsky

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
LICENSE-APACHE
ADDED
@@ -0,0 +1,203 @@
This Apache 2.0 license applies only to autoencoder.py

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2024 Fish Audio

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
app.py
ADDED
The diff for this file is too large to render. See raw diff.
autoencoder.py
ADDED
@@ -0,0 +1,1227 @@
# SPDX-FileCopyrightText: 2025 Jordan Darefsky
# SPDX-License-Identifier: Apache-2.0
#
# This file contains portions adapted from:
#   • Descript Audio Codec (DAC) — MIT License (full text appended below)
#   • Fish-Speech S1 DAC Autoencoder — reference implementation (Apache-2.0 / CC-BY-NC),
#     rewritten here in a single-file Torch module for interoperability and transparency.
#
# OVERALL LICENSE (this file): Apache-2.0, except where explicitly marked:
#   # SPDX-License-Identifier: MIT
# Keep these notices and the embedded MIT text if you redistribute this file.

# NOTE (style/provenance):
# Code in this module has been largely copy-and-pasted from the Fish-S1-DAC and DAC repositories,
# and refactored with help from ChatGPT/Claude (these models also helped with licensing).
# Thus, it stylistically differs from the rest of the codebase (I'm not even sure about internal consistency)
# and is likely much messier than it would have been had it been written from scratch.


from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

from einops import rearrange


# --------------------------------------------------------------------
# Shared helpers
# --------------------------------------------------------------------

def find_multiple(n: int, k: int) -> int:
    return n if n % k == 0 else n + k - (n % k)

def unpad1d(x: Tensor, paddings: Tuple[int, int]) -> Tensor:
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]

def get_extra_padding_for_conv1d(
    x: Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """See pad_for_conv1d; enough right pad so striding evenly covers length."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length

def pad1d(
    x: Tensor,
    paddings: Tuple[int, int],
    mode: str = "zeros",
    value: float = 0.0,
) -> Tensor:
    """
    Reflect‑safe 1D pad: if reflect would underflow on small inputs, insert
    temporary right zero-pad before reflecting.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, (padding_left, padding_right), mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, (padding_left, padding_right), mode, value)

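# Illustrative check (hypothetical usage, not called anywhere in this file):
# F.pad(mode="reflect") requires the pad to be smaller than the input length,
# so reflecting a length-2 signal by 3 on each side would raise; pad1d's
# temporary right zero-pad makes it well-defined, and the output still has
# length + padding_left + padding_right samples.
#
#   x = torch.arange(2.0).view(1, 1, 2)
#   y = pad1d(x, (3, 3), mode="reflect")
#   assert y.shape[-1] == 2 + 3 + 3
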
# --------------------------------------------------------------------
# DAC Layers (adapted) — MIT
# Original: https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/layers.py
# SPDX-License-Identifier: MIT
# --------------------------------------------------------------------

def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))

def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))

@torch.jit.script
def snake(x: Tensor, alpha: Tensor) -> Tensor:
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x

class Snake1d(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))
    def forward(self, x: Tensor) -> Tensor:
        return snake(x, self.alpha)

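# Illustrative check (hypothetical usage): with the default alpha of ones,
# snake reduces, up to the 1e-9 stabilizer, to x + sin(x)**2, i.e. a periodic
# bump added to the identity.
#
#   s = Snake1d(channels=4)
#   x = torch.randn(2, 4, 16)
#   torch.testing.assert_close(s(x), x + torch.sin(x).pow(2), rtol=0, atol=1e-6)
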
# --------------------------------------------------------------------
# DAC Vector Quantize (adapted) — MIT
# Original: https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/quantize.py
# SPDX-License-Identifier: MIT
# --------------------------------------------------------------------

class VectorQuantize(nn.Module):
    """
    VQ with factorized, l2-normalized codes (ViT‑VQGAN style).
    I/O in (B, D, T).
    """
    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z: Tensor):
        z_e = self.in_proj(z)  # (B, D, T)
        z_q, indices = self.decode_latents(z_e)
        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
        z_q = z_e + (z_q - z_e).detach()  # straight‑through
        z_q = self.out_proj(z_q)
        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id: Tensor) -> Tensor:
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id: Tensor) -> Tensor:
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices

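# Illustrative round trip (hypothetical usage): one factorized VQ layer maps a
# (B, input_dim, T) latent to codebook indices of shape (B, T) and back.
#
#   vq = VectorQuantize(input_dim=512, codebook_size=1024, codebook_dim=8)
#   z = torch.randn(2, 512, 50)
#   z_q, commit, codebk, idx, z_e = vq(z)
#   assert z_q.shape == z.shape and idx.shape == (2, 50)
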
class ResidualVectorQuantize(nn.Module):
    """SoundStream-style residual VQ stack."""
    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, List[int]] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList([
            VectorQuantize(input_dim, codebook_size, codebook_dim[i])
            for i in range(n_codebooks)
        ])
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z: Tensor, n_quantizers: Optional[int] = None):
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commit_i, codebk_i, indices_i, z_e_i = quantizer(residual)

            mask = (torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers)
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            commitment_loss += (commit_i * mask).mean()
            codebook_loss += (codebk_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)
            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        z_q = 0
        z_p = []
        codes = []
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)
            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)

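# Illustrative encode/decode (hypothetical usage): with quantizer_dropout=0 the
# stacked codes fully determine the quantized latent, so from_codes reproduces
# the forward output up to float rounding.
#
#   rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024)
#   z = torch.randn(1, 512, 100)
#   z_q, codes, latents, commit, codebk = rvq(z)   # codes: (1, 9, 100)
#   z_q2, z_p, _ = rvq.from_codes(codes)
#   torch.testing.assert_close(z_q, z_q2)
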
# --------------------------------------------------------------------
# S1 DAC rvq
# --------------------------------------------------------------------

@dataclass
class VQResult:
    z: Tensor
    codes: Tensor
    latents: Tensor
    codebook_loss: Tensor
    commitment_loss: Tensor
    semantic_distill_z: Optional[Tensor] = None


class CausalConvNet(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        groups=1,
        padding=None,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            stride=stride, dilation=dilation, groups=groups,
        )
        self.stride = stride
        self.kernel_size = (kernel_size - 1) * dilation + 1
        self.dilation = dilation
        self.padding = self.kernel_size - self.stride

    def forward(self, x: Tensor) -> Tensor:
        pad = self.padding
        extra = get_extra_padding_for_conv1d(x, self.kernel_size, self.stride, pad)
        x = pad1d(x, (pad, extra), mode="constant", value=0)
        return self.conv(x).contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
        return self


class CausalTransConvNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size,
            stride=stride, dilation=dilation
        )
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        pad = self.kernel_size - self.stride
        padding_right = math.ceil(pad)
        padding_left = pad - padding_right
        x = unpad1d(x, (padding_left, padding_right))
        return x.contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
        return self


def CausalWNConv1d(*args, **kwargs):
    return CausalConvNet(*args, **kwargs).weight_norm()

def CausalWNConvTranspose1d(*args, **kwargs):
    return CausalTransConvNet(*args, **kwargs).weight_norm()

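# Illustrative check (hypothetical usage): with stride 1 the causal conv pads
# (kernel_size - 1) * dilation zeros on the left only, so the output length
# matches the input and frame t never sees samples after t.
#
#   conv = CausalConvNet(8, 8, kernel_size=7, dilation=2)
#   x = torch.randn(1, 8, 100)
#   assert conv(x).shape[-1] == 100
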
class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt Block (1D).
    DwConv -> (N, C, L) → (N, L, C) -> LN -> Linear -> GELU -> Linear -> (N, C, L) with residual
    """
    def __init__(
        self,
        dim: int,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()
        convnet_type = CausalConvNet
        self.dwconv = convnet_type(
            dim, dim, kernel_size=kernel_size,
            groups=dim, dilation=dilation,
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, int(mlp_ratio * dim))
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0 else None
        )

    def forward(self, x: Tensor, apply_residual: bool = True) -> Tensor:
        inp = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        if apply_residual:
            x = inp + x
        return x


class DownsampleResidualVectorQuantize(nn.Module):
    def __init__(
        self,
        input_dim: int = 1024,
        n_codebooks: int = 9,
        codebook_dim: int = 8,
        quantizer_dropout: float = 0.5,
        codebook_size: int = 1024,
        semantic_codebook_size: int = 4096,
        downsample_factor: Tuple[int, ...] = (2, 2),
        downsample_dims: Optional[Tuple[int, ...]] = None,
        pre_module: Optional[nn.Module] = None,
        post_module: Optional[nn.Module] = None,
        semantic_predictor_module: Optional[nn.Module] = None,
    ):
        super().__init__()

        if downsample_dims is None:
            downsample_dims = tuple(input_dim for _ in range(len(downsample_factor)))

        all_dims = (input_dim,) + tuple(downsample_dims)

        self.semantic_quantizer = ResidualVectorQuantize(
            input_dim=input_dim,
            n_codebooks=1,
            codebook_size=semantic_codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=0.0,
        )

        self.quantizer = ResidualVectorQuantize(
            input_dim=input_dim,
            n_codebooks=n_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        convnet_type = CausalConvNet
        transconvnet_type = CausalTransConvNet

        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    convnet_type(all_dims[idx], all_dims[idx + 1], kernel_size=factor, stride=factor),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )

        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    transconvnet_type(all_dims[idx + 1], all_dims[idx], kernel_size=factor, stride=factor),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )

        self.apply(self._init_weights)
        self.pre_module = pre_module if pre_module is not None else nn.Identity()
        self.post_module = post_module if post_module is not None else nn.Identity()
        self.semantic_predictor_module = (
            semantic_predictor_module if semantic_predictor_module is not None else nn.Identity()
        )

    @staticmethod
    def _init_weights(m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, z: Tensor, n_quantizers: Optional[int] = None, semantic_len: Optional[Tensor] = None, **kwargs):
        # z: (B, D, T)
        original_shape = z.shape
        if semantic_len is None:
            semantic_len = torch.LongTensor([z.shape[-1]])

        z = self.downsample(z)
        z = self.pre_module(z)  # (B, D, T) or (B, T, D) depending on module; original uses channels-first in/out

        semantic_z, semantic_codes, semantic_latents, semantic_commitment_loss, semantic_codebook_loss = \
            self.semantic_quantizer(z)
        residual_z = z - semantic_z
        residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(residual_z, n_quantizers=n_quantizers)
        z = semantic_z + residual_z
        commitment_loss = commitment_loss + semantic_commitment_loss
        codebook_loss = codebook_loss + semantic_codebook_loss
        codes = torch.cat([semantic_codes, codes], dim=1)
        latents = torch.cat([semantic_latents, latents], dim=1)
        z = self.post_module(z)
        z = self.upsample(z)

        # Pad or crop z to match original shape (time dimension)
        diff = original_shape[-1] - z.shape[-1]
        right = 0
        left = abs(diff) - right
        if diff > 0:
            z = F.pad(z, (left, right))
        elif diff < 0:
            z = z[..., left:]

        return VQResult(
            z=z, codes=codes, latents=latents,
            commitment_loss=commitment_loss, codebook_loss=codebook_loss,
        )

    def decode(self, indices: Tensor) -> Tensor:
        new_indices = torch.zeros_like(indices)
        new_indices[:, 0] = torch.clamp(indices[:, 0], max=self.semantic_quantizer.codebook_size - 1)
        new_indices[:, 1:] = torch.clamp(indices[:, 1:], max=self.quantizer.codebook_size - 1)

        z_q_semantic = self.semantic_quantizer.from_codes(new_indices[:, :1])[0]
        z_q_residual = self.quantizer.from_codes(new_indices[:, 1:])[0]
        z_q = z_q_semantic + z_q_residual
        z_q = self.post_module(z_q)
        z_q = self.upsample(z_q)
        return z_q

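# Illustrative decode (hypothetical usage): row 0 of the code matrix addresses
# the semantic codebook, rows 1: the residual stack; decode() re-embeds both,
# sums them, and upsamples by prod(downsample_factor), i.e. 4x with the defaults.
#
#   dvq = DownsampleResidualVectorQuantize(input_dim=1024, n_codebooks=9)
#   codes = torch.randint(0, 1024, (1, 1 + 9, 25))
#   z = dvq.decode(codes)   # (1, 1024, 100)
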
# --------------------------------------------------------------------
# Transformer stack
# --------------------------------------------------------------------

@dataclass
class ModelArgs:
    block_size: int = 2048
    n_layer: int = 8
    n_head: int = 8
    dim: int = 512
    intermediate_size: int = 1536
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    dropout_rate: float = 0.1
    attn_dropout_rate: float = 0.1
    channels_first: bool = True  # to be compatible with conv1d input/output
    pos_embed_type: str = "rope"  # "rope" or "conformer"
    max_relative_position: int = 128

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        assert self.pos_embed_type in ["rope", "conformer"]

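# Worked example (illustrative): if intermediate_size were left as None with
# dim=512, __post_init__ would compute hidden_dim = 2048,
# n_hidden = int(2 * 2048 / 3) = 1365, and find_multiple(1365, 256) = 1536,
# which is exactly the default intermediate_size declared above.
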
class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val
        return (
            k_out[:, :, : input_pos.max() + 1, :],
            v_out[:, :, : input_pos.max() + 1, :],
        )

    def clear_cache(self, prompt_len: int):
        self.k_cache[:, :, prompt_len:, :].fill_(0)
        self.v_cache[:, :, prompt_len:, :].fill_(0)


class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)

        if config.pos_embed_type == "rope":
            freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base)
            self.register_buffer("freqs_cis", freqs_cis)
        else:
            self.register_buffer("freqs_cis", None)

        causal_mask = torch.tril(torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool))
        self.register_buffer("causal_mask", causal_mask)

        self.max_batch_size = -1
        self.max_seq_length = -1
        self.use_kv_cache = False

    def setup_caches(self, max_batch_size, max_seq_length):
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        dtype = self.norm.weight.dtype
        device = self.norm.weight.device

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype
            ).to(device)

        self.use_kv_cache = True

    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, mask: Optional[Tensor] = None) -> Tensor:
        if self.config.pos_embed_type == "rope":
            assert self.freqs_cis is not None
            freqs_cis = self.freqs_cis[input_pos]
        else:
            freqs_cis = None

        if mask is None:
            if not self.training and self.use_kv_cache:
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., : input_pos.max() + 1]
            else:
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., input_pos]

        for layer in self.layers:
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        return x

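# Usage sketch (hypothetical): incremental decoding with the KV cache. The
# cache buffers take their dtype from the norm weights, so the module and the
# inputs are cast to bfloat16 to match before stepping one token at a time.
#
#   cfg = ModelArgs(dim=512, n_layer=2, n_head=8)
#   tf = Transformer(cfg).to(torch.bfloat16).eval()
#   tf.setup_caches(max_batch_size=1, max_seq_length=64)
#   x = torch.randn(1, 1, cfg.dim, dtype=torch.bfloat16)
#   y = tf(x, input_pos=torch.tensor([0]))   # later steps pass [1], [2], ...
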
class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.attention_layer_scale = LayerScale(config.dim, inplace=True)
        self.ffn_layer_scale = LayerScale(config.dim, inplace=True)

    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        h = x + self.attention_layer_scale(
            self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        )
        out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.attn_dropout_rate = config.attn_dropout_rate
        self.pos_embed_type = config.pos_embed_type

        if self.pos_embed_type == "conformer":
            self.max_relative_position = config.max_relative_position
            num_pos_embeddings = 2 * config.max_relative_position + 1
            self.rel_pos_embeddings = nn.Parameter(torch.zeros(num_pos_embeddings, self.head_dim))
            nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)

    def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
        positions = torch.arange(seqlen, device=q.device)
        relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0)  # [S, S]
        relative_positions = torch.clamp(relative_positions + self.max_relative_position,
                                         0, 2 * self.max_relative_position)
        rel_embeddings = self.rel_pos_embeddings[relative_positions]  # [S, S, D]
        q = q.transpose(1, 2)  # [B, S, H, D]
        rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1))  # [B, S, H, S]
        rel_logits = rel_logits.transpose(1, 2)  # [B, H, S, S]
        return rel_logits

    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
        context_seqlen = seqlen

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)

        if self.pos_embed_type == "rope":
            q = apply_rotary_emb(q, freqs_cis)
            k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.pos_embed_type == "conformer":
            scale = 1.0 / math.sqrt(self.head_dim)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            rel_scores = self._compute_conformer_pos_scores(q, seqlen)
            scores = scores + rel_scores
            if mask is not None:
                scores = scores.masked_fill(~mask, float("-inf"))
            attn = F.softmax(scores, dim=-1)
            if self.attn_dropout_rate > 0 and self.training:
                attn = F.dropout(attn, p=self.attn_dropout_rate)
            y = torch.matmul(attn, v)
        else:
            y = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_dropout_rate if self.training else 0.0,
                attn_mask=mask,
            )
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
        y = self.wo(y)
        return y


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class LayerScale(nn.Module):
    def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-2, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

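# Note (descriptive, added for orientation): the block above is the usual
# pre-norm recipe. RMSNorm is LayerNorm without mean-centering or bias
# (x / rms(x) * weight, computed in fp32), FeedForward is a SwiGLU MLP
# (silu(w1 x) * w3 x -> w2), and LayerScale multiplies each residual branch by
# a small learned per-channel gain (init 1e-2) so deep stacks start near the
# identity.
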
| 746 |
+
class WindowLimitedTransformer(Transformer):
    """Transformer with window-limited causal attention."""

    def __init__(
        self,
        config: ModelArgs,
        input_dim: int = 512,
        window_size: Optional[int] = None,
        causal: bool = True,
        look_ahead_conv: Optional[nn.Module] = None,
    ):
        super().__init__(config)
        self.window_size = window_size
        self.causal = causal
        self.channels_first = config.channels_first
        self.look_ahead_conv = look_ahead_conv if look_ahead_conv is not None else nn.Identity()
        self.input_proj = nn.Linear(input_dim, config.dim) if input_dim != config.dim else nn.Identity()
        self.output_proj = nn.Linear(config.dim, input_dim) if input_dim != config.dim else nn.Identity()

    def make_window_limited_mask(self, max_length: int, x_lens: Optional[Tensor] = None) -> Tensor:
        if self.causal:
            mask = torch.tril(torch.ones(max_length, max_length))
            row_indices = torch.arange(max_length).view(-1, 1)
            window_size = self.window_size or max_length
            valid_range = (row_indices - window_size + 1).clamp(min=0)
            column_indices = torch.arange(max_length)
            mask = (column_indices >= valid_range) & mask.bool()
        else:
            raise NotImplementedError
        mask = mask.bool()[None, None]
        return mask
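
    # Illustration (assumed toy values, not from the shipped configs): with
    # window_size=2 and max_length=4, each query attends to itself and the
    # previous position only:
    # [[1, 0, 0, 0],
    #  [1, 1, 0, 0],
    #  [0, 1, 1, 0],
    #  [0, 0, 1, 1]]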
    def make_mask(self, max_length: int, x_lens: Optional[Tensor] = None) -> Tensor:
        if self.causal:
            mask = torch.tril(torch.ones(max_length, max_length))
        else:
            mask = torch.ones(max_length, max_length)
        mask = mask.bool()[None, None]  # (1, 1, L, L)
        if x_lens is not None:
            # mask out padded key positions past each sequence's valid length
            mask = mask.repeat(len(x_lens), 1, 1, 1)
            for i, x_len in enumerate(x_lens):
                mask[i, :, :, x_len:] = False
        return mask

    def forward(self, x: Tensor, x_lens: Optional[Tensor] = None) -> Tensor:
        if self.channels_first:
            x = x.transpose(1, 2)
        x = self.input_proj(x)
        x = self.look_ahead_conv(x)
        input_pos = torch.arange(x.shape[1], device=x.device)
        max_length = x.shape[1]
        if self.window_size is not None:
            mask = self.make_window_limited_mask(max_length, x_lens)
        else:
            mask = self.make_mask(max_length, x_lens)
        mask = mask.to(x.device)
        x = super().forward(x, input_pos, mask)
        x = self.output_proj(x)
        if self.channels_first:
            x = x.transpose(1, 2)
        return x


def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


# --------------------------------------------------------------------
# Top-level AE
# --------------------------------------------------------------------

class EncoderBlock(nn.Module):
    def __init__(
        self,
        dim: int = 16,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        transformer_module = (
            nn.Identity()
            if n_t_layer == 0
            else WindowLimitedTransformer(
                causal=causal,
                input_dim=dim,
                window_size=512,
                config=transformer_general_config(
                    n_layer=n_t_layer,
                    n_head=dim // 64,
                    dim=dim,
                    intermediate_size=dim * 3,
                ),
            )
        )
        self.block = nn.Sequential(
            # three multi-receptive-field residual units
            ResidualUnit(dim // 2, dilation=1, causal=causal),
            ResidualUnit(dim // 2, dilation=3, causal=causal),
            ResidualUnit(dim // 2, dilation=9, causal=causal),
            Snake1d(dim // 2),
            conv_class(dim // 2, dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
            transformer_module,
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            conv_class(dim, dim, kernel_size=1),
        )
        self.causal = causal

    def forward(self, x: Tensor) -> Tensor:
        y = self.block(x)
        pad = x.shape[-1] - y.shape[-1]
        if pad > 0:
            if self.causal:
                x = x[..., :-pad]
            else:
                x = x[..., pad // 2 : -pad // 2]
        return x + y


class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        strides: List[int] = [2, 4, 8, 8],
        d_latent: int = 64,
        n_transformer_layers: List[int] = [0, 0, 4, 4],
        transformer_general_config: Optional[ModelArgs] = None,
        causal: bool = False,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        layers: List[nn.Module] = [conv_class(1, d_model, kernel_size=7, padding=3)]
        for stride, n_t_layer in zip(strides, n_transformer_layers):
            d_model *= 2
            layers.append(
                EncoderBlock(
                    d_model, stride=stride, causal=causal,
                    n_t_layer=n_t_layer, transformer_general_config=transformer_general_config,
                )
            )
        layers += [Snake1d(d_model), conv_class(d_model, d_latent, kernel_size=3, padding=1)]
        self.block = nn.Sequential(*layers)
        self.enc_dim = d_model

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


class DecoderBlock(nn.Module):
    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
        transformer_module = (
            nn.Identity()
            if n_t_layer == 0
            else WindowLimitedTransformer(
                causal=causal,
                input_dim=input_dim,
                window_size=None,
                config=transformer_general_config(
                    n_layer=n_t_layer,
                    n_head=input_dim // 64,
                    dim=input_dim,
                    intermediate_size=input_dim * 3,
                ),
            )
        )
        self.block = nn.Sequential(
            Snake1d(input_dim),
            conv_trans_class(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
            ResidualUnit(output_dim, dilation=1, causal=causal),
            ResidualUnit(output_dim, dilation=3, causal=causal),
            ResidualUnit(output_dim, dilation=9, causal=causal),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


class Decoder(nn.Module):
    def __init__(
        self,
        input_channel: int,
        channels: int,
        rates: List[int],
        d_out: int = 1,
        causal: bool = False,
        n_transformer_layers: List[int] = [0, 0, 0, 0],
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        layers: List[nn.Module] = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
        for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers.append(
                DecoderBlock(
                    input_dim, output_dim, stride, causal=causal,
                    n_t_layer=n_t_layer, transformer_general_config=transformer_general_config,
                )
            )
        layers += [Snake1d(output_dim), conv_class(output_dim, d_out, kernel_size=7, padding=3), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)


class DAC(nn.Module):
    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: Optional[int] = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        quantizer: Optional[nn.Module] = None,
        sample_rate: int = 44100,
        causal: bool = True,
        encoder_transformer_layers: List[int] = [0, 0, 0, 0],
        decoder_transformer_layers: List[int] = [0, 0, 0, 0],
        transformer_general_config=None,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate

        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))
        self.latent_dim = latent_dim

        self.hop_length = int(np.prod(encoder_rates))
        self.encoder = Encoder(
            encoder_dim, encoder_rates, latent_dim, causal=causal,
            n_transformer_layers=encoder_transformer_layers,
            transformer_general_config=transformer_general_config,
        )
        self.quantizer = quantizer
        self.decoder = Decoder(
            latent_dim, decoder_dim, decoder_rates, causal=causal,
            n_transformer_layers=decoder_transformer_layers,
            transformer_general_config=transformer_general_config,
        )
        self.apply(init_weights)

        self.delay = self.get_delay()
        self.frame_length = self.hop_length * 4

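    # Worked example with the build_ae config below (assumed defaults):
    # encoder_rates = [2, 4, 8, 8] gives hop_length = 2 * 4 * 8 * 8 = 512, and the
    # quantizer's downsample_factor = (2, 2) shrinks time by another 4x, so
    # frame_length = 512 * 4 = 2048 audio samples per code frame (~21.5 frames/sec
    # at 44.1 kHz) -- matching AE_DOWNSAMPLE_FACTOR in inference.py.
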
    def get_output_length(self, input_length: int) -> int:
        length = input_length
        for stride in self.encoder_rates:
            length = math.ceil(length / stride)
        return length

    def get_delay(self) -> int:
        # push a zero-length output back through the conv stack (in reverse) to
        # recover the matching input length, then split the difference
        l_out = self.get_output_length(0)
        L = l_out

        layers = [layer for layer in self.modules() if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d))]
        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]
            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1
            L = math.ceil(L)

        l_in = L
        return (l_in - l_out) // 2

    def preprocess(self, audio_data: Tensor, sample_rate: Optional[int]) -> Tensor:
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = F.pad(audio_data, (0, right_pad))
        return audio_data

    def encode(
        self,
        audio_data: Tensor,
        audio_lengths: Optional[Tensor] = None,
        n_quantizers: Optional[int] = None,
        **kwargs,
    ):
        """Encode audio to quantized code indices."""
        if audio_data.ndim == 2:
            audio_data = audio_data.unsqueeze(1)
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
        audio_data = F.pad(audio_data, (0, right_pad))
        if audio_lengths is None:
            audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)

        z = self.encoder(audio_data)
        vq_results = self.quantizer(z, n_quantizers, **kwargs)
        indices = vq_results.codes
        indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
        return indices, indices_lens

    def decode(self, indices: Tensor, feature_lengths: Tensor):
        """Decode code indices to audio."""
        if indices.ndim == 2:
            indices = indices[None]
        z = self.quantizer.decode(indices)
        audio_lengths = feature_lengths * self.frame_length
        return self.decoder(z), audio_lengths

    def encode_to_codes(self, audio: Tensor, audio_lengths: Optional[Tensor] = None, n_quantizers: Optional[int] = None, **kw):
        return self.encode(audio, audio_lengths, n_quantizers, **kw)

    def decode_codes(self, indices: Tensor, feature_lengths: Tensor):
        return self.decode(indices, feature_lengths)

    @torch.no_grad()
    def encode_zq(self, audio_data: Tensor) -> Tensor:
        indices, _ = self.encode(audio_data)
        new_indices = torch.zeros_like(indices)
        new_indices[:, 0] = torch.clamp(indices[:, 0], max=self.quantizer.semantic_quantizer.codebook_size - 1)
        new_indices[:, 1:] = torch.clamp(indices[:, 1:], max=self.quantizer.quantizer.codebook_size - 1)

        z_q_semantic = self.quantizer.semantic_quantizer.from_codes(new_indices[:, :1])[0]
        z_q_residual = self.quantizer.quantizer.from_codes(new_indices[:, 1:])[0]
        z_q = z_q_semantic + z_q_residual
        return z_q

    @torch.no_grad()
    def decode_zq(self, z_q: Tensor) -> Tensor:
        z_q = self.quantizer.post_module(z_q)
        z_q = self.quantizer.upsample(z_q)
        return self.decoder(z_q)

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype


# --------------------------------------------------------------------
# Build helpers
# --------------------------------------------------------------------

def build_ae(**cfg) -> DAC:
    """Factory used by external loaders."""
    # Shared transformer config for the RVQ pre/post modules
    q_config = ModelArgs(
        block_size=4096, n_layer=8, n_head=16, dim=1024,
        intermediate_size=3072, head_dim=64, norm_eps=1e-5,
        dropout_rate=0.1, attn_dropout_rate=0.1, channels_first=True
    )

    def make_transformer():
        return WindowLimitedTransformer(
            causal=True, window_size=128, input_dim=1024, config=q_config
        )

    quantizer = DownsampleResidualVectorQuantize(
        input_dim=1024, n_codebooks=9, codebook_size=1024, codebook_dim=8,
        quantizer_dropout=0.5, downsample_factor=(2, 2),
        semantic_codebook_size=4096,
        pre_module=make_transformer(),
        post_module=make_transformer(),
    )

    def transformer_general_config(**kw):
        return ModelArgs(
            block_size=kw.get("block_size", 16384),
            n_layer=kw.get("n_layer", 8),
            n_head=kw.get("n_head", 8),
            dim=kw.get("dim", 512),
            intermediate_size=kw.get("intermediate_size", 1536),
            n_local_heads=kw.get("n_local_heads", -1),
            head_dim=kw.get("head_dim", 64),
            rope_base=kw.get("rope_base", 10000),
            norm_eps=kw.get("norm_eps", 1e-5),
            dropout_rate=kw.get("dropout_rate", 0.1),
            attn_dropout_rate=kw.get("attn_dropout_rate", 0.1),
            channels_first=kw.get("channels_first", True),
        )

    dac = DAC(
        encoder_dim=64, encoder_rates=[2, 4, 8, 8], latent_dim=1024,
        decoder_dim=1536, decoder_rates=[8, 8, 4, 2],
        quantizer=quantizer, sample_rate=44100, causal=True,
        encoder_transformer_layers=[0, 0, 0, 4],
        decoder_transformer_layers=[4, 0, 0, 0],
        transformer_general_config=transformer_general_config,
    )
    return dac


__all__ = [
    "DAC",
    "build_ae",
    "VectorQuantize",
    "ResidualVectorQuantize",
    "DownsampleResidualVectorQuantize",
]

# ----- BEGIN DAC MIT LICENSE -----
# MIT License
# Copyright (c) 2023-present, Descript
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ----- END DAC MIT LICENSE -----
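
# A minimal round-trip sketch for the autoencoder above (assumptions: a
# build_ae() model with pretrained weights already loaded -- see inference.py;
# with random init the reconstruction is shape-correct but not meaningful audio):
#
#     import torch
#     ae = build_ae().eval()
#     audio = torch.randn(1, 1, 44100)   # (batch, 1, samples) at 44.1 kHz
#     z_q = ae.encode_zq(audio)          # quantized latent
#     recon = ae.decode_zq(z_q)          # reconstructed waveform
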
inference.py
ADDED
@@ -0,0 +1,290 @@
from dataclasses import dataclass
from typing import Callable, List, Tuple
import torch
import safetensors.torch as st
from huggingface_hub import hf_hub_download

from model import EchoDiT
from autoencoder import build_ae, DAC

import torchaudio
from torchcodec.decoders import AudioDecoder

# from samplers import Sampler

SampleFn = Callable[
    [EchoDiT, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int],
    torch.Tensor
]

### Loading

def load_model_from_hf(
    repo_id: str = 'jordand/echo-tts',
    device: str = 'cuda',
    dtype: torch.dtype | None = torch.bfloat16,
    compile: bool = False,
    token: str | None = None,
) -> EchoDiT:
    with torch.device('meta'):
        model = EchoDiT(
            latent_size=80, model_size=2048, num_layers=24, num_heads=16,
            intermediate_size=5888, norm_eps=1e-5, max_seq_len=640,
            text_vocab_size=256, text_model_size=1280, text_num_layers=14,
            text_num_heads=10, text_intermediate_size=3328, text_max_seq_len=768,
            speaker_patch_size=4, speaker_model_size=1280, speaker_num_layers=14,
            speaker_num_heads=10, speaker_intermediate_size=3328,
            speaker_max_patched_seq_len=640, timestep_embed_size=512, adaln_rank=256,
        )
    w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)

    # Load to CPU first
    state = st.load_file(w_path, device='cpu')

    # Convert dtype on CPU if needed
    if dtype is not None:
        state = {k: v.to(dtype=dtype) for k, v in state.items()}

    # Now move to device
    state = {k: v.to(device=device) for k, v in state.items()}

    model.load_state_dict(state, strict=True, assign=True)
    model = model.eval()

    if compile:
        model = torch.compile(model)
        model.get_kv_cache = torch.compile(model.get_kv_cache)

    return model

def load_fish_ae_from_hf(
    repo_id: str = 'jordand/fish-s1-dac-min',
    device: str = 'cuda',
    dtype: torch.dtype | None = torch.float32,
    compile: bool = False,
    token: str | None = None,
) -> DAC:
    # have not tested lower precisions with fish AE yet

    with torch.device('meta'):
        fish_ae = build_ae()

    w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)
    if dtype is not None and dtype != torch.float32:
        state = st.load_file(w_path, device='cpu')
        state = {k: v.to(dtype=dtype) for k, v in state.items()}
        state = {k: v.to(device=device) for k, v in state.items()}
        fish_ae.load_state_dict(state, strict=False, assign=True)
    else:
        state = st.load_file(w_path, device=device)
        fish_ae.load_state_dict(state, strict=False, assign=True)

    fish_ae = fish_ae.eval().to(device)

    if compile:
        fish_ae.encoder = torch.compile(fish_ae.encoder)
        fish_ae.decoder = torch.compile(fish_ae.decoder)

    return fish_ae


@dataclass
class PCAState:
    pca_components: torch.Tensor
    pca_mean: torch.Tensor
    latent_scale: float

def load_pca_state_from_hf(
    repo_id: str = 'jordand/echo-tts',
    device: str = 'cuda',
    filename: str = 'pca_state.safetensors',
    token: str | None = None,
) -> PCAState:
    p_path = hf_hub_download(repo_id, filename, token=token)
    t = st.load_file(p_path, device=device)
    return PCAState(
        pca_components=t["pca_components"],
        pca_mean=t["pca_mean"],
        latent_scale=float(t["latent_scale"].item()),
    )

### default load audio

def load_audio(path: str) -> torch.Tensor:
    decoder = AudioDecoder(path)
    sr = decoder.metadata.sample_rate
    audio = decoder.get_samples_played_in_range(0, 120)  # take the first 120 s at most
    audio = audio.data.mean(dim=0).unsqueeze(0)          # downmix to mono
    audio = torchaudio.functional.resample(audio, sr, 44_100)
    # peak-normalize only when the peak exceeds 1
    audio = audio / torch.maximum(audio.abs().max(), torch.tensor(1.))
    # TODO is this better than clipping? should we target a specific energy level?
    return audio


### Text helpers

def tokenizer_encode(text: str, append_bos: bool = True, normalize: bool = True) -> torch.Tensor:
    if normalize:
        text = text.replace('…', '...')
        text = text.replace('“', '"')
        text = text.replace('”', '"')
        text = text.replace('’', "'")
        text = text.replace('\n', " ")

    b = list(text.encode('utf-8'))
    if append_bos:
        b.insert(0, 0)
    return torch.tensor(b)

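# Example (byte-level tokenization; the values follow from the code above):
# tokenizer_encode("Hi") -> tensor([0, 72, 105]). Byte 0 doubles as BOS, which
# is why the model's text_vocab_size is only 256.
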
def get_text_input_ids_and_mask(text_arr: List[str], max_length: int | None, device: str | None = None) -> tuple[torch.Tensor, torch.Tensor]:
    batch_size = len(text_arr)
    if max_length is None:
        max_length = max(len(tokenizer_encode(text)) for text in text_arr)  # obviously bad...

    tokens = torch.zeros((batch_size, max_length), dtype=torch.int32)
    mask = torch.zeros((batch_size, max_length), dtype=torch.bool)

    for i, text in enumerate(text_arr):
        encoded = tokenizer_encode(text)
        length = min(len(encoded), max_length)
        tokens[i, :length] = encoded[:length]
        mask[i, :length] = 1

    if device is not None:
        tokens = tokens.to(device)
        mask = mask.to(device)

    return tokens, mask


### Autoencoder Inference

@torch.inference_mode()
def ae_encode(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
    assert audio.ndim == 3 and audio.shape[1] == 1  # (b, 1, length)
    z_q = fish_ae.encode_zq(audio).float()
    z_q = (z_q.transpose(1, 2) - pca_state.pca_mean) @ pca_state.pca_components.T
    z_q = z_q * pca_state.latent_scale
    return z_q

@torch.inference_mode()
def ae_decode(fish_ae: DAC, pca_state: PCAState, z_q: torch.Tensor) -> torch.Tensor:
    z_q = (z_q / pca_state.latent_scale) @ pca_state.pca_components + pca_state.pca_mean
    return fish_ae.decode_zq(z_q.transpose(1, 2).to(fish_ae.dtype)).float()

@torch.inference_mode()
def ae_reconstruct(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
    # (audio is (b, 1, length))
    z_q = ae_encode(fish_ae, pca_state, audio.to(fish_ae.dtype))
    return ae_decode(fish_ae, pca_state, z_q)

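# Note on the PCA maps above: ae_encode projects the 1024-dim quantized latent
# onto a lower-dimensional PCA basis, z -> s * (z - mu) @ W.T, and ae_decode
# maps back with (z / s) @ W + mu. Assuming the rows of pca_components are
# orthonormal (true for a standard PCA fit), the round trip recovers the
# component of z inside the PCA subspace -- a slightly lossy reduction, not an
# exact inverse.
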
@torch.inference_mode()
def get_speaker_latent_and_mask(
    fish_ae: DAC,
    pca_state: PCAState,
    audio: torch.Tensor,  # (1, length)
    max_speaker_latent_len: int = 2560,  # pretrained max length
    audio_chunk_size: int = 640 * 2048  # (~30 seconds, 1/4 max speaker condition size)
) -> tuple[torch.Tensor, torch.Tensor]:
    # gets speaker latent and mask from audio, computes in chunks and concatenates (similar to pretraining setup)

    AE_DOWNSAMPLE_FACTOR = 2048
    max_audio_len = max_speaker_latent_len * AE_DOWNSAMPLE_FACTOR

    assert audio.ndim == 2 and audio.shape[0] == 1  # (1, length)
    audio = audio[:, :max_audio_len]
    audio_len = audio.shape[1]

    latent_arr = []

    for i in range(0, audio_len, audio_chunk_size):
        audio_chunk = audio[:, i:i + audio_chunk_size]
        if audio_chunk.shape[1] < audio_chunk_size:
            audio_chunk = torch.nn.functional.pad(audio_chunk, (0, audio_chunk_size - audio_chunk.shape[1]))

        latent_chunk = ae_encode(fish_ae, pca_state, audio_chunk.unsqueeze(0))
        latent_arr.append(latent_chunk)

    speaker_latent = torch.cat(latent_arr, dim=1)

    actual_latent_len = audio_len // AE_DOWNSAMPLE_FACTOR
    speaker_mask = (torch.arange(speaker_latent.shape[1], device=speaker_latent.device) < actual_latent_len).unsqueeze(0)

    if speaker_latent.shape[1] < max_speaker_latent_len:
        speaker_latent = torch.nn.functional.pad(speaker_latent, (0, 0, 0, max_speaker_latent_len - speaker_latent.shape[1]))
        speaker_mask = torch.nn.functional.pad(speaker_mask, (0, max_speaker_latent_len - speaker_mask.shape[1]))

    return speaker_latent, speaker_mask


### Full sample pipeline

def find_flattening_point(data, target_value=0.0, window_size=20, std_threshold=0.05):
    # scan for the first window where the latent goes flat (near-constant and
    # near target_value) -- used to trim trailing silence from generations
    padded_data = torch.cat([data, torch.zeros(window_size, *data.shape[1:], device=data.device, dtype=data.dtype)])
    for i in range(len(padded_data) - window_size):
        window = padded_data[i:i + window_size]
        if window.std() < std_threshold and abs(window.mean() - target_value) < 0.1:
            return i
    return len(data)


@torch.inference_mode()
def sample_pipeline(
    model: EchoDiT,
    fish_ae: DAC,
    pca_state: PCAState,
    sample_fn: SampleFn,
    text_prompt: str,
    speaker_audio: torch.Tensor | None,
    rng_seed: int,
    pad_to_max_speaker_latent_len: int | None = 2560,
    pad_to_max_text_seq_len: int | None = 768,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:

    MAX_SPEAKER_LATENT_LEN = 2560
    MAX_TEXT_SEQ_LEN = 768

    device, dtype = model.device, model.dtype

    text_input_ids, text_mask = get_text_input_ids_and_mask(
        [text_prompt],
        min(pad_to_max_text_seq_len or MAX_TEXT_SEQ_LEN, MAX_TEXT_SEQ_LEN),
        device=device,
    )

    if speaker_audio is None:
        # No speaker prompt - use zero speaker latent and mask
        speaker_latent = torch.zeros((1, pad_to_max_speaker_latent_len if pad_to_max_speaker_latent_len else MAX_SPEAKER_LATENT_LEN, 80), device=device, dtype=dtype)
        speaker_mask = torch.zeros((1, pad_to_max_speaker_latent_len if pad_to_max_speaker_latent_len else MAX_SPEAKER_LATENT_LEN), device=device, dtype=torch.bool)
    else:
        speaker_latent, speaker_mask = get_speaker_latent_and_mask(
            fish_ae,
            pca_state,
            speaker_audio.to(fish_ae.dtype),
            max_speaker_latent_len=pad_to_max_speaker_latent_len if pad_to_max_speaker_latent_len else MAX_SPEAKER_LATENT_LEN
        )
        speaker_latent = speaker_latent.to(device)
        speaker_mask = speaker_mask.to(device)

    latent_out = sample_fn(model, speaker_latent, speaker_mask, text_input_ids, text_mask, rng_seed)

    audio_out = ae_decode(fish_ae, pca_state, latent_out)

    # trim trailing silence: one latent frame covers 2048 audio samples
    flattening_point = find_flattening_point(latent_out[0])
    audio_out = audio_out[..., :flattening_point * 2048]

    return audio_out
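
# End-to-end usage sketch (assumptions: a sampler from samplers.py supplies the
# SampleFn, and 'prompt.wav' is a hypothetical reference clip -- neither is
# defined in this file):
#
#     model = load_model_from_hf(device='cuda')
#     fish_ae = load_fish_ae_from_hf(device='cuda')
#     pca_state = load_pca_state_from_hf(device='cuda')
#     speaker = load_audio('prompt.wav').cuda()
#     audio = sample_pipeline(model, fish_ae, pca_state, sample_fn,
#                             "Hello there!", speaker, rng_seed=0)
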
model.py
ADDED
@@ -0,0 +1,650 @@
from typing import Tuple, List

import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)] / dim))
    t = torch.arange(end)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.complex(torch.cos(freqs), torch.sin(freqs))
    return freqs_cis


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:3], -1, 2))
    x_ = x_ * freqs_cis[..., None, :]
    x_ = torch.view_as_real(x_).reshape(x.shape)
    return x_.type_as(x)


def get_timestep_embedding(
    timestep: torch.Tensor,
    embed_size: int,
) -> torch.Tensor:
    assert embed_size % 2 == 0

    half = embed_size // 2

    freqs = 1000 * torch.exp(
        -torch.log(torch.tensor(10000.0)) *
        torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(timestep.device)

    args = timestep[..., None] * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    return embedding.to(timestep.dtype)

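# Shape sketch for the embedding above: a timestep batch t of shape (B,) maps
# to (B, embed_size), with frequencies 1000 * 10000^(-i/half) for i in
# [0, half) -- the standard sinusoidal embedding with the timestep scaled by
# 1000 (conventional when t lives in [0, 1] for diffusion/flow models; an
# assumption here, since the training objective is not part of this file).
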
class LowRankAdaLN(nn.Module):
    def __init__(
        self,
        model_size: int,
        rank: int,
        eps: float
    ):
        super().__init__()
        self.eps = eps

        self.shift_down = nn.Linear(model_size, rank, bias=False)
        self.scale_down = nn.Linear(model_size, rank, bias=False)
        self.gate_down = nn.Linear(model_size, rank, bias=False)

        self.shift_up = nn.Linear(rank, model_size, bias=True)
        self.scale_up = nn.Linear(rank, model_size, bias=True)
        self.gate_up = nn.Linear(rank, model_size, bias=True)

    def forward(
        self,
        x: torch.Tensor,
        cond_embed: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        shift, scale, gate = cond_embed.chunk(3, dim=-1)

        # per-layer low-rank residual refinement of the shared conditioning embedding
        shift = self.shift_up(self.shift_down(F.silu(shift))) + shift
        scale = self.scale_up(self.scale_down(F.silu(scale))) + scale
        gate = self.gate_up(self.gate_down(F.silu(gate))) + gate

        x_dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(torch.pow(x.float(), 2).mean(dim=-1, keepdim=True) + self.eps)
        x = x * (scale + 1) + shift

        gate = torch.tanh(gate)

        return x.to(x_dtype), gate

class RMSNorm(nn.Module):  # could also just use torch rmsnorm
    def __init__(
        self,
        model_size: int | Tuple[int, int],
        eps: float
    ):
        super().__init__()
        self.eps = eps

        if isinstance(model_size, int):
            model_size = (model_size, )
        self.weight = nn.Parameter(torch.ones(model_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(torch.pow(x.float(), 2).mean(dim=-1, keepdim=True) + self.eps)
        x = x * self.weight
        return x.to(x_dtype)

class SelfAttention(nn.Module):
    def __init__(
        self,
        model_size: int,
        num_heads: int,
        is_causal: bool,
        norm_eps: float
    ):
        super().__init__()
        self.num_heads = num_heads
        self.is_causal = is_causal

        self.wq = nn.Linear(model_size, model_size, bias=False)
        self.wk = nn.Linear(model_size, model_size, bias=False)
        self.wv = nn.Linear(model_size, model_size, bias=False)
        self.wo = nn.Linear(model_size, model_size, bias=False)
        self.gate = nn.Linear(model_size, model_size, bias=False)

        assert model_size % num_heads == 0
        self.q_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
        self.k_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len = x.shape[:2]

        xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xk = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xv = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1)

        gate = self.gate(x)

        xq = self.q_norm(xq)
        xk = self.k_norm(xk)

        xq = apply_rotary_emb(xq, freqs_cis[:seq_len])
        xk = apply_rotary_emb(xk, freqs_cis[:seq_len])

        if mask is not None:
            assert mask.ndim == 2  # (b, s)
            mask = mask[:, None, None]

        output = F.scaled_dot_product_attention(
            query=xq.transpose(1, 2),
            key=xk.transpose(1, 2),
            value=xv.transpose(1, 2),
            attn_mask=mask,
            is_causal=self.is_causal
        ).transpose(1, 2)

        output = output.reshape(batch_size, seq_len, -1)
        output = output * torch.sigmoid(gate)

        output = self.wo(output)

        return output

class JointAttention(nn.Module):
    def __init__(
        self,
        model_size: int,
        num_heads: int,
        text_model_size: int,
        speaker_model_size: int,
        speaker_patch_size: int,
        norm_eps: float
    ):
        super().__init__()
        self.speaker_patch_size = speaker_patch_size
        self.num_heads = num_heads

        self.wq = nn.Linear(model_size, model_size, bias=False)
        self.wk = nn.Linear(model_size, model_size, bias=False)
        self.wv = nn.Linear(model_size, model_size, bias=False)

        self.wk_text = nn.Linear(text_model_size, model_size, bias=False)
        self.wv_text = nn.Linear(text_model_size, model_size, bias=False)

        self.wk_speaker = nn.Linear(speaker_model_size, model_size, bias=False)
        self.wv_speaker = nn.Linear(speaker_model_size, model_size, bias=False)

        assert model_size % num_heads == 0
        self.q_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
        self.k_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)

        self.gate = nn.Linear(model_size, model_size, bias=False)

        self.wo = nn.Linear(model_size, model_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        text_state: torch.Tensor | None,
        text_mask: torch.Tensor,
        speaker_state: torch.Tensor | None,
        speaker_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        batch_size, seq_len = x.shape[:2]

        xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xk_self = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1)
        xv_self = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1)

        xq = self.q_norm(xq)
        xk_self = self.k_norm(xk_self)

        gate = self.gate(x)

        def _apply_rotary_half(y: torch.Tensor, fc: torch.Tensor) -> torch.Tensor:
            y1, y2 = y.chunk(2, dim=-2)
            y1 = apply_rotary_emb(y1, fc)
            return torch.cat([y1, y2], dim=-2)

        xq = _apply_rotary_half(xq, freqs_cis)
        xk_self = _apply_rotary_half(xk_self, freqs_cis)

        if kv_cache is None:
            xk_text = self.wk_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
            xv_text = self.wv_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)

            xk_speaker = self.wk_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
            xv_speaker = self.wv_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)

            xk_text = self.k_norm(xk_text)
            xk_speaker = self.k_norm(xk_speaker)

            xk = torch.cat([xk_self, xk_text, xk_speaker], dim=1)
            xv = torch.cat([xv_self, xv_text, xv_speaker], dim=1)

        else:
            xk_cross, xv_cross = kv_cache
            xk = torch.cat([xk_self, xk_cross], dim=1)
            xv = torch.cat([xv_self, xv_cross], dim=1)

        self_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=x.device)
        mask = torch.cat([self_mask, text_mask, speaker_mask], dim=1)
        mask = mask[:, None, None]

        output = F.scaled_dot_product_attention(
            query=xq.transpose(1, 2),
            key=xk.transpose(1, 2),
            value=xv.transpose(1, 2),
            attn_mask=mask,
            is_causal=False
        ).transpose(1, 2)

        output = output.reshape(batch_size, seq_len, -1)
        output = output * torch.sigmoid(gate)

        output = self.wo(output)

        return output

    def get_kv_cache(
        self,
        text_state: torch.Tensor,
        speaker_state: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = text_state.shape[0]

        xk_text = self.wk_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
        xv_text = self.wv_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)

        xk_speaker = self.wk_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
        xv_speaker = self.wv_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)

        xk = torch.cat([xk_text, xk_speaker], dim=1)
        xv = torch.cat([xv_text, xv_speaker], dim=1)

        xk = self.k_norm(xk)

        return xk, xv

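# Note on JointAttention: the text/speaker key-value projections depend only on
# the fixed conditioning states, so get_kv_cache can precompute them once per
# prompt and every sampling step then reuses them through the kv_cache branch
# of forward. Rotary embeddings are applied to only half of the heads
# (_apply_rotary_half chunks along the head dimension), and the cross-attention
# keys get no rotary embedding at all, leaving them position-free.
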
class MLP(nn.Module):
    def __init__(
        self,
        model_size: int,
        intermediate_size: int
    ):
        super().__init__()
        self.w1 = nn.Linear(model_size, intermediate_size, bias=False)
        self.w3 = nn.Linear(model_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, model_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class EncoderTransformerBlock(nn.Module):
    def __init__(
        self,
        model_size: int,
        num_heads: int,
        intermediate_size: int,
        is_causal: bool,
        norm_eps: float
    ):
        super().__init__()
        self.attention = SelfAttention(
            model_size=model_size,
            num_heads=num_heads,
            is_causal=is_causal,
            norm_eps=norm_eps
        )
        self.mlp = MLP(
            model_size=model_size,
            intermediate_size=intermediate_size
        )

        self.attention_norm = RMSNorm(model_size, norm_eps)
        self.mlp_norm = RMSNorm(model_size, norm_eps)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
        x = x + self.attention(self.attention_norm(x), mask, freqs_cis)
        x = x + self.mlp(self.mlp_norm(x))

        return x

class TransformerBlock(nn.Module):
    def __init__(
        self,
        model_size: int,
        num_heads: int,
        intermediate_size: int,
        norm_eps: float,
        text_model_size: int,
        speaker_model_size: int,
        speaker_patch_size: int,
        adaln_rank: int,
    ):
        super().__init__()
        self.attention = JointAttention(
            model_size=model_size,
            num_heads=num_heads,
            text_model_size=text_model_size,
            speaker_model_size=speaker_model_size,
            speaker_patch_size=speaker_patch_size,
            norm_eps=norm_eps
        )

        self.mlp = MLP(
            model_size=model_size,
            intermediate_size=intermediate_size
        )

        self.attention_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps)
        self.mlp_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        cond_embed: torch.Tensor,
        text_state: torch.Tensor | None,
        text_mask: torch.Tensor,
        speaker_state: torch.Tensor | None,
        speaker_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        x_norm, attention_gate = self.attention_adaln(x, cond_embed)
        x = x + attention_gate * self.attention(x_norm, text_state, text_mask, speaker_state, speaker_mask, freqs_cis, kv_cache)

        x_norm, mlp_gate = self.mlp_adaln(x, cond_embed)
        x = x + mlp_gate * self.mlp(x_norm)

        return x

    def get_kv_cache(
        self,
        text_state: torch.Tensor,
        speaker_state: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.attention.get_kv_cache(text_state, speaker_state)

class TextEncoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        model_size: int,
        num_layers: int,
        num_heads: int,
        intermediate_size: int,
        norm_eps: float,
        max_seq_len: int,
    ):
        super().__init__()
        self.text_embedding = nn.Embedding(vocab_size, model_size)

        self.blocks = nn.ModuleList()
        for i in range(num_layers):
            block = EncoderTransformerBlock(
                model_size=model_size,
                num_heads=num_heads,
                intermediate_size=intermediate_size,
                is_causal=False,
                norm_eps=norm_eps
            )
            self.blocks.append(block)

        self.head_dim = model_size // num_heads

    def forward(self, input_ids: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        x = self.text_embedding(input_ids)

        freqs_cis = precompute_freqs_cis(self.head_dim, input_ids.shape[1]).to(x.device)  # see below about avoiding recomputation
        for block in self.blocks:
            x = block(x, mask, freqs_cis)

        return x

class SpeakerEncoder(nn.Module):
    def __init__(
        self,
        latent_size: int,
        patch_size: int,
        model_size: int,
        num_layers: int,
        num_heads: int,
        intermediate_size: int,
        norm_eps: float,
        max_patched_seq_len: int,
    ):
        super().__init__()
        self.patch_size = patch_size

        self.in_proj = nn.Linear(latent_size * patch_size, model_size, bias=True)

        self.blocks = nn.ModuleList()
        for i in range(num_layers):
            block = EncoderTransformerBlock(
                model_size=model_size,
                num_heads=num_heads,
                intermediate_size=intermediate_size,
                is_causal=True,
                norm_eps=norm_eps
            )
            self.blocks.append(block)

        self.head_dim = model_size // num_heads

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        x = latent.reshape(*latent.shape[:-2], latent.shape[-2] // self.patch_size, latent.shape[-1] * self.patch_size)

        x = self.in_proj(x)
        x = x / 6.  # this helped with initial activation dynamics in early ablations, could also bake into in_proj

        freqs_cis = precompute_freqs_cis(self.head_dim, x.shape[1]).to(x.device)  # see below about avoiding recomputation

        for block in self.blocks:
            x = block(x, None, freqs_cis)

        return x

class EchoDiT(nn.Module):
|
| 468 |
+
def __init__(
|
| 469 |
+
self,
|
| 470 |
+
latent_size: int,
|
| 471 |
+
#
|
| 472 |
+
model_size: int,
|
| 473 |
+
num_layers: int,
|
| 474 |
+
num_heads: int,
|
| 475 |
+
intermediate_size: int,
|
| 476 |
+
norm_eps: float,
|
| 477 |
+
max_seq_len: int,
|
| 478 |
+
#
|
| 479 |
+
text_vocab_size: int,
|
| 480 |
+
text_model_size: int,
|
| 481 |
+
text_num_layers: int,
|
| 482 |
+
text_num_heads: int,
|
| 483 |
+
text_intermediate_size: int,
|
| 484 |
+
text_max_seq_len: int,
|
| 485 |
+
#
|
| 486 |
+
speaker_patch_size: int,
|
| 487 |
+
speaker_model_size: int,
|
| 488 |
+
speaker_num_layers: int,
|
| 489 |
+
speaker_num_heads: int,
|
| 490 |
+
speaker_intermediate_size: int,
|
| 491 |
+
speaker_max_patched_seq_len: int,
|
| 492 |
+
#
|
| 493 |
+
timestep_embed_size: int,
|
| 494 |
+
adaln_rank: int,
|
| 495 |
+
):
|
| 496 |
+
super().__init__()
|
| 497 |
+
self.speaker_patch_size = speaker_patch_size
|
| 498 |
+
self.timestep_embed_size = timestep_embed_size
|
| 499 |
+
|
| 500 |
+
self.text_encoder = TextEncoder(
|
| 501 |
+
vocab_size=text_vocab_size,
|
| 502 |
+
model_size=text_model_size,
|
| 503 |
+
num_layers=text_num_layers,
|
| 504 |
+
num_heads=text_num_heads,
|
| 505 |
+
intermediate_size=text_intermediate_size,
|
| 506 |
+
norm_eps=norm_eps,
|
| 507 |
+
max_seq_len=text_max_seq_len,
|
| 508 |
+
)
|
| 509 |
+
self.speaker_encoder = SpeakerEncoder(
|
| 510 |
+
latent_size=latent_size,
|
| 511 |
+
patch_size=speaker_patch_size,
|
| 512 |
+
model_size=speaker_model_size,
|
| 513 |
+
num_layers=speaker_num_layers,
|
| 514 |
+
num_heads=speaker_num_heads,
|
| 515 |
+
intermediate_size=speaker_intermediate_size,
|
| 516 |
+
norm_eps=norm_eps,
|
| 517 |
+
max_patched_seq_len=speaker_max_patched_seq_len,
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
self.text_norm = RMSNorm(text_model_size, norm_eps)
|
| 521 |
+
self.speaker_norm = RMSNorm(speaker_model_size, norm_eps)
|
| 522 |
+
|
| 523 |
+
self.cond_module = nn.Sequential(
|
| 524 |
+
nn.Linear(timestep_embed_size, model_size, bias=False),
|
| 525 |
+
nn.SiLU(),
|
| 526 |
+
nn.Linear(model_size, model_size, bias=False),
|
| 527 |
+
nn.SiLU(),
|
| 528 |
+
nn.Linear(model_size, model_size * 3, bias=False),
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
self.in_proj = nn.Linear(latent_size, model_size, bias=True)
|
| 532 |
+
|
| 533 |
+
self.blocks = nn.ModuleList()
|
| 534 |
+
for i in range(num_layers):
|
| 535 |
+
block = TransformerBlock(
|
| 536 |
+
model_size=model_size,
|
| 537 |
+
num_heads=num_heads,
|
| 538 |
+
intermediate_size=intermediate_size,
|
| 539 |
+
norm_eps=norm_eps,
|
| 540 |
+
text_model_size=text_model_size,
|
| 541 |
+
speaker_model_size=speaker_model_size,
|
| 542 |
+
speaker_patch_size=speaker_patch_size,
|
| 543 |
+
adaln_rank=adaln_rank,
|
| 544 |
+
)
|
| 545 |
+
self.blocks.append(block)
|
| 546 |
+
|
| 547 |
+
self.out_norm = RMSNorm(model_size, norm_eps)
|
| 548 |
+
self.out_proj = nn.Linear(model_size, latent_size, bias=True)
|
| 549 |
+
|
| 550 |
+
self.head_dim = model_size // num_heads
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def forward(
|
| 554 |
+
self,
|
| 555 |
+
x: torch.Tensor,
|
| 556 |
+
t: torch.Tensor,
|
| 557 |
+
text_input_ids: torch.Tensor,
|
| 558 |
+
text_mask: torch.Tensor | None,
|
| 559 |
+
speaker_latent: torch.Tensor,
|
| 560 |
+
speaker_mask: torch.Tensor | None,
|
| 561 |
+
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] | None = None,
|
| 562 |
+
) -> torch.Tensor:
|
| 563 |
+
"""
|
| 564 |
+
x: (b, s, d)
|
| 565 |
+
t: (b,)
|
| 566 |
+
text_input_ids: (b, s_t) # not used when kv_cache is provided
|
| 567 |
+
text_mask: (b, s_t)
|
| 568 |
+
speaker_latent: (b, s_r, d) # not used when kv_cache is provided
|
| 569 |
+
speaker_mask: (b, s_r)
|
| 570 |
+
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]]
|
| 571 |
+
|
| 572 |
+
returns: (b, s, d)
|
| 573 |
+
"""
|
| 574 |
+
|
| 575 |
+
freqs_cis = precompute_freqs_cis(self.head_dim, x.shape[1]).to(x.device)
|
| 576 |
+
# can't register as buffer because we'd like it to stay in fp32; however, could optionally pass in to avoid recomputing
|
| 577 |
+
|
| 578 |
+
        if kv_cache is None:  # the encoder states only need to be computed when no kv cache is provided
            text_state = self.text_encoder(text_input_ids, text_mask)
            text_state = self.text_norm(text_state)
            speaker_state = self.speaker_encoder(speaker_latent)
            speaker_state = self.speaker_norm(speaker_state)
        else:
            text_state, speaker_state = None, None
        speaker_mask = speaker_mask[..., ::self.speaker_patch_size]

        cond_embed = self.cond_module(get_timestep_embedding(t, self.timestep_embed_size))

        assert cond_embed.ndim == 2
        cond_embed = cond_embed[:, None]

        x = self.in_proj(x)

        for i, block in enumerate(self.blocks):
            x = block(
                x=x,
                cond_embed=cond_embed,
                text_state=text_state,
                text_mask=text_mask,
                speaker_state=speaker_state,
                speaker_mask=speaker_mask,
                freqs_cis=freqs_cis,
                kv_cache=kv_cache[i] if kv_cache is not None else None,
            )

        x = self.out_norm(x)
        x = self.out_proj(x)

        return x.float()

    def get_kv_cache(
        self,
        speaker_latent: torch.Tensor,
        speaker_mask: torch.Tensor,
        text_input_ids: torch.Tensor,
        text_mask: torch.Tensor,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:

        speaker_state = self.speaker_encoder(speaker_latent)
        speaker_state = self.speaker_norm(speaker_state)

        text_state = self.text_encoder(text_input_ids, text_mask)
        text_state = self.text_norm(text_state)

        return [self.blocks[i].get_kv_cache(text_state, speaker_state) for i in range(len(self.blocks))]

    def get_kv_cache_from_precomputed_speaker_state(
        self,
        speaker_state: torch.Tensor,
        speaker_mask: torch.Tensor,
        text_input_ids: torch.Tensor,
        text_mask: torch.Tensor,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:

        # here, speaker_state has already been computed by the speaker latent encoder transformer

        text_state = self.text_encoder(text_input_ids, text_mask)
        text_state = self.text_norm(text_state)

        return [self.blocks[i].get_kv_cache(text_state, speaker_state) for i in range(len(self.blocks))]

    @property
    def device(self) -> torch.device: return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype: return next(self.parameters()).dtype
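
A minimal sketch of driving EchoDiT end to end. All hyperparameters and tensor sizes below are illustrative assumptions only, not the shipped configuration (the real values come with the checkpoint and are wired up in inference.py / app.py):

    import torch
    from model import EchoDiT

    # assumed hyperparameters, for illustration
    model = EchoDiT(
        latent_size=80, model_size=768, num_layers=12, num_heads=12,
        intermediate_size=3072, norm_eps=1e-5, max_seq_len=640,
        text_vocab_size=256, text_model_size=384, text_num_layers=6,
        text_num_heads=6, text_intermediate_size=1536, text_max_seq_len=256,
        speaker_patch_size=4, speaker_model_size=384, speaker_num_layers=6,
        speaker_num_heads=6, speaker_intermediate_size=1536,
        speaker_max_patched_seq_len=160,
        timestep_embed_size=256, adaln_rank=16,
    ).eval()

    x = torch.randn(1, 640, 80)                       # noised latent (b, s, d)
    t = torch.ones(1)                                 # diffusion time in [0, 1]
    text_ids = torch.zeros(1, 32, dtype=torch.int32)  # tokenized script
    text_mask = torch.ones(1, 32, dtype=torch.bool)
    spk = torch.randn(1, 128, 80)                     # speaker prompt latent (128 frames, divisible by patch size 4)
    spk_mask = torch.ones(1, 128, dtype=torch.bool)

    v = model(x, t, text_ids, text_mask, spk, spk_mask)  # predicted velocity, (1, 640, 80)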
packages.txt
ADDED
@@ -0,0 +1 @@
ffmpeg
prompt_audio/EARS p004 freeform.mp3
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:68947a209bc11064f749ca0a61b7959243df83565a0e462b87dfc0ffe03aa7b0
size 1526439
prompt_audio/EARS p005 freeform.mp3
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:07344d073eb3e22c249ebfe15f31f4ba63fd9f17c71aeee93da199ff3b53fc45
size 1351147
prompt_audio/EARS p028 freeform.mp3
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8351eed5982f1fb5763a475c0fb69dba98a4bb49b0f2bbab12b978ff2b0fedeb
size 1211565
prompt_audio/EARS p036 freeform.mp3
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ce77dbb86ea7c29edf2b9804ce9c9315334e9cfeef532dc0c50898a09bae1583
size 1227585
prompt_audio/expresso_02_ex03-ex01_calm_005.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f2be4d1cb5646b3523a460ec40bf171f959a9b33bde918e6d0f795d00284f52a
size 21168080
prompt_audio/freesound_demon_chant(use_forcespeaker).mp3
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:471f67fff5ea613ec4617b9822b1396da123a1133f199925436a2c40e5d1eb91
size 303438
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch
torchaudio
torchcodec
gradio>=5.49
huggingface-hub
numpy
safetensors
einops
sampler_presets.json
ADDED
@@ -0,0 +1,120 @@
{
  "Independent (High Speaker CFG)": {
    "num_steps": "40",
    "cfg_mode": "independent",
    "cfg_scale_text": "3.0",
    "cfg_scale_speaker": "8.0",
    "cfg_min_t": "0.5",
    "cfg_max_t": "1.0",
    "truncation_factor": "1.",
    "rescale_k": "1.",
    "rescale_sigma": "3.0"
  },
  "Independent (High Speaker CFG) Flat": {
    "num_steps": "40",
    "cfg_mode": "independent",
    "cfg_scale_text": "3.0",
    "cfg_scale_speaker": "8.0",
    "cfg_min_t": "0.5",
    "cfg_max_t": "1.0",
    "truncation_factor": "0.8",
    "rescale_k": "1.2",
    "rescale_sigma": "3.0"
  },
  "APG": {
    "num_steps": "40",
    "cfg_mode": "apg-independent",
    "cfg_scale_text": "8.0",
    "cfg_scale_speaker": "8.0",
    "cfg_min_t": "0.5",
    "cfg_max_t": "1.0",
    "truncation_factor": "1.",
    "rescale_k": "1.",
    "rescale_sigma": "3.0",
    "speaker_k_enable": false,
    "speaker_k_scale": "1.5",
    "speaker_k_min_t": "0.9",
    "speaker_k_max_layers": "24",
    "apg_eta_text": "0.5",
    "apg_eta_speaker": "0.5",
    "apg_momentum_text": "0.0",
    "apg_momentum_speaker": "0.0"
  },
  "APG Flat": {
    "num_steps": "40",
    "cfg_mode": "apg-independent",
    "cfg_scale_text": "8.0",
    "cfg_scale_speaker": "8.0",
    "cfg_min_t": "0.5",
    "cfg_max_t": "1.0",
    "truncation_factor": "0.8",
    "rescale_k": "1.2",
    "rescale_sigma": "3.0",
    "speaker_k_enable": false,
    "speaker_k_scale": "1.5",
    "speaker_k_min_t": "0.9",
    "speaker_k_max_layers": "24",
    "apg_eta_text": "0.5",
    "apg_eta_speaker": "0.5",
    "apg_momentum_text": "0.0",
    "apg_momentum_speaker": "0.0"
  },
  "Independent (High CFG)": {
    "num_steps": "40",
    "cfg_mode": "independent",
    "cfg_scale_text": "8.0",
    "cfg_scale_speaker": "8.0",
    "cfg_min_t": "0.5",
    "cfg_max_t": "1.0",
    "truncation_factor": "1.",
    "rescale_k": "1.",
    "rescale_sigma": "3.0"
  },
  "Independent (High CFG) Flat": {
    "num_steps": "40",
    "cfg_mode": "independent",
    "cfg_scale_text": "8.0",
    "cfg_scale_speaker": "8.0",
    "cfg_min_t": "0.5",
    "cfg_max_t": "1.0",
    "truncation_factor": "0.8",
    "rescale_k": "1.2",
    "rescale_sigma": "3.0"
  },
  "Independent": {
    "num_steps": "40",
    "cfg_mode": "independent",
    "cfg_scale_text": "3.0",
    "cfg_scale_speaker": "5.0",
    "cfg_min_t": "0.5",
    "cfg_max_t": "1.0",
    "truncation_factor": "1.",
    "rescale_k": "1.",
    "rescale_sigma": "3.0"
  },
  "Independent Flat": {
    "num_steps": "40",
    "cfg_mode": "independent",
    "cfg_scale_text": "3.0",
    "cfg_scale_speaker": "5.0",
    "cfg_min_t": "0.5",
    "cfg_max_t": "1.0",
    "truncation_factor": "0.8",
    "rescale_k": "1.2",
    "rescale_sigma": "3.0"
  },
  "Joint 20-step Flat": {
    "num_steps": "20",
    "cfg_mode": "joint-unconditional",
    "cfg_scale_text": "3.0",
    "cfg_scale_speaker": "3.0",
    "cfg_min_t": "0.5",
    "cfg_max_t": "1.0",
    "truncation_factor": "0.8",
    "rescale_k": "1.2",
    "rescale_sigma": "3.0"
  }
}
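
Note that every preset field is stored as a string (aside from the booleans). A minimal sketch of consuming the file; the preset-to-sampler wiring shown here is an assumption (the actual wiring lives in app.py):

    import json

    with open("sampler_presets.json") as f:
        presets = json.load(f)

    p = presets["Independent (High Speaker CFG)"]
    num_steps = int(p["num_steps"])            # numeric fields arrive as strings
    cfg_scale_text = float(p["cfg_scale_text"])
    cfg_scale_speaker = float(p["cfg_scale_speaker"])
    truncation_factor = float(p["truncation_factor"])
    # "cfg_mode" values such as "independent", "apg-independent", and
    # "joint-unconditional" select among the samplers defined in samplers.py below.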
samplers.py
ADDED
@@ -0,0 +1,690 @@
from typing import List, Tuple
from enum import Enum

import torch
from model import EchoDiT

# helper
def _get_uncond_text_input_ids_and_mask(batch_size: int, max_length: int, device: str | None = None) -> tuple[torch.Tensor, torch.Tensor]:
    # returns zeros for text input ids, and (True, False, False, ...) for text mask
    text_input_ids_uncond = torch.zeros((batch_size, max_length), dtype=torch.int32)
    text_mask_uncond = torch.zeros((batch_size, max_length), dtype=torch.bool)
    text_mask_uncond[:, 0] = True
    if device is not None:
        text_input_ids_uncond = text_input_ids_uncond.to(device)
        text_mask_uncond = text_mask_uncond.to(device)
    return text_input_ids_uncond, text_mask_uncond
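
# For illustration (hypothetical sizes): with batch_size=2, max_length=4 the helper
# above yields all-zero input ids and a mask that keeps only the first position,
# i.e. the unconditional branch attends to a single BOS-like token:
#
#   ids, mask = _get_uncond_text_input_ids_and_mask(2, 4)
#   ids  -> tensor([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=torch.int32)
#   mask -> tensor([[ True, False, False, False], [ True, False, False, False]])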

# SIMPLE SAMPLER FOR REFERENCE, SHOULD PROBABLY AVOID
@torch.inference_mode()
def sample_euler_cfg_simple(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale: float,
) -> torch.Tensor:

    device, dtype = model.device, model.dtype

    batch_size = text_input_ids.shape[0]

    torch.manual_seed(rng_seed)

    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device)

    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)

    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)

    full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)

    full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)

    kv_cache = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )

    x_t = torch.randn((batch_size, 640, 80), device=device, dtype=torch.float32)

    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i+1]
        v_cond, v_uncond = model(
            x=torch.cat([x_t, x_t], dim=0).to(dtype),
            t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
            text_input_ids=None,
            text_mask=full_text_mask,
            speaker_latent=None,
            speaker_mask=full_speaker_mask,
            kv_cache=kv_cache,
        ).float().chunk(2, dim=0)

        v_pred = v_cond + cfg_scale * (v_cond - v_uncond)
        # note: x_0_pred is x_t - v_pred * t
        x_t = x_t + v_pred * (t_next - t)

    return x_t
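
# Usage sketch for the reference sampler above (tensor shapes are assumptions,
# matching the hard-coded (batch, 640, 80) latent block):
#
#   latents = sample_euler_cfg_simple(
#       model=model,
#       speaker_latent=spk_latent,   # (b, s_r, 80) speaker prompt latent
#       speaker_mask=spk_mask,       # (b, s_r), bool
#       text_input_ids=text_ids,     # (b, s_t), int
#       text_mask=text_mask,         # (b, s_t), bool
#       rng_seed=0, num_steps=40, cfg_scale=3.0,
#   )                                 # -> (b, 640, 80) denoised latents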

######

def _temporal_score_rescale(v_pred: torch.Tensor, x_t: torch.Tensor, t: float, rescale_k: float, rescale_sigma: float) -> torch.Tensor:
    if t < 1:
        snr = (1 - t) ** 2 / (t ** 2)
        ratio = (snr * rescale_sigma ** 2 + 1) / (snr * rescale_sigma ** 2 / rescale_k + 1)
        return 1 / (1 - t) * (ratio * ((1 - t) * v_pred + x_t) - x_t)
    return v_pred
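
# A reading of _temporal_score_rescale, derived from the conventions these
# samplers use (t runs 1 -> 0, x_t = (1 - t) * x_0 + t * eps, v = eps - x_0):
# the implied noise prediction is eps_pred = (1 - t) * v_pred + x_t, and
# snr(t) = ((1 - t) / t) ** 2. The function multiplies eps_pred by
#     ratio = (snr * sigma^2 + 1) / (snr * sigma^2 / k + 1),
# which is ~1 at low SNR (early, noisy steps near t = 1) and approaches k at
# high SNR (late steps near t = 0), then maps the rescaled eps_pred back to a
# velocity. So rescale_k > 1 (e.g. 1.2 in the "Flat" presets) boosts the noise
# prediction toward the end of sampling.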
def _get_first_n_kv_cache(kv_cache: List[List[torch.Tensor]], n: int) -> List[List[torch.Tensor]]:
    return [[kv_cache[i][0][:n], kv_cache[i][1][:n]] for i in range(len(kv_cache))]

def _multiply_speaker_kv_cache(
    kv_cache: List[List[torch.Tensor]],
    scale: float,
    text_length: int,
    max_layers: int = 24,
) -> None:
    # multiplies the speaker entries of the kv cache by scale, in place
    # speaker keys/values start after the text entries (at position text_length)
    for i in range(min(max_layers, len(kv_cache))):
        for j in range(len(kv_cache[i])):
            kv_cache[i][j][:, text_length:] *= scale

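
# Illustration (hypothetical sizes): for a cache built over a CFG batch laid out
# as [cond, uncond] with batch_size=2 conditional rows,
# _get_first_n_kv_cache(cache, 2) keeps only the conditional rows for the cheap
# no-CFG steps, while _multiply_speaker_kv_cache(cache, 1.5, text_length=64)
# scales, in place, the speaker positions (index 64 onward) in the first
# max_layers layers; scaling back by 1/1.5 later undoes it exactly.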
@torch.inference_mode()
def sample_euler_cfg(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale: float,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    speaker_k_scale: float | None,
    speaker_k_max_layers: int | None,
    speaker_k_min_t: float | None,
    block_size: int | None = None,
) -> torch.Tensor:

    if block_size is None:
        block_size = 640

    torch.manual_seed(rng_seed)

    INIT_SCALE = 0.999

    device, dtype = model.device, model.dtype

    batch_size = text_input_ids.shape[0]

    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE

    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)

    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)

    full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)

    full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)

    kv_cache_full = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )  # could make faster by not computing fully / recomputing for unconditional batch elements
    kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
    if speaker_k_scale is not None:
        _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)

    x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)

    if truncation_factor is not None:
        x_t = x_t * truncation_factor

    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i+1]

        has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()

        if has_cfg:
            v_cond, v_uncond = model(
                x=torch.cat([x_t, x_t], dim=0).to(dtype),
                t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=full_text_mask,
                speaker_latent=None,
                speaker_mask=full_speaker_mask,
                kv_cache=kv_cache_full,
            ).float().chunk(2, dim=0)
            v_pred = v_cond + cfg_scale * (v_cond - v_uncond)
        else:
            v_pred = model(
                x=x_t.to(dtype),
                t=(torch.ones((batch_size,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=text_mask,
                speaker_latent=None,
                speaker_mask=speaker_mask,
                kv_cache=kv_cache,
            ).float()

        if rescale_k is not None and rescale_sigma is not None:
            v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)

        if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
            _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)

        x_t = x_t + v_pred * (t_next - t)

    return x_t
@torch.inference_mode()
def sample_euler_cfg_independent_guidances(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale_text: float,
    cfg_scale_speaker: float,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    speaker_k_scale: float | None,
    speaker_k_max_layers: int | None,
    speaker_k_min_t: float | None,
    block_size: int | None = None,
) -> torch.Tensor:

    if block_size is None:
        block_size = 640

    torch.manual_seed(rng_seed)

    INIT_SCALE = 0.999

    device, dtype = model.device, model.dtype

    batch_size = text_input_ids.shape[0]

    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE

    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)

    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)

    # batch layout: [conditional, text-unconditional, speaker-unconditional]
    full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond, text_input_ids], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)

    full_speaker_latent = torch.cat([speaker_latent, speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)

    kv_cache_full = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )
    kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)

    if speaker_k_scale is not None:
        _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)

    x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
    if truncation_factor is not None:
        x_t = x_t * truncation_factor

    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i+1]

        has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()

        if has_cfg:
            v_cond, v_uncond_text, v_uncond_speaker = model(
                x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
                t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=full_text_mask,
                speaker_latent=None,
                speaker_mask=full_speaker_mask,
                kv_cache=kv_cache_full,
            ).float().chunk(3, dim=0)
            v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond_text) + cfg_scale_speaker * (v_cond - v_uncond_speaker)
        else:
            v_pred = model(
                x=x_t.to(dtype),
                t=(torch.ones((batch_size,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=text_mask,
                speaker_latent=None,
                speaker_mask=speaker_mask,
                kv_cache=kv_cache,
            ).float()

        if rescale_k is not None and rescale_sigma is not None:
            v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)

        if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
            _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)

        x_t = x_t + v_pred * (t_next - t)

    return x_t
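
# The independent-guidance update above composes two CFG terms around a single
# conditional pass:
#     v_pred = v_cond
#            + cfg_scale_text    * (v_cond - v_uncond_text)     # text dropped
#            + cfg_scale_speaker * (v_cond - v_uncond_speaker)  # speaker dropped
# so text adherence and speaker adherence can be pushed with separate strengths,
# at the cost of a 3x batch while the CFG window [cfg_min_t, cfg_max_t] is active.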
@torch.inference_mode()
def sample_euler_cfg_alternating_guidances(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale_text: float,
    cfg_scale_speaker: float,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    speaker_k_scale: float | None,
    speaker_k_max_layers: int | None,
    speaker_k_min_t: float | None,
    block_size: int | None = None,
) -> torch.Tensor:

    if block_size is None:
        block_size = 640

    torch.manual_seed(rng_seed)

    INIT_SCALE = 0.999

    device, dtype = model.device, model.dtype

    batch_size = text_input_ids.shape[0]

    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE

    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)

    # TODO: THIS / THE BELOW IS TECHNICALLY INCORRECT, AS IT ASSUMES A CAUSAL TEXT ENCODER (which is not the case)
    # IF THE TEXT ENCODER WERE CAUSAL, THEN USING AN UNCOND TEXT MASK ON COND TEXT INPUTS WOULD GIVE AN UNCOND STATE DUE TO BOS=0
    # HOWEVER, IT MIGHT NOT MAKE MUCH OF A DIFFERENCE
    # ALL OTHER SAMPLERS HAVE BEEN CHANGED TO USE CORRECT UNCONDITIONAL CACHES

    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)

    full_text_input_ids = torch.cat([text_input_ids, text_input_ids], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)

    full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)

    kv_cache_full = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )
    kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)

    if speaker_k_scale is not None:
        _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)

    x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
    if truncation_factor is not None:
        x_t = x_t * truncation_factor

    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i+1]

        has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()

        if has_cfg:
            # even steps guide against dropped text, odd steps against dropped speaker
            v_cond, v_uncond = model(
                x=torch.cat([x_t, x_t], dim=0).to(dtype),
                t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=torch.cat([text_mask, text_mask_uncond if i % 2 == 0 else text_mask], dim=0),
                speaker_latent=None,
                speaker_mask=torch.cat([speaker_mask, speaker_mask if i % 2 == 0 else speaker_mask_uncond], dim=0),
                kv_cache=kv_cache_full,
            ).float().chunk(2, dim=0)
            v_pred = v_cond + (cfg_scale_text if i % 2 == 0 else cfg_scale_speaker) * (v_cond - v_uncond)
        else:
            v_pred = model(
                x=x_t.to(dtype),
                t=(torch.ones((batch_size,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=text_mask,
                speaker_latent=None,
                speaker_mask=speaker_mask,
                kv_cache=kv_cache,
            ).float()

        if rescale_k is not None and rescale_sigma is not None:
            v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)

        if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
            _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)

        x_t = x_t + v_pred * (t_next - t)

    return x_t
@torch.inference_mode()
def sample_euler_apg_independent_guidances(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    num_steps: int,
    cfg_scale_text: float,
    cfg_scale_speaker: float,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    apg_eta_text: float,
    apg_eta_speaker: float,
    apg_momentum_text: float | None,
    apg_momentum_speaker: float | None,
    apg_norm_text: float | None,
    apg_norm_speaker: float | None,
    speaker_k_scale: float | None,
    speaker_k_max_layers: int | None,
    speaker_k_min_t: float | None,
    block_size: int | None = None,
) -> torch.Tensor:

    if block_size is None:
        block_size = 640

    if apg_momentum_text is None:
        apg_momentum_text = 0.0
    if apg_momentum_speaker is None:
        apg_momentum_speaker = 0.0

    torch.manual_seed(rng_seed)

    INIT_SCALE = 0.999

    device, dtype = model.device, model.dtype

    batch_size = text_input_ids.shape[0]

    t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE

    text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)

    speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)

    full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond, text_input_ids], dim=0)
    full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)

    full_speaker_latent = torch.cat([speaker_latent, speaker_latent, speaker_latent_uncond], dim=0)
    full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)

    kv_cache_full = model.get_kv_cache(
        speaker_latent=full_speaker_latent.to(dtype),
        speaker_mask=full_speaker_mask,
        text_input_ids=full_text_input_ids,
        text_mask=full_text_mask,
    )
    kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)

    if speaker_k_scale is not None:
        _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)

    x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
    if truncation_factor is not None:
        x_t = x_t * truncation_factor

    buf_text = torch.zeros_like(x_t)
    buf_speaker = torch.zeros_like(x_t)

    for i in range(num_steps):
        t, t_next = t_schedule[i], t_schedule[i+1]

        has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()

        if has_cfg:
            v_cond, v_uncond_text, v_uncond_speaker = model(
                x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
                t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=full_text_mask,
                speaker_latent=None,
                speaker_mask=full_speaker_mask,
                kv_cache=kv_cache_full,
            ).float().chunk(3, dim=0)

            x0_cond = x_t - t * v_cond
            x0_uncond_text = x_t - t * v_uncond_text
            x0_uncond_speaker = x_t - t * v_uncond_speaker

            diff_text = x0_cond - x0_uncond_text
            diff_speaker = x0_cond - x0_uncond_speaker

            buf_text = diff_text + apg_momentum_text * buf_text
            diff_text = buf_text

            buf_speaker = diff_speaker + apg_momentum_speaker * buf_speaker
            diff_speaker = buf_speaker

            if apg_norm_text is not None:
                nt = torch.sqrt((diff_text * diff_text).sum(dim=tuple(range(1, diff_text.dim())), keepdim=True) + 1e-12)
                s = torch.minimum(torch.ones_like(nt), (torch.as_tensor(apg_norm_text, device=device, dtype=diff_text.dtype) / nt))
                diff_text = diff_text * s
            if apg_norm_speaker is not None:
                ns = torch.sqrt((diff_speaker * diff_speaker).sum(dim=tuple(range(1, diff_speaker.dim())), keepdim=True) + 1e-12)
                s = torch.minimum(torch.ones_like(ns), (torch.as_tensor(apg_norm_speaker, device=device, dtype=diff_speaker.dtype) / ns))
                diff_speaker = diff_speaker * s

            c_norm = torch.sqrt((x0_cond * x0_cond).sum(dim=tuple(range(1, x0_cond.dim())), keepdim=True) + 1e-12)
            c_hat = x0_cond / c_norm

            par_text = (diff_text * c_hat).sum(dim=tuple(range(1, diff_text.dim())), keepdim=True) * c_hat
            ort_text = diff_text - par_text
            upd_text = ort_text + apg_eta_text * par_text

            par_speaker = (diff_speaker * c_hat).sum(dim=tuple(range(1, diff_speaker.dim())), keepdim=True) * c_hat
            ort_speaker = diff_speaker - par_speaker
            upd_speaker = ort_speaker + apg_eta_speaker * par_speaker

            x0_pred = x0_cond + cfg_scale_text * upd_text + cfg_scale_speaker * upd_speaker
            v_pred = (x_t - x0_pred) / t
        else:
            v_pred = model(
                x=x_t.to(dtype),
                t=(torch.ones((batch_size,), device=device) * t).to(dtype),
                text_input_ids=None,
                text_mask=text_mask,
                speaker_latent=None,
                speaker_mask=speaker_mask,
                kv_cache=kv_cache,
            ).float()

        if rescale_k is not None and rescale_sigma is not None:
            v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)

        if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
            _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)

        x_t = x_t + v_pred * (t_next - t)

    return x_t
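
# The sampler above follows the adaptive projected guidance (APG) recipe: each
# guidance difference is formed in x0 space, optionally momentum-smoothed and
# norm-clipped, then decomposed against the unit conditional prediction c_hat
# into a parallel part (rescaled by apg_eta_*) and an orthogonal part (kept as is):
#     diff = par + ort,   update = ort + eta * par
# With eta < 1 this damps the component that merely rescales x0_cond, the part
# of plain CFG usually blamed for oversaturation-style artifacts at high scales.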

# router

class GuidanceMode(Enum):
    INDEPENDENT = "independent"
    APG = "apg"
    JOINT = "joint"
    ALTERNATING = "alternating"


def sample_euler_cfg_any(
    model: EchoDiT,
    speaker_latent: torch.Tensor,
    speaker_mask: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_mask: torch.Tensor,
    rng_seed: int,
    guidance_mode: GuidanceMode,
    num_steps: int,
    cfg_scale_text: float,
    cfg_scale_speaker: float | None,
    cfg_min_t: float,
    cfg_max_t: float,
    truncation_factor: float | None,
    rescale_k: float | None,
    rescale_sigma: float | None,
    speaker_k_scale: float | None,
    speaker_k_min_t: float | None,
    speaker_k_max_layers: int | None,
    apg_eta_text: float | None,
    apg_eta_speaker: float | None,
    apg_momentum_text: float | None,
    apg_momentum_speaker: float | None,
    apg_norm_text: float | None,
    apg_norm_speaker: float | None,
    block_size: int | None = None,
) -> torch.Tensor:

    if guidance_mode == GuidanceMode.INDEPENDENT:
        assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for independent guidances"
        return sample_euler_cfg_independent_guidances(
            model=model,
            speaker_latent=speaker_latent,
            speaker_mask=speaker_mask,
            text_input_ids=text_input_ids,
            text_mask=text_mask,
            rng_seed=rng_seed,
            num_steps=num_steps,
            cfg_scale_text=cfg_scale_text,
            cfg_scale_speaker=cfg_scale_speaker,
            cfg_min_t=cfg_min_t,
            cfg_max_t=cfg_max_t,
            truncation_factor=truncation_factor,
            rescale_k=rescale_k,
            rescale_sigma=rescale_sigma,
            speaker_k_scale=speaker_k_scale,
            speaker_k_max_layers=speaker_k_max_layers,
            speaker_k_min_t=speaker_k_min_t,
            block_size=block_size,
        )

    elif guidance_mode == GuidanceMode.APG:
        assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for APG"
        assert apg_eta_text is not None, "apg_eta_text must be provided for APG"
        assert apg_eta_speaker is not None, "apg_eta_speaker must be provided for APG"
        return sample_euler_apg_independent_guidances(
            model=model,
            speaker_latent=speaker_latent,
            speaker_mask=speaker_mask,
            text_input_ids=text_input_ids,
            text_mask=text_mask,
            rng_seed=rng_seed,
            num_steps=num_steps,
            cfg_scale_text=cfg_scale_text,
            cfg_scale_speaker=cfg_scale_speaker,
            cfg_min_t=cfg_min_t,
            cfg_max_t=cfg_max_t,
            truncation_factor=truncation_factor,
            rescale_k=rescale_k,
            rescale_sigma=rescale_sigma,
            apg_eta_text=apg_eta_text,
            apg_eta_speaker=apg_eta_speaker,
            apg_momentum_text=apg_momentum_text,
            apg_momentum_speaker=apg_momentum_speaker,
            apg_norm_text=apg_norm_text,
            apg_norm_speaker=apg_norm_speaker,
            speaker_k_scale=speaker_k_scale,
            speaker_k_max_layers=speaker_k_max_layers,
            speaker_k_min_t=speaker_k_min_t,
            block_size=block_size,
        )

    elif guidance_mode == GuidanceMode.JOINT:
        assert cfg_scale_text == cfg_scale_speaker or cfg_scale_speaker is None, "cfg_scale_text and cfg_scale_speaker must be the same or cfg_scale_speaker must be None"
        return sample_euler_cfg(
            model=model,
            speaker_latent=speaker_latent,
            speaker_mask=speaker_mask,
            text_input_ids=text_input_ids,
            text_mask=text_mask,
            rng_seed=rng_seed,
            num_steps=num_steps,
            cfg_scale=cfg_scale_text,
            cfg_min_t=cfg_min_t,
            cfg_max_t=cfg_max_t,
            truncation_factor=truncation_factor,
            rescale_k=rescale_k,
            rescale_sigma=rescale_sigma,
            speaker_k_scale=speaker_k_scale,
            speaker_k_max_layers=speaker_k_max_layers,
            speaker_k_min_t=speaker_k_min_t,
            block_size=block_size,
        )

    elif guidance_mode == GuidanceMode.ALTERNATING:
        assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for alternating guidances"
        return sample_euler_cfg_alternating_guidances(
            model=model,
            speaker_latent=speaker_latent,
            speaker_mask=speaker_mask,
            text_input_ids=text_input_ids,
            text_mask=text_mask,
            rng_seed=rng_seed,
            num_steps=num_steps,
            cfg_scale_text=cfg_scale_text,
            cfg_scale_speaker=cfg_scale_speaker,
            cfg_min_t=cfg_min_t,
            cfg_max_t=cfg_max_t,
            truncation_factor=truncation_factor,
            rescale_k=rescale_k,
            rescale_sigma=rescale_sigma,
            speaker_k_scale=speaker_k_scale,
            speaker_k_max_layers=speaker_k_max_layers,
            speaker_k_min_t=speaker_k_min_t,
            block_size=block_size,
        )

    else:
        raise ValueError(f"Unknown guidance mode: {guidance_mode}")
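
# Dispatch sketch for the router above (argument values are assumptions that
# mirror the "Independent (High Speaker CFG)" preset in sampler_presets.json):
#
#   latents = sample_euler_cfg_any(
#       model=model,
#       speaker_latent=spk_latent, speaker_mask=spk_mask,
#       text_input_ids=text_ids, text_mask=text_mask,
#       rng_seed=0,
#       guidance_mode=GuidanceMode.INDEPENDENT,
#       num_steps=40, cfg_scale_text=3.0, cfg_scale_speaker=8.0,
#       cfg_min_t=0.5, cfg_max_t=1.0,
#       truncation_factor=1.0, rescale_k=1.0, rescale_sigma=3.0,
#       speaker_k_scale=None, speaker_k_min_t=None, speaker_k_max_layers=None,
#       apg_eta_text=None, apg_eta_speaker=None,
#       apg_momentum_text=None, apg_momentum_speaker=None,
#       apg_norm_text=None, apg_norm_speaker=None,
#   )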
silentcipher/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .server import get_model

__version__ = '1.0.4'
silentcipher/model.py
ADDED
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
import numpy as np


class Layer(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, stride, padding):
        super(Layer, self).__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
        self.gate = nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
        self.bn = nn.BatchNorm2d(dim_out)

    def forward(self, x):
        return self.bn(self.conv(x) * torch.sigmoid(self.gate(x)))
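
# Layer is a GLU-style gated convolution: two parallel convs over the same
# input, one squashed through a sigmoid and used as a multiplicative gate,
# followed by batch norm. For example, on a (B, 1, F, T) spectrogram input,
# Layer(1, 32, kernel_size=3, stride=1, padding=1) returns a (B, 32, F, T)
# gated feature map.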

class Encoder(nn.Module):
    def __init__(self, out_dim=32, n_layers=3, message_dim=0, message_band_size=None, n_fft=None):
        super(Encoder, self).__init__()
        assert message_band_size is not None
        assert n_fft is not None
        self.message_band_size = message_band_size
        main = [Layer(dim_in=1, dim_out=32, kernel_size=3, stride=1, padding=1)]

        for i in range(n_layers - 2):
            main.append(Layer(dim_in=32, dim_out=32, kernel_size=3, stride=1, padding=1))
        main.append(Layer(dim_in=32, dim_out=out_dim, kernel_size=3, stride=1, padding=1))

        self.main = nn.Sequential(*main)
        self.linear = nn.Linear(message_dim, message_band_size)
        self.n_fft = n_fft

    def forward(self, x):
        h = self.main(x)
        return h

    def transform_message(self, msg):
        output = self.linear(msg.transpose(2, 3)).transpose(2, 3)
        if self.message_band_size != self.n_fft // 2 + 1:
            output = torch.nn.functional.pad(output, (0, 0, 0, self.n_fft // 2 + 1 - self.message_band_size))
        return output

class CarrierDecoder(nn.Module):
    def __init__(self, config, conv_dim, n_layers=4, message_band_size=1024):
        super(CarrierDecoder, self).__init__()
        self.config = config
        self.message_band_size = message_band_size
        layers = [Layer(dim_in=conv_dim, dim_out=96, kernel_size=3, stride=1, padding=1)]

        for i in range(n_layers - 2):
            layers.append(Layer(dim_in=96, dim_out=96, kernel_size=3, stride=1, padding=1))

        layers.append(Layer(dim_in=96, dim_out=1, kernel_size=1, stride=1, padding=0))

        self.main = nn.Sequential(*layers)

    def forward(self, x, message_sdr):
        h = self.main(x)

        if self.config.ensure_negative_message:
            h = torch.abs(h)

        h[:, :, self.message_band_size:, :] = 0

        if not self.config.no_normalization:
            h = h / torch.mean(h**2, dim=2, keepdim=True)**0.5 / (10**(message_sdr/20))

        return h

class MsgDecoder(nn.Module):
    def __init__(self, message_dim=0, message_band_size=None, channel_dim=128, num_layers=10):
        super(MsgDecoder, self).__init__()
        assert message_band_size is not None
        self.message_band_size = message_band_size

        main = [
            nn.Dropout(0),
            Layer(dim_in=1, dim_out=channel_dim, kernel_size=3, stride=1, padding=1)
        ]
        for l in range(num_layers - 2):
            main += [
                nn.Dropout(0),
                Layer(dim_in=channel_dim, dim_out=channel_dim, kernel_size=3, stride=1, padding=1),
            ]
        main += [
            nn.Dropout(0),
            Layer(dim_in=channel_dim, dim_out=message_dim, kernel_size=3, stride=1, padding=1)
        ]
        self.main = nn.Sequential(*main)
        self.linear = nn.Linear(self.message_band_size, 1)

    def forward(self, x):

        h = self.main(x[:, :, :self.message_band_size])
        h = self.linear(h.transpose(2, 3)).squeeze(3).unsqueeze(1)
        return h
silentcipher/server.py
ADDED
@@ -0,0 +1,480 @@
import os
import argparse
import re
import yaml
import time
import numpy as np
import soundfile as sf
from scipy import stats as st
import librosa
from pydub import AudioSegment
import torch
from torch import nn

from .model import Encoder, CarrierDecoder, MsgDecoder
from .stft import STFT

class Model():

    def __init__(self, config, device='cpu'):

        self.config = config
        self.device = device

        self.n_messages = config.n_messages
        self.model_type = config.model_type
        self.message_dim = config.message_dim
        self.message_len = config.message_len

        # model dimensions
        self.enc_conv_dim = 16
        self.enc_num_repeat = 3
        self.dec_c_num_repeat = self.enc_num_repeat
        self.dec_m_conv_dim = 1
        self.dec_m_num_repeat = 8
        self.encoder_out_dim = 32
        self.dec_c_conv_dim = 32 * 3

        self.enc_c = Encoder(n_layers=self.config.enc_n_layers,
                             message_dim=self.message_dim,
                             out_dim=self.encoder_out_dim,
                             message_band_size=self.config.message_band_size,
                             n_fft=self.config.N_FFT)

        self.dec_c = CarrierDecoder(config=self.config,
                                    conv_dim=self.dec_c_conv_dim,
                                    n_layers=self.config.dec_c_n_layers,
                                    message_band_size=self.config.message_band_size)

        self.dec_m = [MsgDecoder(message_dim=self.message_dim,
                                 message_band_size=self.config.message_band_size) for _ in range(self.n_messages)]
        # ------ move modules to device ------
        self.enc_c = self.enc_c.to(self.device)
        self.dec_c = self.dec_c.to(self.device)
        self.dec_m = [m.to(self.device) for m in self.dec_m]

        self.average_energy_VCTK = 0.002837200844477648
        self.stft = STFT(self.config.N_FFT, self.config.HOP_LENGTH)
        self.stft.to(self.device)
        self.load_models(config.load_ckpt)
        self.sr = self.config.SR

    def letters_encoding(self, patch_len, message_lst):
        """
        Encodes a list of messages into a compact representation and a padded representation.

        Args:
            patch_len (int): The length of the patch.
            message_lst (list): A list of messages to be encoded.

        Returns:
            tuple: A tuple containing two numpy arrays:
                - message: A padded representation of the messages, where each message is repeated to match the patch length.
                - message_compact: A compact representation of the messages, where each message is encoded as a one-hot vector.

        Raises:
            AssertionError: If the length of any message in message_lst is not equal to self.config.message_len - 1.
        """

        message = []
        message_compact = []
        for i in range(self.n_messages):

            assert len(message_lst[i]) == self.config.message_len - 1
            index = np.concatenate((np.array(message_lst[i]) + 1, [0]))
            one_hot = np.identity(self.message_dim)[index]
            message_compact.append(one_hot)
            if patch_len % self.message_len == 0:
                message.append(np.tile(one_hot.T, (1, patch_len // self.message_len)))
            else:
                _ = np.tile(one_hot.T, (1, patch_len // self.message_len))
                _ = np.concatenate([_, one_hot.T[:, 0:patch_len % self.message_len]], axis=1)
                message.append(_)
        message = np.stack(message)
        message_compact = np.stack(message_compact)
        # message = np.pad(message, ((0, 0), (0, 129 - self.message_dim), (0, 0)), 'constant')
        return message, message_compact
|
| 101 |
+
def get_best_ps(self, y_one_sec):
|
| 102 |
+
|
| 103 |
+
"""
|
| 104 |
+
Calculates the best phase shift value for watermark decoding.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
y_one_sec (numpy.ndarray): Input audio signal.
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
int: The best phase shift value.
|
| 111 |
+
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
def check_accuracy(pred_values):
|
| 115 |
+
|
| 116 |
+
accuracy = 0
|
| 117 |
+
for i in range(pred_values.shape[1]):
|
| 118 |
+
unique, counts = np.unique(pred_values[:, i], return_counts=True)
|
| 119 |
+
accuracy += np.max(counts) / pred_values.shape[0]
|
| 120 |
+
|
| 121 |
+
return accuracy / pred_values.shape[1]
|
| 122 |
+
|
| 123 |
+
y = torch.FloatTensor(y_one_sec).unsqueeze(0).unsqueeze(0).to(self.device)
|
| 124 |
+
max_accuracy = 0
|
| 125 |
+
final_phase_shift = 0
|
| 126 |
+
|
| 127 |
+
for ps in range(0, self.config.HOP_LENGTH, 10):
|
| 128 |
+
|
| 129 |
+
carrier, _ = self.stft.transform(y[0:1, 0:1, ps:].squeeze(1))
|
| 130 |
+
carrier = carrier[:, None]
|
| 131 |
+
|
| 132 |
+
for i in range(self.n_messages): # decode each msg_i using decoder_m_i
|
| 133 |
+
msg_reconst = self.dec_m[i](carrier)
|
| 134 |
+
pred_values = torch.argmax(msg_reconst[0, 0], dim=0).data.cpu().numpy()
|
| 135 |
+
pred_values = pred_values[0:int(msg_reconst.shape[3]/self.config.message_len)*self.config.message_len]
|
| 136 |
+
pred_values = pred_values.reshape([-1, self.config.message_len])
|
| 137 |
+
cur_acc = check_accuracy(pred_values)
|
| 138 |
+
if cur_acc > max_accuracy:
|
| 139 |
+
max_accuracy = cur_acc
|
| 140 |
+
final_phase_shift = ps
|
| 141 |
+
|
| 142 |
+
return final_phase_shift
|
| 143 |
+
|
| 144 |
+
def get_confidence(self, pred_values, message):
|
| 145 |
+
"""
|
| 146 |
+
Calculates the confidence of the predicted values based on the provided message.
|
| 147 |
+
|
| 148 |
+
Parameters:
|
| 149 |
+
pred_values (numpy.ndarray): The predicted values.
|
| 150 |
+
message (str): The message used for prediction.
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
float: The confidence score.
|
| 154 |
+
|
| 155 |
+
Raises:
|
| 156 |
+
AssertionError: If the length of the message is not equal to the number of columns in pred_values.
|
| 157 |
+
|
| 158 |
+
"""
|
| 159 |
+
assert len(message) == pred_values.shape[1], f'{len(message)} | {pred_values.shape}'
|
| 160 |
+
return np.mean((pred_values == message[None]).astype(np.float32)).item()
|
| 161 |
+
|
| 162 |
+
def sdr(self, orig, recon):
|
| 163 |
+
"""
|
| 164 |
+
Calculate the Signal-to-Distortion Ratio (SDR) between the original and reconstructed signals.
|
| 165 |
+
|
| 166 |
+
Parameters:
|
| 167 |
+
orig (numpy.ndarray): The original signal.
|
| 168 |
+
recon (numpy.ndarray): The reconstructed signal.
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
float: The Signal-to-Distortion Ratio (SDR) value.
|
| 172 |
+
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
rms1 = ((np.mean(orig ** 2)) ** 0.5)
|
| 176 |
+
rms2 = ((np.mean((orig - recon) ** 2)) ** 0.5)
|
| 177 |
+
sdr = 20 * np.log10(rms1 / rms2)
|
| 178 |
+
return sdr
|
| 179 |
+
|
| 180 |
+
def load_audio(self, path):
|
| 181 |
+
"""
|
| 182 |
+
Load an audio file from the given path and return the audio array and sample rate.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
path (str): The path to the audio file.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
tuple: A tuple containing the audio array and sample rate.
|
| 189 |
+
|
| 190 |
+
"""
|
| 191 |
+
audio = AudioSegment.from_file(path)
|
| 192 |
+
audio_array, sr = (np.array(audio.get_array_of_samples(), dtype=np.float32).reshape((-1, audio.channels)) / (
|
| 193 |
+
1 << (8 * audio.sample_width - 1))), audio.frame_rate
|
| 194 |
+
if audio_array.shape[1] == 1:
|
| 195 |
+
audio_array = audio_array[:, 0]
|
| 196 |
+
|
| 197 |
+
return audio_array, sr
|
| 198 |
+
|
| 199 |
+
def encode(self, in_path, out_path, message_list, message_sdr=None, calc_sdr=True, disable_checks=False):
|
| 200 |
+
"""
|
| 201 |
+
Encodes a message into an audio file.
|
| 202 |
+
|
| 203 |
+
Parameters:
|
| 204 |
+
- in_path (str): The path to the input audio file.
|
| 205 |
+
- out_path (str): The path to save the output audio file.
|
| 206 |
+
- message_list (list): A list of messages to be encoded into the audio file.
|
| 207 |
+
- message_sdr (float, optional): The Signal-to-Distortion Ratio (SDR) of the message. Defaults to None.
|
| 208 |
+
- calc_sdr (bool, optional): Whether to calculate the SDR of the encoded audio. Defaults to True.
|
| 209 |
+
- disable_checks (bool, optional): Whether to disable input checks. Defaults to False.
|
| 210 |
+
|
| 211 |
+
Returns:
|
| 212 |
+
- dict: A dictionary containing the status of the encoding process, the SDR value(s), the time taken for encoding, and the time taken per second of audio.
|
| 213 |
+
|
| 214 |
+
"""
|
| 215 |
+
y, orig_sr = self.load_audio(in_path)
|
| 216 |
+
start = time.time()
|
| 217 |
+
encoded_y, sdr = self.encode_wav(y, orig_sr, message_list=message_list, message_sdr=message_sdr, calc_sdr=calc_sdr, disable_checks=disable_checks)
|
| 218 |
+
time_taken = time.time() - start
|
| 219 |
+
sf.write(out_path, encoded_y, orig_sr)
|
| 220 |
+
|
| 221 |
+
if type(sdr) == list:
|
| 222 |
+
return {'status': True, 'sdr': [f'{sdr_i:.2f}' for sdr_i in sdr], 'time_taken': time_taken, 'time_taken_per_second': time_taken / (y.shape[0] / orig_sr)}
|
| 223 |
+
else:
|
| 224 |
+
return {'status': True, 'sdr': f'{sdr:.2f}', 'time_taken': time_taken, 'time_taken_per_second': time_taken / (y.shape[0] / orig_sr)}
|
| 225 |
+
|
| 226 |
+
def decode(self, path, phase_shift_decoding):
|
| 227 |
+
"""
|
| 228 |
+
Decode the audio file at the given path using phase shift decoding.
|
| 229 |
+
|
| 230 |
+
Parameters:
|
| 231 |
+
path (str): The path to the audio file.
|
| 232 |
+
phase_shift_decoding (bool): Flag indicating whether to use phase shift decoding.
|
| 233 |
+
|
| 234 |
+
Returns:
|
| 235 |
+
dictionary: A dictionary containing the decoded message status and value
|
| 236 |
+
"""
|
| 237 |
+
|
| 238 |
+
y, orig_sr = self.load_audio(path)
|
| 239 |
+
|
| 240 |
+
return self.decode_wav(y, orig_sr, phase_shift_decoding)
|
| 241 |
+
|
| 242 |
+
def encode_wav(self, y_multi_channel, orig_sr, message_list, message_sdr=None, calc_sdr=True, disable_checks=False):
|
| 243 |
+
|
| 244 |
+
"""
|
| 245 |
+
Encodes a multi-channel audio waveform with a given message.
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
y_multi_channel (numpy.ndarray): The multi-channel audio waveform to be encoded.
|
| 249 |
+
orig_sr (int): The original sampling rate of the audio waveform.
|
| 250 |
+
message_list (list): The list of messages to be encoded. Each message may correspond to a channel in the audio waveform.
|
| 251 |
+
message_sdr (float, optional): The signal-to-distortion ratio (SDR) of the message. If not provided, the default SDR from the configuration is used.
|
| 252 |
+
calc_sdr (bool, optional): Flag indicating whether to calculate the SDR of the encoded waveform. Defaults to True.
|
| 253 |
+
disable_checks (bool, optional): Flag indicating whether to disable input audio checks. Defaults to False.
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
tuple: A tuple containing the encoded multi-channel audio waveform and the SDR (if calculated).
|
| 257 |
+
|
| 258 |
+
Raises:
|
| 259 |
+
AssertionError: If the number of messages does not match the number of channels in the input audio waveform.
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
single_channel = False
|
| 263 |
+
if len(y_multi_channel.shape) == 1:
|
| 264 |
+
single_channel = True
|
| 265 |
+
y_multi_channel = y_multi_channel[:, None]
|
| 266 |
+
|
| 267 |
+
if message_sdr is None:
|
| 268 |
+
message_sdr = self.config.message_sdr
|
| 269 |
+
print(f'Using the default SDR of {self.config.message_sdr} dB')
|
| 270 |
+
|
| 271 |
+
if type(message_list[0]) == int:
|
| 272 |
+
message_list = [message_list]*y_multi_channel.shape[1]
|
| 273 |
+
|
| 274 |
+
y_watermarked_multi_channel = []
|
| 275 |
+
sdrs = []
|
| 276 |
+
|
| 277 |
+
assert len(message_list) == y_multi_channel.shape[1], f'{len(message_list)} | {y_multi_channel.shape[1]} Mismatch in the number of messages and channels in the input audio.'
|
| 278 |
+
|
| 279 |
+
for channel_i in range(y_multi_channel.shape[1]):
|
| 280 |
+
y = y_multi_channel[:, channel_i]
|
| 281 |
+
message = message_list[channel_i]
|
| 282 |
+
|
| 283 |
+
with torch.no_grad():
|
| 284 |
+
|
| 285 |
+
orig_y = y.copy()
|
| 286 |
+
if orig_sr != self.sr:
|
| 287 |
+
if orig_sr > self.sr:
|
| 288 |
+
print(f'WARNING! Reducing the sampling rate of the original audio from {orig_sr} -> {self.sr}. High frequency components may be lost!')
|
| 289 |
+
y = librosa.resample(y, orig_sr = orig_sr, target_sr = self.sr)
|
| 290 |
+
original_power = np.mean(y**2)
|
| 291 |
+
|
| 292 |
+
if not disable_checks:
|
| 293 |
+
if original_power == 0:
|
| 294 |
+
print('WARNING! The input audio has a power of 0.This means the audio is likely just silence. Skipping encoding.')
|
| 295 |
+
return orig_y, 0
|
| 296 |
+
|
| 297 |
+
y = y * np.sqrt(self.average_energy_VCTK / original_power) # Noise has a power of 5% power of VCTK samples
|
| 298 |
+
y = torch.FloatTensor(y).unsqueeze(0).unsqueeze(0).to(self.device)
|
| 299 |
+
carrier, carrier_phase = self.stft.transform(y.squeeze(1))
|
| 300 |
+
carrier = carrier[:, None]
|
| 301 |
+
carrier_phase = carrier_phase[:, None]
|
| 302 |
+
|
| 303 |
+
def binary_encode(mes):
|
| 304 |
+
binary_message = ''.join(['{0:08b}'.format(mes_i) for mes_i in mes])
|
| 305 |
+
four_bit_msg = []
|
| 306 |
+
for i in range(len(binary_message)//2):
|
| 307 |
+
four_bit_msg.append(int(binary_message[i*2:i*2+2], 2))
|
| 308 |
+
return four_bit_msg
|
| 309 |
+
|
| 310 |
+
binary_encoded_message = binary_encode(message)
|
| 311 |
+
|
| 312 |
+
msgs, msgs_compact = self.letters_encoding(carrier.shape[3], [binary_encoded_message])
|
| 313 |
+
msg_enc = torch.from_numpy(msgs[None]).to(self.device).float()
|
| 314 |
+
|
| 315 |
+
carrier_enc = self.enc_c(carrier) # encode the carrier
|
| 316 |
+
msg_enc = self.enc_c.transform_message(msg_enc)
|
| 317 |
+
|
| 318 |
+
merged_enc = torch.cat((carrier_enc, carrier.repeat(1, 32, 1, 1), msg_enc.repeat(1, 32, 1, 1)), dim=1) # concat encodings on features axis
|
| 319 |
+
|
| 320 |
+
message_info = self.dec_c(merged_enc, message_sdr)
|
| 321 |
+
if self.config.frame_level_normalization:
|
| 322 |
+
message_info = message_info*(torch.mean((carrier**2), dim=2, keepdim=True)**0.5) # *time_weighing
|
| 323 |
+
elif self.config.utterance_level_normalization:
|
| 324 |
+
message_info = message_info*(torch.mean((carrier**2), dim=(2,3), keepdim=True)**0.5) # *time_weighing
|
| 325 |
+
|
| 326 |
+
if self.config.ensure_negative_message:
|
| 327 |
+
message_info = -message_info
|
| 328 |
+
carrier_reconst = torch.nn.functional.relu(message_info + carrier) # decode carrier, output in stft domain
|
| 329 |
+
elif self.config.ensure_constrained_message:
|
| 330 |
+
message_info[message_info > carrier] = carrier[message_info > carrier]
|
| 331 |
+
message_info[-message_info > carrier] = -carrier[-message_info > carrier]
|
| 332 |
+
carrier_reconst = message_info + carrier # decode carrier, output in stft domain
|
| 333 |
+
assert torch.all(carrier_reconst >= 0), 'negative values found in carrier_reconst'
|
| 334 |
+
else:
|
| 335 |
+
carrier_reconst = torch.abs(message_info + carrier) # decode carrier, output in stft domain
|
| 336 |
+
|
| 337 |
+
self.stft.num_samples = y.shape[2]
|
| 338 |
+
|
| 339 |
+
y = self.stft.inverse(carrier_reconst.squeeze(1), carrier_phase.squeeze(1)).data.cpu().numpy()[0, 0]
|
| 340 |
+
y = y * np.sqrt(original_power / (self.average_energy_VCTK)) # Noise has a power of 5% power of VCTK samples
|
| 341 |
+
if orig_sr != self.sr:
|
| 342 |
+
y = librosa.resample(y, orig_sr = self.sr, target_sr = orig_sr)
|
| 343 |
+
|
| 344 |
+
if calc_sdr:
|
| 345 |
+
sdr = self.sdr(orig_y, y)
|
| 346 |
+
else:
|
| 347 |
+
sdr = 0
|
| 348 |
+
|
| 349 |
+
y_watermarked_multi_channel.append(y[:, None])
|
| 350 |
+
sdrs.append(sdr)
|
| 351 |
+
|
| 352 |
+
y_watermarked_multi_channel = np.concatenate(y_watermarked_multi_channel, axis=1)
|
| 353 |
+
|
| 354 |
+
if single_channel:
|
| 355 |
+
y_watermarked_multi_channel = y_watermarked_multi_channel[:, 0]
|
| 356 |
+
sdrs = sdrs[0]
|
| 357 |
+
|
| 358 |
+
return y_watermarked_multi_channel, sdrs
|
| 359 |
+
|
| 360 |
+
def decode_wav(self, y_multi_channel, orig_sr, phase_shift_decoding):
|
| 361 |
+
"""
|
| 362 |
+
Decode the given audio waveform to extract hidden messages.
|
| 363 |
+
|
| 364 |
+
Args:
|
| 365 |
+
y_multi_channel (numpy.ndarray): The multi-channel audio waveform.
|
| 366 |
+
orig_sr (int): The original sample rate of the audio waveform.
|
| 367 |
+
phase_shift_decoding (str): Flag indicating whether to perform phase shift decoding.
|
| 368 |
+
|
| 369 |
+
Returns:
|
| 370 |
+
dict or list: A list of dictionary containing the decoded messages, confidences, and status for each channel if the input is multi-channel.
|
| 371 |
+
Otherwise, a dictionary containing the decoded messages, confidences, and status for a single channel.
|
| 372 |
+
|
| 373 |
+
Raises:
|
| 374 |
+
Exception: If the decoding process fails.
|
| 375 |
+
|
| 376 |
+
"""
|
| 377 |
+
single_channel = False
|
| 378 |
+
if len(y_multi_channel.shape) == 1:
|
| 379 |
+
single_channel = True
|
| 380 |
+
y_multi_channel = y_multi_channel[:, None]
|
| 381 |
+
|
| 382 |
+
results = []
|
| 383 |
+
|
| 384 |
+
for channel_i in range(y_multi_channel.shape[1]):
|
| 385 |
+
y = y_multi_channel[:, channel_i]
|
| 386 |
+
try:
|
| 387 |
+
with torch.no_grad():
|
| 388 |
+
if orig_sr != self.sr:
|
| 389 |
+
y = librosa.resample(y, orig_sr = orig_sr, target_sr = self.sr)
|
| 390 |
+
original_power = np.mean(y**2)
|
| 391 |
+
y = y * np.sqrt(self.average_energy_VCTK / original_power) # Noise has a power of 5% power of VCTK samples
|
| 392 |
+
if phase_shift_decoding and phase_shift_decoding != 'false':
|
| 393 |
+
ps = self.get_best_ps(y)
|
| 394 |
+
else:
|
| 395 |
+
ps = 0
|
| 396 |
+
y = torch.FloatTensor(y[ps:]).unsqueeze(0).unsqueeze(0).to(self.device)
|
| 397 |
+
carrier, _ = self.stft.transform(y.squeeze(1))
|
| 398 |
+
carrier = carrier[:, None]
|
| 399 |
+
|
| 400 |
+
msg_reconst_list = []
|
| 401 |
+
confidence = []
|
| 402 |
+
|
| 403 |
+
for i in range(self.n_messages): # decode each msg_i using decoder_m_i
|
| 404 |
+
msg_reconst = self.dec_m[i](carrier)
|
| 405 |
+
pred_values = torch.argmax(msg_reconst[0, 0], dim=0).data.cpu().numpy()
|
| 406 |
+
pred_values = pred_values[0:int(msg_reconst.shape[3]/self.config.message_len)*self.config.message_len]
|
| 407 |
+
pred_values = pred_values.reshape([-1, self.config.message_len])
|
| 408 |
+
|
| 409 |
+
ord_values = st.mode(pred_values, keepdims=False).mode
|
| 410 |
+
end_char = np.min(np.nonzero(ord_values == 0)[0])
|
| 411 |
+
confidence.append(self.get_confidence(pred_values, ord_values))
|
| 412 |
+
if end_char == self.config.message_len:
|
| 413 |
+
ord_values = ord_values[:self.config.message_len-1]
|
| 414 |
+
else:
|
| 415 |
+
ord_values = np.concatenate([ord_values[end_char+1:], ord_values[:end_char]], axis=0)
|
| 416 |
+
|
| 417 |
+
# pred_values = ''.join([chr(v + 64) for v in ord_values])
|
| 418 |
+
msg_reconst_list.append((ord_values - 1).tolist())
|
| 419 |
+
|
| 420 |
+
def convert_to_8_bit_segments(msg_list):
|
| 421 |
+
segment_message_list = []
|
| 422 |
+
for msg_list_i in msg_list:
|
| 423 |
+
binary_format = ''.join(['{0:02b}'.format(mes_i) for mes_i in msg_list_i])
|
| 424 |
+
eight_bit_segments = [int(binary_format[i*8:i*8+8], 2) for i in range(len(binary_format)//8)]
|
| 425 |
+
segment_message_list.append(eight_bit_segments)
|
| 426 |
+
return segment_message_list
|
| 427 |
+
msg_reconst_list = convert_to_8_bit_segments(msg_reconst_list)
|
| 428 |
+
|
| 429 |
+
results.append({'messages': msg_reconst_list, 'confidences': confidence, 'status': True})
|
| 430 |
+
except:
|
| 431 |
+
results.append({'messages': [], 'confidences': [], 'error': 'Could not find message', 'status': False})
|
| 432 |
+
|
| 433 |
+
if single_channel:
|
| 434 |
+
results = results[0]
|
| 435 |
+
|
| 436 |
+
return results
|
| 437 |
+
|
| 438 |
+
def convert_dataparallel_to_normal(self, checkpoint):
|
| 439 |
+
|
| 440 |
+
return {i[len('module.'):] if i.startswith('module.') else i: checkpoint[i] for i in checkpoint }
|
| 441 |
+
|
| 442 |
+
def load_models(self, ckpt_dir):
|
| 443 |
+
|
| 444 |
+
self.enc_c.load_state_dict(self.convert_dataparallel_to_normal(torch.load(os.path.join(ckpt_dir, "enc_c.ckpt"), map_location=self.device)))
|
| 445 |
+
self.dec_c.load_state_dict(self.convert_dataparallel_to_normal(torch.load(os.path.join(ckpt_dir, "dec_c.ckpt"), map_location=self.device)))
|
| 446 |
+
for i,m in enumerate(self.dec_m):
|
| 447 |
+
m.load_state_dict(self.convert_dataparallel_to_normal(torch.load(os.path.join(ckpt_dir, f"dec_m_{i}.ckpt"), map_location=self.device)))
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def get_model(model_type='44.1k', ckpt_path='../Models/44_1_khz/73999_iteration', config_path='../Models/44_1_khz/73999_iteration/hparams.yaml', device='cpu'):
|
| 451 |
+
|
| 452 |
+
if model_type == '44.1k':
|
| 453 |
+
if not os.path.exists(ckpt_path) or not os.path.exists(config_path):
|
| 454 |
+
print('ckpt path or config path does not exist! Downloading the model from the Hugging Face Hub...')
|
| 455 |
+
from huggingface_hub import snapshot_download
|
| 456 |
+
folder_dir = snapshot_download(repo_id="sony/silentcipher")
|
| 457 |
+
ckpt_path = os.path.join(folder_dir, '44_1_khz/73999_iteration')
|
| 458 |
+
config_path = os.path.join(folder_dir, '44_1_khz/73999_iteration/hparams.yaml')
|
| 459 |
+
|
| 460 |
+
config = yaml.safe_load(open(config_path))
|
| 461 |
+
config = argparse.Namespace(**config)
|
| 462 |
+
config.load_ckpt = ckpt_path
|
| 463 |
+
model = Model(config, device)
|
| 464 |
+
elif model_type == '16k':
|
| 465 |
+
if not os.path.exists(ckpt_path) or not os.path.exists(config_path):
|
| 466 |
+
print('ckpt path or config path does not exist! Downloading the model from the Hugging Face Hub...')
|
| 467 |
+
from huggingface_hub import snapshot_download
|
| 468 |
+
folder_dir = snapshot_download(repo_id="sony/silentcipher")
|
| 469 |
+
ckpt_path = os.path.join(folder_dir, '16_khz/97561_iteration')
|
| 470 |
+
config_path = os.path.join(folder_dir, '16_khz/97561_iteration/hparams.yaml')
|
| 471 |
+
|
| 472 |
+
config = yaml.safe_load(open(config_path))
|
| 473 |
+
config = argparse.Namespace(**config)
|
| 474 |
+
config.load_ckpt = ckpt_path
|
| 475 |
+
|
| 476 |
+
model = Model(config, device)
|
| 477 |
+
else:
|
| 478 |
+
print('Please specify a valid model_type [44.1k, 16k]')
|
| 479 |
+
|
| 480 |
+
return model
|
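A minimal usage sketch of the API above. The file paths and the payload bytes are hypothetical, and the payload length must satisfy 4 * len(payload) == config.message_len - 1, since encode_wav packs each 8-bit value into four 2-bit symbols; get_model falls back to downloading sony/silentcipher from the Hugging Face Hub when the local checkpoint paths do not exist.

# Usage sketch (hypothetical paths and payload; not part of this diff).
from silentcipher.server import get_model

model = get_model(model_type='44.1k', device='cpu')

# Payload: a list of 8-bit integers whose length matches the checkpoint's config.
payload = [123, 234, 111, 222, 11]

enc = model.encode('speech.wav', 'speech_watermarked.wav', payload)
print(enc['sdr'], enc['time_taken_per_second'])

dec = model.decode('speech_watermarked.wav', phase_shift_decoding=True)
if dec['status']:
    print(dec['messages'], dec['confidences'])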
silentcipher/stft.py
ADDED
@@ -0,0 +1,40 @@
import torch

class Singleton(type):
    # Classes with this metaclass are instantiated at most once; later calls
    # return the cached instance.
    _instances = {}
    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]

class STFT(torch.nn.Module, metaclass=Singleton):
    def __init__(self, filter_length=1024, hop_length=512):
        super(STFT, self).__init__()

        self.filter_length = filter_length
        self.hop_len = hop_length
        self.win_len = filter_length
        self.window = torch.hann_window(self.win_len)
        self.num_samples = -1

    def transform(self, x):
        # Zero-pad the signal to a multiple of the window length before the STFT.
        x = torch.nn.functional.pad(x, (0, self.win_len - x.shape[1] % self.win_len))
        fft = torch.stft(x, self.filter_length, self.hop_len, self.win_len, window=self.window.to(x.device), return_complex=True)

        real_part, imag_part = fft.real, fft.imag

        # Numerically safe magnitude: add a tiny epsilon only where the power is
        # exactly zero so the sqrt gradient stays finite.
        squared = real_part**2 + imag_part**2
        additive_epsilon = torch.ones_like(squared) * (squared == 0).float() * 1e-24
        magnitude = torch.sqrt(squared + additive_epsilon) - torch.sqrt(additive_epsilon)

        phase = torch.atan2(imag_part, real_part).float()
        return magnitude, phase

    def inverse(self, magnitude, phase):

        recombine_magnitude_phase = magnitude * torch.cos(phase) + 1j * magnitude * torch.sin(phase)
        inverse_transform = torch.istft(recombine_magnitude_phase, self.filter_length, hop_length=self.hop_len, win_length=self.win_len, window=self.window.to(magnitude.device)).unsqueeze(1)
        # Strip the padding added in transform(); num_samples must be set by the caller first.
        padding = self.win_len - (self.num_samples % self.win_len)
        inverse_transform = inverse_transform[:, :, :-padding]
        return inverse_transform
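A round-trip sketch of this STFT wrapper, with hypothetical sizes. Two points worth noting: the Singleton metaclass means every STFT(...) call in a process returns the first instance created, and num_samples must be set before inverse() so the padding added by transform() can be stripped (encode_wav in server.py does exactly this).

# Round-trip sketch (hypothetical sizes; not part of this diff).
import torch
from silentcipher.stft import STFT

stft = STFT(filter_length=1024, hop_length=512)
x = torch.randn(2, 44100)             # (batch, samples)

magnitude, phase = stft.transform(x)  # zero-pads x to a multiple of win_len internally
stft.num_samples = x.shape[1]         # inverse() uses this to strip that padding
y = stft.inverse(magnitude, phase)    # (batch, 1, samples)

print(x.shape, y.shape)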
text_presets.txt
ADDED
@@ -0,0 +1,42 @@
Reading | [S1] The old lighthouse keeper had seen many storms in his thirty years on the rock, but nothing like this. The fog rolled in thick as wool, swallowing the beam of light before it could reach the churning waves below. Then he heard it, three short bells from the channel, where no ship should be at this hour. He grabbed his lantern and peered into the mist, his heart pounding. Something was out there, something that shouldn't exist.

Reading | [S1] Deep beneath the ocean's surface, where sunlight fades to perpetual twilight, extraordinary creatures have evolved in ways that defy imagination. Bioluminescent jellyfish pulse with ethereal blue light, while giant squid hunt in the crushing darkness. At depths of over two miles, the pressure is immense, enough to collapse a submarine, yet life persists.

Reading | [S1] The telegram arrived on a Tuesday morning in June, nineteen forty-three. Margaret's hands trembled as she tore open the envelope, dreading the words she knew might be inside. Her brother had shipped out to North Africa six months ago, and his letters had grown increasingly sparse.

Reading | [S1] The ancient map showed a path through the Whispering Mountains that no living traveler had taken in generations. Legends spoke of a hidden valley where time moved differently, where a single day in the outside world meant years had passed within. As dawn broke over the snow-capped peaks, Elena shouldered her pack and began the ascent. Whatever waited at the journey's end, whether treasure or peril,

Cartoon | [S1] After giving everything some more thought, I've decided it's in the best interest of humanity to acquire Nexus AI. (laughs) I've spoken with the CEO and he's on board. Well (laughs), at least that's the impression he gave initially.

Single (Disfluent) | [S1] ... explore how we can design, create interfaces that are not confusing, but at the same time can be powerful. Um, you know, I think, uh, in the, the famous, um, usability book, it's, uh, it's this, um, um, oh, geez, I'm, I'm blanking on the term, uh, uh, the, the rule about, um, uh, it's like the simplicity rule. I can't recall. Oh, cognitive load maybe.

Single (Disfluent) | [S1] Uh, complacency when the motivation isn't structured properly. Like for example, if you, if you're in the cor- if you work in the corporation for many years, a lot of corporate employees, they just, they're, they're aiming for that stock vesting and they're, they're doing just a sufficient job to, to, to reach that vesting and, and they don't, they're not performing any better than that. Um, and so I think, um, that showed me an important insight. Yeah.

Single (Disfluent) | [S1] We see the pattern of revelations, major shifts. I think Neptune in Pisces, which that transit has been happening all of 2021, and Neptune will remain in the sign of Pisces until March of 2029. So it's several years more of this transit. And what it brings is a lot of things, you know, the thing that I tend to emphasize is the profound dissolution or profound changes

Single (Disfluent) | [S1] I asked her, "Do you have like a phrase you use," and she mentioned she actually does. Like when things get tense, when there's like a moment, like if her, if her roommate is like venting about work drama or just like is stressed, and her, her roommate like deals with anxiety, I'm like, "Oh, this is probably how it feels to live with me." But, um, and like if, if, if things are rough, like she'll internally just like use this practice where she's like, like, "Not my problem, not mine to carry, not mine to handle, not mine to change." Like she'll sort of repeat that. So that's interesting.

Single (Disfluent) | [S1] If I examine the, the, if, if you examine the range of options, uh, beginning from, like, say, individual all the way, right? There will be some revenue stream, uh, there will be some purchase, there'll be some hardware profit margin for someone who creates a smart product, um, uh, there will be memberships, personal and business, uh, and then there'll be usage-based, right? So I still believe that that's kinda how, those are all the metrics. To your point, what is a membership? Up to now, folks

Single (Disfluent) | [S1] I think, if, if we can keep it under 25 points allowed, sure, our odds improve significantly. We wouldn't need to put up huge numbers ourselves, or at least that's the theory. And I should, I want to share some other stats which might be a bit outside our current discussion, but regarding this compared to 2018, the team's final four games that year, they managed 18 points total.

Singing | [S1] (singing) Amazing grace, how sweet the sound, that saved a wretch like me. I once was lost, but now am found, was blind, but now I see.

Conversation | [S1] Alright then. So, so 18 years you spent in that, uh, in that role, but alongside that in, in, was it while you were working that position in '93, you started doing some work with the network? [S2] Uh, yes. It was somewhere around '93. I, I, I played tennis pretty well, you know? I, I, I competed as a tennis player. And the, I got a chance to do some broadcasting over in Brisbane.

Conversation | [S1] ... that will provide the analytics component- [S2] Right. [S1] ... to ideally get you to adopt some of their other tools. And- [S2] (laughs) [S1] ... some of those features are valuable too. [S2] That's interesting. [S1] Mailchimp, I mean, that's campaign manage-, uh, not exactly campaign management, but messaging platforms. [S2] Uh-huh. [S1] The, the companies that are, you know,

Conversation | [S1] They were like, they were pumped for it, going wild for it, and it disappeared immediately. [S2] Yeah, I think it's about people understanding what's available first. Um... [S1] I think the finish on that one too was really nice. [S2] Yeah. [S1] I mean, that was pretty awesome. [S2] Have you seen those new editions?

Conversation | [S1] He was just practicing with them and they were on rotation. [S2] So that was probably in January. [S1] I think startup stereotypes, there is some like that, but some of them, I think they need to be changed. Like we don't all work twenty-hour days. [S2] No, they just need to, it's called not, it's based in Silicon Valley. [S1] Yeah. [S2] But the stereotypes would apply if they, it was called Techlife- [S1] Palo Alto. [S2] ... Cupertino or Mountain View, California.

Conversation | [S1] That's a nice overview. [S2] We were at the downtown cinema. [S1] By that, you mean the one in Riverside? [S2] Yeah. [S1] Yeah. So not exactly downtown. [S2] Not exactly downtown, yeah. [S1] I know a little bit about that area. [S2] (laughs) [S1] You know, Millbrook doesn't have a cinema. [S2] (laughs) It's the closest one for us. It's the closest. [S1] Yeah, that's true. [S2] The most nearby. [S1] Riverside is nearby. [S2] Riverside's close. [S1] That's fair. [S2] Support nearby. [S1] You can say, say Riverside, definitely. [S2] Well, yeah, fair enough.

Conversation | [S1] But they also, they also discovered, um, they also discovered like patterns in the desert, um, near Peru, like in the Atacama Desert. [S2] Yeah. [S1] Um, and like, it was like, of like perfectly, like, geo- geometric shapes. And they're like, "Yo, this is definitely not like formed by wind. This has to be artificial." [S2] Yeah, it's too precise.

Conversation | [S1] 'Cause I, yeah, there, there has to be a way that they can just make the, the system recognize that, no, you did not earn this- [S2] (laughs) [S1] ... on your own. You still have to go and complete one if you want it for your own- [S2] Right. [S1] ... like, profile. [S2] Right. Mm-hmm. [S1] So, yeah. [S2] Um, yeah. So let's actually move into multiplayer.

Conversation | [S1] Yeah. [S2] Yeah. TRS as a whole is just relaxed. [S1] But anyway, you know that Mirror app that launched and then got removed like a month later? [S2] Mirror, what, like, to your future? [S1] Yeah. [S2] Oh. [S1] So basically, there was an app, there's a show coming out. [S2] This is a show. [S1] Coming, I don't know what it is. [S2] Yeah, yeah, yeah. [S1] Like 2026 or something. Basically, Marcus, have you heard about this? [S2] I'm sorry, I don't know. No, I don't have an, it's an app- [S1] Okay, so I'll explain. I'll explain. [S2] Yeah. [S1] For context. So there's this app that launched in terms of the show called Mirror.

Conversation | [S1] Jamie Patterson, right? [S2] No, I know where- [S1] I know where- [S2] ... Patterson works as well. I know where- [S1] I know- I know he used to work near- on this street, and this is a weird street. [S2] The only person who I don't know where they work, Jamie. But anyway, why are we even talking about who works where? [S1] It was a- it was- it was a really weird street name where Jamie worked. [S2] I- I drove past this street on my commute. [S1] No, you didn't. [S2] Yeah, I did. [S1] No, you drove past the street that my street is down the street of. [S2] Nice. There's, like, one street in Oakfield, I think I'll be able to find it, mate.