jordand committed
Commit 60cc71a · verified · 1 Parent(s): 14e4ac4

Upload 21 files

.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ prompt_audio/EARS[[:space:]]p004[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
37
+ prompt_audio/EARS[[:space:]]p005[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
38
+ prompt_audio/EARS[[:space:]]p028[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
39
+ prompt_audio/EARS[[:space:]]p036[[:space:]]freeform.mp3 filter=lfs diff=lfs merge=lfs -text
40
+ prompt_audio/expresso_02_ex03-ex01_calm_005.wav filter=lfs diff=lfs merge=lfs -text
41
+ prompt_audio/freesound_demon_chant(use_forcespeaker).mp3 filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,9 @@
1
+
2
+
3
+ Copyright 2025 Jordan Darefsky
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6
+
7
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8
+
9
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
LICENSE-APACHE ADDED
@@ -0,0 +1,203 @@
1
+ This Apache 2.0 license applies only to autoencoder.py
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "[]"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright 2024 Fish Audio
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
app.py ADDED
The diff for this file is too large to render. See raw diff
 
autoencoder.py ADDED
@@ -0,0 +1,1227 @@
1
+ # SPDX-FileCopyrightText: 2025 Jordan Darefsky
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # This file contains portions adapted from:
5
+ # • Descript Audio Codec (DAC) — MIT License (full text appended below)
6
+ # • Fish-Speech S1 DAC Autoencoder — reference implementation (Apache-2.0 / CC-BY-NC),
7
+ # rewritten here in a single-file Torch module for interoperability and transparency.
8
+ #
9
+ # OVERALL LICENSE (this file): Apache-2.0, except where explicitly marked:
10
+ # # SPDX-License-Identifier: MIT
11
+ # Keep these notices and the embedded MIT text if you redistribute this file.
12
+
13
+ # NOTE (style/provenance):
14
+ # Code in this module has been largely copy-and-pasted from the Fish-S1-DAC and DAC repositories,
15
+ # and refactored with help from ChatGPT/Claude (these models also helped with licensing).
16
+ # Thus, it stylistically differs from the rest of the codebase (I'm not even sure about internal consistency)
17
+ # and is likely much messier than it would have been had it been written from scratch.
18
+
19
+
20
+ from __future__ import annotations
21
+
22
+ import math
23
+ from dataclasses import dataclass
24
+ from typing import List, Optional, Tuple, Union
25
+
26
+ import numpy as np
27
+ import torch
28
+ from torch import Tensor, nn
29
+ from torch.nn import functional as F
30
+ from torch.nn.utils.parametrizations import weight_norm
31
+ from torch.nn.utils.parametrize import remove_parametrizations
32
+
33
+ from einops import rearrange
34
+
35
+
36
+ # --------------------------------------------------------------------
37
+ # Shared helpers
38
+ # --------------------------------------------------------------------
39
+
40
+ def find_multiple(n: int, k: int) -> int:
41
+ return n if n % k == 0 else n + k - (n % k)
42
+
43
+ def unpad1d(x: Tensor, paddings: Tuple[int, int]) -> Tensor:
44
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
45
+ padding_left, padding_right = paddings
46
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
47
+ assert (padding_left + padding_right) <= x.shape[-1]
48
+ end = x.shape[-1] - padding_right
49
+ return x[..., padding_left:end]
50
+
51
+ def get_extra_padding_for_conv1d(
52
+ x: Tensor, kernel_size: int, stride: int, padding_total: int = 0
53
+ ) -> int:
54
+ """See pad_for_conv1d; enough right pad so striding evenly covers length."""
55
+ length = x.shape[-1]
56
+ n_frames = (length - kernel_size + padding_total) / stride + 1
57
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
58
+ return ideal_length - length
59
+
60
+ def pad1d(
61
+ x: Tensor,
62
+ paddings: Tuple[int, int],
63
+ mode: str = "zeros",
64
+ value: float = 0.0,
65
+ ) -> Tensor:
66
+ """
67
+ Reflect-safe 1D pad: if reflect padding would underflow on small inputs, insert a
68
+ temporary right zero-pad before reflecting, then trim it afterwards.
69
+ """
70
+ length = x.shape[-1]
71
+ padding_left, padding_right = paddings
72
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
73
+ if mode == "reflect":
74
+ max_pad = max(padding_left, padding_right)
75
+ extra_pad = 0
76
+ if length <= max_pad:
77
+ extra_pad = max_pad - length + 1
78
+ x = F.pad(x, (0, extra_pad))
79
+ padded = F.pad(x, (padding_left, padding_right), mode, value)
80
+ end = padded.shape[-1] - extra_pad
81
+ return padded[..., :end]
82
+ else:
83
+ return F.pad(x, (padding_left, padding_right), mode, value)
84
+
85
+
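
A small usage sketch (illustrative only, not part of the committed file; sizes are arbitrary) of how the three helpers above compose: pad on the right so a strided conv tiles the whole signal, then undo the pad exactly.

import torch

x = torch.randn(1, 1, 101)                                    # (B, C, T)
kernel_size, stride = 8, 4
extra = get_extra_padding_for_conv1d(x, kernel_size, stride)  # extra right pad for full coverage
y = pad1d(x, (0, extra), mode="reflect")                      # reflect-safe padding
assert (y.shape[-1] - kernel_size) % stride == 0              # strides now cover the signal evenly
assert unpad1d(y, (0, extra)).shape == x.shape                # the pad is exactly reversible
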
86
+ # --------------------------------------------------------------------
87
+ # DAC Layers (adapted) — MIT
88
+ # Original: https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/layers.py
89
+ # SPDX-License-Identifier: MIT
90
+ # --------------------------------------------------------------------
91
+
92
+ def WNConv1d(*args, **kwargs):
93
+ return weight_norm(nn.Conv1d(*args, **kwargs))
94
+
95
+ def WNConvTranspose1d(*args, **kwargs):
96
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
97
+
98
+ @torch.jit.script
99
+ def snake(x: Tensor, alpha: Tensor) -> Tensor:
100
+ shape = x.shape
101
+ x = x.reshape(shape[0], shape[1], -1)
102
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
103
+ x = x.reshape(shape)
104
+ return x
105
+
106
+ class Snake1d(nn.Module):
107
+ def __init__(self, channels: int):
108
+ super().__init__()
109
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
110
+ def forward(self, x: Tensor) -> Tensor:
111
+ return snake(x, self.alpha)
112
+
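
A quick sanity check for the snake activation (illustrative, arbitrary sizes): it acts pointwise with one learnable alpha per channel, so the output shape always matches the input.

import torch

act = Snake1d(channels=16)
x = torch.randn(2, 16, 50)      # (B, C, T)
assert act(x).shape == x.shape
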
113
+ # --------------------------------------------------------------------
114
+ # DAC Vector Quantize (adapted) — MIT
115
+ # Original: https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/quantize.py
116
+ # SPDX-License-Identifier: MIT
117
+ # --------------------------------------------------------------------
118
+
119
+ class VectorQuantize(nn.Module):
120
+ """
121
+ VQ with factorized, l2-normalized codes (ViT‑VQGAN style).
122
+ I/O in (B, D, T).
123
+ """
124
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
125
+ super().__init__()
126
+ self.codebook_size = codebook_size
127
+ self.codebook_dim = codebook_dim
128
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
129
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
130
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
131
+
132
+ def forward(self, z: Tensor):
133
+ z_e = self.in_proj(z) # (B, D, T)
134
+ z_q, indices = self.decode_latents(z_e)
135
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
136
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
137
+ z_q = z_e + (z_q - z_e).detach() # straight‑through
138
+ z_q = self.out_proj(z_q)
139
+ return z_q, commitment_loss, codebook_loss, indices, z_e
140
+
141
+ def embed_code(self, embed_id: Tensor) -> Tensor:
142
+ return F.embedding(embed_id, self.codebook.weight)
143
+
144
+ def decode_code(self, embed_id: Tensor) -> Tensor:
145
+ return self.embed_code(embed_id).transpose(1, 2)
146
+
147
+ def decode_latents(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
148
+ encodings = rearrange(latents, "b d t -> (b t) d")
149
+ codebook = self.codebook.weight
150
+ encodings = F.normalize(encodings)
151
+ codebook = F.normalize(codebook)
152
+ dist = (
153
+ encodings.pow(2).sum(1, keepdim=True)
154
+ - 2 * encodings @ codebook.t()
155
+ + codebook.pow(2).sum(1, keepdim=True).t()
156
+ )
157
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
158
+ z_q = self.decode_code(indices)
159
+ return z_q, indices
160
+
161
+
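
An illustrative forward pass (arbitrary sizes, not part of the committed file) showing the I/O shapes of the factorized VQ: indices are one code per frame, z_e lives in the small codebook_dim space, and z_q is projected back to input_dim.

import torch

vq = VectorQuantize(input_dim=64, codebook_size=256, codebook_dim=8)
z = torch.randn(2, 64, 40)                       # (B, D, T)
z_q, commit_loss, cb_loss, indices, z_e = vq(z)
assert z_q.shape == z.shape                      # projected back to input_dim
assert indices.shape == (2, 40)                  # one code index per frame
assert z_e.shape == (2, 8, 40)                   # factorized (codebook_dim) latents
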
162
+ class ResidualVectorQuantize(nn.Module):
163
+ """SoundStream-style residual VQ stack."""
164
+ def __init__(
165
+ self,
166
+ input_dim: int = 512,
167
+ n_codebooks: int = 9,
168
+ codebook_size: int = 1024,
169
+ codebook_dim: Union[int, List[int]] = 8,
170
+ quantizer_dropout: float = 0.0,
171
+ ):
172
+ super().__init__()
173
+ if isinstance(codebook_dim, int):
174
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
175
+
176
+ self.n_codebooks = n_codebooks
177
+ self.codebook_dim = codebook_dim
178
+ self.codebook_size = codebook_size
179
+
180
+ self.quantizers = nn.ModuleList([
181
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i])
182
+ for i in range(n_codebooks)
183
+ ])
184
+ self.quantizer_dropout = quantizer_dropout
185
+
186
+ def forward(self, z: Tensor, n_quantizers: Optional[int] = None):
187
+ z_q = 0
188
+ residual = z
189
+ commitment_loss = 0
190
+ codebook_loss = 0
191
+
192
+ codebook_indices = []
193
+ latents = []
194
+
195
+ if n_quantizers is None:
196
+ n_quantizers = self.n_codebooks
197
+ if self.training:
198
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
199
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
200
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
201
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
202
+ n_quantizers = n_quantizers.to(z.device)
203
+
204
+ for i, quantizer in enumerate(self.quantizers):
205
+ if self.training is False and i >= n_quantizers:
206
+ break
207
+
208
+ z_q_i, commit_i, codebk_i, indices_i, z_e_i = quantizer(residual)
209
+
210
+ mask = (torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers)
211
+ z_q = z_q + z_q_i * mask[:, None, None]
212
+ residual = residual - z_q_i
213
+
214
+ commitment_loss += (commit_i * mask).mean()
215
+ codebook_loss += (codebk_i * mask).mean()
216
+
217
+ codebook_indices.append(indices_i)
218
+ latents.append(z_e_i)
219
+
220
+ codes = torch.stack(codebook_indices, dim=1)
221
+ latents = torch.cat(latents, dim=1)
222
+
223
+ return z_q, codes, latents, commitment_loss, codebook_loss
224
+
225
+ def from_codes(self, codes: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
226
+ z_q = 0.0
227
+ z_p = []
228
+ n_codebooks = codes.shape[1]
229
+ for i in range(n_codebooks):
230
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
231
+ z_p.append(z_p_i)
232
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
233
+ z_q = z_q + z_q_i
234
+ return z_q, torch.cat(z_p, dim=1), codes
235
+
236
+ def from_latents(self, latents: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
237
+ z_q = 0
238
+ z_p = []
239
+ codes = []
240
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
241
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
242
+ for i in range(n_codebooks):
243
+ j, k = dims[i], dims[i + 1]
244
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
245
+ z_p.append(z_p_i)
246
+ codes.append(codes_i)
247
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
248
+ z_q = z_q + z_q_i
249
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
250
+
251
+
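
A minimal round trip through the residual VQ stack (illustrative sizes; eval mode so the quantizer-dropout branch is skipped): each codebook quantizes the residual left by the previous one, and from_codes rebuilds z_q from the indices alone.

import torch

rvq = ResidualVectorQuantize(input_dim=64, n_codebooks=4, codebook_size=256, codebook_dim=8)
rvq.eval()
z = torch.randn(2, 64, 40)
z_q, codes, latents, commit_loss, cb_loss = rvq(z)
assert codes.shape == (2, 4, 40)                    # one index per codebook per frame
assert latents.shape == (2, 4 * 8, 40)              # concatenated factorized latents
assert rvq.from_codes(codes)[0].shape == z_q.shape  # decode straight from indices
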
252
+ # --------------------------------------------------------------------
253
+ # S1 DAC rvq
254
+ # --------------------------------------------------------------------
255
+
256
+ @dataclass
257
+ class VQResult:
258
+ z: Tensor
259
+ codes: Tensor
260
+ latents: Tensor
261
+ codebook_loss: Tensor
262
+ commitment_loss: Tensor
263
+ semantic_distill_z: Optional[Tensor] = None
264
+
265
+
266
+ class CausalConvNet(nn.Module):
267
+ def __init__(
268
+ self,
269
+ in_channels,
270
+ out_channels,
271
+ kernel_size,
272
+ dilation=1,
273
+ stride=1,
274
+ groups=1,
275
+ padding=None,
276
+ ):
277
+ super().__init__()
278
+ self.conv = nn.Conv1d(
279
+ in_channels, out_channels, kernel_size,
280
+ stride=stride, dilation=dilation, groups=groups,
281
+ )
282
+ self.stride = stride
283
+ self.kernel_size = (kernel_size - 1) * dilation + 1
284
+ self.dilation = dilation
285
+ self.padding = self.kernel_size - self.stride
286
+
287
+ def forward(self, x: Tensor) -> Tensor:
288
+ pad = self.padding
289
+ extra = get_extra_padding_for_conv1d(x, self.kernel_size, self.stride, pad)
290
+ x = pad1d(x, (pad, extra), mode="constant", value=0)
291
+ return self.conv(x).contiguous()
292
+
293
+ def weight_norm(self, name="weight", dim=0):
294
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
295
+ return self
296
+
297
+ def remove_weight_norm(self):
298
+ self.conv = remove_parametrizations(self.conv)
299
+ return self
300
+
301
+
302
+ class CausalTransConvNet(nn.Module):
303
+ def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None):
304
+ super().__init__()
305
+ self.conv = nn.ConvTranspose1d(
306
+ in_channels, out_channels, kernel_size,
307
+ stride=stride, dilation=dilation
308
+ )
309
+ self.stride = stride
310
+ self.kernel_size = kernel_size
311
+
312
+ def forward(self, x: Tensor) -> Tensor:
313
+ x = self.conv(x)
314
+ pad = self.kernel_size - self.stride
315
+ padding_right = math.ceil(pad)
316
+ padding_left = pad - padding_right
317
+ x = unpad1d(x, (padding_left, padding_right))
318
+ return x.contiguous()
319
+
320
+ def weight_norm(self, name="weight", dim=0):
321
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
322
+ return self
323
+
324
+ def remove_weight_norm(self):
325
+ self.conv = remove_parametrizations(self.conv)
326
+ return self
327
+
328
+
329
+ def CausalWNConv1d(*args, **kwargs):
330
+ return CausalConvNet(*args, **kwargs).weight_norm()
331
+
332
+ def CausalWNConvTranspose1d(*args, **kwargs):
333
+ return CausalTransConvNet(*args, **kwargs).weight_norm()
334
+
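
A quick length check for the causal wrappers (illustrative): left-only padding plus the computed extra right pad gives ceil(T / stride) output frames, and the transposed version restores the stride factor.

import math
import torch

conv = CausalWNConv1d(8, 16, kernel_size=4, stride=2)
x = torch.randn(1, 8, 101)
y = conv(x)
assert y.shape[-1] == math.ceil(101 / 2)    # ceil(T / stride) frames

up = CausalWNConvTranspose1d(16, 8, kernel_size=4, stride=2)
assert up(y).shape[-1] == 2 * y.shape[-1]   # upsampled back by the stride factor
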
335
+ class ConvNeXtBlock(nn.Module):
336
+ r"""ConvNeXt Block (1D).
337
+ DwConv -> (N, C, L) → (N, L, C) -> LN -> Linear -> GELU -> Linear -> (N, C, L) with residual
338
+ """
339
+ def __init__(
340
+ self,
341
+ dim: int,
342
+ layer_scale_init_value: float = 1e-6,
343
+ mlp_ratio: float = 4.0,
344
+ kernel_size: int = 7,
345
+ dilation: int = 1,
346
+ ):
347
+ super().__init__()
348
+ convnet_type = CausalConvNet
349
+ self.dwconv = convnet_type(
350
+ dim, dim, kernel_size=kernel_size,
351
+ groups=dim, dilation=dilation,
352
+ ) # depthwise conv
353
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
354
+ self.pwconv1 = nn.Linear(dim, int(mlp_ratio * dim))
355
+ self.act = nn.GELU()
356
+ self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
357
+ self.gamma = (
358
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
359
+ if layer_scale_init_value > 0 else None
360
+ )
361
+
362
+ def forward(self, x: Tensor, apply_residual: bool = True) -> Tensor:
363
+ inp = x
364
+ x = self.dwconv(x)
365
+ x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
366
+ x = self.norm(x)
367
+ x = self.pwconv1(x)
368
+ x = self.act(x)
369
+ x = self.pwconv2(x)
370
+ if self.gamma is not None:
371
+ x = self.gamma * x
372
+ x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
373
+ if apply_residual:
374
+ x = inp + x
375
+ return x
376
+
377
+
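
A shape check for the 1D ConvNeXt block (illustrative): the depthwise conv is causal and length-preserving, so the residual output matches the channels-first input shape.

import torch

block = ConvNeXtBlock(dim=32)
x = torch.randn(2, 32, 50)        # (B, C, T), channels-first like the surrounding conv stack
assert block(x).shape == x.shape
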
378
+ class DownsampleResidualVectorQuantize(nn.Module):
379
+ def __init__(
380
+ self,
381
+ input_dim: int = 1024,
382
+ n_codebooks: int = 9,
383
+ codebook_dim: int = 8,
384
+ quantizer_dropout: float = 0.5,
385
+ codebook_size: int = 1024,
386
+ semantic_codebook_size: int = 4096,
387
+ downsample_factor: Tuple[int, ...] = (2, 2),
388
+ downsample_dims: Optional[Tuple[int, ...]] = None,
389
+ pre_module: Optional[nn.Module] = None,
390
+ post_module: Optional[nn.Module] = None,
391
+ semantic_predictor_module: Optional[nn.Module] = None,
392
+ ):
393
+ super().__init__()
394
+
395
+ if downsample_dims is None:
396
+ downsample_dims = tuple(input_dim for _ in range(len(downsample_factor)))
397
+
398
+ all_dims = (input_dim,) + tuple(downsample_dims)
399
+
400
+ self.semantic_quantizer = ResidualVectorQuantize(
401
+ input_dim=input_dim,
402
+ n_codebooks=1,
403
+ codebook_size=semantic_codebook_size,
404
+ codebook_dim=codebook_dim,
405
+ quantizer_dropout=0.0,
406
+ )
407
+
408
+ self.quantizer = ResidualVectorQuantize(
409
+ input_dim=input_dim,
410
+ n_codebooks=n_codebooks,
411
+ codebook_size=codebook_size,
412
+ codebook_dim=codebook_dim,
413
+ quantizer_dropout=quantizer_dropout,
414
+ )
415
+
416
+ convnet_type = CausalConvNet
417
+ transconvnet_type = CausalTransConvNet
418
+
419
+ self.downsample = nn.Sequential(
420
+ *[
421
+ nn.Sequential(
422
+ convnet_type(all_dims[idx], all_dims[idx + 1], kernel_size=factor, stride=factor),
423
+ ConvNeXtBlock(dim=all_dims[idx + 1]),
424
+ )
425
+ for idx, factor in enumerate(downsample_factor)
426
+ ]
427
+ )
428
+
429
+ self.upsample = nn.Sequential(
430
+ *[
431
+ nn.Sequential(
432
+ transconvnet_type(all_dims[idx + 1], all_dims[idx], kernel_size=factor, stride=factor),
433
+ ConvNeXtBlock(dim=all_dims[idx]),
434
+ )
435
+ for idx, factor in reversed(list(enumerate(downsample_factor)))
436
+ ]
437
+ )
438
+
439
+ self.apply(self._init_weights)
440
+ self.pre_module = pre_module if pre_module is not None else nn.Identity()
441
+ self.post_module = post_module if post_module is not None else nn.Identity()
442
+ self.semantic_predictor_module = (
443
+ semantic_predictor_module if semantic_predictor_module is not None else nn.Identity()
444
+ )
445
+
446
+ @staticmethod
447
+ def _init_weights(m):
448
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
449
+ nn.init.trunc_normal_(m.weight, std=0.02)
450
+ if getattr(m, "bias", None) is not None:
451
+ nn.init.constant_(m.bias, 0)
452
+
453
+ def forward(self, z: Tensor, n_quantizers: Optional[int] = None, semantic_len: Optional[Tensor] = None, **kwargs):
454
+ # z: (B, D, T)
455
+ original_shape = z.shape
456
+ if semantic_len is None:
457
+ semantic_len = torch.LongTensor([z.shape[-1]])
458
+
459
+ z = self.downsample(z)
460
+ z = self.pre_module(z) # (B, D, T) or (B, T, D) depending on module; original uses channels-first in/out
461
+
462
+ semantic_z, semantic_codes, semantic_latents, semantic_commitment_loss, semantic_codebook_loss = \
463
+ self.semantic_quantizer(z)
464
+ residual_z = z - semantic_z
465
+ residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(residual_z, n_quantizers=n_quantizers)
466
+ z = semantic_z + residual_z
467
+ commitment_loss = commitment_loss + semantic_commitment_loss
468
+ codebook_loss = codebook_loss + semantic_codebook_loss
469
+ codes = torch.cat([semantic_codes, codes], dim=1)
470
+ latents = torch.cat([semantic_latents, latents], dim=1)
471
+ z = self.post_module(z)
472
+ z = self.upsample(z)
473
+
474
+ # Pad or crop z to match original shape (time dimension)
475
+ diff = original_shape[-1] - z.shape[-1]
476
+ right = 0
477
+ left = abs(diff) - right
478
+ if diff > 0:
479
+ z = F.pad(z, (left, right))
480
+ elif diff < 0:
481
+ z = z[..., left:]
482
+
483
+ return VQResult(
484
+ z=z, codes=codes, latents=latents,
485
+ commitment_loss=commitment_loss, codebook_loss=codebook_loss,
486
+ )
487
+
488
+ def decode(self, indices: Tensor) -> Tensor:
489
+ new_indices = torch.zeros_like(indices)
490
+ new_indices[:, 0] = torch.clamp(indices[:, 0], max=self.semantic_quantizer.codebook_size - 1)
491
+ new_indices[:, 1:] = torch.clamp(indices[:, 1:], max=self.quantizer.codebook_size - 1)
492
+
493
+ z_q_semantic = self.semantic_quantizer.from_codes(new_indices[:, :1])[0]
494
+ z_q_residual = self.quantizer.from_codes(new_indices[:, 1:])[0]
495
+ z_q = z_q_semantic + z_q_residual
496
+ z_q = self.post_module(z_q)
497
+ z_q = self.upsample(z_q)
498
+ return z_q
499
+
500
+
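
An illustrative end-to-end pass through the downsampling RVQ (small arbitrary dims, eval mode): features are downsampled 4x, quantized by one semantic codebook plus the residual stack, then upsampled back to the input rate; decode() rebuilds features from indices alone.

import torch

dq = DownsampleResidualVectorQuantize(
    input_dim=64, n_codebooks=3, codebook_size=256, codebook_dim=8,
    semantic_codebook_size=512, downsample_factor=(2, 2),
)
dq.eval()
z = torch.randn(1, 64, 40)                 # (B, D, T) encoder features
out = dq(z)
assert out.codes.shape == (1, 1 + 3, 10)   # 1 semantic + 3 residual codebooks, T/4 frames
assert out.z.shape == z.shape              # upsampled back to the encoder frame rate
assert dq.decode(out.codes).shape == z.shape
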
501
+ # --------------------------------------------------------------------
502
+ # Transformer stack
503
+ # --------------------------------------------------------------------
504
+
505
+ @dataclass
506
+ class ModelArgs:
507
+ block_size: int = 2048
508
+ n_layer: int = 8
509
+ n_head: int = 8
510
+ dim: int = 512
511
+ intermediate_size: int = 1536
512
+ n_local_heads: int = -1
513
+ head_dim: int = 64
514
+ rope_base: float = 10000
515
+ norm_eps: float = 1e-5
516
+ dropout_rate: float = 0.1
517
+ attn_dropout_rate: float = 0.1
518
+ channels_first: bool = True # to be compatible with conv1d input/output
519
+ pos_embed_type: str = "rope" # "rope" or "conformer"
520
+ max_relative_position: int = 128
521
+
522
+ def __post_init__(self):
523
+ if self.n_local_heads == -1:
524
+ self.n_local_heads = self.n_head
525
+ if self.intermediate_size is None:
526
+ hidden_dim = 4 * self.dim
527
+ n_hidden = int(2 * hidden_dim / 3)
528
+ self.intermediate_size = find_multiple(n_hidden, 256)
529
+ assert self.pos_embed_type in ["rope", "conformer"]
530
+
531
+
532
+ class KVCache(nn.Module):
533
+ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
534
+ super().__init__()
535
+ cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
536
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
537
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
538
+
539
+ def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
540
+ # input_pos: [S], k_val: [B, H, S, D]
541
+ assert input_pos.shape[0] == k_val.shape[2]
542
+ k_out = self.k_cache
543
+ v_out = self.v_cache
544
+ k_out[:, :, input_pos] = k_val
545
+ v_out[:, :, input_pos] = v_val
546
+ return (
547
+ k_out[:, :, : input_pos.max() + 1, :],
548
+ v_out[:, :, : input_pos.max() + 1, :],
549
+ )
550
+
551
+ def clear_cache(self, prompt_len: int):
552
+ self.k_cache[:, :, prompt_len:, :].fill_(0)
553
+ self.v_cache[:, :, prompt_len:, :].fill_(0)
554
+
555
+
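
A small sketch of the cache contract (illustrative sizes): update() writes keys/values at the given positions and returns views truncated at max(input_pos) + 1.

import torch

cache = KVCache(max_batch_size=1, max_seq_length=16, n_heads=4, head_dim=8, dtype=torch.float32)
pos = torch.arange(3)
k = torch.randn(1, 4, 3, 8)          # (B, H, S, D)
v = torch.randn(1, 4, 3, 8)
k_out, v_out = cache.update(pos, k, v)
assert k_out.shape == (1, 4, 3, 8)   # truncated at max(input_pos) + 1
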
556
+ class Transformer(nn.Module):
557
+ def __init__(self, config: ModelArgs) -> None:
558
+ super().__init__()
559
+ self.config = config
560
+
561
+ self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
562
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
563
+
564
+ if config.pos_embed_type == "rope":
565
+ freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base)
566
+ self.register_buffer("freqs_cis", freqs_cis)
567
+ else:
568
+ self.register_buffer("freqs_cis", None)
569
+
570
+ causal_mask = torch.tril(torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool))
571
+ self.register_buffer("causal_mask", causal_mask)
572
+
573
+ self.max_batch_size = -1
574
+ self.max_seq_length = -1
575
+ self.use_kv_cache = False
576
+
577
+ def setup_caches(self, max_batch_size, max_seq_length):
578
+ head_dim = self.config.dim // self.config.n_head
579
+ max_seq_length = find_multiple(max_seq_length, 8)
580
+ self.max_seq_length = max_seq_length
581
+ self.max_batch_size = max_batch_size
582
+ dtype = self.norm.weight.dtype
583
+ device = self.norm.weight.device
584
+
585
+ for b in self.layers:
586
+ b.attention.kv_cache = KVCache(
587
+ max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype
588
+ ).to(device)
589
+
590
+ self.use_kv_cache = True
591
+
592
+ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, mask: Optional[Tensor] = None) -> Tensor:
593
+ if self.config.pos_embed_type == "rope":
594
+ assert self.freqs_cis is not None
595
+ freqs_cis = self.freqs_cis[input_pos]
596
+ else:
597
+ freqs_cis = None
598
+
599
+ if mask is None:
600
+ if not self.training and self.use_kv_cache:
601
+ mask = self.causal_mask[None, None, input_pos]
602
+ mask = mask[..., : input_pos.max() + 1]
603
+ else:
604
+ mask = self.causal_mask[None, None, input_pos]
605
+ mask = mask[..., input_pos]
606
+
607
+ for layer in self.layers:
608
+ x = layer(x, input_pos, freqs_cis, mask)
609
+ x = self.norm(x)
610
+ return x
611
+
612
+
613
+ class TransformerBlock(nn.Module):
614
+ def __init__(self, config: ModelArgs) -> None:
615
+ super().__init__()
616
+ self.attention = Attention(config)
617
+ self.feed_forward = FeedForward(config)
618
+ self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
619
+ self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
620
+ self.attention_layer_scale = LayerScale(config.dim, inplace=True)
621
+ self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
622
+
623
+ def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
624
+ h = x + self.attention_layer_scale(
625
+ self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
626
+ )
627
+ out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
628
+ return out
629
+
630
+
631
+ class Attention(nn.Module):
632
+ def __init__(self, config: ModelArgs):
633
+ super().__init__()
634
+ assert config.dim % config.n_head == 0
635
+
636
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
637
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
638
+ self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
639
+ self.kv_cache = None
640
+
641
+ self.n_head = config.n_head
642
+ self.head_dim = config.head_dim
643
+ self.n_local_heads = config.n_local_heads
644
+ self.dim = config.dim
645
+ self.attn_dropout_rate = config.attn_dropout_rate
646
+ self.pos_embed_type = config.pos_embed_type
647
+
648
+ if self.pos_embed_type == "conformer":
649
+ self.max_relative_position = config.max_relative_position
650
+ num_pos_embeddings = 2 * config.max_relative_position + 1
651
+ self.rel_pos_embeddings = nn.Parameter(torch.zeros(num_pos_embeddings, self.head_dim))
652
+ nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)
653
+
654
+ def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
655
+ positions = torch.arange(seqlen, device=q.device)
656
+ relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0) # [S, S]
657
+ relative_positions = torch.clamp(relative_positions + self.max_relative_position,
658
+ 0, 2 * self.max_relative_position)
659
+ rel_embeddings = self.rel_pos_embeddings[relative_positions] # [S, S, D]
660
+ q = q.transpose(1, 2) # [B, S, H, D]
661
+ rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1)) # [B, S, H, S]
662
+ rel_logits = rel_logits.transpose(1, 2) # [B, H, S, S]
663
+ return rel_logits
664
+
665
+ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
666
+ bsz, seqlen, _ = x.shape
667
+
668
+ kv_size = self.n_local_heads * self.head_dim
669
+ q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
670
+ context_seqlen = seqlen
671
+
672
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
673
+ k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
674
+ v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
675
+
676
+ if self.pos_embed_type == "rope":
677
+ q = apply_rotary_emb(q, freqs_cis)
678
+ k = apply_rotary_emb(k, freqs_cis)
679
+
680
+ q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
681
+
682
+ if self.kv_cache is not None:
683
+ k, v = self.kv_cache.update(input_pos, k, v)
684
+
685
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
686
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
687
+
688
+ if self.pos_embed_type == "conformer":
689
+ scale = 1.0 / math.sqrt(self.head_dim)
690
+ scores = torch.matmul(q, k.transpose(-2, -1)) * scale
691
+ rel_scores = self._compute_conformer_pos_scores(q, seqlen)
692
+ scores = scores + rel_scores
693
+ if mask is not None:
694
+ scores = scores.masked_fill(~mask, float("-inf"))
695
+ attn = F.softmax(scores, dim=-1)
696
+ if self.attn_dropout_rate > 0 and self.training:
697
+ attn = F.dropout(attn, p=self.attn_dropout_rate)
698
+ y = torch.matmul(attn, v)
699
+ else:
700
+ y = F.scaled_dot_product_attention(
701
+ q, k, v,
702
+ dropout_p=self.attn_dropout_rate if self.training else 0.0,
703
+ attn_mask=mask,
704
+ )
705
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
706
+ y = self.wo(y)
707
+ return y
708
+
709
+
710
+ class FeedForward(nn.Module):
711
+ def __init__(self, config: ModelArgs) -> None:
712
+ super().__init__()
713
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
714
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
715
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
716
+ self.dropout = nn.Dropout(config.dropout_rate)
717
+
718
+ def forward(self, x: Tensor) -> Tensor:
719
+ return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
720
+
721
+
722
+ class RMSNorm(nn.Module):
723
+ def __init__(self, dim: int, eps: float = 1e-5):
724
+ super().__init__()
725
+ self.eps = eps
726
+ self.weight = nn.Parameter(torch.ones(dim))
727
+
728
+ def _norm(self, x):
729
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
730
+
731
+ def forward(self, x: Tensor) -> Tensor:
732
+ output = self._norm(x.float()).type_as(x)
733
+ return output * self.weight
734
+
735
+
736
+ class LayerScale(nn.Module):
737
+ def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-2, inplace: bool = False) -> None:
738
+ super().__init__()
739
+ self.inplace = inplace
740
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
741
+
742
+ def forward(self, x: Tensor) -> Tensor:
743
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
744
+
745
+
746
+ class WindowLimitedTransformer(Transformer):
747
+ """Transformer with window-limited causal attention."""
748
+ def __init__(
749
+ self,
750
+ config: ModelArgs,
751
+ input_dim: int = 512,
752
+ window_size: Optional[int] = None,
753
+ causal: bool = True,
754
+ look_ahead_conv: Optional[nn.Module] = None,
755
+ ):
756
+ super().__init__(config)
757
+ self.window_size = window_size
758
+ self.causal = causal
759
+ self.channels_first = config.channels_first
760
+ self.look_ahead_conv = look_ahead_conv if look_ahead_conv is not None else nn.Identity()
761
+ self.input_proj = nn.Linear(input_dim, config.dim) if input_dim != config.dim else nn.Identity()
762
+ self.output_proj = nn.Linear(config.dim, input_dim) if input_dim != config.dim else nn.Identity()
763
+
764
+ def make_window_limited_mask(self, max_length: int, x_lens: Optional[Tensor] = None) -> Tensor:
765
+ if self.causal:
766
+ mask = torch.tril(torch.ones(max_length, max_length))
767
+ row_indices = torch.arange(max_length).view(-1, 1)
768
+ window_size = self.window_size or max_length
769
+ valid_range = (row_indices - window_size + 1).clamp(min=0)
770
+ column_indices = torch.arange(max_length)
771
+ mask = (column_indices >= valid_range) & mask.bool()
772
+ else:
773
+ raise NotImplementedError
774
+ mask = mask.bool()[None, None]
775
+ return mask
776
+
777
+ def make_mask(self, max_length: int, x_lens: Optional[Tensor] = None) -> Tensor:
778
+ if self.causal:
779
+ mask = torch.tril(torch.ones(max_length, max_length))
780
+ else:
781
+ mask = torch.ones(max_length, max_length)
782
+ mask = mask.bool()[None, None]
783
+ for i, x_len in enumerate(x_lens):
784
+ mask[:x_len, i] = 0
785
+ mask = mask.bool()[None, None]
786
+ return mask
787
+
788
+ def forward(self, x: Tensor, x_lens: Optional[Tensor] = None) -> Tensor:
789
+ if self.channels_first:
790
+ x = x.transpose(1, 2)
791
+ x = self.input_proj(x)
792
+ x = self.look_ahead_conv(x)
793
+ input_pos = torch.arange(x.shape[1], device=x.device)
794
+ max_length = x.shape[1]
795
+ if self.window_size is not None:
796
+ mask = self.make_window_limited_mask(max_length, x_lens)
797
+ else:
798
+ mask = self.make_mask(max_length, x_lens)
799
+ mask = mask.to(x.device)
800
+ x = super().forward(x, input_pos, mask)
801
+ x = self.output_proj(x)
802
+ if self.channels_first:
803
+ x = x.transpose(1, 2)
804
+ return x
805
+
806
+
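
An illustrative look at the window-limited causal mask (tiny arbitrary config): each position may attend only to itself and the previous window_size - 1 positions.

import torch

args = ModelArgs(block_size=64, n_layer=1, n_head=2, dim=32, intermediate_size=64, head_dim=16)
wlt = WindowLimitedTransformer(args, input_dim=32, window_size=4)
mask = wlt.make_window_limited_mask(max_length=8)
assert mask.shape == (1, 1, 8, 8)
assert not mask[0, 0, 7, 3] and mask[0, 0, 7, 4]   # row 7 attends to positions 4..7 only
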
807
+ def precompute_freqs_cis(
808
+ seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
809
+ ) -> Tensor:
810
+ freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
811
+ t = torch.arange(seq_len, device=freqs.device)
812
+ freqs = torch.outer(t, freqs)
813
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
814
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
815
+ return cache.to(dtype=dtype)
816
+
817
+ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
818
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
819
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
820
+ x_out2 = torch.stack(
821
+ [
822
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
823
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
824
+ ],
825
+ -1,
826
+ )
827
+ x_out2 = x_out2.flatten(3)
828
+ return x_out2.type_as(x)
829
+
830
+
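
A shape check for the rotary helpers (illustrative): the cache holds (seq_len, head_dim / 2, 2) cos/sin pairs, and applying it preserves the (B, S, H, head_dim) query/key shape.

import torch

freqs = precompute_freqs_cis(seq_len=32, n_elem=16, dtype=torch.float32)  # (32, 8, 2)
q = torch.randn(1, 32, 4, 16)                                             # (B, S, H, head_dim)
assert apply_rotary_emb(q, freqs).shape == q.shape
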
831
+ def init_weights(m):
832
+ if isinstance(m, nn.Conv1d):
833
+ nn.init.trunc_normal_(m.weight, std=0.02)
834
+ nn.init.constant_(m.bias, 0)
835
+
836
+
837
+ # --------------------------------------------------------------------
838
+ # Top-level AE
839
+ # --------------------------------------------------------------------
840
+
841
+ class EncoderBlock(nn.Module):
842
+ def __init__(
843
+ self,
844
+ dim: int = 16,
845
+ stride: int = 1,
846
+ causal: bool = False,
847
+ n_t_layer: int = 0,
848
+ transformer_general_config=None,
849
+ ):
850
+ super().__init__()
851
+ conv_class = CausalWNConv1d if causal else WNConv1d
852
+ transformer_module = (
853
+ nn.Identity()
854
+ if n_t_layer == 0
855
+ else WindowLimitedTransformer(
856
+ causal=causal,
857
+ input_dim=dim,
858
+ window_size=512,
859
+ config=transformer_general_config(
860
+ n_layer=n_t_layer,
861
+ n_head=dim // 64,
862
+ dim=dim,
863
+ intermediate_size=dim * 3,
864
+ ),
865
+ )
866
+ )
867
+ self.block = nn.Sequential(
868
+ # three multi‑receptive‑field residual units
869
+ ResidualUnit(dim // 2, dilation=1, causal=causal),
870
+ ResidualUnit(dim // 2, dilation=3, causal=causal),
871
+ ResidualUnit(dim // 2, dilation=9, causal=causal),
872
+ Snake1d(dim // 2),
873
+ conv_class(dim // 2, dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
874
+ transformer_module,
875
+ )
876
+
877
+ def forward(self, x: Tensor) -> Tensor:
878
+ return self.block(x)
879
+
880
+
881
+ class ResidualUnit(nn.Module):
882
+ def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
883
+ super().__init__()
884
+ conv_class = CausalWNConv1d if causal else WNConv1d
885
+ pad = ((7 - 1) * dilation) // 2
886
+ self.block = nn.Sequential(
887
+ Snake1d(dim),
888
+ conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
889
+ Snake1d(dim),
890
+ conv_class(dim, dim, kernel_size=1),
891
+ )
892
+ self.causal = causal
893
+
894
+ def forward(self, x: Tensor) -> Tensor:
895
+ y = self.block(x)
896
+ pad = x.shape[-1] - y.shape[-1]
897
+ if pad > 0:
898
+ if self.causal:
899
+ x = x[..., :-pad]
900
+ else:
901
+ x = x[..., pad // 2 : -pad // 2]
902
+ return x + y
903
+
904
+
905
+ class Encoder(nn.Module):
906
+ def __init__(
907
+ self,
908
+ d_model: int = 64,
909
+ strides: List[int] = [2, 4, 8, 8],
910
+ d_latent: int = 64,
911
+ n_transformer_layers: List[int] = [0, 0, 4, 4],
912
+ transformer_general_config: Optional[ModelArgs] = None,
913
+ causal: bool = False,
914
+ ):
915
+ super().__init__()
916
+ conv_class = CausalWNConv1d if causal else WNConv1d
917
+ layers: List[nn.Module] = [conv_class(1, d_model, kernel_size=7, padding=3)]
918
+ for stride, n_t_layer in zip(strides, n_transformer_layers):
919
+ d_model *= 2
920
+ layers.append(
921
+ EncoderBlock(
922
+ d_model, stride=stride, causal=causal,
923
+ n_t_layer=n_t_layer, transformer_general_config=transformer_general_config,
924
+ )
925
+ )
926
+ layers += [Snake1d(d_model), conv_class(d_model, d_latent, kernel_size=3, padding=1)]
927
+ self.block = nn.Sequential(*layers)
928
+ self.enc_dim = d_model
929
+
930
+ def forward(self, x: Tensor) -> Tensor:
931
+ return self.block(x)
932
+
933
+
934
+ class DecoderBlock(nn.Module):
935
+ def __init__(
936
+ self,
937
+ input_dim: int = 16,
938
+ output_dim: int = 8,
939
+ stride: int = 1,
940
+ causal: bool = False,
941
+ n_t_layer: int = 0,
942
+ transformer_general_config=None,
943
+ ):
944
+ super().__init__()
945
+ conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
946
+ transformer_module = (
947
+ nn.Identity()
948
+ if n_t_layer == 0
949
+ else WindowLimitedTransformer(
950
+ causal=causal,
951
+ input_dim=input_dim,
952
+ window_size=None,
953
+ config=transformer_general_config(
954
+ n_layer=n_t_layer,
955
+ n_head=input_dim // 64,
956
+ dim=input_dim,
957
+ intermediate_size=input_dim * 3,
958
+ ),
959
+ )
960
+ )
961
+ self.block = nn.Sequential(
962
+ Snake1d(input_dim),
963
+ conv_trans_class(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
964
+ ResidualUnit(output_dim, dilation=1, causal=causal),
965
+ ResidualUnit(output_dim, dilation=3, causal=causal),
966
+ ResidualUnit(output_dim, dilation=9, causal=causal),
967
+ )
968
+
969
+ def forward(self, x: Tensor) -> Tensor:
970
+ return self.block(x)
971
+
972
+
973
+ class Decoder(nn.Module):
974
+ def __init__(
975
+ self,
976
+ input_channel: int,
977
+ channels: int,
978
+ rates: List[int],
979
+ d_out: int = 1,
980
+ causal: bool = False,
981
+ n_transformer_layers: List[int] = [0, 0, 0, 0],
982
+ transformer_general_config=None,
983
+ ):
984
+ super().__init__()
985
+ conv_class = CausalWNConv1d if causal else WNConv1d
986
+ layers: List[nn.Module] = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
987
+ for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
988
+ input_dim = channels // 2**i
989
+ output_dim = channels // 2 ** (i + 1)
990
+ layers.append(
991
+ DecoderBlock(
992
+ input_dim, output_dim, stride, causal=causal,
993
+ n_t_layer=n_t_layer, transformer_general_config=transformer_general_config,
994
+ )
995
+ )
996
+ layers += [Snake1d(output_dim), conv_class(output_dim, d_out, kernel_size=7, padding=3), nn.Tanh()]
997
+ self.model = nn.Sequential(*layers)
998
+
999
+ def forward(self, x: Tensor) -> Tensor:
1000
+ return self.model(x)
1001
+
1002
+
1003
+ class DAC(nn.Module):
1004
+ def __init__(
1005
+ self,
1006
+ encoder_dim: int = 64,
1007
+ encoder_rates: List[int] = [2, 4, 8, 8],
1008
+ latent_dim: Optional[int] = None,
1009
+ decoder_dim: int = 1536,
1010
+ decoder_rates: List[int] = [8, 8, 4, 2],
1011
+ quantizer: Optional[nn.Module] = None,
1012
+ sample_rate: int = 44100,
1013
+ causal: bool = True,
1014
+ encoder_transformer_layers: List[int] = [0, 0, 0, 0],
1015
+ decoder_transformer_layers: List[int] = [0, 0, 0, 0],
1016
+ transformer_general_config=None,
1017
+ ):
1018
+ super().__init__()
1019
+
1020
+ self.encoder_dim = encoder_dim
1021
+ self.encoder_rates = encoder_rates
1022
+ self.decoder_dim = decoder_dim
1023
+ self.decoder_rates = decoder_rates
1024
+ self.sample_rate = sample_rate
1025
+
1026
+ if latent_dim is None:
1027
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
1028
+ self.latent_dim = latent_dim
1029
+
1030
+ self.hop_length = int(np.prod(encoder_rates))
1031
+ self.encoder = Encoder(
1032
+ encoder_dim, encoder_rates, latent_dim, causal=causal,
1033
+ n_transformer_layers=encoder_transformer_layers,
1034
+ transformer_general_config=transformer_general_config,
1035
+ )
1036
+ self.quantizer = quantizer
1037
+ self.decoder = Decoder(
1038
+ latent_dim, decoder_dim, decoder_rates, causal=causal,
1039
+ n_transformer_layers=decoder_transformer_layers,
1040
+ transformer_general_config=transformer_general_config,
1041
+ )
1042
+ self.sample_rate = sample_rate
1043
+ self.apply(init_weights)
1044
+
1045
+ self.delay = self.get_delay()
1046
+ self.frame_length = self.hop_length * 4
1047
+
1048
+ def get_output_length(self, input_length: int) -> int:
1049
+ length = input_length
1050
+ for stride in self.encoder_rates:
1051
+ length = math.ceil(length / stride)
1052
+ return length
1053
+
1054
+ def get_delay(self) -> int:
1055
+ l_out = self.get_output_length(0)
1056
+ L = l_out
1057
+
1058
+ layers = [layer for layer in self.modules() if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d))]
1059
+ for layer in reversed(layers):
1060
+ d = layer.dilation[0]
1061
+ k = layer.kernel_size[0]
1062
+ s = layer.stride[0]
1063
+ if isinstance(layer, nn.ConvTranspose1d):
1064
+ L = ((L - d * (k - 1) - 1) / s) + 1
1065
+ elif isinstance(layer, nn.Conv1d):
1066
+ L = (L - 1) * s + d * (k - 1) + 1
1067
+ L = math.ceil(L)
1068
+
1069
+ l_in = L
1070
+ return (l_in - l_out) // 2
1071
+
1072
+ def preprocess(self, audio_data: Tensor, sample_rate: Optional[int]) -> Tensor:
1073
+ if sample_rate is None:
1074
+ sample_rate = self.sample_rate
1075
+ assert sample_rate == self.sample_rate
1076
+
1077
+ length = audio_data.shape[-1]
1078
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
1079
+ audio_data = F.pad(audio_data, (0, right_pad))
1080
+ return audio_data
1081
+
1082
+ def encode(
1083
+ self,
1084
+ audio_data: Tensor,
1085
+ audio_lengths: Optional[Tensor] = None,
1086
+ n_quantizers: Optional[int] = None,
1087
+ **kwargs,
1088
+ ):
1089
+ """Encode audio to quantized code indices."""
1090
+ if audio_data.ndim == 2:
1091
+ audio_data = audio_data.unsqueeze(1)
1092
+ length = audio_data.shape[-1]
1093
+ right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
1094
+ audio_data = F.pad(audio_data, (0, right_pad))
1095
+ if audio_lengths is None:
1096
+ audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)
1097
+
1098
+ z = self.encoder(audio_data)
1099
+ vq_results = self.quantizer(z, n_quantizers, **kwargs)
1100
+ indices = vq_results.codes
1101
+ indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
1102
+ return indices, indices_lens
1103
+
1104
+ def decode(self, indices: Tensor, feature_lengths: Tensor):
1105
+ """Decode code indices to audio."""
1106
+ if indices.ndim == 2:
1107
+ indices = indices[None]
1108
+ z = self.quantizer.decode(indices)
1109
+ audio_lengths = feature_lengths * self.frame_length
1110
+ return self.decoder(z), audio_lengths
1111
+
1112
+ def encode_to_codes(self, audio: Tensor, audio_lengths: Optional[Tensor] = None, n_quantizers: Optional[int] = None, **kw):
1113
+ return self.encode(audio, audio_lengths, n_quantizers, **kw)
1114
+
1115
+ def decode_codes(self, indices: Tensor, feature_lengths: Tensor):
1116
+ return self.decode(indices, feature_lengths)
1117
+
1118
+ @torch.no_grad()
1119
+ def encode_zq(self, audio_data: Tensor) -> Tensor:
1120
+ indices, _ = self.encode(audio_data)
1121
+ new_indices = torch.zeros_like(indices)
1122
+ new_indices[:, 0] = torch.clamp(indices[:, 0], max=self.quantizer.semantic_quantizer.codebook_size - 1)
1123
+ new_indices[:, 1:] = torch.clamp(indices[:, 1:], max=self.quantizer.quantizer.codebook_size - 1)
1124
+
1125
+ z_q_semantic = self.quantizer.semantic_quantizer.from_codes(new_indices[:, :1])[0]
1126
+ z_q_residual = self.quantizer.quantizer.from_codes(new_indices[:, 1:])[0]
1127
+ z_q = z_q_semantic + z_q_residual
1128
+ return z_q
1129
+
1130
+ @torch.no_grad()
1131
+ def decode_zq(self, z_q: Tensor) -> Tensor:
1132
+ z_q = self.quantizer.post_module(z_q)
1133
+ z_q = self.quantizer.upsample(z_q)
1134
+ return self.decoder(z_q)
1135
+
1136
+ @property
1137
+ def device(self) -> torch.device: return next(self.parameters()).device
1138
+
1139
+ @property
1140
+ def dtype(self) -> torch.dtype: return next(self.parameters()).dtype
1141
+
1142
+ # --------------------------------------------------------------------
1143
+ # Build helpers
1144
+ # --------------------------------------------------------------------
1145
+
1146
+ def build_ae(**cfg) -> DAC:
1147
+ """
1148
+ Factory used by external loaders
1149
+ """
1150
+ # Shared transformer config for the RVQ pre/post modules
1151
+ q_config = ModelArgs(
1152
+ block_size=4096, n_layer=8, n_head=16, dim=1024,
1153
+ intermediate_size=3072, head_dim=64, norm_eps=1e-5,
1154
+ dropout_rate=0.1, attn_dropout_rate=0.1, channels_first=True
1155
+ )
1156
+
1157
+ def make_transformer():
1158
+ return WindowLimitedTransformer(
1159
+ causal=True, window_size=128, input_dim=1024, config=q_config
1160
+ )
1161
+
1162
+ quantizer = DownsampleResidualVectorQuantize(
1163
+ input_dim=1024, n_codebooks=9, codebook_size=1024, codebook_dim=8,
1164
+ quantizer_dropout=0.5, downsample_factor=(2, 2),
1165
+ semantic_codebook_size=4096,
1166
+ pre_module=make_transformer(),
1167
+ post_module=make_transformer(),
1168
+ )
1169
+
1170
+ def transformer_general_config(**kw):
1171
+ return ModelArgs(
1172
+ block_size=kw.get("block_size", 16384),
1173
+ n_layer=kw.get("n_layer", 8),
1174
+ n_head=kw.get("n_head", 8),
1175
+ dim=kw.get("dim", 512),
1176
+ intermediate_size=kw.get("intermediate_size", 1536),
1177
+ n_local_heads=kw.get("n_local_heads", -1),
1178
+ head_dim=kw.get("head_dim", 64),
1179
+ rope_base=kw.get("rope_base", 10000),
1180
+ norm_eps=kw.get("norm_eps", 1e-5),
1181
+ dropout_rate=kw.get("dropout_rate", 0.1),
1182
+ attn_dropout_rate=kw.get("attn_dropout_rate", 0.1),
1183
+ channels_first=kw.get("channels_first", True),
1184
+ )
1185
+
1186
+ dac = DAC(
1187
+ encoder_dim=64, encoder_rates=[2, 4, 8, 8], latent_dim=1024,
1188
+ decoder_dim=1536, decoder_rates=[8, 8, 4, 2],
1189
+ quantizer=quantizer, sample_rate=44100, causal=True,
1190
+ encoder_transformer_layers=[0, 0, 0, 4],
1191
+ decoder_transformer_layers=[4, 0, 0, 0],
1192
+ transformer_general_config=transformer_general_config,
1193
+ )
1194
+ return dac
1195
+
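+ # Illustrative usage sketch (not executed anywhere in this repo): build_ae()
+ # only constructs the architecture the pretrained checkpoint expects; weights
+ # are loaded separately, e.g. via inference.load_fish_ae_from_hf. Roughly:
+ #
+ #   import safetensors.torch as st
+ #   ae = build_ae()
+ #   state = st.load_file("pytorch_model.safetensors")  # local path is an assumption
+ #   ae.load_state_dict(state, strict=False)
+ #   ae = ae.eval()
+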
1196
+ __all__ = [
1197
+ "DAC",
1198
+ "build_ae",
1199
+ "VectorQuantize",
1200
+ "ResidualVectorQuantize",
1201
+ "DownsampleResidualVectorQuantize",
1202
+ ]
1203
+
1204
+
1205
+ # ----- BEGIN DAC MIT LICENSE -----
1206
+ # MIT License
1207
+ # Copyright (c) 2023-present, Descript
1208
+ #
1209
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
1210
+ # of this software and associated documentation files (the "Software"), to deal
1211
+ # in the Software without restriction, including without limitation the rights
1212
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
1213
+ # copies of the Software, and to permit persons to whom the Software is
1214
+ # furnished to do so, subject to the following conditions:
1215
+ #
1216
+ # The above copyright notice and this permission notice shall be included in all
1217
+ # copies or substantial portions of the Software.
1218
+ #
1219
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
1220
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
1221
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
1222
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
1223
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
1224
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1225
+ # SOFTWARE.
1226
+ # ----- END DAC MIT LICENSE -----
1227
+
inference.py ADDED
@@ -0,0 +1,290 @@
1
+ from dataclasses import dataclass
2
+ from typing import Callable, List, Tuple
3
+ import torch
4
+ import safetensors.torch as st
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ from model import EchoDiT
8
+ from autoencoder import build_ae, DAC
9
+
10
+ import torchaudio
11
+ from torchcodec.decoders import AudioDecoder
12
+
13
+ # from samplers import Sampler
14
+
15
+ SampleFn = Callable[
16
+ [EchoDiT, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int],
17
+ torch.Tensor
18
+ ]
19
+ ### Loading
20
+
21
+ def load_model_from_hf(repo_id: str = 'jordand/echo-tts', device: str = 'cuda', dtype: torch.dtype | None = torch.bfloat16, compile: bool = False, token: str | None = None) -> EchoDiT:
22
+ with torch.device('meta'):
23
+ model = EchoDiT(
24
+ latent_size=80, model_size=2048, num_layers=24, num_heads=16,
25
+ intermediate_size=5888, norm_eps=1e-5, max_seq_len=640,
26
+ text_vocab_size=256, text_model_size=1280, text_num_layers=14,
27
+ text_num_heads=10, text_intermediate_size=3328, text_max_seq_len=768,
28
+ speaker_patch_size=4, speaker_model_size=1280, speaker_num_layers=14,
29
+ speaker_num_heads=10, speaker_intermediate_size=3328,
30
+ speaker_max_patched_seq_len=640, timestep_embed_size=512, adaln_rank=256,
31
+ )
32
+ w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)
33
+
34
+ # Load to CPU first
35
+ state = st.load_file(w_path, device='cpu')
36
+
37
+ # Convert dtype on CPU if needed
38
+ if dtype is not None:
39
+ state = {k: v.to(dtype=dtype) for k, v in state.items()}
40
+
41
+ # Now move to device
42
+ state = {k: v.to(device=device) for k, v in state.items()}
43
+
44
+ model.load_state_dict(state, strict=True, assign=True)
45
+ model = model.eval()
46
+
47
+ if compile:
48
+ model = torch.compile(model)
49
+ model.get_kv_cache = torch.compile(model.get_kv_cache)
50
+
51
+ return model
52
+
53
+ def load_fish_ae_from_hf(repo_id: str = 'jordand/fish-s1-dac-min', device: str = 'cuda', dtype: torch.dtype | None = torch.float32, compile: bool = False, token: str | None = None) -> DAC:
54
+ # have not tested lower precisions with fish AE yet
55
+
56
+ with torch.device('meta'):
57
+ fish_ae = build_ae()
58
+
59
+ w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)
60
+ if dtype is not None and dtype != torch.float32:
61
+ state = st.load_file(w_path, device='cpu')
62
+ state = {k: v.to(dtype=dtype) for k, v in state.items()}
63
+ state = {k: v.to(device=device) for k, v in state.items()}
64
+ fish_ae.load_state_dict(state, strict=False, assign=True)
65
+ else:
66
+ state = st.load_file(w_path, device=device)
67
+ fish_ae.load_state_dict(state, strict=False, assign=True)
68
+
69
+ fish_ae = fish_ae.eval().to(device)
70
+
71
+ if compile:
72
+ fish_ae.encoder = torch.compile(fish_ae.encoder)
73
+ fish_ae.decoder = torch.compile(fish_ae.decoder)
74
+
75
+ return fish_ae
76
+
77
+
78
+ @dataclass
79
+ class PCAState:
80
+ pca_components: torch.Tensor
81
+ pca_mean: torch.Tensor
82
+ latent_scale: float
83
+
84
+ def load_pca_state_from_hf(repo_id: str = 'jordand/echo-tts', device: str = 'cuda', filename: str = 'pca_state.safetensors', token: str | None = None) -> PCAState:
85
+ p_path = hf_hub_download(repo_id, filename, token=token)
86
+ t = st.load_file(p_path, device=device)
87
+ return PCAState(
88
+ pca_components=t["pca_components"],
89
+ pca_mean=t["pca_mean"],
90
+ latent_scale=float(t["latent_scale"].item()),
91
+ )
92
+
93
+ ### default load audio
94
+
95
+ def load_audio(path: str) -> torch.Tensor:
96
+
97
+ decoder = AudioDecoder(path)
98
+ sr = decoder.metadata.sample_rate
99
+ audio = decoder.get_samples_played_in_range(0, 120)
100
+ audio = audio.data.mean(dim=0).unsqueeze(0)
101
+ audio = torchaudio.functional.resample(audio, sr, 44_100)
102
+ audio = audio / torch.maximum(audio.abs().max(), torch.tensor(1.))
103
+ # TODO is this better than clipping? should we target a specific energy level?
104
+ return audio
105
+
106
+
107
+
108
+ ### Text helpers
109
+
110
+ def tokenizer_encode(text: str, append_bos: bool = True, normalize: bool = True) -> torch.Tensor:
111
+
112
+ if normalize:
113
+ text = text.replace('…', '...')
114
+ text = text.replace('“', '"')
115
+ text = text.replace('”', '"')
116
+ text = text.replace('’', "'")
117
+ text = text.replace('\n', " ")
118
+
119
+ b = list(text.encode('utf-8'))
120
+ if append_bos:
121
+ b.insert(0, 0)
122
+ return torch.tensor(b)
123
+
124
+ def get_text_input_ids_and_mask(text_arr: List[str], max_length: int | None, device: str | None = None) -> tuple[torch.Tensor, torch.Tensor]:
125
+ batch_size = len(text_arr)
126
+ if max_length is None:
127
+ max_length = max(len(tokenizer_encode(text)) for text in text_arr) # obviously bad...
128
+
129
+ tokens = torch.zeros((batch_size, max_length), dtype=torch.int32)
130
+ mask = torch.zeros((batch_size, max_length), dtype=torch.bool)
131
+
132
+ for i, text in enumerate(text_arr):
133
+ encoded = tokenizer_encode(text)
134
+ length = min(len(encoded), max_length)
135
+ tokens[i, :length] = encoded[:length]
136
+ mask[i, :length] = 1
137
+
138
+ if device is not None:
139
+ tokens = tokens.to(device)
140
+ mask = mask.to(device)
141
+
142
+ return tokens, mask
143
+
144
+
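+ # Worked example (values follow directly from the byte-level scheme above):
+ #   tokenizer_encode("Hi")                  -> tensor([0, 72, 105])   (BOS=0, then UTF-8 bytes)
+ #   get_text_input_ids_and_mask(["Hi"], 5)  -> ids  [[0, 72, 105, 0, 0]],
+ #                                              mask [[True, True, True, False, False]]
+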
145
+ ### Autoencoder Inference
146
+
147
+ @torch.inference_mode()
148
+ def ae_encode(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
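+ # Maps (b, 1, samples) audio to the 80-dim diffusion latent: the fish AE's
+ # 1024-dim quantized latent z_q is projected through the stored PCA basis
+ # (pca_components has shape (80, 1024) here) and scaled by latent_scale,
+ # yielding (b, frames, 80) with one frame per 2048 samples.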
149
+ assert audio.ndim == 3 and audio.shape[1] == 1 # (b, 1, length)
150
+ z_q = fish_ae.encode_zq(audio).float()
151
+ z_q = (z_q.transpose(1, 2) - pca_state.pca_mean) @ pca_state.pca_components.T
152
+ z_q = z_q * pca_state.latent_scale
153
+ return z_q
154
+
155
+ @torch.inference_mode()
156
+ def ae_decode(fish_ae: DAC, pca_state: PCAState, z_q: torch.Tensor) -> torch.Tensor:
157
+ z_q = (z_q / pca_state.latent_scale) @ pca_state.pca_components + pca_state.pca_mean
158
+ return fish_ae.decode_zq(z_q.transpose(1, 2).to(fish_ae.dtype)).float()
159
+
160
+ @torch.inference_mode()
161
+ def ae_reconstruct(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
162
+ # (audio is (b, 1, length))
163
+ z_q = ae_encode(fish_ae, pca_state, audio.to(fish_ae.dtype))
164
+ return ae_decode(fish_ae, pca_state, z_q)
165
+
166
+
167
+ @torch.inference_mode()
168
+ def get_speaker_latent_and_mask(
169
+ fish_ae: DAC,
170
+ pca_state: PCAState,
171
+ audio: torch.Tensor, # (1, length)
172
+ max_speaker_latent_len: int = 2560, # pretrained max length
173
+ audio_chunk_size: int = 640 * 2048 # (~30 seconds, 1/4 max speaker condition size)
174
+ ) -> tuple[torch.Tensor, torch.Tensor]:
175
+
176
+ # gets speaker latent and mask from audio, computes in chunks and concatenates (similar to pretraining setup)
177
+
178
+ AE_DOWNSAMPLE_FACTOR = 2048
179
+ max_audio_len = max_speaker_latent_len * AE_DOWNSAMPLE_FACTOR
180
+
181
+ assert audio.ndim == 2 and audio.shape[0] == 1 # (1, length)
182
+ audio = audio[:, :max_audio_len]
183
+ audio_len = audio.shape[1]
184
+
185
+ latent_arr = []
186
+
187
+ for i in range(0, audio_len, audio_chunk_size):
188
+ audio_chunk = audio[:, i:i + audio_chunk_size]
189
+ if audio_chunk.shape[1] < audio_chunk_size:
190
+ audio_chunk = torch.nn.functional.pad(audio_chunk, (0, audio_chunk_size - audio_chunk.shape[1]))
191
+
192
+ latent_chunk = ae_encode(fish_ae, pca_state, audio_chunk.unsqueeze(0))
193
+ latent_arr.append(latent_chunk)
194
+
195
+ speaker_latent = torch.cat(latent_arr, dim=1)
196
+
197
+ actual_latent_len = audio_len // AE_DOWNSAMPLE_FACTOR
198
+ speaker_mask = (torch.arange(speaker_latent.shape[1], device=speaker_latent.device) < actual_latent_len).unsqueeze(0)
199
+
200
+ if speaker_latent.shape[1] < max_speaker_latent_len:
201
+ speaker_latent = torch.nn.functional.pad(speaker_latent, (0, 0, 0, max_speaker_latent_len - speaker_latent.shape[1]))
202
+ speaker_mask = torch.nn.functional.pad(speaker_mask, (0, max_speaker_latent_len - speaker_mask.shape[1]))
203
+
204
+ return speaker_latent, speaker_mask
205
+
206
+
207
+ ### Full sample pipeline
208
+
209
+ def find_flattening_point(data, target_value=0.0, window_size=20, std_threshold=0.05):
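+ # Heuristic end-of-speech detector: scans the (frames, 80) latent for the first
+ # window whose values have flat-lined near target_value (low std, mean close to it),
+ # so sample_pipeline can trim the silent tail of the fixed 640-frame block from
+ # the decoded audio.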
210
+ padded_data = torch.cat([data, torch.zeros(window_size, *data.shape[1:], device=data.device, dtype=data.dtype)])
211
+ for i in range(len(padded_data) - window_size):
212
+ window = padded_data[i:i + window_size]
213
+ if window.std() < std_threshold and abs(window.mean() - target_value) < 0.1:
214
+ return i
215
+ return len(data)
216
+
217
+
218
+ @torch.inference_mode()
219
+ def sample_pipeline(
220
+ model: EchoDiT,
221
+ fish_ae: DAC,
222
+ pca_state: PCAState,
223
+ sample_fn: SampleFn,
224
+ text_prompt: str,
225
+ speaker_audio: torch.Tensor | None,
226
+ rng_seed: int,
227
+ pad_to_max_speaker_latent_len: int | None = 2560,
228
+ pad_to_max_text_seq_len: int | None = 768,
229
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
230
+
231
+ MAX_SPEAKER_LATENT_LEN = 2560
232
+ MAX_TEXT_SEQ_LEN = 768
233
+
234
+ device, dtype = model.device, model.dtype
235
+
236
+ text_input_ids, text_mask = get_text_input_ids_and_mask([text_prompt], min(pad_to_max_text_seq_len or MAX_TEXT_SEQ_LEN, MAX_TEXT_SEQ_LEN), device=device)
237
+
238
+ # print('initial text input ids length: ', text_input_ids.shape[1])
239
+ # torch.cuda.synchronize()
240
+
241
+ # import time
242
+
243
+ # t0 = time.time()
244
+
245
+ if speaker_audio is None:
246
+ # No speaker prompt - use zero speaker latent and mask
247
+ speaker_latent = torch.zeros((1, pad_to_max_speaker_latent_len if pad_to_max_speaker_latent_len else MAX_SPEAKER_LATENT_LEN, 80), device=device, dtype=dtype)
248
+ speaker_mask = torch.zeros((1, pad_to_max_speaker_latent_len if pad_to_max_speaker_latent_len else MAX_SPEAKER_LATENT_LEN), device=device, dtype=torch.bool)
249
+ # print("Using zero speaker latent and mask (no speaker prompt)")
250
+ else:
251
+ speaker_latent, speaker_mask = get_speaker_latent_and_mask(
252
+ fish_ae,
253
+ pca_state,
254
+ speaker_audio.to(fish_ae.dtype),
255
+ max_speaker_latent_len=pad_to_max_speaker_latent_len if pad_to_max_speaker_latent_len else MAX_SPEAKER_LATENT_LEN
256
+ )
257
+ speaker_latent = speaker_latent.to(device)
258
+ speaker_mask = speaker_mask.to(device)
259
+
260
+ # print('speaker latent shape: ', speaker_latent.shape)
261
+ # print('speaker mask shape: ', speaker_mask.shape)
262
+
263
+ # torch.cuda.synchronize()
264
+ # t1 = time.time()
265
+ # print(f"Time taken encode: {t1 - t0} seconds")
266
+
267
+ latent_out = sample_fn(model, speaker_latent, speaker_mask, text_input_ids, text_mask, rng_seed)
268
+
269
+ # torch.cuda.synchronize()
270
+ # t2 = time.time()
271
+
272
+ # print(f"Time taken sample: {t2 - t1} seconds")
273
+
274
+ audio_out = ae_decode(fish_ae, pca_state, latent_out)
275
+ # torch.cuda.synchronize()
276
+ # t3 = time.time()
277
+ # print(f"Time taken decode: {t3 - t2} seconds")
278
+
279
+ flattening_point = find_flattening_point(latent_out[0])
280
+ audio_out = audio_out[..., :flattening_point * 2048]
281
+
282
+ # print(f"\nTime taken total: {t3 - t0} seconds")
283
+
284
+ # peak_mem = torch.cuda.max_memory_allocated()
285
+ # print(f"Peak memory: {peak_mem / 1024**2:.2f} MB")
286
+ # print(torch.cuda.memory_summary(abbreviated=True))
287
+
288
+ return audio_out
289
+
290
+
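+ # Minimal end-to-end sketch (illustrative only; sampler arguments follow the
+ # "Independent" preset in sampler_presets.json, and the prompt/output paths are
+ # just examples):
+ #
+ #   from functools import partial
+ #   import torchaudio
+ #   from samplers import sample_euler_cfg_independent_guidances
+ #
+ #   model = load_model_from_hf()
+ #   fish_ae = load_fish_ae_from_hf()
+ #   pca_state = load_pca_state_from_hf()
+ #   sample_fn = partial(
+ #       sample_euler_cfg_independent_guidances,
+ #       num_steps=40, cfg_scale_text=3.0, cfg_scale_speaker=5.0,
+ #       cfg_min_t=0.5, cfg_max_t=1.0, truncation_factor=1.0,
+ #       rescale_k=1.0, rescale_sigma=3.0,
+ #       speaker_k_scale=None, speaker_k_max_layers=None, speaker_k_min_t=None,
+ #   )
+ #   speaker = load_audio("prompt_audio/expresso_02_ex03-ex01_calm_005.wav").to(model.device)
+ #   audio = sample_pipeline(model, fish_ae, pca_state, sample_fn,
+ #                           "Hello there.", speaker, rng_seed=0)
+ #   torchaudio.save("out.wav", audio[0].cpu(), 44100)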
model.py ADDED
@@ -0,0 +1,650 @@
1
+ from typing import Tuple, List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+
7
+ import torch.nn.functional as F
8
+
9
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
10
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)] / dim))
11
+ t = torch.arange(end)
12
+ freqs = torch.outer(t, freqs)
13
+ freqs_cis = torch.complex(torch.cos(freqs), torch.sin(freqs))
14
+ return freqs_cis
15
+
16
+
17
+ def apply_rotary_emb(
18
+ x: torch.Tensor,
19
+ freqs_cis: torch.Tensor,
20
+ ) -> torch.Tensor:
21
+ x_ = torch.view_as_complex(x.float().reshape(*x.shape[:3], -1, 2))
22
+ x_ = x_ * freqs_cis[..., None, :]
23
+ x_ = torch.view_as_real(x_).reshape(x.shape)
24
+ return x_.type_as(x)
25
+
26
+
27
+ def get_timestep_embedding(
28
+ timestep: torch.Tensor,
29
+ embed_size: int,
30
+ ) -> torch.Tensor:
31
+ assert embed_size % 2 == 0
32
+
33
+ half = embed_size // 2
34
+
35
+ freqs = 1000 * torch.exp(
36
+ -torch.log(torch.tensor(10000.0)) *
37
+ torch.arange(start=0, end=half, dtype=torch.float32) / half
38
+ ).to(timestep.device)
39
+
40
+ args = timestep[..., None] * freqs[None]
41
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
42
+
43
+ return embedding.to(timestep.dtype)
44
+
45
+
46
+ class LowRankAdaLN(nn.Module):
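+ # AdaLN-style modulation with a low-rank, per-block correction: the shared
+ # conditioning embedding supplies base shift/scale/gate vectors, and each block
+ # adds its own (model_size -> rank -> model_size) residual refinement before
+ # applying RMS normalization, scale/shift, and a tanh gate.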
47
+ def __init__(
48
+ self,
49
+ model_size: int,
50
+ rank: int,
51
+ eps: float
52
+ ):
53
+ super().__init__()
54
+ self.eps = eps
55
+
56
+ self.shift_down = nn.Linear(model_size, rank, bias=False)
57
+ self.scale_down = nn.Linear(model_size, rank, bias=False)
58
+ self.gate_down = nn.Linear(model_size, rank, bias=False)
59
+
60
+ self.shift_up = nn.Linear(rank, model_size, bias=True)
61
+ self.scale_up = nn.Linear(rank, model_size, bias=True)
62
+ self.gate_up = nn.Linear(rank, model_size, bias=True)
63
+
64
+ def forward(
65
+ self,
66
+ x: torch.Tensor,
67
+ cond_embed: torch.Tensor,
68
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
69
+
70
+ shift, scale, gate = cond_embed.chunk(3, dim=-1)
71
+
72
+ shift = self.shift_up(self.shift_down(F.silu(shift))) + shift
73
+ scale = self.scale_up(self.scale_down(F.silu(scale))) + scale
74
+ gate = self.gate_up(self.gate_down(F.silu(gate))) + gate
75
+
76
+ x_dtype = x.dtype
77
+ x = x.float()
78
+ x = x * torch.rsqrt(torch.pow(x.float(), 2).mean(dim=-1, keepdim=True) + self.eps)
79
+ x = x * (scale + 1) + shift
80
+
81
+ gate = torch.tanh(gate)
82
+
83
+ return x.to(x_dtype), gate
84
+
85
+
86
+ class RMSNorm(nn.Module): # could also just use torch rmsnorm
87
+ def __init__(
88
+ self,
89
+ model_size: int | Tuple[int, int],
90
+ eps: float
91
+ ):
92
+ super().__init__()
93
+ self.eps = eps
94
+
95
+ if isinstance(model_size, int):
96
+ model_size = (model_size, )
97
+ self.weight = nn.Parameter(torch.ones(model_size))
98
+
99
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
100
+ x_dtype = x.dtype
101
+ x = x.float()
102
+ x = x * torch.rsqrt(torch.pow(x.float(), 2).mean(dim=-1, keepdim=True) + self.eps)
103
+ x = x * self.weight
104
+ return x.to(x_dtype)
105
+
106
+ class SelfAttention(nn.Module):
107
+ def __init__(
108
+ self,
109
+ model_size: int,
110
+ num_heads: int,
111
+ is_causal: bool,
112
+ norm_eps: float
113
+ ):
114
+ super().__init__()
115
+ self.num_heads = num_heads
116
+ self.is_causal = is_causal
117
+
118
+ self.wq = nn.Linear(model_size, model_size, bias=False)
119
+ self.wk = nn.Linear(model_size, model_size, bias=False)
120
+ self.wv = nn.Linear(model_size, model_size, bias=False)
121
+ self.wo = nn.Linear(model_size, model_size, bias=False)
122
+ self.gate = nn.Linear(model_size, model_size, bias=False)
123
+
124
+ assert model_size % num_heads == 0
125
+ self.q_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
126
+ self.k_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
127
+
128
+ def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
129
+
130
+ batch_size, seq_len = x.shape[:2]
131
+
132
+ xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1)
133
+ xk = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1)
134
+ xv = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1)
135
+
136
+ gate = self.gate(x)
137
+
138
+ xq = self.q_norm(xq)
139
+ xk = self.k_norm(xk)
140
+
141
+ xq = apply_rotary_emb(xq, freqs_cis[:seq_len])
142
+ xk = apply_rotary_emb(xk, freqs_cis[:seq_len])
143
+
144
+ if mask is not None:
145
+ assert mask.ndim == 2 # (b, s)
146
+ mask = mask[:, None, None]
147
+
148
+ output = F.scaled_dot_product_attention(
149
+ query=xq.transpose(1, 2),
150
+ key=xk.transpose(1, 2),
151
+ value=xv.transpose(1, 2),
152
+ attn_mask=mask,
153
+ is_causal=self.is_causal
154
+ ).transpose(1, 2)
155
+
156
+ output = output.reshape(batch_size, seq_len, -1)
157
+ output = output * torch.sigmoid(gate)
158
+
159
+ output = self.wo(output)
160
+
161
+ return output
162
+
163
+ class JointAttention(nn.Module):
164
+ def __init__(
165
+ self,
166
+ model_size: int,
167
+ num_heads: int,
168
+ text_model_size: int,
169
+ speaker_model_size: int,
170
+ speaker_patch_size: int,
171
+ norm_eps: float
172
+ ):
173
+ super().__init__()
174
+ self.speaker_patch_size = speaker_patch_size
175
+ self.num_heads = num_heads
176
+
177
+ self.wq = nn.Linear(model_size, model_size, bias=False)
178
+ self.wk = nn.Linear(model_size, model_size, bias=False)
179
+ self.wv = nn.Linear(model_size, model_size, bias=False)
180
+
181
+ self.wk_text = nn.Linear(text_model_size, model_size, bias=False)
182
+ self.wv_text = nn.Linear(text_model_size, model_size, bias=False)
183
+
184
+ self.wk_speaker = nn.Linear(speaker_model_size, model_size, bias=False)
185
+ self.wv_speaker = nn.Linear(speaker_model_size, model_size, bias=False)
186
+
187
+ assert model_size % num_heads == 0
188
+ self.q_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
189
+ self.k_norm = RMSNorm((num_heads, model_size // num_heads), eps=norm_eps)
190
+
191
+ self.gate = nn.Linear(model_size, model_size, bias=False)
192
+
193
+ self.wo = nn.Linear(model_size, model_size, bias=False)
194
+
195
+ def forward(
196
+ self,
197
+ x: torch.Tensor,
198
+ text_state: torch.Tensor | None,
199
+ text_mask: torch.Tensor,
200
+ speaker_state: torch.Tensor | None,
201
+ speaker_mask: torch.Tensor,
202
+ freqs_cis: torch.Tensor,
203
+ kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
204
+ ) -> torch.Tensor:
205
+ batch_size, seq_len = x.shape[:2]
206
+
207
+ xq = self.wq(x).reshape(batch_size, seq_len, self.num_heads, -1)
208
+ xk_self = self.wk(x).reshape(batch_size, seq_len, self.num_heads, -1)
209
+ xv_self = self.wv(x).reshape(batch_size, seq_len, self.num_heads, -1)
210
+
211
+ xq = self.q_norm(xq)
212
+ xk_self = self.k_norm(xk_self)
213
+
214
+ gate = self.gate(x)
215
+
216
+
217
+ def _apply_rotary_half(y: torch.Tensor, fc: torch.Tensor) -> torch.Tensor:
218
+ y1, y2 = y.chunk(2, dim=-2)
219
+ y1 = apply_rotary_emb(y1, fc)
220
+ return torch.cat([y1, y2], dim=-2)
221
+
222
+ xq = _apply_rotary_half(xq, freqs_cis)
223
+ xk_self = _apply_rotary_half(xk_self, freqs_cis)
224
+
225
+ if kv_cache is None:
226
+
227
+ xk_text = self.wk_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
228
+ xv_text = self.wv_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
229
+
230
+ xk_speaker = self.wk_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
231
+ xv_speaker = self.wv_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
232
+
233
+ xk_text = self.k_norm(xk_text)
234
+ xk_speaker = self.k_norm(xk_speaker)
235
+
236
+ xk = torch.cat([xk_self, xk_text, xk_speaker], dim=1)
237
+ xv = torch.cat([xv_self, xv_text, xv_speaker], dim=1)
238
+
239
+ else:
240
+ xk_cross, xv_cross = kv_cache
241
+ xk = torch.cat([xk_self, xk_cross], dim=1)
242
+ xv = torch.cat([xv_self, xv_cross], dim=1)
243
+
244
+ self_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=x.device)
245
+ mask = torch.cat([self_mask, text_mask, speaker_mask], dim=1)
246
+ mask = mask[:, None, None]
247
+
248
+ output = F.scaled_dot_product_attention(
249
+ query=xq.transpose(1, 2),
250
+ key=xk.transpose(1, 2),
251
+ value=xv.transpose(1, 2),
252
+ attn_mask=mask,
253
+ is_causal=False
254
+ ).transpose(1, 2)
255
+
256
+ output = output.reshape(batch_size, seq_len, -1)
257
+ output = output * torch.sigmoid(gate)
258
+
259
+ output = self.wo(output)
260
+
261
+ return output
262
+
263
+ def get_kv_cache(
264
+ self,
265
+ text_state: torch.Tensor,
266
+ speaker_state: torch.Tensor,
267
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
268
+
269
+ batch_size = text_state.shape[0]
270
+
271
+ xk_text = self.wk_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
272
+ xv_text = self.wv_text(text_state).reshape(batch_size, text_state.shape[1], self.num_heads, -1)
273
+
274
+ xk_speaker = self.wk_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
275
+ xv_speaker = self.wv_speaker(speaker_state).reshape(batch_size, speaker_state.shape[1], self.num_heads, -1)
276
+
277
+ xk = torch.cat([xk_text, xk_speaker], dim=1)
278
+ xv = torch.cat([xv_text, xv_speaker], dim=1)
279
+
280
+ xk = self.k_norm(xk)
281
+
282
+ return xk, xv
283
+
284
+ class MLP(nn.Module):
285
+ def __init__(
286
+ self,
287
+ model_size: int,
288
+ intermediate_size: int
289
+ ):
290
+ super().__init__()
291
+ self.w1 = nn.Linear(model_size, intermediate_size, bias=False)
292
+ self.w3 = nn.Linear(model_size, intermediate_size, bias=False)
293
+ self.w2 = nn.Linear(intermediate_size, model_size, bias=False)
294
+
295
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
296
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
297
+
298
+
299
+ class EncoderTransformerBlock(nn.Module):
300
+ def __init__(
301
+ self,
302
+ model_size: int,
303
+ num_heads: int,
304
+ intermediate_size: int,
305
+ is_causal: bool,
306
+ norm_eps: float
307
+ ):
308
+ super().__init__()
309
+ self.attention = SelfAttention(
310
+ model_size=model_size,
311
+ num_heads=num_heads,
312
+ is_causal=is_causal,
313
+ norm_eps=norm_eps
314
+ )
315
+ self.mlp = MLP(
316
+ model_size=model_size,
317
+ intermediate_size=intermediate_size
318
+ )
319
+
320
+ self.attention_norm = RMSNorm(model_size, norm_eps)
321
+ self.mlp_norm = RMSNorm(model_size, norm_eps)
322
+
323
+ def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor:
324
+ x = x + self.attention(self.attention_norm(x), mask, freqs_cis)
325
+ x = x + self.mlp(self.mlp_norm(x))
326
+
327
+ return x
328
+
329
+ class TransformerBlock(nn.Module):
330
+ def __init__(
331
+ self,
332
+ model_size: int,
333
+ num_heads: int,
334
+ intermediate_size: int,
335
+ norm_eps: float,
336
+ text_model_size: int,
337
+ speaker_model_size: int,
338
+ speaker_patch_size: int,
339
+ adaln_rank: int,
340
+ ):
341
+ super().__init__()
342
+ self.attention = JointAttention(
343
+ model_size=model_size,
344
+ num_heads=num_heads,
345
+ text_model_size=text_model_size,
346
+ speaker_model_size=speaker_model_size,
347
+ speaker_patch_size=speaker_patch_size,
348
+ norm_eps=norm_eps
349
+ )
350
+
351
+ self.mlp = MLP(
352
+ model_size=model_size,
353
+ intermediate_size=intermediate_size
354
+ )
355
+
356
+ self.attention_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps)
357
+ self.mlp_adaln = LowRankAdaLN(model_size=model_size, rank=adaln_rank, eps=norm_eps)
358
+
359
+ def forward(
360
+ self,
361
+ x: torch.Tensor,
362
+ cond_embed: torch.Tensor,
363
+ text_state: torch.Tensor | None,
364
+ text_mask: torch.Tensor,
365
+ speaker_state: torch.Tensor | None,
366
+ speaker_mask: torch.Tensor,
367
+ freqs_cis: torch.Tensor,
368
+ kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
369
+ ) -> torch.Tensor:
370
+
371
+ x_norm, attention_gate = self.attention_adaln(x, cond_embed)
372
+ x = x + attention_gate * self.attention(x_norm, text_state, text_mask, speaker_state, speaker_mask, freqs_cis, kv_cache)
373
+
374
+ x_norm, mlp_gate = self.mlp_adaln(x, cond_embed)
375
+ x = x + mlp_gate * self.mlp(x_norm)
376
+
377
+ return x
378
+
379
+ def get_kv_cache(
380
+ self,
381
+ text_state: torch.Tensor,
382
+ speaker_state: torch.Tensor,
383
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
384
+ return self.attention.get_kv_cache(text_state, speaker_state)
385
+
386
+ class TextEncoder(nn.Module):
387
+ def __init__(
388
+ self,
389
+ vocab_size: int,
390
+ model_size: int,
391
+ num_layers: int,
392
+ num_heads: int,
393
+ intermediate_size: int,
394
+ norm_eps: float,
395
+ max_seq_len: int,
396
+ ):
397
+ super().__init__()
398
+ self.text_embedding = nn.Embedding(vocab_size, model_size)
399
+
400
+ self.blocks = nn.ModuleList()
401
+ for i in range(num_layers):
402
+ block = EncoderTransformerBlock(
403
+ model_size=model_size,
404
+ num_heads=num_heads,
405
+ intermediate_size=intermediate_size,
406
+ is_causal=False,
407
+ norm_eps=norm_eps
408
+ )
409
+ self.blocks.append(block)
410
+
411
+ self.head_dim = model_size // num_heads
412
+
413
+
414
+ def forward(self, input_ids: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
415
+ x = self.text_embedding(input_ids)
416
+
417
+ freqs_cis = precompute_freqs_cis(self.head_dim, input_ids.shape[1]).to(x.device) # see below about avoiding recomputation
418
+ for block in self.blocks:
419
+ x = block(x, mask, freqs_cis)
420
+
421
+ return x
422
+
423
+ class SpeakerEncoder(nn.Module):
424
+ def __init__(
425
+ self,
426
+ latent_size: int,
427
+ patch_size: int,
428
+ model_size: int,
429
+ num_layers: int,
430
+ num_heads: int,
431
+ intermediate_size: int,
432
+ norm_eps: float,
433
+ max_patched_seq_len: int,
434
+ ):
435
+ super().__init__()
436
+ self.patch_size = patch_size
437
+
438
+ self.in_proj = nn.Linear(latent_size * patch_size, model_size, bias=True)
439
+
440
+ self.blocks = nn.ModuleList()
441
+ for i in range(num_layers):
442
+ block = EncoderTransformerBlock(
443
+ model_size=model_size,
444
+ num_heads=num_heads,
445
+ intermediate_size=intermediate_size,
446
+ is_causal=True,
447
+ norm_eps=norm_eps
448
+ )
449
+ self.blocks.append(block)
450
+
451
+ self.head_dim = model_size // num_heads
452
+
453
+ def forward(self, latent: torch.Tensor) -> torch.Tensor:
454
+ x = latent.reshape(*latent.shape[:-2], latent.shape[-2] // self.patch_size, latent.shape[-1] * self.patch_size)
455
+
456
+ x = self.in_proj(x)
457
+ x = x / 6. # this helped with initial activation dynamics in early ablations, could also bake into in_proj
458
+
459
+ freqs_cis = precompute_freqs_cis(self.head_dim, x.shape[1]).to(x.device) # see below about avoiding recomputation
460
+
461
+ for block in self.blocks:
462
+ x = block(x, None, freqs_cis)
463
+
464
+ return x
465
+
466
+
467
+ class EchoDiT(nn.Module):
468
+ def __init__(
469
+ self,
470
+ latent_size: int,
471
+ #
472
+ model_size: int,
473
+ num_layers: int,
474
+ num_heads: int,
475
+ intermediate_size: int,
476
+ norm_eps: float,
477
+ max_seq_len: int,
478
+ #
479
+ text_vocab_size: int,
480
+ text_model_size: int,
481
+ text_num_layers: int,
482
+ text_num_heads: int,
483
+ text_intermediate_size: int,
484
+ text_max_seq_len: int,
485
+ #
486
+ speaker_patch_size: int,
487
+ speaker_model_size: int,
488
+ speaker_num_layers: int,
489
+ speaker_num_heads: int,
490
+ speaker_intermediate_size: int,
491
+ speaker_max_patched_seq_len: int,
492
+ #
493
+ timestep_embed_size: int,
494
+ adaln_rank: int,
495
+ ):
496
+ super().__init__()
497
+ self.speaker_patch_size = speaker_patch_size
498
+ self.timestep_embed_size = timestep_embed_size
499
+
500
+ self.text_encoder = TextEncoder(
501
+ vocab_size=text_vocab_size,
502
+ model_size=text_model_size,
503
+ num_layers=text_num_layers,
504
+ num_heads=text_num_heads,
505
+ intermediate_size=text_intermediate_size,
506
+ norm_eps=norm_eps,
507
+ max_seq_len=text_max_seq_len,
508
+ )
509
+ self.speaker_encoder = SpeakerEncoder(
510
+ latent_size=latent_size,
511
+ patch_size=speaker_patch_size,
512
+ model_size=speaker_model_size,
513
+ num_layers=speaker_num_layers,
514
+ num_heads=speaker_num_heads,
515
+ intermediate_size=speaker_intermediate_size,
516
+ norm_eps=norm_eps,
517
+ max_patched_seq_len=speaker_max_patched_seq_len,
518
+ )
519
+
520
+ self.text_norm = RMSNorm(text_model_size, norm_eps)
521
+ self.speaker_norm = RMSNorm(speaker_model_size, norm_eps)
522
+
523
+ self.cond_module = nn.Sequential(
524
+ nn.Linear(timestep_embed_size, model_size, bias=False),
525
+ nn.SiLU(),
526
+ nn.Linear(model_size, model_size, bias=False),
527
+ nn.SiLU(),
528
+ nn.Linear(model_size, model_size * 3, bias=False),
529
+ )
530
+
531
+ self.in_proj = nn.Linear(latent_size, model_size, bias=True)
532
+
533
+ self.blocks = nn.ModuleList()
534
+ for i in range(num_layers):
535
+ block = TransformerBlock(
536
+ model_size=model_size,
537
+ num_heads=num_heads,
538
+ intermediate_size=intermediate_size,
539
+ norm_eps=norm_eps,
540
+ text_model_size=text_model_size,
541
+ speaker_model_size=speaker_model_size,
542
+ speaker_patch_size=speaker_patch_size,
543
+ adaln_rank=adaln_rank,
544
+ )
545
+ self.blocks.append(block)
546
+
547
+ self.out_norm = RMSNorm(model_size, norm_eps)
548
+ self.out_proj = nn.Linear(model_size, latent_size, bias=True)
549
+
550
+ self.head_dim = model_size // num_heads
551
+
552
+
553
+ def forward(
554
+ self,
555
+ x: torch.Tensor,
556
+ t: torch.Tensor,
557
+ text_input_ids: torch.Tensor,
558
+ text_mask: torch.Tensor | None,
559
+ speaker_latent: torch.Tensor,
560
+ speaker_mask: torch.Tensor | None,
561
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] | None = None,
562
+ ) -> torch.Tensor:
563
+ """
564
+ x: (b, s, d)
565
+ t: (b,)
566
+ text_input_ids: (b, s_t) # not used when kv_cache is provided
567
+ text_mask: (b, s_t)
568
+ speaker_latent: (b, s_r, d) # not used when kv_cache is provided
569
+ speaker_mask: (b, s_r)
570
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]]
571
+
572
+ returns: (b, s, d)
573
+ """
574
+
575
+ freqs_cis = precompute_freqs_cis(self.head_dim, x.shape[1]).to(x.device)
576
+ # can't register as buffer because we'd like it to stay in fp32; however, could optionally pass in to avoid recomputing
577
+
578
+ if kv_cache is None:
579
+ text_state = self.text_encoder(text_input_ids, text_mask)
580
+ text_state = self.text_norm(text_state)
581
+ speaker_state = self.speaker_encoder(speaker_latent)
582
+ speaker_state = self.speaker_norm(speaker_state)
583
+ else:
584
+ text_state, speaker_state = None, None
585
+
586
+ speaker_mask = speaker_mask[..., ::self.speaker_patch_size]
587
+
588
+ cond_embed = self.cond_module(get_timestep_embedding(t, self.timestep_embed_size))
589
+
590
+ assert cond_embed.ndim == 2
591
+ cond_embed = cond_embed[:, None]
592
+
593
+ x = self.in_proj(x)
594
+
595
+ for i, block in enumerate(self.blocks):
596
+ x = block(
597
+ x=x,
598
+ cond_embed=cond_embed,
599
+ text_state=text_state,
600
+ text_mask=text_mask,
601
+ speaker_state=speaker_state,
602
+ speaker_mask=speaker_mask,
603
+ freqs_cis=freqs_cis,
604
+ kv_cache=kv_cache[i] if kv_cache is not None else None,
605
+ )
606
+
607
+ x = self.out_norm(x)
608
+ x = self.out_proj(x)
609
+
610
+ return x.float()
611
+
612
+ def get_kv_cache(
613
+ self,
614
+ speaker_latent: torch.Tensor,
615
+ speaker_mask: torch.Tensor,
616
+ text_input_ids: torch.Tensor,
617
+ text_mask: torch.Tensor,
618
+ ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
619
+
620
+ speaker_state = self.speaker_encoder(speaker_latent)
621
+ speaker_state = self.speaker_norm(speaker_state)
622
+
623
+ text_state = self.text_encoder(text_input_ids, text_mask)
624
+ text_state = self.text_norm(text_state)
625
+
626
+ return [self.blocks[i].get_kv_cache(text_state, speaker_state) for i in range(len(self.blocks))]
627
+
628
+
629
+ def get_kv_cache_from_precomputed_speaker_state(
630
+ self,
631
+ speaker_state: torch.Tensor,
632
+ speaker_mask: torch.Tensor,
633
+ text_input_ids: torch.Tensor,
634
+ text_mask: torch.Tensor,
635
+ ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
636
+
637
+ # here, speaker state is already computed from the speaker latent encoder transformer
638
+
639
+ text_state = self.text_encoder(text_input_ids, text_mask)
640
+ text_state = self.text_norm(text_state)
641
+
642
+ return [self.blocks[i].get_kv_cache(text_state, speaker_state) for i in range(len(self.blocks))]
643
+
644
+
645
+
646
+ @property
647
+ def device(self) -> torch.device: return next(self.parameters()).device
648
+
649
+ @property
650
+ def dtype(self) -> torch.dtype: return next(self.parameters()).dtype
packages.txt ADDED
@@ -0,0 +1 @@
1
+ ffmpeg
prompt_audio/EARS p004 freeform.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68947a209bc11064f749ca0a61b7959243df83565a0e462b87dfc0ffe03aa7b0
3
+ size 1526439
prompt_audio/EARS p005 freeform.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07344d073eb3e22c249ebfe15f31f4ba63fd9f17c71aeee93da199ff3b53fc45
3
+ size 1351147
prompt_audio/EARS p028 freeform.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8351eed5982f1fb5763a475c0fb69dba98a4bb49b0f2bbab12b978ff2b0fedeb
3
+ size 1211565
prompt_audio/EARS p036 freeform.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce77dbb86ea7c29edf2b9804ce9c9315334e9cfeef532dc0c50898a09bae1583
3
+ size 1227585
prompt_audio/expresso_02_ex03-ex01_calm_005.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2be4d1cb5646b3523a460ec40bf171f959a9b33bde918e6d0f795d00284f52a
3
+ size 21168080
prompt_audio/freesound_demon_chant(use_forcespeaker).mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:471f67fff5ea613ec4617b9822b1396da123a1133f199925436a2c40e5d1eb91
3
+ size 303438
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ torch
2
+ torchaudio
3
+ torchcodec
4
+ gradio>=5.49
5
+ huggingface-hub
6
+ numpy
7
+ safetensors
8
+ einops
sampler_presets.json ADDED
@@ -0,0 +1,120 @@
1
+ {
2
+
3
+ "Independent (High Speaker CFG)": {
4
+ "num_steps": "40",
5
+ "cfg_mode": "independent",
6
+ "cfg_scale_text": "3.0",
7
+ "cfg_scale_speaker": "8.0",
8
+ "cfg_min_t": "0.5",
9
+ "cfg_max_t": "1.0",
10
+ "truncation_factor": "1.",
11
+ "rescale_k": "1.",
12
+ "rescale_sigma": "3.0"
13
+ },
14
+ "Independent (High Speaker CFG) Flat": {
15
+ "num_steps": "40",
16
+ "cfg_mode": "independent",
17
+ "cfg_scale_text": "3.0",
18
+ "cfg_scale_speaker": "8.0",
19
+ "cfg_min_t": "0.5",
20
+ "cfg_max_t": "1.0",
21
+ "truncation_factor": "0.8",
22
+ "rescale_k": "1.2",
23
+ "rescale_sigma": "3.0"
24
+ },
25
+ "APG": {
26
+ "num_steps": "40",
27
+ "cfg_mode": "apg-independent",
28
+ "cfg_scale_text": "8.0",
29
+ "cfg_scale_speaker": "8.0",
30
+ "cfg_min_t": "0.5",
31
+ "cfg_max_t": "1.0",
32
+ "truncation_factor": "1.",
33
+ "rescale_k": "1.",
34
+ "rescale_sigma": "3.0",
35
+ "speaker_k_enable": false,
36
+ "speaker_k_scale": "1.5",
37
+ "speaker_k_min_t": "0.9",
38
+ "speaker_k_max_layers": "24",
39
+ "apg_eta_text": "0.5",
40
+ "apg_eta_speaker": "0.5",
41
+ "apg_momentum_text": "0.0",
42
+ "apg_momentum_speaker": "0.0"
43
+ },
44
+ "APG Flat": {
45
+ "num_steps": "40",
46
+ "cfg_mode": "apg-independent",
47
+ "cfg_scale_text": "8.0",
48
+ "cfg_scale_speaker": "8.0",
49
+ "cfg_min_t": "0.5",
50
+ "cfg_max_t": "1.0",
51
+ "truncation_factor": "0.8",
52
+ "rescale_k": "1.2",
53
+ "rescale_sigma": "3.0",
54
+ "speaker_k_enable": false,
55
+ "speaker_k_scale": "1.5",
56
+ "speaker_k_min_t": "0.9",
57
+ "speaker_k_max_layers": "24",
58
+ "apg_eta_text": "0.5",
59
+ "apg_eta_speaker": "0.5",
60
+ "apg_momentum_text": "0.0",
61
+ "apg_momentum_speaker": "0.0"
62
+ },
63
+ "Independent (High CFG)": {
64
+ "num_steps": "40",
65
+ "cfg_mode": "independent",
66
+ "cfg_scale_text": "8.0",
67
+ "cfg_scale_speaker": "8.0",
68
+ "cfg_min_t": "0.5",
69
+ "cfg_max_t": "1.0",
70
+ "truncation_factor": "1.",
71
+ "rescale_k": "1.",
72
+ "rescale_sigma": "3.0"
73
+ },
74
+ "Independent (High CFG) Flat": {
75
+ "num_steps": "40",
76
+ "cfg_mode": "independent",
77
+ "cfg_scale_text": "8.0",
78
+ "cfg_scale_speaker": "8.0",
79
+ "cfg_min_t": "0.5",
80
+ "cfg_max_t": "1.0",
81
+ "truncation_factor": "0.8",
82
+ "rescale_k": "1.2",
83
+ "rescale_sigma": "3.0"
84
+ },
85
+
86
+ "Independent": {
87
+ "num_steps": "40",
88
+ "cfg_mode": "independent",
89
+ "cfg_scale_text": "3.0",
90
+ "cfg_scale_speaker": "5.0",
91
+ "cfg_min_t": "0.5",
92
+ "cfg_max_t": "1.0",
93
+ "truncation_factor": "1.",
94
+ "rescale_k": "1.",
95
+ "rescale_sigma": "3.0"
96
+ },
97
+ "Independent Flat": {
98
+ "num_steps": "40",
99
+ "cfg_mode": "independent",
100
+ "cfg_scale_text": "3.0",
101
+ "cfg_scale_speaker": "5.0",
102
+ "cfg_min_t": "0.5",
103
+ "cfg_max_t": "1.0",
104
+ "truncation_factor": "0.8",
105
+ "rescale_k": "1.2",
106
+ "rescale_sigma": "3.0"
107
+ },
108
+ "Joint 20-step Flat": {
109
+ "num_steps": "20",
110
+ "cfg_mode": "joint-unconditional",
111
+ "cfg_scale_text": "3.0",
112
+ "cfg_scale_speaker": "3.0",
113
+ "cfg_min_t": "0.5",
114
+ "cfg_max_t": "1.0",
115
+ "truncation_factor": "0.8",
116
+ "rescale_k": "1.2",
117
+ "rescale_sigma": "3.0"
118
+ }
119
+ }
120
+
samplers.py ADDED
@@ -0,0 +1,690 @@
1
+ from typing import List, Tuple
2
+ from enum import Enum
3
+
4
+ import torch
5
+ from model import EchoDiT
6
+
7
+ # helper
8
+ def _get_uncond_text_input_ids_and_mask(batch_size: int, max_length: int, device: str | None = None) -> tuple[torch.Tensor, torch.Tensor]:
9
+ # returns zeros for text input ids, and (True, False, False, ... ) for text mask
10
+ text_input_ids_uncond = torch.zeros((batch_size, max_length), dtype=torch.int32)
11
+ text_mask_uncond = torch.zeros((batch_size, max_length), dtype=torch.bool)
12
+ text_mask_uncond[:, 0] = True
13
+ if device is not None:
14
+ text_input_ids_uncond = text_input_ids_uncond.to(device)
15
+ text_mask_uncond = text_mask_uncond.to(device)
16
+ return text_input_ids_uncond, text_mask_uncond
17
+
18
+
19
+ # SIMPLE SAMPLER FOR REFERENCE, SHOULD PROBABLY AVOID
20
+ @torch.inference_mode()
21
+ def sample_euler_cfg_simple(
22
+ model: EchoDiT,
23
+ speaker_latent: torch.Tensor,
24
+ speaker_mask: torch.Tensor,
25
+ text_input_ids: torch.Tensor,
26
+ text_mask: torch.Tensor,
27
+ rng_seed: int,
28
+ num_steps: int,
29
+ cfg_scale: float,
30
+ ) -> torch.Tensor:
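+ # Runs the conditional and unconditional branches as one batch of size 2 * b
+ # that shares a single precomputed cross-attention KV cache, then integrates
+ # dx/dt = v_pred with Euler steps from t = 1 (noise) down to t = 0 (data).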
31
+
32
+ device, dtype = model.device, model.dtype
33
+
34
+ batch_size = text_input_ids.shape[0]
35
+
36
+ torch.manual_seed(rng_seed)
37
+
38
+ t_schedule = torch.linspace(1., 0., num_steps + 1, device=device)
39
+
40
+ text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
41
+
42
+ speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
43
+
44
+ full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond], dim=0)
45
+ full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)
46
+
47
+ full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
48
+ full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)
49
+
50
+ kv_cache = model.get_kv_cache(
51
+ speaker_latent=full_speaker_latent.to(dtype),
52
+ speaker_mask=full_speaker_mask,
53
+ text_input_ids=full_text_input_ids,
54
+ text_mask=full_text_mask,
55
+ )
56
+
57
+ x_t = torch.randn((batch_size, 640, 80), device=device, dtype=torch.float32)
58
+
59
+ for i in range(num_steps):
60
+ t, t_next = t_schedule[i], t_schedule[i+1]
61
+ v_cond, v_uncond = model(
62
+ x=torch.cat([x_t, x_t], dim=0).to(dtype),
63
+ t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
64
+ text_input_ids=None,
65
+ text_mask=full_text_mask,
66
+ speaker_latent=None,
67
+ speaker_mask=full_speaker_mask,
68
+ kv_cache=kv_cache,
69
+ ).float().chunk(2, dim=0)
70
+
71
+ v_pred = v_cond + cfg_scale * (v_cond - v_uncond)
72
+ # note: x_0_pred is x_t - v_pred * t
73
+ x_t = x_t + v_pred * (t_next - t)
74
+
75
+ return x_t
76
+
77
+
78
+ ######
79
+
80
+ def _temporal_score_rescale(v_pred: torch.Tensor, x_t: torch.Tensor, t: float, rescale_k: float, rescale_sigma: float) -> torch.Tensor:
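+ # With the flow interpolation x_t = (1 - t) * x_0 + t * eps, the implied noise
+ # prediction is eps_pred = x_t + (1 - t) * v_pred and snr = ((1 - t) / t)^2.
+ # This scales eps_pred by an SNR-dependent ratio (set by rescale_k and
+ # rescale_sigma) and maps the result back to a velocity.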
81
+ if t < 1:
82
+ snr = (1 - t) ** 2 / (t ** 2)
83
+ ratio = (snr * rescale_sigma ** 2 + 1) / (snr * rescale_sigma ** 2 / rescale_k + 1)
84
+ return 1 / (1 - t) * (ratio * ((1 - t) * v_pred + x_t) - x_t)
85
+ return v_pred
86
+
87
+
88
+ def _get_first_n_kv_cache(kv_cache: List[List[torch.Tensor]], n: int) -> List[List[torch.Tensor]]:
89
+ return [[kv_cache[i][0][:n], kv_cache[i][1][:n]] for i in range(len(kv_cache))]
90
+
91
+ def _multiply_speaker_kv_cache(
92
+ kv_cache: List[List[torch.Tensor]],
93
+ scale: float,
94
+ text_length: int,
95
+ max_layers: int = 24,
96
+ ) -> List[List[torch.Tensor]]:
97
+ # multiplies speaker kv cache by scale
98
+ # speaker keys start after text keys (at position text_length)
99
+ for i in range(min(max_layers, len(kv_cache))):
100
+ for j in range(len(kv_cache[i])):
101
+ kv_cache[i][j][:, text_length:] *= scale
102
+
103
+
104
+ @torch.inference_mode()
105
+ def sample_euler_cfg(
106
+ model: EchoDiT,
107
+ speaker_latent: torch.Tensor,
108
+ speaker_mask: torch.Tensor,
109
+ text_input_ids: torch.Tensor,
110
+ text_mask: torch.Tensor,
111
+ rng_seed: int,
112
+ num_steps: int,
113
+ cfg_scale: float,
114
+ cfg_min_t: float,
115
+ cfg_max_t: float,
116
+ truncation_factor: float | None,
117
+ rescale_k: float | None,
118
+ rescale_sigma: float | None,
119
+ speaker_k_scale: float | None,
120
+ speaker_k_max_layers: int | None,
121
+ speaker_k_min_t: float | None,
122
+ block_size: int | None = None,
123
+ ) -> torch.Tensor:
124
+
125
+ if block_size is None:
126
+ block_size = 640
127
+
128
+ torch.manual_seed(rng_seed)
129
+
130
+ INIT_SCALE = 0.999
131
+
132
+ device, dtype = model.device, model.dtype
133
+
134
+ batch_size = text_input_ids.shape[0]
135
+
136
+ t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
137
+
138
+ text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
139
+
140
+ speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
141
+
142
+ full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond], dim=0)
143
+ full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)
144
+
145
+ full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
146
+ full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)
147
+
148
+ kv_cache_full = model.get_kv_cache(
149
+ speaker_latent=full_speaker_latent.to(dtype),
150
+ speaker_mask=full_speaker_mask,
151
+ text_input_ids=full_text_input_ids,
152
+ text_mask=full_text_mask,
153
+ ) # could make faster by not computing fully / recomputing for unconditional batch elements
154
+ kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
155
+ if speaker_k_scale is not None:
156
+ _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
157
+
158
+ x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
159
+
160
+ if truncation_factor is not None:
161
+ x_t = x_t * truncation_factor
162
+
163
+ for i in range(num_steps):
164
+ t, t_next = t_schedule[i], t_schedule[i+1]
165
+
166
+ has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
167
+
168
+ if has_cfg:
169
+ v_cond, v_uncond = model(
170
+ x=torch.cat([x_t, x_t], dim=0).to(dtype),
171
+ t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
172
+ text_input_ids=None,
173
+ text_mask=full_text_mask,
174
+ speaker_latent=None,
175
+ speaker_mask=full_speaker_mask,
176
+ kv_cache=kv_cache_full,
177
+ ).float().chunk(2, dim=0)
178
+ v_pred = v_cond + cfg_scale * (v_cond - v_uncond)
179
+ else:
180
+ v_pred = model(
181
+ x=x_t.to(dtype),
182
+ t=(torch.ones((batch_size,), device=device) * t).to(dtype),
183
+ text_input_ids=None,
184
+ text_mask=text_mask,
185
+ speaker_latent=None,
186
+ speaker_mask=speaker_mask,
187
+ kv_cache=kv_cache,
188
+ ).float()
189
+
190
+ if rescale_k is not None and rescale_sigma is not None:
191
+ v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
192
+
193
+ if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
194
+ _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
195
+
196
+ x_t = x_t + v_pred * (t_next - t)
197
+
198
+ return x_t
199
+
200
+
201
+ @torch.inference_mode()
202
+ def sample_euler_cfg_independent_guidances(
203
+ model: EchoDiT,
204
+ speaker_latent: torch.Tensor,
205
+ speaker_mask: torch.Tensor,
206
+ text_input_ids: torch.Tensor,
207
+ text_mask: torch.Tensor,
208
+ rng_seed: int,
209
+ num_steps: int,
210
+ cfg_scale_text: float,
211
+ cfg_scale_speaker: float,
212
+ cfg_min_t: float,
213
+ cfg_max_t: float,
214
+ truncation_factor: float | None,
215
+ rescale_k: float | None,
216
+ rescale_sigma: float | None,
217
+ speaker_k_scale: float | None,
218
+ speaker_k_max_layers: int | None,
219
+ speaker_k_min_t: float | None,
220
+ block_size: int | None = None,
221
+ ) -> torch.Tensor:
222
+
223
+ if block_size is None:
224
+ block_size = 640
225
+
226
+ torch.manual_seed(rng_seed)
227
+
228
+ INIT_SCALE = 0.999
229
+
230
+ device, dtype = model.device, model.dtype
231
+
232
+ batch_size = text_input_ids.shape[0]
233
+
234
+ t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
235
+
236
+ text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
237
+
238
+ speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
239
+
240
+ full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond, text_input_ids], dim=0)
241
+ full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)
242
+
243
+ full_speaker_latent = torch.cat([speaker_latent, speaker_latent, speaker_latent_uncond], dim=0)
244
+ full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)
245
+
246
+ kv_cache_full = model.get_kv_cache(
247
+ speaker_latent=full_speaker_latent.to(dtype),
248
+ speaker_mask=full_speaker_mask,
249
+ text_input_ids=full_text_input_ids,
250
+ text_mask=full_text_mask,
251
+ )
252
+ kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
253
+
254
+ if speaker_k_scale is not None:
255
+ _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
256
+
257
+ x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
258
+ if truncation_factor is not None:
259
+ x_t = x_t * truncation_factor
260
+
261
+ for i in range(num_steps):
262
+ t, t_next = t_schedule[i], t_schedule[i+1]
263
+
264
+ has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
265
+
266
+ if has_cfg:
267
+ v_cond, v_uncond_text, v_uncond_speaker = model(
268
+ x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
269
+ t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
270
+ text_input_ids=None,
271
+ text_mask=full_text_mask,
272
+ speaker_latent=None,
273
+ speaker_mask=full_speaker_mask,
274
+ kv_cache=kv_cache_full,
275
+ ).float().chunk(3, dim=0)
276
+ v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond_text) + cfg_scale_speaker * (v_cond - v_uncond_speaker)
277
+ else:
278
+ v_pred = model(
279
+ x=x_t.to(dtype),
280
+ t=(torch.ones((batch_size,), device=device) * t).to(dtype),
281
+ text_input_ids=None,
282
+ text_mask=text_mask,
283
+ speaker_latent=None,
284
+ speaker_mask=speaker_mask,
285
+ kv_cache=kv_cache,
286
+ ).float()
287
+
288
+ if rescale_k is not None and rescale_sigma is not None:
289
+ v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
290
+
291
+ if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
292
+ _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
293
+
294
+ x_t = x_t + v_pred * (t_next - t)
295
+
296
+ return x_t
297
+
298
+
299
+
300
+ @torch.inference_mode()
301
+ def sample_euler_cfg_alternating_guidances(
302
+ model: EchoDiT,
303
+ speaker_latent: torch.Tensor,
304
+ speaker_mask: torch.Tensor,
305
+ text_input_ids: torch.Tensor,
306
+ text_mask: torch.Tensor,
307
+ rng_seed: int,
308
+ num_steps: int,
309
+ cfg_scale_text: float,
310
+ cfg_scale_speaker: float,
311
+ cfg_min_t: float,
312
+ cfg_max_t: float,
313
+ truncation_factor: float | None,
314
+ rescale_k: float | None,
315
+ rescale_sigma: float | None,
316
+ speaker_k_scale: float | None,
317
+ speaker_k_max_layers: int | None,
318
+ speaker_k_min_t: float | None,
319
+ block_size: int | None = None,
320
+ ) -> torch.Tensor:
321
+
322
+ if block_size is None:
323
+ block_size = 640
324
+
325
+ torch.manual_seed(rng_seed)
326
+
327
+ INIT_SCALE = 0.999
328
+
329
+ device, dtype = model.device, model.dtype
330
+
331
+ batch_size = text_input_ids.shape[0]
332
+
333
+ t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
334
+
335
+ text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
336
+
337
+ # TODO THIS / THE BELOW IS TECHNICALLY INCORRECT, AS IT ASSUMES A CAUSAL TEXT ENCODER (which is not the case)
338
+ # IF THE TEXT ENCODER WERE CAUSAL, THEN USING AN UNCOND TEXT MASK ON COND TEXT INPUTS GIVES YOU AN UNCOND STATE DUE TO BOS=0
339
+ # HOWEVER, MIGHT NOT MAKE MUCH OF A DIFFERENCE
340
+ # CHANGED ALL OTHER SAMPLERS TO USE CORRECT UNCONDITIONAL CACHES
341
+
342
+ speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
343
+
344
+ full_text_input_ids = torch.cat([text_input_ids, text_input_ids], dim=0)
345
+ full_text_mask = torch.cat([text_mask, text_mask_uncond], dim=0)
346
+
347
+ full_speaker_latent = torch.cat([speaker_latent, speaker_latent_uncond], dim=0)
348
+ full_speaker_mask = torch.cat([speaker_mask, speaker_mask_uncond], dim=0)
349
+
350
+ kv_cache_full = model.get_kv_cache(
351
+ speaker_latent=full_speaker_latent.to(dtype),
352
+ speaker_mask=full_speaker_mask,
353
+ text_input_ids=full_text_input_ids,
354
+ text_mask=full_text_mask,
355
+ )
356
+ kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
357
+
358
+ if speaker_k_scale is not None:
359
+ _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
360
+
361
+ x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
362
+ if truncation_factor is not None:
363
+ x_t = x_t * truncation_factor
364
+
365
+ for i in range(num_steps):
366
+ t, t_next = t_schedule[i], t_schedule[i+1]
367
+
368
+ has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
369
+
370
+ if has_cfg:
371
+ v_cond, v_uncond = model(
372
+ x=torch.cat([x_t, x_t], dim=0).to(dtype),
373
+ t=(torch.ones((batch_size * 2,), device=device) * t).to(dtype),
374
+ text_input_ids=None,
375
+ text_mask=torch.cat([text_mask, text_mask_uncond if i % 2 == 0 else text_mask], dim=0),
376
+ speaker_latent=None,
377
+ speaker_mask=torch.cat([speaker_mask, speaker_mask if i % 2 == 0 else speaker_mask_uncond], dim=0),
378
+ kv_cache=kv_cache_full,
379
+ ).float().chunk(2, dim=0)
380
+ v_pred = v_cond + (cfg_scale_text if i % 2 == 0 else cfg_scale_speaker) * (v_cond - v_uncond)
381
+ else:
382
+ v_pred = model(
383
+ x=x_t.to(dtype),
384
+ t=(torch.ones((batch_size,), device=device) * t).to(dtype),
385
+ text_input_ids=None,
386
+ text_mask=text_mask,
387
+ speaker_latent=None,
388
+ speaker_mask=speaker_mask,
389
+ kv_cache=kv_cache,
390
+ ).float()
391
+
392
+ if rescale_k is not None and rescale_sigma is not None:
393
+ v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
394
+
395
+ if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
396
+ _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
397
+
398
+ x_t = x_t + v_pred * (t_next - t)
399
+
400
+ return x_t
401
+
402
+
403
+ @torch.inference_mode()
404
+ def sample_euler_apg_independent_guidances(
405
+ model: EchoDiT,
406
+ speaker_latent: torch.Tensor,
407
+ speaker_mask: torch.Tensor,
408
+ text_input_ids: torch.Tensor,
409
+ text_mask: torch.Tensor,
410
+ rng_seed: int,
411
+ num_steps: int,
412
+ cfg_scale_text: float,
413
+ cfg_scale_speaker: float,
414
+ cfg_min_t: float,
415
+ cfg_max_t: float,
416
+ truncation_factor: float | None,
417
+ rescale_k: float | None,
418
+ rescale_sigma: float | None,
419
+ apg_eta_text: float,
420
+ apg_eta_speaker: float,
421
+ apg_momentum_text: float | None,
422
+ apg_momentum_speaker: float | None,
423
+ apg_norm_text: float | None,
424
+ apg_norm_speaker: float | None,
425
+ speaker_k_scale: float | None,
426
+ speaker_k_max_layers: int | None,
427
+ speaker_k_min_t: float | None,
428
+ block_size: int | None = None,
429
+ ) -> torch.Tensor:
430
+
431
+ if block_size is None:
432
+ block_size = 640
433
+
434
+ if apg_momentum_text is None:
435
+ apg_momentum_text = 0.0
436
+ if apg_momentum_speaker is None:
437
+ apg_momentum_speaker = 0.0
438
+
439
+ torch.manual_seed(rng_seed)
440
+
441
+ INIT_SCALE = 0.999
442
+
443
+ device, dtype = model.device, model.dtype
444
+
445
+ batch_size = text_input_ids.shape[0]
446
+
447
+ t_schedule = torch.linspace(1., 0., num_steps + 1, device=device) * INIT_SCALE
448
+
449
+ text_input_ids_uncond, text_mask_uncond = _get_uncond_text_input_ids_and_mask(text_input_ids.shape[0], text_input_ids.shape[1], device=device)
450
+
451
+ speaker_latent_uncond, speaker_mask_uncond = torch.zeros_like(speaker_latent), torch.zeros_like(speaker_mask)
452
+
453
+ full_text_input_ids = torch.cat([text_input_ids, text_input_ids_uncond, text_input_ids], dim=0)
454
+ full_text_mask = torch.cat([text_mask, text_mask_uncond, text_mask], dim=0)
455
+
456
+ full_speaker_latent = torch.cat([speaker_latent, speaker_latent, speaker_latent_uncond], dim=0)
457
+ full_speaker_mask = torch.cat([speaker_mask, speaker_mask, speaker_mask_uncond], dim=0)
458
+
459
+ kv_cache_full = model.get_kv_cache(
460
+ speaker_latent=full_speaker_latent.to(dtype),
461
+ speaker_mask=full_speaker_mask,
462
+ text_input_ids=full_text_input_ids,
463
+ text_mask=full_text_mask,
464
+ )
465
+ kv_cache = _get_first_n_kv_cache(kv_cache_full, batch_size)
466
+
467
+ if speaker_k_scale is not None:
468
+ _multiply_speaker_kv_cache(kv_cache_full, speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
469
+
470
+ x_t = torch.randn((batch_size, block_size, 80), device=device, dtype=torch.float32)
471
+ if truncation_factor is not None:
472
+ x_t = x_t * truncation_factor
473
+
474
+ buf_text = torch.zeros_like(x_t)
475
+ buf_speaker = torch.zeros_like(x_t)
476
+
477
+ for i in range(num_steps):
478
+ t, t_next = t_schedule[i], t_schedule[i+1]
479
+
480
+ has_cfg = ((t >= cfg_min_t) * (t <= cfg_max_t)).item()
481
+
482
+ if has_cfg:
483
+ v_cond, v_uncond_text, v_uncond_speaker = model(
484
+ x=torch.cat([x_t, x_t, x_t], dim=0).to(dtype),
485
+ t=(torch.ones((batch_size * 3,), device=device) * t).to(dtype),
486
+ text_input_ids=None,
487
+ text_mask=full_text_mask,
488
+ speaker_latent=None,
489
+ speaker_mask=full_speaker_mask,
490
+ kv_cache=kv_cache_full,
491
+ ).float().chunk(3, dim=0)
492
+
493
+ x0_cond = x_t - t * v_cond
494
+ x0_uncond_text = x_t - t * v_uncond_text
495
+ x0_uncond_speaker = x_t - t * v_uncond_speaker
496
+
497
+ diff_text = x0_cond - x0_uncond_text
498
+ diff_speaker = x0_cond - x0_uncond_speaker
499
+
500
+ buf_text = diff_text + apg_momentum_text * buf_text
501
+ diff_text = buf_text
502
+
503
+ buf_speaker = diff_speaker + apg_momentum_speaker * buf_speaker
504
+ diff_speaker = buf_speaker
505
+
506
+ if apg_norm_text is not None:
507
+ nt = torch.sqrt((diff_text * diff_text).sum(dim=tuple(range(1, diff_text.dim())), keepdim=True) + 1e-12)
508
+ s = torch.minimum(torch.ones_like(nt), (torch.as_tensor(apg_norm_text, device=device, dtype=diff_text.dtype) / nt))
509
+ diff_text = diff_text * s
510
+ if apg_norm_speaker is not None:
511
+ ns = torch.sqrt((diff_speaker * diff_speaker).sum(dim=tuple(range(1, diff_speaker.dim())), keepdim=True) + 1e-12)
512
+ s = torch.minimum(torch.ones_like(ns), (torch.as_tensor(apg_norm_speaker, device=device, dtype=diff_speaker.dtype) / ns))
513
+ diff_speaker = diff_speaker * s
514
+
515
+ c_norm = torch.sqrt((x0_cond * x0_cond).sum(dim=tuple(range(1, x0_cond.dim())), keepdim=True) + 1e-12)
516
+ c_hat = x0_cond / c_norm
517
+
518
+ par_text = (diff_text * c_hat).sum(dim=tuple(range(1, diff_text.dim())), keepdim=True) * c_hat
519
+ ort_text = diff_text - par_text
520
+ upd_text = ort_text + apg_eta_text * par_text
521
+
522
+ par_speaker = (diff_speaker * c_hat).sum(dim=tuple(range(1, diff_speaker.dim())), keepdim=True) * c_hat
523
+ ort_speaker = diff_speaker - par_speaker
524
+ upd_speaker = ort_speaker + apg_eta_speaker * par_speaker
525
+
526
+ x0_pred = x0_cond + cfg_scale_text * upd_text + cfg_scale_speaker * upd_speaker
527
+ v_pred = (x_t - x0_pred) / t
528
+ else:
529
+ v_pred = model(
530
+ x=x_t.to(dtype),
531
+ t=(torch.ones((batch_size,), device=device) * t).to(dtype),
532
+ text_input_ids=None,
533
+ text_mask=text_mask,
534
+ speaker_latent=None,
535
+ speaker_mask=speaker_mask,
536
+ kv_cache=kv_cache,
537
+ ).float()
538
+
539
+ if rescale_k is not None and rescale_sigma is not None:
540
+ v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
541
+
542
+ if speaker_k_scale is not None and t_next < speaker_k_min_t and t >= speaker_k_min_t:
543
+ _multiply_speaker_kv_cache(kv_cache_full, 1. / speaker_k_scale, text_input_ids.shape[-1], speaker_k_max_layers)
544
+
545
+ x_t = x_t + v_pred * (t_next - t)
546
+
547
+ return x_t
548
+
549
+
550
+
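A note on the APG branch above: it differs from plain CFG only in how each guidance delta is applied. The delta is decomposed into a component parallel to the conditional x0 prediction, which is damped by eta, and an orthogonal component, which is kept as-is. A toy sketch of that projection on stand-in tensors (shapes and eta are illustrative only):

import torch

x0_cond = torch.randn(1, 640, 80)
diff = torch.randn(1, 640, 80)        # x0_cond - x0_uncond for one guidance branch
eta = 0.5

dims = tuple(range(1, diff.dim()))
c_hat = x0_cond / torch.sqrt((x0_cond ** 2).sum(dim=dims, keepdim=True) + 1e-12)
par = (diff * c_hat).sum(dim=dims, keepdim=True) * c_hat    # component along x0_cond
ort = diff - par                                            # component orthogonal to x0_cond
upd = ort + eta * par                                       # eta < 1 shrinks the parallel part

assert torch.allclose(par + ort, diff, atol=1e-5)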
551
+ # router
552
+
553
+ class GuidanceMode(Enum):
554
+ INDEPENDENT = "independent"
555
+ APG = "apg"
556
+ JOINT = "joint"
557
+ ALTERNATING = "alternating"
558
+
559
+
560
+ def sample_euler_cfg_any(
561
+ model: EchoDiT,
562
+ speaker_latent: torch.Tensor,
563
+ speaker_mask: torch.Tensor,
564
+ text_input_ids: torch.Tensor,
565
+ text_mask: torch.Tensor,
566
+ rng_seed: int,
567
+ guidance_mode: GuidanceMode,
568
+ num_steps: int,
569
+ cfg_scale_text: float,
570
+ cfg_scale_speaker: float | None,
571
+ cfg_min_t: float,
572
+ cfg_max_t: float,
573
+ truncation_factor: float | None,
574
+ rescale_k: float | None,
575
+ rescale_sigma: float | None,
576
+ speaker_k_scale: float | None,
577
+ speaker_k_min_t: float | None,
578
+ speaker_k_max_layers: int | None,
579
+ apg_eta_text: float | None,
580
+ apg_eta_speaker: float | None,
581
+ apg_momentum_text: float | None,
582
+ apg_momentum_speaker: float | None,
583
+ apg_norm_text: float | None,
584
+ apg_norm_speaker: float | None,
585
+ block_size: int | None = None,
586
+ ) -> torch.Tensor:
587
+
588
+ if guidance_mode == GuidanceMode.INDEPENDENT:
589
+ assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for independent guidances"
590
+ return sample_euler_cfg_independent_guidances(
591
+ model=model,
592
+ speaker_latent=speaker_latent,
593
+ speaker_mask=speaker_mask,
594
+ text_input_ids=text_input_ids,
595
+ text_mask=text_mask,
596
+ rng_seed=rng_seed,
597
+ num_steps=num_steps,
598
+ cfg_scale_text=cfg_scale_text,
599
+ cfg_scale_speaker=cfg_scale_speaker,
600
+ cfg_min_t=cfg_min_t,
601
+ cfg_max_t=cfg_max_t,
602
+ truncation_factor=truncation_factor,
603
+ rescale_k=rescale_k,
604
+ rescale_sigma=rescale_sigma,
605
+ speaker_k_scale=speaker_k_scale,
606
+ speaker_k_max_layers=speaker_k_max_layers,
607
+ speaker_k_min_t=speaker_k_min_t,
608
+ block_size=block_size,
609
+ )
610
+
611
+ elif guidance_mode == GuidanceMode.APG:
612
+ assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for APG"
613
+ assert apg_eta_text is not None, "apg_eta_text must be provided for APG"
614
+ assert apg_eta_speaker is not None, "apg_eta_speaker must be provided for APG"
615
+ return sample_euler_apg_independent_guidances(
616
+ model=model,
617
+ speaker_latent=speaker_latent,
618
+ speaker_mask=speaker_mask,
619
+ text_input_ids=text_input_ids,
620
+ text_mask=text_mask,
621
+ rng_seed=rng_seed,
622
+ num_steps=num_steps,
623
+ cfg_scale_text=cfg_scale_text,
624
+ cfg_scale_speaker=cfg_scale_speaker,
625
+ cfg_min_t=cfg_min_t,
626
+ cfg_max_t=cfg_max_t,
627
+ truncation_factor=truncation_factor,
628
+ rescale_k=rescale_k,
629
+ rescale_sigma=rescale_sigma,
630
+ apg_eta_text=apg_eta_text,
631
+ apg_eta_speaker=apg_eta_speaker,
632
+ apg_momentum_text=apg_momentum_text,
633
+ apg_momentum_speaker=apg_momentum_speaker,
634
+ apg_norm_text=apg_norm_text,
635
+ apg_norm_speaker=apg_norm_speaker,
636
+ speaker_k_scale=speaker_k_scale,
637
+ speaker_k_max_layers=speaker_k_max_layers,
638
+ speaker_k_min_t=speaker_k_min_t,
639
+ block_size=block_size,
640
+ )
641
+
642
+ elif guidance_mode == GuidanceMode.JOINT:
643
+ assert cfg_scale_text == cfg_scale_speaker or cfg_scale_speaker is None, "cfg_scale_text and cfg_scale_speaker must be the same or cfg_scale_speaker must be None"
644
+ return sample_euler_cfg(
645
+ model=model,
646
+ speaker_latent=speaker_latent,
647
+ speaker_mask=speaker_mask,
648
+ text_input_ids=text_input_ids,
649
+ text_mask=text_mask,
650
+ rng_seed=rng_seed,
651
+ num_steps=num_steps,
652
+ cfg_scale=cfg_scale_text,
653
+ cfg_min_t=cfg_min_t,
654
+ cfg_max_t=cfg_max_t,
655
+ truncation_factor=truncation_factor,
656
+ rescale_k=rescale_k,
657
+ rescale_sigma=rescale_sigma,
658
+ speaker_k_scale=speaker_k_scale,
659
+ speaker_k_max_layers=speaker_k_max_layers,
660
+ speaker_k_min_t=speaker_k_min_t,
661
+ block_size=block_size,
662
+ )
663
+
664
+ elif guidance_mode == GuidanceMode.ALTERNATING:
665
+ assert cfg_scale_speaker is not None, "cfg_scale_speaker must be provided for alternating guidances"
666
+ return sample_euler_cfg_alternating_guidances(
667
+ model=model,
668
+ speaker_latent=speaker_latent,
669
+ speaker_mask=speaker_mask,
670
+ text_input_ids=text_input_ids,
671
+ text_mask=text_mask,
672
+ rng_seed=rng_seed,
673
+ num_steps=num_steps,
674
+ cfg_scale_text=cfg_scale_text,
675
+ cfg_scale_speaker=cfg_scale_speaker,
676
+ cfg_min_t=cfg_min_t,
677
+ cfg_max_t=cfg_max_t,
678
+ truncation_factor=truncation_factor,
679
+ rescale_k=rescale_k,
680
+ rescale_sigma=rescale_sigma,
681
+ speaker_k_scale=speaker_k_scale,
682
+ speaker_k_max_layers=speaker_k_max_layers,
683
+ speaker_k_min_t=speaker_k_min_t,
684
+ block_size=block_size,
685
+ )
686
+
687
+ else:
688
+ raise ValueError(f"Unknown guidance mode: {guidance_mode}")
689
+
690
+
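For reference, sample_euler_cfg_any simply dispatches to the sampler matching the requested GuidanceMode. A minimal usage sketch, assuming an EchoDiT instance and pre-built speaker/text tensors are already available; the variable names, shapes, and parameter values below are illustrative, not prescribed defaults:

latents = sample_euler_cfg_any(
    model=model,                        # an EchoDiT instance on the target device
    speaker_latent=speaker_latent,      # speaker-prompt latents, (batch, prompt_len, latent_dim)
    speaker_mask=speaker_mask,          # (batch, prompt_len)
    text_input_ids=text_input_ids,      # tokenized transcript, (batch, text_len)
    text_mask=text_mask,                # (batch, text_len)
    rng_seed=0,
    guidance_mode=GuidanceMode.INDEPENDENT,
    num_steps=40,
    cfg_scale_text=2.0,
    cfg_scale_speaker=1.5,
    cfg_min_t=0.0,
    cfg_max_t=1.0,
    truncation_factor=None,
    rescale_k=None,
    rescale_sigma=None,
    speaker_k_scale=None,
    speaker_k_min_t=None,
    speaker_k_max_layers=None,
    apg_eta_text=None,
    apg_eta_speaker=None,
    apg_momentum_text=None,
    apg_momentum_speaker=None,
    apg_norm_text=None,
    apg_norm_speaker=None,
)                                       # -> (batch, block_size, 80) latent block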
silentcipher/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .server import get_model
2
+
3
+ __version__ = '1.0.4'
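Since the package's __init__ re-exports get_model from server.py, the vendored watermarker can be loaded with a single call. A minimal sketch, assuming the checkpoints are either present at the default local paths or downloadable from the Hugging Face Hub:

import silentcipher

# Falls back to downloading sony/silentcipher from the Hub when the default paths are missing.
model = silentcipher.get_model(model_type='44.1k', device='cpu')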
silentcipher/model.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+
6
+ class Layer(nn.Module):
7
+ def __init__(self, dim_in, dim_out, kernel_size, stride, padding):
8
+ super(Layer, self).__init__()
9
+ self.conv = nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
10
+ self.gate = nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
11
+ self.bn = nn.BatchNorm2d(dim_out)
12
+
13
+ def forward(self, x):
14
+ return self.bn(self.conv(x) * torch.sigmoid(self.gate(x)))
15
+
16
+ class Encoder(nn.Module):
17
+ def __init__(self, out_dim=32, n_layers=3, message_dim=0, message_band_size=None, n_fft=None):
18
+ super(Encoder, self).__init__()
19
+ assert message_band_size is not None
20
+ assert n_fft is not None
21
+ self.message_band_size = message_band_size
22
+ main = [Layer(dim_in=1, dim_out=32, kernel_size=3, stride=1, padding=1)]
23
+
24
+ for i in range(n_layers-2):
25
+ main.append(Layer(dim_in=32, dim_out=32, kernel_size=3, stride=1, padding=1))
26
+ main.append(Layer(dim_in=32, dim_out=out_dim, kernel_size=3, stride=1, padding=1))
27
+
28
+ self.main = nn.Sequential(*main)
29
+ self.linear = nn.Linear(message_dim, message_band_size)
30
+ self.n_fft = n_fft
31
+
32
+ def forward(self, x):
33
+ h = self.main(x)
34
+ return h
35
+
36
+ def transform_message(self, msg):
37
+ output = self.linear(msg.transpose(2, 3)).transpose(2, 3)
38
+ if self.message_band_size != self.n_fft // 2 + 1:
39
+ output = torch.nn.functional.pad(output, (0, 0, 0, self.n_fft // 2 + 1 - self.message_band_size))
40
+ return output
41
+
42
+ class CarrierDecoder(nn.Module):
43
+ def __init__(self, config, conv_dim, n_layers=4, message_band_size=1024):
44
+ super(CarrierDecoder, self).__init__()
45
+ self.config = config
46
+ self.message_band_size = message_band_size
47
+ layers = [Layer(dim_in=conv_dim, dim_out=96, kernel_size=3, stride=1, padding=1)]
48
+
49
+ for i in range(n_layers-2):
50
+ layers.append(Layer(dim_in=96, dim_out=96, kernel_size=3, stride=1, padding=1))
51
+
52
+ layers.append(Layer(dim_in=96, dim_out=1, kernel_size=1, stride=1, padding=0))
53
+
54
+ self.main = nn.Sequential(*layers)
55
+
56
+ def forward(self, x, message_sdr):
57
+ h = self.main(x)
58
+
59
+ if self.config.ensure_negative_message:
60
+ h = torch.abs(h)
61
+
62
+ h[:, :, self.message_band_size:, :] = 0
63
+
64
+ if not self.config.no_normalization:
65
+ h = h / torch.mean(h**2, dim=2, keepdim=True)**0.5 / (10**(message_sdr/20))
66
+
67
+ return h
68
+
69
+ class MsgDecoder(nn.Module):
70
+ def __init__(self, message_dim=0, message_band_size=None, channel_dim=128, num_layers=10):
71
+ super(MsgDecoder, self).__init__()
72
+ assert message_band_size is not None
73
+ self.message_band_size = message_band_size
74
+
75
+ main = [
76
+ nn.Dropout(0),
77
+ Layer(dim_in=1, dim_out=channel_dim, kernel_size=3, stride=1, padding=1)
78
+ ]
79
+ for l in range(num_layers - 2):
80
+ main += [
81
+ nn.Dropout(0),
82
+ Layer(dim_in=channel_dim, dim_out=channel_dim, kernel_size=3, stride=1, padding=1),
83
+ ]
84
+ main += [
85
+ nn.Dropout(0),
86
+ Layer(dim_in=channel_dim, dim_out=message_dim, kernel_size=3, stride=1, padding=1)
87
+ ]
88
+ self.main = nn.Sequential(*main)
89
+ self.linear = nn.Linear(self.message_band_size, 1)
90
+
91
+ def forward(self, x):
92
+
93
+ h = self.main(x[:, :, :self.message_band_size])
94
+ h = self.linear(h.transpose(2, 3)).squeeze(3).unsqueeze(1)
95
+ return h
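The modules above operate on 4-D spectrogram-shaped tensors, (batch, channels, frequency_bins, frames), and the stride-1, padding-1 convolutions preserve the spatial shape. A small shape-check sketch with hypothetical hyperparameters (the real values come from the checkpoint's hparams.yaml):

import torch

enc = Encoder(out_dim=32, n_layers=3, message_dim=40, message_band_size=1024, n_fft=2048).eval()
spec = torch.randn(2, 1, 1025, 200)    # (batch, 1, n_fft // 2 + 1, frames) magnitude spectrogram
with torch.no_grad():
    feat = enc(spec)                   # -> (2, 32, 1025, 200): same spatial size, out_dim channels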
silentcipher/server.py ADDED
@@ -0,0 +1,480 @@
1
+ from calendar import c
2
+ import os
3
+ import argparse
4
+ import re
5
+ from tabnanny import check
6
+ import yaml
7
+ import time
8
+ import numpy as np
9
+ import soundfile as sf
10
+ from scipy import stats as st
11
+ import librosa
12
+ from pydub import AudioSegment
13
+ import torch
14
+ from torch import nn
15
+
16
+ from .model import Encoder, CarrierDecoder, MsgDecoder
17
+ from .stft import STFT
18
+
19
+ class Model():
20
+
21
+ def __init__(self, config, device='cpu'):
22
+
23
+ self.config = config
24
+ self.device = device
25
+
26
+ self.n_messages = config.n_messages
27
+ self.model_type = config.model_type
28
+ self.message_dim = config.message_dim
29
+ self.message_len = config.message_len
30
+
31
+ # model dimensions
32
+ self.enc_conv_dim = 16
33
+ self.enc_num_repeat = 3
34
+ self.dec_c_num_repeat = self.enc_num_repeat
35
+ self.dec_m_conv_dim = 1
36
+ self.dec_m_num_repeat = 8
37
+ self.encoder_out_dim = 32
38
+ self.dec_c_conv_dim = 32*3
39
+
40
+ self.enc_c = Encoder(n_layers=self.config.enc_n_layers,
41
+ message_dim=self.message_dim,
42
+ out_dim=self.encoder_out_dim,
43
+ message_band_size=self.config.message_band_size,
44
+ n_fft=self.config.N_FFT)
45
+
46
+ self.dec_c = CarrierDecoder(config=self.config,
47
+ conv_dim=self.dec_c_conv_dim,
48
+ n_layers=self.config.dec_c_n_layers,
49
+ message_band_size=self.config.message_band_size)
50
+
51
+ self.dec_m = [MsgDecoder(message_dim=self.message_dim,
52
+ message_band_size=self.config.message_band_size) for _ in range(self.n_messages)]
53
+ # ------ make parallel ------
54
+ self.enc_c = self.enc_c.to(self.device)
55
+ self.dec_c = self.dec_c.to(self.device)
56
+ self.dec_m = [m.to(self.device) for m in self.dec_m]
57
+
58
+ self.average_energy_VCTK=0.002837200844477648
59
+ self.stft = STFT(self.config.N_FFT, self.config.HOP_LENGTH)
60
+ self.stft.to(self.device)
61
+ self.load_models(config.load_ckpt)
62
+ self.sr = self.config.SR
63
+
64
+ def letters_encoding(self, patch_len, message_lst):
65
+
66
+ """
67
+ Encodes a list of messages into a compact representation and a padded representation.
68
+
69
+ Args:
70
+ patch_len (int): The length of the patch.
71
+ message_lst (list): A list of messages to be encoded.
72
+
73
+ Returns:
74
+ tuple: A tuple containing two numpy arrays:
75
+ - message: A padded representation of the messages, where each message is repeated to match the patch length.
76
+ - message_compact: A compact representation of the messages, where each message is encoded as a one-hot vector.
77
+
78
+ Raises:
79
+ AssertionError: If the length of any message in message_lst is not equal to self.config.message_len - 1.
80
+ """
81
+
82
+ message = []
83
+ message_compact = []
84
+ for i in range(self.n_messages):
85
+
86
+ assert len(message_lst[i]) == self.config.message_len - 1
87
+ index = np.concatenate((np.array(message_lst[i])+1, [0]))
88
+ one_hot = np.identity(self.message_dim)[index]
89
+ message_compact.append(one_hot)
90
+ if patch_len % self.message_len == 0:
91
+ message.append(np.tile(one_hot.T, (1, patch_len // self.message_len)))
92
+ else:
93
+ _ = np.tile(one_hot.T, (1, patch_len // self.message_len))
94
+ _ = np.concatenate([_, one_hot.T[:, 0:patch_len % self.message_len]], axis=1)
95
+ message.append(_)
96
+ message = np.stack(message)
97
+ message_compact = np.stack(message_compact)
98
+ # message = np.pad(message, ((0, 0), (0, 129 - self.message_dim), (0, 0)), 'constant')
99
+ return message, message_compact
100
+
101
+ def get_best_ps(self, y_one_sec):
102
+
103
+ """
104
+ Calculates the best phase shift value for watermark decoding.
105
+
106
+ Args:
107
+ y_one_sec (numpy.ndarray): Input audio signal.
108
+
109
+ Returns:
110
+ int: The best phase shift value.
111
+
112
+ """
113
+
114
+ def check_accuracy(pred_values):
115
+
116
+ accuracy = 0
117
+ for i in range(pred_values.shape[1]):
118
+ unique, counts = np.unique(pred_values[:, i], return_counts=True)
119
+ accuracy += np.max(counts) / pred_values.shape[0]
120
+
121
+ return accuracy / pred_values.shape[1]
122
+
123
+ y = torch.FloatTensor(y_one_sec).unsqueeze(0).unsqueeze(0).to(self.device)
124
+ max_accuracy = 0
125
+ final_phase_shift = 0
126
+
127
+ for ps in range(0, self.config.HOP_LENGTH, 10):
128
+
129
+ carrier, _ = self.stft.transform(y[0:1, 0:1, ps:].squeeze(1))
130
+ carrier = carrier[:, None]
131
+
132
+ for i in range(self.n_messages): # decode each msg_i using decoder_m_i
133
+ msg_reconst = self.dec_m[i](carrier)
134
+ pred_values = torch.argmax(msg_reconst[0, 0], dim=0).data.cpu().numpy()
135
+ pred_values = pred_values[0:int(msg_reconst.shape[3]/self.config.message_len)*self.config.message_len]
136
+ pred_values = pred_values.reshape([-1, self.config.message_len])
137
+ cur_acc = check_accuracy(pred_values)
138
+ if cur_acc > max_accuracy:
139
+ max_accuracy = cur_acc
140
+ final_phase_shift = ps
141
+
142
+ return final_phase_shift
143
+
144
+ def get_confidence(self, pred_values, message):
145
+ """
146
+ Calculates the confidence of the predicted values based on the provided message.
147
+
148
+ Parameters:
149
+ pred_values (numpy.ndarray): The predicted values.
150
+ message (str): The message used for prediction.
151
+
152
+ Returns:
153
+ float: The confidence score.
154
+
155
+ Raises:
156
+ AssertionError: If the length of the message is not equal to the number of columns in pred_values.
157
+
158
+ """
159
+ assert len(message) == pred_values.shape[1], f'{len(message)} | {pred_values.shape}'
160
+ return np.mean((pred_values == message[None]).astype(np.float32)).item()
161
+
162
+ def sdr(self, orig, recon):
163
+ """
164
+ Calculate the Signal-to-Distortion Ratio (SDR) between the original and reconstructed signals.
165
+
166
+ Parameters:
167
+ orig (numpy.ndarray): The original signal.
168
+ recon (numpy.ndarray): The reconstructed signal.
169
+
170
+ Returns:
171
+ float: The Signal-to-Distortion Ratio (SDR) value.
172
+
173
+ """
174
+
175
+ rms1 = ((np.mean(orig ** 2)) ** 0.5)
176
+ rms2 = ((np.mean((orig - recon) ** 2)) ** 0.5)
177
+ sdr = 20 * np.log10(rms1 / rms2)
178
+ return sdr
179
+
180
+ def load_audio(self, path):
181
+ """
182
+ Load an audio file from the given path and return the audio array and sample rate.
183
+
184
+ Args:
185
+ path (str): The path to the audio file.
186
+
187
+ Returns:
188
+ tuple: A tuple containing the audio array and sample rate.
189
+
190
+ """
191
+ audio = AudioSegment.from_file(path)
192
+ audio_array, sr = (np.array(audio.get_array_of_samples(), dtype=np.float32).reshape((-1, audio.channels)) / (
193
+ 1 << (8 * audio.sample_width - 1))), audio.frame_rate
194
+ if audio_array.shape[1] == 1:
195
+ audio_array = audio_array[:, 0]
196
+
197
+ return audio_array, sr
198
+
199
+ def encode(self, in_path, out_path, message_list, message_sdr=None, calc_sdr=True, disable_checks=False):
200
+ """
201
+ Encodes a message into an audio file.
202
+
203
+ Parameters:
204
+ - in_path (str): The path to the input audio file.
205
+ - out_path (str): The path to save the output audio file.
206
+ - message_list (list): A list of messages to be encoded into the audio file.
207
+ - message_sdr (float, optional): The Signal-to-Distortion Ratio (SDR) of the message. Defaults to None.
208
+ - calc_sdr (bool, optional): Whether to calculate the SDR of the encoded audio. Defaults to True.
209
+ - disable_checks (bool, optional): Whether to disable input checks. Defaults to False.
210
+
211
+ Returns:
212
+ - dict: A dictionary containing the status of the encoding process, the SDR value(s), the time taken for encoding, and the time taken per second of audio.
213
+
214
+ """
215
+ y, orig_sr = self.load_audio(in_path)
216
+ start = time.time()
217
+ encoded_y, sdr = self.encode_wav(y, orig_sr, message_list=message_list, message_sdr=message_sdr, calc_sdr=calc_sdr, disable_checks=disable_checks)
218
+ time_taken = time.time() - start
219
+ sf.write(out_path, encoded_y, orig_sr)
220
+
221
+ if type(sdr) == list:
222
+ return {'status': True, 'sdr': [f'{sdr_i:.2f}' for sdr_i in sdr], 'time_taken': time_taken, 'time_taken_per_second': time_taken / (y.shape[0] / orig_sr)}
223
+ else:
224
+ return {'status': True, 'sdr': f'{sdr:.2f}', 'time_taken': time_taken, 'time_taken_per_second': time_taken / (y.shape[0] / orig_sr)}
225
+
226
+ def decode(self, path, phase_shift_decoding):
227
+ """
228
+ Decode the audio file at the given path using phase shift decoding.
229
+
230
+ Parameters:
231
+ path (str): The path to the audio file.
232
+ phase_shift_decoding (bool): Flag indicating whether to use phase shift decoding.
233
+
234
+ Returns:
235
+ dict or list: A dictionary (or a list of dictionaries for multi-channel input) containing the decoded messages, confidences, and status
236
+ """
237
+
238
+ y, orig_sr = self.load_audio(path)
239
+
240
+ return self.decode_wav(y, orig_sr, phase_shift_decoding)
241
+
242
+ def encode_wav(self, y_multi_channel, orig_sr, message_list, message_sdr=None, calc_sdr=True, disable_checks=False):
243
+
244
+ """
245
+ Encodes a multi-channel audio waveform with a given message.
246
+
247
+ Args:
248
+ y_multi_channel (numpy.ndarray): The multi-channel audio waveform to be encoded.
249
+ orig_sr (int): The original sampling rate of the audio waveform.
250
+ message_list (list): The list of messages to be encoded. Each message may correspond to a channel in the audio waveform.
251
+ message_sdr (float, optional): The signal-to-distortion ratio (SDR) of the message. If not provided, the default SDR from the configuration is used.
252
+ calc_sdr (bool, optional): Flag indicating whether to calculate the SDR of the encoded waveform. Defaults to True.
253
+ disable_checks (bool, optional): Flag indicating whether to disable input audio checks. Defaults to False.
254
+
255
+ Returns:
256
+ tuple: A tuple containing the encoded multi-channel audio waveform and the SDR (if calculated).
257
+
258
+ Raises:
259
+ AssertionError: If the number of messages does not match the number of channels in the input audio waveform.
260
+ """
261
+
262
+ single_channel = False
263
+ if len(y_multi_channel.shape) == 1:
264
+ single_channel = True
265
+ y_multi_channel = y_multi_channel[:, None]
266
+
267
+ if message_sdr is None:
268
+ message_sdr = self.config.message_sdr
269
+ print(f'Using the default SDR of {self.config.message_sdr} dB')
270
+
271
+ if type(message_list[0]) == int:
272
+ message_list = [message_list]*y_multi_channel.shape[1]
273
+
274
+ y_watermarked_multi_channel = []
275
+ sdrs = []
276
+
277
+ assert len(message_list) == y_multi_channel.shape[1], f'{len(message_list)} | {y_multi_channel.shape[1]} Mismatch in the number of messages and channels in the input audio.'
278
+
279
+ for channel_i in range(y_multi_channel.shape[1]):
280
+ y = y_multi_channel[:, channel_i]
281
+ message = message_list[channel_i]
282
+
283
+ with torch.no_grad():
284
+
285
+ orig_y = y.copy()
286
+ if orig_sr != self.sr:
287
+ if orig_sr > self.sr:
288
+ print(f'WARNING! Reducing the sampling rate of the original audio from {orig_sr} -> {self.sr}. High frequency components may be lost!')
289
+ y = librosa.resample(y, orig_sr = orig_sr, target_sr = self.sr)
290
+ original_power = np.mean(y**2)
291
+
292
+ if not disable_checks:
293
+ if original_power == 0:
294
+ print('WARNING! The input audio has a power of 0. This means the audio is likely just silence. Skipping encoding.')
295
+ return orig_y, 0
296
+
297
+ y = y * np.sqrt(self.average_energy_VCTK / original_power) # Noise has a power of 5% power of VCTK samples
298
+ y = torch.FloatTensor(y).unsqueeze(0).unsqueeze(0).to(self.device)
299
+ carrier, carrier_phase = self.stft.transform(y.squeeze(1))
300
+ carrier = carrier[:, None]
301
+ carrier_phase = carrier_phase[:, None]
302
+
303
+ def binary_encode(mes):
304
+ binary_message = ''.join(['{0:08b}'.format(mes_i) for mes_i in mes])
305
+ four_bit_msg = []
306
+ for i in range(len(binary_message)//2):
307
+ four_bit_msg.append(int(binary_message[i*2:i*2+2], 2))
308
+ return four_bit_msg
309
+
310
+ binary_encoded_message = binary_encode(message)
311
+
312
+ msgs, msgs_compact = self.letters_encoding(carrier.shape[3], [binary_encoded_message])
313
+ msg_enc = torch.from_numpy(msgs[None]).to(self.device).float()
314
+
315
+ carrier_enc = self.enc_c(carrier) # encode the carrier
316
+ msg_enc = self.enc_c.transform_message(msg_enc)
317
+
318
+ merged_enc = torch.cat((carrier_enc, carrier.repeat(1, 32, 1, 1), msg_enc.repeat(1, 32, 1, 1)), dim=1) # concat encodings on features axis
319
+
320
+ message_info = self.dec_c(merged_enc, message_sdr)
321
+ if self.config.frame_level_normalization:
322
+ message_info = message_info*(torch.mean((carrier**2), dim=2, keepdim=True)**0.5) # *time_weighing
323
+ elif self.config.utterance_level_normalization:
324
+ message_info = message_info*(torch.mean((carrier**2), dim=(2,3), keepdim=True)**0.5) # *time_weighing
325
+
326
+ if self.config.ensure_negative_message:
327
+ message_info = -message_info
328
+ carrier_reconst = torch.nn.functional.relu(message_info + carrier) # decode carrier, output in stft domain
329
+ elif self.config.ensure_constrained_message:
330
+ message_info[message_info > carrier] = carrier[message_info > carrier]
331
+ message_info[-message_info > carrier] = -carrier[-message_info > carrier]
332
+ carrier_reconst = message_info + carrier # decode carrier, output in stft domain
333
+ assert torch.all(carrier_reconst >= 0), 'negative values found in carrier_reconst'
334
+ else:
335
+ carrier_reconst = torch.abs(message_info + carrier) # decode carrier, output in stft domain
336
+
337
+ self.stft.num_samples = y.shape[2]
338
+
339
+ y = self.stft.inverse(carrier_reconst.squeeze(1), carrier_phase.squeeze(1)).data.cpu().numpy()[0, 0]
340
+ y = y * np.sqrt(original_power / (self.average_energy_VCTK)) # Noise has a power of 5% power of VCTK samples
341
+ if orig_sr != self.sr:
342
+ y = librosa.resample(y, orig_sr = self.sr, target_sr = orig_sr)
343
+
344
+ if calc_sdr:
345
+ sdr = self.sdr(orig_y, y)
346
+ else:
347
+ sdr = 0
348
+
349
+ y_watermarked_multi_channel.append(y[:, None])
350
+ sdrs.append(sdr)
351
+
352
+ y_watermarked_multi_channel = np.concatenate(y_watermarked_multi_channel, axis=1)
353
+
354
+ if single_channel:
355
+ y_watermarked_multi_channel = y_watermarked_multi_channel[:, 0]
356
+ sdrs = sdrs[0]
357
+
358
+ return y_watermarked_multi_channel, sdrs
359
+
360
+ def decode_wav(self, y_multi_channel, orig_sr, phase_shift_decoding):
361
+ """
362
+ Decode the given audio waveform to extract hidden messages.
363
+
364
+ Args:
365
+ y_multi_channel (numpy.ndarray): The multi-channel audio waveform.
366
+ orig_sr (int): The original sample rate of the audio waveform.
367
+ phase_shift_decoding (str): Flag indicating whether to perform phase shift decoding.
368
+
369
+ Returns:
370
+ dict or list: A list of dictionaries, one per channel, each containing the decoded messages, confidences, and status, if the input is multi-channel.
371
+ Otherwise, a dictionary containing the decoded messages, confidences, and status for a single channel.
372
+
373
+ Raises:
374
+ Exception: If the decoding process fails.
375
+
376
+ """
377
+ single_channel = False
378
+ if len(y_multi_channel.shape) == 1:
379
+ single_channel = True
380
+ y_multi_channel = y_multi_channel[:, None]
381
+
382
+ results = []
383
+
384
+ for channel_i in range(y_multi_channel.shape[1]):
385
+ y = y_multi_channel[:, channel_i]
386
+ try:
387
+ with torch.no_grad():
388
+ if orig_sr != self.sr:
389
+ y = librosa.resample(y, orig_sr = orig_sr, target_sr = self.sr)
390
+ original_power = np.mean(y**2)
391
+ y = y * np.sqrt(self.average_energy_VCTK / original_power) # Noise has a power of 5% power of VCTK samples
392
+ if phase_shift_decoding and phase_shift_decoding != 'false':
393
+ ps = self.get_best_ps(y)
394
+ else:
395
+ ps = 0
396
+ y = torch.FloatTensor(y[ps:]).unsqueeze(0).unsqueeze(0).to(self.device)
397
+ carrier, _ = self.stft.transform(y.squeeze(1))
398
+ carrier = carrier[:, None]
399
+
400
+ msg_reconst_list = []
401
+ confidence = []
402
+
403
+ for i in range(self.n_messages): # decode each msg_i using decoder_m_i
404
+ msg_reconst = self.dec_m[i](carrier)
405
+ pred_values = torch.argmax(msg_reconst[0, 0], dim=0).data.cpu().numpy()
406
+ pred_values = pred_values[0:int(msg_reconst.shape[3]/self.config.message_len)*self.config.message_len]
407
+ pred_values = pred_values.reshape([-1, self.config.message_len])
408
+
409
+ ord_values = st.mode(pred_values, keepdims=False).mode
410
+ end_char = np.min(np.nonzero(ord_values == 0)[0])
411
+ confidence.append(self.get_confidence(pred_values, ord_values))
412
+ if end_char == self.config.message_len:
413
+ ord_values = ord_values[:self.config.message_len-1]
414
+ else:
415
+ ord_values = np.concatenate([ord_values[end_char+1:], ord_values[:end_char]], axis=0)
416
+
417
+ # pred_values = ''.join([chr(v + 64) for v in ord_values])
418
+ msg_reconst_list.append((ord_values - 1).tolist())
419
+
420
+ def convert_to_8_bit_segments(msg_list):
421
+ segment_message_list = []
422
+ for msg_list_i in msg_list:
423
+ binary_format = ''.join(['{0:02b}'.format(mes_i) for mes_i in msg_list_i])
424
+ eight_bit_segments = [int(binary_format[i*8:i*8+8], 2) for i in range(len(binary_format)//8)]
425
+ segment_message_list.append(eight_bit_segments)
426
+ return segment_message_list
427
+ msg_reconst_list = convert_to_8_bit_segments(msg_reconst_list)
428
+
429
+ results.append({'messages': msg_reconst_list, 'confidences': confidence, 'status': True})
430
+ except Exception:
431
+ results.append({'messages': [], 'confidences': [], 'error': 'Could not find message', 'status': False})
432
+
433
+ if single_channel:
434
+ results = results[0]
435
+
436
+ return results
437
+
438
+ def convert_dataparallel_to_normal(self, checkpoint):
439
+
440
+ return {i[len('module.'):] if i.startswith('module.') else i: checkpoint[i] for i in checkpoint }
441
+
442
+ def load_models(self, ckpt_dir):
443
+
444
+ self.enc_c.load_state_dict(self.convert_dataparallel_to_normal(torch.load(os.path.join(ckpt_dir, "enc_c.ckpt"), map_location=self.device)))
445
+ self.dec_c.load_state_dict(self.convert_dataparallel_to_normal(torch.load(os.path.join(ckpt_dir, "dec_c.ckpt"), map_location=self.device)))
446
+ for i,m in enumerate(self.dec_m):
447
+ m.load_state_dict(self.convert_dataparallel_to_normal(torch.load(os.path.join(ckpt_dir, f"dec_m_{i}.ckpt"), map_location=self.device)))
448
+
449
+
450
+ def get_model(model_type='44.1k', ckpt_path='../Models/44_1_khz/73999_iteration', config_path='../Models/44_1_khz/73999_iteration/hparams.yaml', device='cpu'):
451
+
452
+ if model_type == '44.1k':
453
+ if not os.path.exists(ckpt_path) or not os.path.exists(config_path):
454
+ print('ckpt path or config path does not exist! Downloading the model from the Hugging Face Hub...')
455
+ from huggingface_hub import snapshot_download
456
+ folder_dir = snapshot_download(repo_id="sony/silentcipher")
457
+ ckpt_path = os.path.join(folder_dir, '44_1_khz/73999_iteration')
458
+ config_path = os.path.join(folder_dir, '44_1_khz/73999_iteration/hparams.yaml')
459
+
460
+ config = yaml.safe_load(open(config_path))
461
+ config = argparse.Namespace(**config)
462
+ config.load_ckpt = ckpt_path
463
+ model = Model(config, device)
464
+ elif model_type == '16k':
465
+ if not os.path.exists(ckpt_path) or not os.path.exists(config_path):
466
+ print('ckpt path or config path does not exist! Downloading the model from the Hugging Face Hub...')
467
+ from huggingface_hub import snapshot_download
468
+ folder_dir = snapshot_download(repo_id="sony/silentcipher")
469
+ ckpt_path = os.path.join(folder_dir, '16_khz/97561_iteration')
470
+ config_path = os.path.join(folder_dir, '16_khz/97561_iteration/hparams.yaml')
471
+
472
+ config = yaml.safe_load(open(config_path))
473
+ config = argparse.Namespace(**config)
474
+ config.load_ckpt = ckpt_path
475
+
476
+ model = Model(config, device)
477
+ else:
478
+ raise ValueError('Please specify a valid model_type [44.1k, 16k]')
479
+
480
+ return model
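Putting the pieces together, the Model class above watermarks an audio file and recovers the embedded bytes. A usage sketch, treating the file paths and the five-byte message as placeholders (the accepted message length is dictated by the checkpoint's message_len):

from silentcipher import get_model

model = get_model(model_type='44.1k', device='cpu')

result = model.encode('input.wav', 'watermarked.wav', [222, 173, 190, 239, 42])
print(result['sdr'], result['time_taken_per_second'])

decoded = model.decode('watermarked.wav', phase_shift_decoding=True)
print(decoded['messages'], decoded['confidences'])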
silentcipher/stft.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+
3
+ class Singleton(type):
4
+ _instances = {}
5
+ def __call__(cls, *args, **kwargs):
6
+ if cls not in cls._instances:
7
+ cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
8
+ return cls._instances[cls]
9
+
10
+ class STFT(torch.nn.Module, metaclass=Singleton):
11
+ def __init__(self, filter_length=1024, hop_length=512):
12
+ super(STFT, self).__init__()
13
+
14
+ self.filter_length = filter_length
15
+ self.hop_len = hop_length
16
+ self.win_len = filter_length
17
+ self.window = torch.hann_window(self.win_len)
18
+ self.num_samples = -1
19
+
20
+ def transform(self, x):
21
+ x = torch.nn.functional.pad(x, (0, self.win_len - x.shape[1]%self.win_len))
22
+ fft = torch.stft(x, self.filter_length, self.hop_len, self.win_len, window=self.window.to(x.device), return_complex=True)
23
+
24
+ real_part, imag_part = fft.real, fft.imag
25
+
26
+ squared = real_part**2 + imag_part**2
27
+ additive_epsilon = torch.ones_like(squared) * (squared == 0).float() * 1e-24
28
+ magnitude = torch.sqrt(squared + additive_epsilon) - torch.sqrt(additive_epsilon)
29
+
30
+ phase = torch.atan2(imag_part, real_part).float()  # torch.autograd.Variable is deprecated and a no-op in modern PyTorch
31
+ return magnitude, phase
32
+
33
+ def inverse(self, magnitude, phase):
34
+
35
+ recombine_magnitude_phase = magnitude*torch.cos(phase) + 1j*magnitude*torch.sin(phase)
36
+ inverse_transform = torch.istft(recombine_magnitude_phase, self.filter_length, hop_length=self.hop_len, win_length=self.win_len, window=self.window.to(magnitude.device)).unsqueeze(1) # , length=self.num_samples
37
+ padding = self.win_len - (self.num_samples % self.win_len)
38
+ inverse_transform = inverse_transform[:, :, :-padding]
39
+ return inverse_transform
40
+
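The STFT wrapper keeps magnitude and phase separate so the watermark can be injected in the magnitude domain and resynthesized with the original phase. A round-trip sketch on a batch of mono audio (the tensor sizes are illustrative):

import torch
from silentcipher.stft import STFT

stft = STFT(filter_length=1024, hop_length=512)   # Singleton: repeated calls return the same instance
x = torch.randn(1, 44100)                         # (batch, samples)
mag, phase = stft.transform(x)                    # each (batch, 513, frames)
stft.num_samples = x.shape[1]                     # inverse() trims its internal padding based on this
y = stft.inverse(mag, phase)                      # -> (batch, 1, samples)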
text_presets.txt ADDED
@@ -0,0 +1,42 @@
1
+ Reading | [S1] The old lighthouse keeper had seen many storms in his thirty years on the rock, but nothing like this. The fog rolled in thick as wool, swallowing the beam of light before it could reach the churning waves below. Then he heard it, three short bells from the channel, where no ship should be at this hour. He grabbed his lantern and peered into the mist, his heart pounding. Something was out there, something that shouldn't exist.
2
+
3
+ Reading | [S1] Deep beneath the ocean's surface, where sunlight fades to perpetual twilight, extraordinary creatures have evolved in ways that defy imagination. Bioluminescent jellyfish pulse with ethereal blue light, while giant squid hunt in the crushing darkness. At depths of over two miles, the pressure is immense, enough to collapse a submarine, yet life persists.
4
+
5
+ Reading | [S1] The telegram arrived on a Tuesday morning in June, nineteen forty-three. Margaret's hands trembled as she tore open the envelope, dreading the words she knew might be inside. Her brother had shipped out to North Africa six months ago, and his letters had grown increasingly sparse.
6
+
7
+ Reading | [S1] The ancient map showed a path through the Whispering Mountains that no living traveler had taken in generations. Legends spoke of a hidden valley where time moved differently, where a single day in the outside world meant years had passed within. As dawn broke over the snow-capped peaks, Elena shouldered her pack and began the ascent. Whatever waited at the journey's end, whether treasure or peril,
8
+
9
+ Cartoon | [S1] After giving everything some more thought, I've decided it's in the best interest of humanity to acquire Nexus AI. (laughs) I've spoken with the CEO and he's on board. Well (laughs), at least that's the impression he gave initially.
10
+
11
+ Single (Disfluent) | [S1] ... explore how we can design, create interfaces that are not confusing, but at the same time can be powerful. Um, you know, I think, uh, in the, the famous, um, usability book, it's, uh, it's this, um, um, oh, geez, I'm, I'm blanking on the term, uh, uh, the, the rule about, um, uh, it's like the simplicity rule. I can't recall. Oh, cognitive load maybe.
12
+
13
+ Single (Disfluent) | [S1] Uh, complacency when the motivation isn't structured properly. Like for example, if you, if you're in the cor- if you work in the corporation for many years, a lot of corporate employees, they just, they're, they're aiming for that stock vesting and they're, they're doing just a sufficient job to, to, to reach that vesting and, and they don't, they're not performing any better than that. Um, and so I think, um, that showed me an important insight. Yeah.
14
+
15
+ Single (Disfluent) | [S1] We see the pattern of revelations, major shifts. I think Neptune in Pisces, which that transit has been happening all of 2021, and Neptune will remain in the sign of Pisces until March of 2029. So it's several years more of this transit. And what it brings is a lot of things, you know, the thing that I tend to emphasize is the profound dissolution or profound changes
16
+
17
+ Single (Disfluent) | [S1] I asked her, "Do you have like a phrase you use," and she mentioned she actually does. Like when things get tense, when there's like a moment, like if her, if her roommate is like venting about work drama or just like is stressed, and her, her roommate like deals with anxiety, I'm like, "Oh, this is probably how it feels to live with me." But, um, and like if, if, if things are rough, like she'll internally just like use this practice where she's like, like, "Not my problem, not mine to carry, not mine to handle, not mine to change." Like she'll sort of repeat that. So that's interesting.
18
+
19
+ Single (Disfluent) | [S1] If I examine the, the, if, if you examine the range of options, uh, beginning from, like, say, individual all the way, right? There will be some revenue stream, uh, there will be some purchase, there'll be some hardware profit margin for someone who creates a smart product, um, uh, there will be memberships, personal and business, uh, and then there'll be usage-based, right? So I still believe that that's kinda how, those are all the metrics. To your point, what is a membership? Up to now, folks
20
+
21
+ Single (Disfluent) | [S1] I think, if, if we can keep it under 25 points allowed, sure, our odds improve significantly. We wouldn't need to put up huge numbers ourselves, or at least that's the theory. And I should, I want to share some other stats which might be a bit outside our current discussion, but regarding this compared to 2018, the team's final four games that year, they managed 18 points total.
22
+
23
+ Singing | [S1] (singing) Amazing grace, how sweet the sound, that saved a wretch like me. I once was lost, but now am found, was blind, but now I see.
24
+
25
+ Conversation | [S1] Alright then. So, so 18 years you spent in that, uh, in that role, but alongside that in, in, was it while you were working that position in '93, you started doing some work with the network? [S2] Uh, yes. It was somewhere around '93. I, I, I played tennis pretty well, you know? I, I, I competed as a tennis player. And the, I got a chance to do some broadcasting over in Brisbane.
26
+
27
+ Conversation | [S1] ... that will provide the analytics component- [S2] Right. [S1] ... to ideally get you to adopt some of their other tools. And- [S2] (laughs) [S1] ... some of those features are valuable too. [S2] That's interesting. [S1] Mailchimp, I mean, that's campaign manage-, uh, not exactly campaign management, but messaging platforms. [S2] Uh-huh. [S1] The, the companies that are, you know,
28
+
29
+ Conversation | [S1] They were like, they were pumped for it, going wild for it, and it disappeared immediately. [S2] Yeah, I think it's about people understanding what's available first. Um... [S1] I think the finish on that one too was really nice. [S2] Yeah. [S1] I mean, that was pretty awesome. [S2] Have you seen those new editions?
30
+
31
+ Conversation | [S1] He was just practicing with them and they were on rotation. [S2] So that was probably in January. [S1] I think startup stereotypes, there is some like that, but some of them, I think they need to be changed. Like we don't all work twenty-hour days. [S2] No, they just need to, it's called not, it's based in Silicon Valley. [S1] Yeah. [S2] But the stereotypes would apply if they, it was called Techlife- [S1] Palo Alto. [S2] ... Cupertino or Mountain View, California.
32
+
33
+
34
+ Conversation | [S1] That's a nice overview. [S2] We were at the downtown cinema. [S1] By that, you mean the one in Riverside? [S2] Yeah. [S1] Yeah. So not exactly downtown. [S2] Not exactly downtown, yeah. [S1] I know a little bit about that area. [S2] (laughs) [S1] You know, Millbrook doesn't have a cinema. [S2] (laughs) It's the closest one for us. It's the closest. [S1] Yeah, that's true. [S2] The most nearby. [S1] Riverside is nearby. [S2] Riverside's close. [S1] That's fair. [S2] Support nearby. [S1] You can say, say Riverside, definitely. [S2] Well, yeah, fair enough.
35
+
36
+ Conversation | [S1] But they also, they also discovered, um, they also discovered like patterns in the desert, um, near Peru, like in the Atacama Desert. [S2] Yeah. [S1] Um, and like, it was like, of like perfectly, like, geo- geometric shapes. And they're like, "Yo, this is definitely not like formed by wind. This has to be artificial." [S2] Yeah, it's too precise.
37
+
38
+ Conversation | [S1] 'Cause I, yeah, there, there has to be a way that they can just make the, the system recognize that, no, you did not earn this- [S2] (laughs) [S1] ... on your own. You still have to go and complete one if you want it for your own- [S2] Right. [S1] ... like, profile. [S2] Right. Mm-hmm. [S1] So, yeah. [S2] Um, yeah. So let's actually move into multiplayer.
39
+
40
+ Conversation | [S1] Yeah. [S2] Yeah. TRS as a whole is just relaxed. [S1] But anyway, you know that Mirror app that launched and then got removed like a month later? [S2] Mirror, what, like, to your future? [S1] Yeah. [S2] Oh. [S1] So basically, there was an app, there's a show coming out. [S2] This is a show. [S1] Coming, I don't know what it is. [S2] Yeah, yeah, yeah. [S1] Like 2026 or something. Basically, Marcus, have you heard about this? [S2] I'm sorry, I don't know. No, I don't have an, it's an app- [S1] Okay, so I'll explain. I'll explain. [S2] Yeah. [S1] For context. So there's this app that launched in terms of the show called Mirror.
41
+
42
+ Conversation | [S1] Jamie Patterson, right? [S2] No, I know where- [S1] I know where- [S2] ... Patterson works as well. I know where- [S1] I know- I know he used to work near- on this street, and this is a weird street. [S2] The only person who I don't know where they work, Jamie. But anyway, why are we even talking about who works where? [S1] It was a- it was- it was a really weird street name where Jamie worked. [S2] I- I drove past this street on my commute. [S1] No, you didn't. [S2] Yeah, I did. [S1] No, you drove past the street that my street is down the street of. [S2] Nice. There's, like, one street in Oakfield, I think I'll be able to find it, mate.