mohammed-aljafry committed
Commit f38952d · verified · 1 Parent(s): 2bc171c

Upload model.py with huggingface_hub

Files changed (1):
  1. model.py +1064 -0

model.py ADDED
@@ -0,0 +1,1064 @@
# model.py

import torch
from torch import nn
import torch.nn.functional as F
from functools import partial
import math
import copy
from typing import Optional, Tuple, Union, List
from torch import Tensor
from collections import OrderedDict
import numpy as np

from timm.models.resnet import resnet50d, resnet101d, resnet26d, resnet18d
from timm.models.registry import register_model


# --- Helper Functions for Model Definition ---
def to_2tuple(x):
    if isinstance(x, tuple):
        return x
    return (x, x)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def build_attn_mask(mask_type):
    mask = torch.ones((151, 151), dtype=torch.bool).cuda()
    if mask_type == "seperate_all":
        mask[:50, :50] = False
        mask[50:67, 50:67] = False
        mask[67:84, 67:84] = False
        mask[84:101, 84:101] = False
        mask[101:151, 101:151] = False
    elif mask_type == "seperate_view":
        mask[:50, :50] = False
        mask[50:67, 50:67] = False
        mask[67:84, 67:84] = False
        mask[84:101, 84:101] = False
        mask[101:151, :] = False
        mask[:, 101:151] = False
    return mask
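
# NOTE (inferred from forward_features below, not documented in the file):
# the 151-token layout assumed by these masks matches 49 + 1 (global) front
# tokens, 16 + 1 each for the left, right, and center views, and 49 + 1 lidar
# tokens, i.e. a 224x224 front/lidar input and 112x112 side inputs through a
# stride-32 backbone.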


# --- Model Components ---
class HybridEmbed(nn.Module):
    def __init__(
        self,
        backbone,
        img_size=224,
        patch_size=1,
        feature_size=None,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # Infer the feature map size and channel count with a dummy
                # forward pass, restoring the backbone's training mode after.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, "feature_info"):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features

        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x)
        global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
        return x, global_x
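
# HybridEmbed wraps a CNN backbone as a patch embedder: for example, a
# resnet50d on a 224x224 input yields a (B, 2048, 7, 7) feature map, which the
# 1x1 conv projects to (B, embed_dim, 7, 7); the mean-pooled global token has
# shape (B, embed_dim, 1).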


class PositionEmbeddingSine(nn.Module):
    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor):
        x = tensor
        bs, _, h, w = x.shape
        not_mask = torch.ones((bs, h, w), device=x.device)
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src

        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )

        if self.norm is not None:
            output = self.norm(output)

        return output


class SpatialSoftmax(nn.Module):
    def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
        super().__init__()

        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel

        if temperature:
            self.temperature = nn.Parameter(torch.ones(1) * temperature)
        else:
            self.temperature = 1.0

        pos_x, pos_y = np.meshgrid(
            np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)

    def forward(self, feature):
        if self.data_format == "NHWC":
            feature = (
                feature.transpose(1, 3)
                .transpose(2, 3)
                .view(-1, self.height * self.width)
            )
        else:
            feature = feature.view(-1, self.height * self.width)

        weight = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * weight, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * weight, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        feature_keypoints = expected_xy.view(-1, self.channel, 2)
        feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12
        feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12
        return feature_keypoints
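
# SpatialSoftmax implements a soft-argmax: a softmax over the flattened
# heatmap, then the expectation of a fixed [-1, 1] coordinate grid under that
# distribution, giving one (x, y) keypoint per channel. The trailing `* 12`
# lines appear to rescale the normalized coordinates into the ego-centric
# range the waypoint heatmaps are trained in.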


class MultiPath_Generator(nn.Module):
    def __init__(self, in_channel, embed_dim, out_channel):
        super().__init__()
        self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
        self.tconv0 = nn.Sequential(
            nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv1 = nn.Sequential(
            nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv2 = nn.Sequential(
            nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )
        self.tconv3 = nn.Sequential(
            nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )
        self.tconv4_list = torch.nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
                    nn.Tanh(),
                )
                for _ in range(6)
            ]
        )

        self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")

    def forward(self, x, measurements):
        # The first 6 measurement entries act as a command mask selecting one
        # of the 6 command-specific heatmap heads; entry 6 is the velocity.
        mask = measurements[:, :6]
        mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
        velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1)
        velocity = velocity.repeat(1, 32, 2, 2)

        n, d, c = x.shape
        x = x.transpose(1, 2)
        x = x.view(n, -1, 2, 2)
        x = torch.cat([x, velocity], dim=1)
        x = self.tconv0(x)
        x = self.tconv1(x)
        x = self.tconv2(x)
        x = self.tconv3(x)
        x = self.upsample(x)
        xs = []
        for i in range(6):
            xt = self.tconv4_list[i](x)
            xs.append(xt)
        xs = torch.stack(xs, dim=1)
        x = torch.sum(xs * mask, dim=1)
        x = self.spatial_softmax(x)
        return x


class LinearWaypointsPredictor(nn.Module):
    def __init__(self, input_dim, cumsum=True):
        super().__init__()
        self.cumsum = cumsum
        self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
        self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
        self.head_relu = nn.ReLU(inplace=True)
        self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])

    def forward(self, x, measurements):
        # input shape: (bs, 10, embed_dim)
        bs, n, dim = x.shape
        x = x + self.rank_embed
        x = x.reshape(-1, dim)

        mask = measurements[:, :6]
        mask = torch.unsqueeze(mask, -1).repeat(n, 1, 2)

        rs = []
        for i in range(6):
            res = self.head_fc1_list[i](x)
            res = self.head_relu(res)
            res = self.head_fc2_list[i](res)
            rs.append(res)
        rs = torch.stack(rs, 1)
        x = torch.sum(rs * mask, dim=1)

        x = x.view(bs, n, 2)
        if self.cumsum:
            x = torch.cumsum(x, 1)
        return x


class GRUWaypointsPredictor(nn.Module):
    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
        self.encoder = nn.Linear(2, 64)
        self.decoder = nn.Linear(64, 2)
        self.waypoints = waypoints

    def forward(self, x, target_point):
        bs = x.shape[0]
        # Encode the target point as the GRU's initial hidden state, decode a
        # per-step (dx, dy) offset, and accumulate offsets into waypoints.
        z = self.encoder(target_point).unsqueeze(0)
        output, _ = self.gru(x, z)
        output = output.reshape(bs * self.waypoints, -1)
        output = self.decoder(output).reshape(bs, self.waypoints, 2)
        output = torch.cumsum(output, 1)
        return output


class GRUWaypointsPredictorWithCommand(nn.Module):
    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.grus = nn.ModuleList(
            [
                torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
                for _ in range(6)
            ]
        )
        self.encoder = nn.Linear(2, 64)
        self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
        self.waypoints = waypoints

    def forward(self, x, target_point, measurements):
        bs, n, dim = x.shape
        mask = measurements[:, :6, None, None]
        mask = mask.repeat(1, 1, self.waypoints, 2)

        z = self.encoder(target_point).unsqueeze(0)
        outputs = []
        for i in range(6):
            output, _ = self.grus[i](x, z)
            output = output.reshape(bs * self.waypoints, -1)
            output = self.decoders[i](output).reshape(bs, self.waypoints, 2)
            output = torch.cumsum(output, 1)
            outputs.append(output)
        outputs = torch.stack(outputs, 1)
        output = torch.sum(outputs * mask, dim=1)
        return output


class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # `activation` should be a module class (e.g. nn.ReLU or nn.GELU);
        # it is instantiated here.
        self.activation = activation()
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # `activation` should be a module class (e.g. nn.ReLU or nn.GELU);
        # it is instantiated here.
        self.activation = activation()
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )


class Interfuser(nn.Module):
    def __init__(
        self,
        img_size=224,
        multi_view_img_size=112,
        patch_size=8,
        in_chans=3,
        embed_dim=768,
        enc_depth=6,
        dec_depth=6,
        dim_feedforward=2048,
        normalize_before=False,
        rgb_backbone_name="r26",
        lidar_backbone_name="r26",
        num_heads=8,
        norm_layer=None,
        dropout=0.1,
        end2end=False,
        direct_concat=True,
        separate_view_attention=False,
        separate_all_attention=False,
        act_layer=None,
        weight_init="",
        freeze_num=-1,
        with_lidar=False,
        with_right_left_sensors=True,
        with_center_sensor=False,
        traffic_pred_head_type="det",
        waypoints_pred_head="heatmap",
        reverse_pos=True,
        use_different_backbone=False,
        use_view_embed=True,
        use_mmad_pretrain=None,
    ):
        super().__init__()
        self.traffic_pred_head_type = traffic_pred_head_type
        self.num_features = self.embed_dim = embed_dim
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.reverse_pos = reverse_pos
        self.waypoints_pred_head = waypoints_pred_head
        self.with_lidar = with_lidar
        self.with_right_left_sensors = with_right_left_sensors
        self.with_center_sensor = with_center_sensor

        self.direct_concat = direct_concat
        self.separate_view_attention = separate_view_attention
        self.separate_all_attention = separate_all_attention
        self.end2end = end2end
        self.use_view_embed = use_view_embed

        if self.direct_concat:
            in_chans = in_chans * 4
            self.with_center_sensor = False
            self.with_right_left_sensors = False

        if self.separate_view_attention:
            self.attn_mask = build_attn_mask("seperate_view")
        elif self.separate_all_attention:
            self.attn_mask = build_attn_mask("seperate_all")
        else:
            self.attn_mask = None

        if use_different_backbone:
            if rgb_backbone_name == "r50":
                self.rgb_backbone = resnet50d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif rgb_backbone_name == "r26":
                self.rgb_backbone = resnet26d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif rgb_backbone_name == "r18":
                self.rgb_backbone = resnet18d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            if lidar_backbone_name == "r50":
                self.lidar_backbone = resnet50d(
                    pretrained=False,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif lidar_backbone_name == "r26":
                self.lidar_backbone = resnet26d(
                    pretrained=False,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif lidar_backbone_name == "r18":
                self.lidar_backbone = resnet18d(
                    pretrained=False, in_chans=3, features_only=True, out_indices=[4]
                )
            rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
            lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)

            if use_mmad_pretrain:
                params = torch.load(use_mmad_pretrain)["state_dict"]
                updated_params = OrderedDict()
                for key in params:
                    if "backbone" in key:
                        updated_params[key.replace("backbone.", "")] = params[key]
                self.rgb_backbone.load_state_dict(updated_params)

            self.rgb_patch_embed = rgb_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = lidar_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=3,
                embed_dim=embed_dim,
            )
        else:
            if rgb_backbone_name == "r50":
                self.rgb_backbone = resnet50d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r101":
                self.rgb_backbone = resnet101d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r26":
                self.rgb_backbone = resnet26d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r18":
                self.rgb_backbone = resnet18d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)

            self.rgb_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )

        self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
        self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))

        if self.end2end:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
            self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
        elif self.waypoints_pred_head == "heatmap":
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
            self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
        else:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
            self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))

        if self.end2end:
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
        elif self.waypoints_pred_head == "heatmap":
            self.waypoints_generator = MultiPath_Generator(
                embed_dim + 32, embed_dim, 10
            )
        elif self.waypoints_pred_head == "gru":
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "gru-command":
            self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
        elif self.waypoints_pred_head == "linear":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "linear-sum":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)

        self.junction_pred_head = nn.Linear(embed_dim, 2)
        self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
        self.stop_sign_head = nn.Linear(embed_dim, 2)

        if self.traffic_pred_head_type == "det":
            self.traffic_pred_head = nn.Sequential(
                *[
                    nn.Linear(embed_dim + 32, 64),
                    nn.ReLU(),
                    nn.Linear(64, 7),
                    nn.Sigmoid(),
                ]
            )
        elif self.traffic_pred_head_type == "seg":
            self.traffic_pred_head = nn.Sequential(
                *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
            )

        self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)

        encoder_layer = TransformerEncoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)

        decoder_layer = TransformerDecoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        decoder_norm = nn.LayerNorm(embed_dim)
        self.decoder = TransformerDecoder(
            decoder_layer, dec_depth, decoder_norm, return_intermediate=False
        )
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.global_embed)
        nn.init.uniform_(self.view_embed)
        nn.init.uniform_(self.query_embed)
        nn.init.uniform_(self.query_pos_embed)


    def forward_features(
        self,
        front_image,
        left_image,
        right_image,
        front_center_image,
        lidar,
        measurements,
    ):
        features = []

        front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
        if self.use_view_embed:
            front_image_token = (
                front_image_token
                + self.view_embed[:, :, 0:1, :]
                + self.position_encoding(front_image_token)
            )
        else:
            front_image_token = front_image_token + self.position_encoding(
                front_image_token
            )
        front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
        front_image_token_global = (
            front_image_token_global
            + self.view_embed[:, :, 0, :]
            + self.global_embed[:, :, 0:1]
        )
        front_image_token_global = front_image_token_global.permute(2, 0, 1)
        features.extend([front_image_token, front_image_token_global])

        if self.with_right_left_sensors:
            left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
            if self.use_view_embed:
                left_image_token = (
                    left_image_token
                    + self.view_embed[:, :, 1:2, :]
                    + self.position_encoding(left_image_token)
                )
            else:
                left_image_token = left_image_token + self.position_encoding(
                    left_image_token
                )
            left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
            left_image_token_global = (
                left_image_token_global
                + self.view_embed[:, :, 1, :]
                + self.global_embed[:, :, 1:2]
            )
            left_image_token_global = left_image_token_global.permute(2, 0, 1)

            right_image_token, right_image_token_global = self.rgb_patch_embed(
                right_image
            )
            if self.use_view_embed:
                right_image_token = (
                    right_image_token
                    + self.view_embed[:, :, 2:3, :]
                    + self.position_encoding(right_image_token)
                )
            else:
                right_image_token = right_image_token + self.position_encoding(
                    right_image_token
                )
            right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
            right_image_token_global = (
                right_image_token_global
                + self.view_embed[:, :, 2, :]
                + self.global_embed[:, :, 2:3]
            )
            right_image_token_global = right_image_token_global.permute(2, 0, 1)

            features.extend(
                [
                    left_image_token,
                    left_image_token_global,
                    right_image_token,
                    right_image_token_global,
                ]
            )

        if self.with_center_sensor:
            (
                front_center_image_token,
                front_center_image_token_global,
            ) = self.rgb_patch_embed(front_center_image)
            if self.use_view_embed:
                front_center_image_token = (
                    front_center_image_token
                    + self.view_embed[:, :, 3:4, :]
                    + self.position_encoding(front_center_image_token)
                )
            else:
                front_center_image_token = (
                    front_center_image_token
                    + self.position_encoding(front_center_image_token)
                )

            front_center_image_token = front_center_image_token.flatten(2).permute(
                2, 0, 1
            )
            front_center_image_token_global = (
                front_center_image_token_global
                + self.view_embed[:, :, 3, :]
                + self.global_embed[:, :, 3:4]
            )
            front_center_image_token_global = front_center_image_token_global.permute(
                2, 0, 1
            )
            features.extend([front_center_image_token, front_center_image_token_global])

        if self.with_lidar:
            lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
            if self.use_view_embed:
                lidar_token = (
                    lidar_token
                    + self.view_embed[:, :, 4:5, :]
                    + self.position_encoding(lidar_token)
                )
            else:
                lidar_token = lidar_token + self.position_encoding(lidar_token)
            lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
            lidar_token_global = (
                lidar_token_global
                + self.view_embed[:, :, 4, :]
                + self.global_embed[:, :, 4:5]
            )
            lidar_token_global = lidar_token_global.permute(2, 0, 1)
            features.extend([lidar_token, lidar_token_global])

        features = torch.cat(features, 0)
        return features


    def forward(self, x):
        front_image = x["rgb"]
        left_image = x["rgb_left"]
        right_image = x["rgb_right"]
        front_center_image = x["rgb_center"]
        measurements = x["measurements"]
        target_point = x["target_point"]
        lidar = x["lidar"]

        if self.direct_concat:
            img_size = front_image.shape[-1]
            left_image = torch.nn.functional.interpolate(
                left_image, size=(img_size, img_size)
            )
            right_image = torch.nn.functional.interpolate(
                right_image, size=(img_size, img_size)
            )
            front_center_image = torch.nn.functional.interpolate(
                front_center_image, size=(img_size, img_size)
            )
            front_image = torch.cat(
                [front_image, left_image, right_image, front_center_image], dim=1
            )
        features = self.forward_features(
            front_image,
            left_image,
            right_image,
            front_center_image,
            lidar,
            measurements,
        )

        bs = front_image.shape[0]

        if self.end2end:
            tgt = self.query_pos_embed.repeat(bs, 1, 1)
        else:
            tgt = self.position_encoding(
                torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
            )
            tgt = tgt.flatten(2)
            tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
        tgt = tgt.permute(2, 0, 1)

        memory = self.encoder(features, mask=self.attn_mask)
        hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]

        hs = hs.permute(1, 0, 2)  # (bs, num_queries, embed_dim)

        if self.end2end:
            waypoints = self.waypoints_generator(hs, target_point)
            return waypoints

        # The first 400 queries form a 20x20 grid of traffic/object queries;
        # query 400 is shared by the junction, traffic-light, and stop-sign
        # heads; the remaining queries feed the waypoint head.
        if self.waypoints_pred_head != "heatmap":
            traffic_feature = hs[:, :400]
            is_junction_feature = hs[:, 400]
            traffic_light_state_feature = hs[:, 400]
            stop_sign_feature = hs[:, 400]
            waypoints_feature = hs[:, 401:411]
        else:
            traffic_feature = hs[:, :400]
            is_junction_feature = hs[:, 400]
            traffic_light_state_feature = hs[:, 400]
            stop_sign_feature = hs[:, 400]
            waypoints_feature = hs[:, 401:405]

        if self.waypoints_pred_head == "heatmap":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        elif self.waypoints_pred_head == "gru":
            waypoints = self.waypoints_generator(waypoints_feature, target_point)
        elif self.waypoints_pred_head == "gru-command":
            waypoints = self.waypoints_generator(
                waypoints_feature, target_point, measurements
            )
        elif self.waypoints_pred_head == "linear":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        elif self.waypoints_pred_head == "linear-sum":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)

        is_junction = self.junction_pred_head(is_junction_feature)
        traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
        stop_sign = self.stop_sign_head(stop_sign_feature)

        velocity = measurements[:, 6:7].unsqueeze(-1)
        velocity = velocity.repeat(1, 400, 32)
        traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
        traffic = self.traffic_pred_head(traffic_feature_with_vel)
        return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
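
# Output tuple of Interfuser.forward (non-end2end path):
#   traffic:             (bs, 400, 7) per-cell sigmoid outputs over the 20x20
#                        grid (with the default "det" head)
#   waypoints:           (bs, num_waypoints, 2) cumulative ego-frame offsets
#   is_junction:         (bs, 2) logits
#   traffic_light_state: (bs, 2) logits
#   stop_sign:           (bs, 2) logits
#   traffic_feature:     (bs, 400, embed_dim) grid features before the det head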


# --- Model Builder Function ---
@register_model
def interfuser_baseline(**kwargs):
    model = Interfuser(
        enc_depth=6,
        dec_depth=6,
        embed_dim=256,
        rgb_backbone_name="r50",
        lidar_backbone_name="r18",
        waypoints_pred_head="gru-command",  # matching the original code's logic
        use_different_backbone=True,
        direct_concat=False,  # matching the original code's logic
    )
    return model
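

A minimal smoke-test sketch, not part of the uploaded file: it assumes the input dict keys read by Interfuser.forward above, a 224x224 front camera with 112x112 side cameras, and a measurements vector whose first 6 entries are the command mask and whose entry 6 is speed. With pretrained=True, timm downloads the ResNet weights on first use; the import path assumes this file is saved as model.py.

import torch
from model import interfuser_baseline

model = interfuser_baseline().eval()
dummy = {
    "rgb": torch.randn(1, 3, 224, 224),
    "rgb_left": torch.randn(1, 3, 112, 112),
    "rgb_right": torch.randn(1, 3, 112, 112),
    "rgb_center": torch.randn(1, 3, 128, 128),  # unused: with_center_sensor=False
    "lidar": torch.randn(1, 3, 112, 112),       # unused: with_lidar=False
    "measurements": torch.zeros(1, 7),
    "target_point": torch.zeros(1, 2),
}
dummy["measurements"][0, 0] = 1.0  # select the first command branch
with torch.no_grad():
    traffic, waypoints, is_junction, light, stop_sign, feat = model(dummy)
print(waypoints.shape)  # expected: torch.Size([1, 10, 2])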