NeoPy committed (verified) · Commit e1bc36a · 1 parent: 9bf28e2
infer/lib/predictors/RMVPE/RMVPE.py ADDED
import os
import sys
import torch

import numpy as np
import torch.nn.functional as F

sys.path.append(os.getcwd())

from main.library.predictors.RMVPE.mel import MelSpectrogram

N_MELS, N_CLASS = 128, 360

class RMVPE:
    def __init__(self, model_path, is_half, device=None, providers=None, onnx=False, hpa=False):
        self.onnx = onnx

        if self.onnx:
            import onnxruntime as ort

            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3
            self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
        else:
            from main.library.predictors.RMVPE.e2e import E2E

            model = E2E(4, 1, (2, 2), 5, 4, 1, 16, hpa=hpa)
            model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
            model.eval()
            if is_half: model = model.half()
            self.model = model.to(device)

        self.device = device
        self.is_half = is_half
        self.mel_extractor = MelSpectrogram(N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
        cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191  # cents value of each of the 360 pitch bins
        self.cents_mapping = np.pad(cents_mapping, (4, 4))  # padded for the 9-bin window in to_local_average_cents

    def mel2hidden(self, mel, chunk_size=32000):
        with torch.no_grad():
            n_frames = mel.shape[-1]
            mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect")  # pad frames to a multiple of 32 for the U-Net strides

            output_chunks = []
            pad_frames = mel.shape[-1]

            for start in range(0, pad_frames, chunk_size):
                mel_chunk = mel[..., start:min(start + chunk_size, pad_frames)]
                assert mel_chunk.shape[-1] % 32 == 0

                if self.onnx:
                    mel_chunk = mel_chunk.cpu().numpy().astype(np.float32)
                    out_chunk = torch.as_tensor(
                        self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: mel_chunk})[0],
                        device=self.device
                    )
                else:
                    if self.is_half: mel_chunk = mel_chunk.half()
                    out_chunk = self.model(mel_chunk)

                output_chunks.append(out_chunk)

            hidden = torch.cat(output_chunks, dim=1)
            return hidden[:, :n_frames]

    def decode(self, hidden, thred=0.03):
        f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))  # cents -> Hz, referenced to 10 Hz
        f0[f0 == 10] = 0  # 0 cents marks unvoiced frames

        return f0

    def infer_from_audio(self, audio, thred=0.03):
        hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
        return self.decode(hidden.squeeze(0).cpu().numpy().astype(np.float32), thred=thred)

    def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
        f0 = self.infer_from_audio(audio, thred)
        f0[(f0 < f0_min) | (f0 > f0_max)] = 0

        return f0

    def to_local_average_cents(self, salience, thred=0.05):
        # weighted average of cents over the 9 bins around each frame's salience peak
        center = np.argmax(salience, axis=1)
        salience = np.pad(salience, ((0, 0), (4, 4)))
        center += 4
        todo_salience, todo_cents_mapping = [], []
        starts = center - 4
        ends = center + 5

        for idx in range(salience.shape[0]):
            todo_salience.append(salience[idx, starts[idx]:ends[idx]])
            todo_cents_mapping.append(self.cents_mapping[starts[idx]:ends[idx]])

        todo_salience = np.array(todo_salience)
        divided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
        divided[np.max(salience, axis=1) <= thred] = 0  # zero out frames whose peak salience is below the threshold

        return divided
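
Editor's note: a minimal usage sketch, not part of the commit. The checkpoint path is hypothetical, and the import follows the main.library.predictors.RMVPE package path used inside these files (the commit adds them under infer/lib/predictors/RMVPE). RMVPE expects 16 kHz mono float32 audio:

    import numpy as np
    from main.library.predictors.RMVPE.RMVPE import RMVPE

    rmvpe = RMVPE("assets/models/rmvpe.pt", is_half=False, device="cpu")  # hypothetical path
    audio = np.zeros(16000, dtype=np.float32)  # one second of dummy input
    f0 = rmvpe.infer_from_audio_with_pitch(audio, thred=0.03, f0_min=50, f0_max=1100)
    print(f0.shape)  # roughly one F0 value per 10 ms hop (hop_length=160 at 16 kHz)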
infer/lib/predictors/RMVPE/deepunet.py ADDED
import os
import sys

import torch
import torch.nn as nn

sys.path.append(os.getcwd())

from main.library.predictors.RMVPE.yolo import YOLO13Encoder, YOLO13FullPADDecoder, HyperACE

class ConvBlockRes(nn.Module):
    # two 3x3 conv + BN + ReLU stages, with a 1x1 shortcut when the channel count changes
    def __init__(self, in_channels, out_channels, momentum=0.01):
        super(ConvBlockRes, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU()
        )

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
            self.is_shortcut = True
        else: self.is_shortcut = False

    def forward(self, x):
        return self.conv(x) + self.shortcut(x) if self.is_shortcut else self.conv(x) + x

class ResEncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
        super(ResEncoderBlock, self).__init__()
        self.n_blocks = n_blocks
        self.conv = nn.ModuleList()
        self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))

        for _ in range(n_blocks - 1):
            self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))

        self.kernel_size = kernel_size
        if self.kernel_size is not None: self.pool = nn.AvgPool2d(kernel_size=kernel_size)

    def forward(self, x):
        for i in range(self.n_blocks):
            x = self.conv[i](x)

        if self.kernel_size is not None: return x, self.pool(x)
        else: return x

class Encoder(nn.Module):
    def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
        super(Encoder, self).__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()

        for _ in range(self.n_encoders):
            self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
            in_channels = out_channels
            out_channels *= 2
            in_size //= 2

        self.out_size = in_size
        self.out_channel = out_channels

    def forward(self, x):
        concat_tensors = []
        x = self.bn(x)

        for layer in self.layers:
            t, x = layer(x)
            concat_tensors.append(t)  # pre-pooling features, reused by the decoder as skip connections

        return x, concat_tensors

class Intermediate(nn.Module):
    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
        super(Intermediate, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))

        for _ in range(n_inters - 1):
            self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)

        return x

class ResDecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
        super(ResDecoderBlock, self).__init__()
        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), output_padding=out_padding, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU()
        )

        self.conv2 = nn.ModuleList()
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))

        for _ in range(n_blocks - 1):
            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

    def forward(self, x, concat_tensor):
        x = torch.cat((self.conv1(x), concat_tensor), dim=1)
        for conv2 in self.conv2:
            x = conv2(x)

        return x

class Decoder(nn.Module):
    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList()

        for _ in range(n_decoders):
            out_channels = in_channels // 2
            self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
            in_channels = out_channels

    def forward(self, x, concat_tensors):
        for i, layer in enumerate(self.layers):
            x = layer(x, concat_tensors[-1 - i])

        return x

class DeepUnet(nn.Module):
    def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
        super(DeepUnet, self).__init__()
        self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
        self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
        self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)

    def forward(self, x):
        x, concat_tensors = self.encoder(x)
        return self.decoder(self.intermediate(x), concat_tensors)

class HPADeepUnet(nn.Module):
    # alternative U-Net built from the YOLOv13-style blocks in yolo.py, with HyperACE context fusion
    def __init__(self, in_channels=1, en_out_channels=16, base_channels=64, hyperace_k=2, hyperace_l=1, num_hyperedges=16, num_heads=8):
        super().__init__()
        self.encoder = YOLO13Encoder(in_channels, base_channels)

        enc_ch = self.encoder.out_channels

        self.hyperace = HyperACE(in_channels=enc_ch, out_channels=enc_ch[-1], num_hyperedges=num_hyperedges, num_heads=num_heads, k=hyperace_k, l=hyperace_l)
        self.decoder = YOLO13FullPADDecoder(encoder_channels=enc_ch, hyperace_out_c=enc_ch[-1], out_channels_final=en_out_channels)

    def forward(self, x):
        features = self.encoder(x)
        return nn.functional.interpolate(self.decoder(features, self.hyperace(features)), size=x.shape[2:], mode='bilinear', align_corners=False)
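
Editor's note: a quick shape check, assuming the configuration that E2E (e2e.py below) passes in — kernel_size=(2, 2), n_blocks=4, 128 mel bins on the last axis, and a frame count divisible by 32 (RMVPE.mel2hidden pads to guarantee this):

    import torch
    from main.library.predictors.RMVPE.deepunet import DeepUnet

    net = DeepUnet(kernel_size=(2, 2), n_blocks=4).eval()
    x = torch.randn(1, 1, 96, 128)  # (batch, 1, frames, n_mels), as produced by E2E.forward
    with torch.no_grad():
        print(net(x).shape)  # torch.Size([1, 16, 96, 128]): spatial size restored, en_out_channels=16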
infer/lib/predictors/RMVPE/e2e.py ADDED
import os
import sys
import torch

import torch.nn as nn

sys.path.append(os.getcwd())

from main.library.predictors.RMVPE.deepunet import DeepUnet, HPADeepUnet

N_MELS, N_CLASS = 128, 360

class BiGRU(nn.Module):
    def __init__(self, input_features, hidden_features, num_layers):
        super(BiGRU, self).__init__()
        self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)

    def forward(self, x):
        try:
            return self.gru(x)[0]
        except Exception:
            torch.backends.cudnn.enabled = False  # cuDNN's GRU kernels can fail on some inputs; retry without cuDNN
            return self.gru(x)[0]

class E2E(nn.Module):
    def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16, hpa=False):
        super(E2E, self).__init__()
        self.unet = (
            HPADeepUnet(in_channels=in_channels, en_out_channels=en_out_channels, base_channels=64, hyperace_k=2, hyperace_l=1, num_hyperedges=16, num_heads=4)
        ) if hpa else (
            DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
        )

        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        self.fc = (
            nn.Sequential(BiGRU(3 * N_MELS, 256, n_gru), nn.Linear(512, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())
        ) if n_gru else (
            nn.Sequential(nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())
        )

    def forward(self, mel):
        # (B, n_mels, T) -> (B, 1, T, n_mels) -> U-Net -> (B, 3, T, n_mels) -> (B, T, 3 * n_mels) -> (B, T, N_CLASS)
        return self.fc(self.cnn(self.unet(mel.transpose(-1, -2).unsqueeze(1))).transpose(1, 2).flatten(-2))
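
Editor's note: the constructor call below mirrors the one RMVPE.py uses, E2E(4, 1, (2, 2), 5, 4, 1, 16); the mel tensor is dummy data:

    import torch
    from main.library.predictors.RMVPE.e2e import E2E

    model = E2E(4, 1, (2, 2), 5, 4, 1, 16).eval()
    mel = torch.randn(1, 128, 96)  # (batch, n_mels, frames); frames must be a multiple of 32 here
    with torch.no_grad():
        print(model(mel).shape)  # torch.Size([1, 96, 360]): per-frame salience over the 360 pitch bins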
infer/lib/predictors/RMVPE/mel.py ADDED
import os
import sys
import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from librosa.filters import mel

sys.path.append(os.getcwd())

class MelSpectrogram(nn.Module):
    def __init__(self, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
        super().__init__()
        n_fft = win_length if n_fft is None else n_fft
        self.hann_window = {}  # one cached Hann window per (keyshift, device) pair
        mel_basis = mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sample_rate = sample_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp

    def forward(self, audio, keyshift=0, speed=1, center=True):
        factor = 2 ** (keyshift / 12)  # keyshift in semitones rescales the analysis window
        win_length_new = int(np.round(self.win_length * factor))

        keyshift_key = str(keyshift) + "_" + str(audio.device)
        if keyshift_key not in self.hann_window: self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)

        n_fft = int(np.round(self.n_fft * factor))
        hop_length = int(np.round(self.hop_length * speed))

        if str(audio.device).startswith(("ocl", "privateuseone")):  # backends without native torch.stft support
            if not hasattr(self, "stft"):
                from main.library.backends.utils import STFT

                self.stft = STFT(filter_length=n_fft, hop_length=hop_length, win_length=win_length_new).to(audio.device)

            magnitude = self.stft.transform(audio, 1e-9)
        else:
            fft = torch.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
            magnitude = (fft.real.pow(2) + fft.imag.pow(2)).sqrt()

        if keyshift != 0:
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)

            if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new

        mel_output = self.mel_basis @ magnitude
        return mel_output.clamp(min=self.clamp).log()
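
Editor's note: a sketch with the exact parameters RMVPE.py passes in (128 mels, 16 kHz, win 1024, hop 160, fmin 30 Hz, fmax 8 kHz), on dummy audio:

    import torch
    from main.library.predictors.RMVPE.mel import MelSpectrogram

    mel_extractor = MelSpectrogram(128, 16000, 1024, 160, None, 30, 8000)
    audio = torch.randn(1, 16000)  # (batch, samples): one second at 16 kHz
    print(mel_extractor(audio, center=True).shape)  # torch.Size([1, 128, 101]): log-mel, one frame per 160-sample hop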
infer/lib/predictors/RMVPE/yolo.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F

def autopad(k, p=None):
    # "same" padding for odd kernel sizes
    if p is None: p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p

class Conv(nn.Module):
    # conv + BN + SiLU
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

class DSConv(nn.Module):
    # depthwise separable conv: depthwise k x k followed by pointwise 1 x 1
    def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
        super().__init__()
        self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
        self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.pwconv(self.dwconv(x))))

class DS_Bottleneck(nn.Module):
    def __init__(self, c1, c2, k=3, shortcut=True):
        super().__init__()
        self.dsconv1 = DSConv(c1, c1, k=3, s=1)
        self.dsconv2 = DSConv(c1, c2, k=k, s=1)
        self.shortcut = shortcut and c1 == c2

    def forward(self, x):
        return x + self.dsconv2(self.dsconv1(x)) if self.shortcut else self.dsconv2(self.dsconv1(x))

class DS_C3k(nn.Module):
    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        self.cv1 = Conv(c1, int(c2 * e), 1, 1)
        self.cv2 = Conv(c1, int(c2 * e), 1, 1)
        self.cv3 = Conv(2 * int(c2 * e), c2, 1, 1)
        self.m = nn.Sequential(*[DS_Bottleneck(int(c2 * e), int(c2 * e), k=k, shortcut=True) for _ in range(n)])

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))

class DS_C3k2(nn.Module):
    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        self.cv1 = Conv(c1, int(c2 * e), 1, 1)
        self.m = DS_C3k(int(c2 * e), int(c2 * e), n=n, k=k, e=1.0)
        self.cv2 = Conv(int(c2 * e), c2, 1, 1)

    def forward(self, x):
        return self.cv2(self.m(self.cv1(x)))

class AdaptiveHyperedgeGeneration(nn.Module):
    # builds a soft participation matrix A of shape (B, num_hyperedges, N) from per-node
    # queries and hyperedge prototypes conditioned on global context
    def __init__(self, in_channels, num_hyperedges, num_heads):
        super().__init__()
        self.num_hyperedges = num_hyperedges
        self.num_heads = num_heads
        self.head_dim = max(1, in_channels // num_heads)
        self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))
        self.context_mapper = nn.Linear(2 * in_channels, num_hyperedges * in_channels, bias=False)
        self.query_proj = nn.Linear(in_channels, in_channels, bias=False)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        B, N, C = x.shape
        # prototypes = global prototypes + a per-sample offset derived from avg/max-pooled context
        P = self.global_proto.unsqueeze(0) + self.context_mapper(
            torch.cat(
                (
                    F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1),
                    F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
                ),
                dim=1
            )
        ).view(B, self.num_hyperedges, C)

        # multi-head scaled dot-product similarity between node queries and prototypes, averaged over heads
        return F.softmax((
            (self.query_proj(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) @ P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(0, 2, 3, 1)) * self.scale
        ).mean(dim=1).permute(0, 2, 1), dim=-1)

class HypergraphConvolution(nn.Module):
    # node -> hyperedge -> node message passing with a residual connection
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.W_e = nn.Linear(in_channels, in_channels, bias=False)
        self.W_v = nn.Linear(in_channels, out_channels, bias=False)
        self.act = nn.SiLU()

    def forward(self, x, A):
        return x + self.act(self.W_v(A.transpose(1, 2).bmm(self.act(self.W_e(A.bmm(x))))))

class AdaptiveHypergraphComputation(nn.Module):
    def __init__(self, in_channels, out_channels, num_hyperedges, num_heads):
        super().__init__()
        self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(in_channels, num_hyperedges, num_heads)
        self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)

    def forward(self, x):
        B, _, H, W = x.shape
        x_flat = x.flatten(2).permute(0, 2, 1)  # (B, C, H, W) -> (B, H*W, C): pixels become hypergraph nodes
        return self.hypergraph_conv(x_flat, self.adaptive_hyperedge_gen(x_flat)).permute(0, 2, 1).view(B, -1, H, W)

class C3AH(nn.Module):
    def __init__(self, c1, c2, num_hyperedges, num_heads, e=0.5):
        super().__init__()
        self.cv1 = Conv(c1, int(c1 * e), 1, 1)
        self.cv2 = Conv(c1, int(c1 * e), 1, 1)
        self.ahc = AdaptiveHypergraphComputation(int(c1 * e), int(c1 * e), num_hyperedges, num_heads)
        self.cv3 = Conv(2 * int(c1 * e), c2, 1, 1)

    def forward(self, x):
        return self.cv3(torch.cat((self.ahc(self.cv2(x)), self.cv1(x)), dim=1))

class HyperACE(nn.Module):
    # fuses the P2-P5 pyramid at P4 resolution, then splits the result into a high-order
    # (hypergraph), a low-order (DS_C3k) and a pass-through branch before the final fusion
    def __init__(self, in_channels, out_channels, num_hyperedges=16, num_heads=8, k=2, l=1, c_h=0.5, c_l=0.25):
        super().__init__()
        c2, c3, c4, c5 = in_channels
        c_mid = c4
        self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)
        self.c_h = int(c_mid * c_h)
        self.c_l = int(c_mid * c_l)
        self.c_s = c_mid - self.c_h - self.c_l
        self.high_order_branch = nn.ModuleList([C3AH(self.c_h, self.c_h, num_hyperedges=num_hyperedges, num_heads=num_heads, e=1.0) for _ in range(k)])
        self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)
        self.low_order_branch = nn.Sequential(*[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)])
        self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)

    def forward(self, x):
        B2, B3, B4, B5 = x
        _, _, H4, W4 = B4.shape

        x_h, x_l, x_s = self.fuse_conv(
            torch.cat(
                (
                    F.interpolate(B2, size=(H4, W4), mode='bilinear', align_corners=False),
                    F.interpolate(B3, size=(H4, W4), mode='bilinear', align_corners=False),
                    B4,
                    F.interpolate(B5, size=(H4, W4), mode='bilinear', align_corners=False)
                ),
                dim=1
            )
        ).split([self.c_h, self.c_l, self.c_s], dim=1)

        return self.final_fuse(
            torch.cat(
                (
                    self.high_order_fuse(torch.cat([m(x_h) for m in self.high_order_branch], dim=1)),
                    self.low_order_branch(x_l),
                    x_s
                ),
                dim=1
            )
        )

class GatedFusion(nn.Module):
    # f_in + gamma * h with a learnable, zero-initialized per-channel gate
    def __init__(self, in_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))

    def forward(self, f_in, h):
        return f_in + self.gamma * h

class YOLO13Encoder(nn.Module):
    def __init__(self, in_channels, base_channels=32):
        super().__init__()
        self.stem = DSConv(in_channels, base_channels, k=3, s=1)
        self.p2 = nn.Sequential(DSConv(base_channels, base_channels * 2, k=3, s=(2, 2)), DS_C3k2(base_channels * 2, base_channels * 2, n=1))
        self.p3 = nn.Sequential(DSConv(base_channels * 2, base_channels * 4, k=3, s=(2, 2)), DS_C3k2(base_channels * 4, base_channels * 4, n=2))
        self.p4 = nn.Sequential(DSConv(base_channels * 4, base_channels * 8, k=3, s=(2, 2)), DS_C3k2(base_channels * 8, base_channels * 8, n=2))
        self.p5 = nn.Sequential(DSConv(base_channels * 8, base_channels * 16, k=3, s=(2, 2)), DS_C3k2(base_channels * 16, base_channels * 16, n=1))

        self.out_channels = [base_channels * 2, base_channels * 4, base_channels * 8, base_channels * 16]

    def forward(self, x):
        p2 = self.p2(self.stem(x))
        p3 = self.p3(p2)
        p4 = self.p4(p3)
        p5 = self.p5(p4)

        return [p2, p3, p4, p5]

class YOLO13FullPADDecoder(nn.Module):
    # top-down decoder: at every level the HyperACE output is resized, gated into the
    # current feature map, upsampled, and added to the next skip path
    def __init__(self, encoder_channels, hyperace_out_c, out_channels_final):
        super().__init__()
        c_p2, c_p3, c_p4, c_p5 = encoder_channels
        c_d5, c_d4, c_d3, c_d2 = c_p5, c_p4, c_p3, c_p2

        self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
        self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
        self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
        self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)

        self.fusion_d5 = GatedFusion(c_d5)
        self.fusion_d4 = GatedFusion(c_d4)
        self.fusion_d3 = GatedFusion(c_d3)
        self.fusion_d2 = GatedFusion(c_d2)

        self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
        self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
        self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
        self.skip_p2 = Conv(c_p2, c_d2, 1, 1)

        self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
        self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
        self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)

        self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)
        self.final_conv = Conv(c_d2, out_channels_final, 1, 1)

    def forward(self, enc_feats, h_ace):
        p2, p3, p4, p5 = enc_feats
        d5 = self.skip_p5(p5)

        d4 = self.up_d5(
            F.interpolate(
                self.fusion_d5(d5, self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode='bilinear', align_corners=False))),
                size=p4.shape[2:], mode='bilinear', align_corners=False
            )
        ) + self.skip_p4(p4)

        d3 = self.up_d4(
            F.interpolate(
                self.fusion_d4(d4, self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode='bilinear', align_corners=False))),
                size=p3.shape[2:], mode='bilinear', align_corners=False
            )
        ) + self.skip_p3(p3)

        d2 = self.up_d3(
            F.interpolate(
                self.fusion_d3(d3, self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode='bilinear', align_corners=False))),
                size=p2.shape[2:], mode='bilinear', align_corners=False
            )
        ) + self.skip_p2(p2)

        return self.final_conv(self.final_d2(self.fusion_d2(d2, self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode='bilinear', align_corners=False)))))
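
Editor's note: a shape sketch of how HPADeepUnet (deepunet.py above) wires these three modules together. num_heads=4 matches what E2E passes when hpa=True; the input is arbitrary dummy data whose height and width must be divisible by 16 for the four stride-2 stages:

    import torch
    from main.library.predictors.RMVPE.yolo import YOLO13Encoder, HyperACE, YOLO13FullPADDecoder

    enc = YOLO13Encoder(in_channels=1, base_channels=64)
    ace = HyperACE(enc.out_channels, enc.out_channels[-1], num_hyperedges=16, num_heads=4, k=2, l=1)
    dec = YOLO13FullPADDecoder(enc.out_channels, enc.out_channels[-1], out_channels_final=16)

    x = torch.randn(1, 1, 64, 128)
    feats = enc(x)                # [P2, P3, P4, P5] at strides 2, 4, 8, 16
    out = dec(feats, ace(feats))  # hypergraph context injected at every decoder level
    print(out.shape)              # torch.Size([1, 16, 32, 64]): P2 resolution; HPADeepUnet then upsamples back to x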