DarthReca committed
Commit ad28d31 · verified · 1 Parent(s): 09e20f6

Update modeling_actu.py

Files changed (1): modeling_actu.py (+48, -271)
modeling_actu.py CHANGED
@@ -1,203 +1,89 @@
-from dataclasses import dataclass
-
 import numpy as np
 import timm
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange, repeat
+from einops import rearrange
 from segmentation_models_pytorch.base import SegmentationHead
 from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder
 from timm.layers.create_act import create_act_layer
-from transformers import PretrainedConfig, PreTrainedModel
-from transformers.modeling_outputs import SemanticSegmenterOutput
 
 from .convlstm import ConvLSTM
 
 
-class ACTUConfig(PretrainedConfig):
-    model_type = "actu"
-
+class ACTU(nn.Module):
     def __init__(
         self,
-        # Base ACTU parameters
-        in_channels: int = 3,
-        kernel_size: tuple[int, int] = (3, 3),
-        padding="same",
-        stride=(1, 1),
-        backbone="resnet34",
+        in_channels,
+        kernel_size,
+        padding,
+        stride,
+        backbone: str,
         bias=True,
         batch_first=True,
         bidirectional=False,
         original_resolution=(256, 256),
-        act_layer="sigmoid",
-        n_classes=1,
-        # Variant control parameters
-        use_dem_input: bool = False,
-        use_climate_branch: bool = False,
-        # Climate branch parameters
-        climate_seq_len=5,
-        climate_input_dim=6,
-        lstm_hidden_dim=128,
-        num_lstm_layers=1,
+        act_layer: str = "sigmoid",
+        n_classes: int = 1,
         **kwargs,
     ):
-        super().__init__(**kwargs)
-        self.in_channels = in_channels
-        self.kernel_size = kernel_size
-        self.padding = padding
-        self.stride = stride
-        self.backbone = backbone
-        self.bias = bias
-        self.batch_first = batch_first
-        self.bidirectional = bidirectional
-        self.original_resolution = original_resolution
-        self.act_layer = act_layer
+        super(ACTU, self).__init__()
         self.n_classes = n_classes
-
-        # Parameters to control variants
-        self.use_dem_input = use_dem_input
-        self.use_climate_branch = use_climate_branch
-        self.climate_seq_len = climate_seq_len
-        self.climate_input_dim = climate_input_dim
-        self.lstm_hidden_dim = lstm_hidden_dim
-        self.num_lstm_layers = num_lstm_layers
-
-        # Adjust in_channels if DEM is used
-        if self.use_dem_input:
-            self.in_channels += 1
-
-
-class ACTUForImageSegmentation(PreTrainedModel):
-    config_class = ACTUConfig
-
-    def __init__(self, config: ACTUConfig):
-        super().__init__(config)
-        self.config = config
+        self.backbone = backbone
 
         self.encoder: nn.Module = timm.create_model(
-            config.backbone, features_only=True, in_chans=config.in_channels
+            backbone, features_only=True, in_chans=in_channels
         )
 
         with torch.no_grad():
-            dummy_input_channels = config.in_channels
-            dummy_input = torch.randn(
-                1, dummy_input_channels, *config.original_resolution
+            embs = self.encoder.forward(
+                torch.randn(1, in_channels, *original_resolution)
             )
-            embs = self.encoder(dummy_input)
-            self.embs_shape = [e.shape for e in embs]
-            self.encoder_channels = [e[1] for e in self.embs_shape]
+            embs_shape = [e.shape for e in embs]
 
+        # The ConvLSTM expects inputs of shape (B, T, feature_dim, H_enc, W_enc)
+        # We assume the provided ConvLSTM code is available.
         self.convlstm = nn.ModuleList(
-            [
-                ConvLSTM(
-                    in_channels=shape[1],
-                    hidden_channels=shape[1],
-                    kernel_size=config.kernel_size,
-                    padding=config.padding,
-                    stride=config.stride,
-                    bias=config.bias,
-                    batch_first=config.batch_first,
-                    bidirectional=config.bidirectional,
-                )
-                for shape in self.embs_shape
-            ]
-        )
-
-        if self.config.use_climate_branch:
-            self.climate_branch = ClimateBranchLSTM(
-                output_shapes=[e[1:] for e in self.embs_shape],
-                lstm_hidden_dim=config.lstm_hidden_dim,
-                climate_seq_len=config.climate_seq_len,
-                climate_input_dim=config.climate_input_dim,
-                num_lstm_layers=config.num_lstm_layers,
-            )
-            self.fusers = nn.ModuleList(
-                GatedFusion(enc, enc) for enc in self.encoder_channels
+            ConvLSTM(
+                in_channels=shape[1],
+                hidden_channels=shape[1],
+                kernel_size=kernel_size,
+                padding=padding,
+                stride=stride,
+                bias=bias,
+                batch_first=batch_first,
+                bidirectional=bidirectional,
             )
+            for shape in embs_shape
+        )
+        # If bidirectional, the hidden representation is concatenated from both directions.
+        n_upsamples = int(np.log2(original_resolution[0] / embs_shape[-1][-2]))
+        skip_channels_list = [shape[1] for shape in embs_shape[-(n_upsamples + 1) : -1]]
+        skip_channels_list = skip_channels_list[::-1]  # Reverse the list.
+        encoder_channels = [e[1] for e in embs_shape]
 
         self.decoder = UnetDecoder(
-            encoder_channels=[1] + self.encoder_channels,
-            decoder_channels=self.encoder_channels[::-1],
-            n_blocks=len(self.encoder_channels),
+            encoder_channels=[1, *encoder_channels],
+            decoder_channels=encoder_channels[::-1],
+            n_blocks=len(encoder_channels),
         )
-
         self.seg_head = nn.Sequential(
             SegmentationHead(
-                in_channels=self.encoder_channels[0],
-                out_channels=config.n_classes,
+                in_channels=encoder_channels[0],
+                out_channels=n_classes,
             ),
-            create_act_layer(config.act_layer, inplace=True),
+            create_act_layer(act_layer, inplace=True),
         )
+        self.encoder_channels = encoder_channels
+        self.embs_shape = embs_shape
 
-    def forward(
-        self,
-        pixel_values: torch.Tensor,
-        climate: torch.Tensor = None,
-        dem: torch.Tensor = None,
-        labels: torch.Tensor = None,
-        **kwargs,
-    ) -> SemanticSegmenterOutput:
-        b, t = pixel_values.shape[:2]
-        original_size = pixel_values.shape[-2:]
-
-        # Handle DEM input
-        if self.config.use_dem_input:
-            if dem is None:
-                raise ValueError(
-                    "DEM tensor must be provided when use_dem_input is True."
-                )
-            dem_repeated = repeat(dem, "b c h w -> b t c h w", t=t)
-            pixel_values = torch.cat([pixel_values, dem_repeated], dim=2)
-
-        # 1. Encode images per time step
-        encoded_sequence = self._encode_images(pixel_values)
-
-        # 2. Handle Climate Branch Fusion
-        if self.config.use_climate_branch:
-            if climate is None:
-                raise ValueError(
-                    "Climate tensor must be provided when use_climate_branch is True."
-                )
-
-            climate_features = self.climate_branch(climate)
-
-            # Reshape for fusion
-            encoded_sequence_reshaped = [
-                rearrange(f, "b t c h w -> (b t) c h w") for f in encoded_sequence
-            ]
-            climate_features_reshaped = [
-                rearrange(f, "b t c h w -> (b t) c h w") for f in climate_features
-            ]
-
-            # Fuse features
-            fused_features = [
-                fuser(img, clim)
-                for fuser, img, clim in zip(
-                    self.fusers, encoded_sequence_reshaped, climate_features_reshaped
-                )
-            ]
-
-            # Reshape back to sequence
-            encoded_sequence = [
-                rearrange(f, "(b t) c h w -> b t c h w", b=b) for f in fused_features
-            ]
-
-        # 3. Process sequence with ConvLSTM
-        temporal_features = self._encode_timeseries(encoded_sequence)
-
-        # 4. Decode to get the segmentation map
-        logits = self._decode(temporal_features, size=original_size)
-
-        loss = None
-        if labels is not None:
-            loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(logits, labels.float().unsqueeze(1))
-
-        return SemanticSegmenterOutput(
-            loss=loss,
-            logits=logits,
-        )
+    def forward(self, x: torch.Tensor, **kwargs):
+        size = x.size()[-2:]
+        # Process each time step through the encoder.
+        x = self._encode_images(x)
+        # Pass the encoded sequence through the ConvLSTM.
+        x = self._encode_timeseries(x)
+        return self._decode(x, size=size)
 
     def _encode_images(self, x: torch.Tensor) -> list[torch.Tensor]:
         B = x.size(0)
@@ -221,112 +107,3 @@ class ACTUForImageSegmentation(PreTrainedModel):
             trend_map, size=size, mode="bilinear", align_corners=False
         )
         return trend_map
-
-
-class ClimateBranchLSTM(nn.Module):
-    """
-    Processes climate time series data using an LSTM.
-    Input shape: (B, T, T_1, C_clim) -> e.g., (B, 5, 6, 5)
-    Output shape: (B, T, output_dim) -> e.g., (B, 5, 128)
-    """
-
-    def __init__(
-        self,
-        output_shapes: list[tuple[int, int, int]],
-        climate_input_dim=5,
-        climate_seq_len=6,
-        lstm_hidden_dim=64,
-        num_lstm_layers=1,
-    ):
-        super().__init__()
-        self.climate_seq_len = climate_seq_len
-        self.climate_input_dim = climate_input_dim
-        self.lstm_hidden_dim = lstm_hidden_dim
-        self.num_lstm_layers = num_lstm_layers
-        self.proj_dim = 128
-        self.output_shapes = output_shapes
-
-        self.lstm = nn.LSTM(
-            input_size=climate_input_dim,
-            hidden_size=lstm_hidden_dim,
-            num_layers=num_lstm_layers,
-            batch_first=True,  # Crucial: expects input shape (batch, seq_len, features)
-            dropout=0.3 if num_lstm_layers > 1 else 0,
-            bidirectional=False,
-        )
-
-        # Linear layer to project LSTM output to the desired final dimension
-        self.fc = nn.Linear(lstm_hidden_dim, self.proj_dim)
-
-        self.upsamples = nn.ModuleList(
-            _build_upsampler(self.proj_dim, *shape[:2]) for shape in output_shapes
-        )
-
-    def forward(self, climate_data: torch.Tensor) -> list[torch.Tensor]:
-        # climate_data shape: (B, T, T_1, C_clim), e.g., (B, 5, 6, 5)
-        B_img, B_cli, T, C = climate_data.shape
-
-        # Reshape for LSTM: treat each sequence independently
-        lstm_input = rearrange(climate_data, "Bi Bc T C -> (Bi Bc) T C")
-
-        # Pass through LSTM
-        _, (hidden, _) = self.lstm.forward(lstm_input)
-        # Get the last layer's hidden state
-        last_hidden = (
-            hidden[[hidden.size(0) // 2, -1]] if self.lstm.bidirectional else hidden[-1]
-        )
-        if last_hidden.ndim == 3:
-            last_hidden = hidden.mean(dim=0)
-
-        # Pass the final hidden state through the fully connected layer(s) and upsample
-        climate_features = self.fc(last_hidden)
-        climate_features = rearrange(climate_features, "b c -> b c 1 1")
-        climate_features = [
-            rearrange(
-                u(climate_features), "(Bi Bc) C H W -> Bi Bc C H W", Bi=B_img, Bc=B_cli
-            )
-            for u in self.upsamples
-        ]
-
-        return climate_features
-
-
-class GatedFusion(nn.Module):
-    def __init__(self, img_channels, clim_channels):
-        super().__init__()
-        self.gate = nn.Sequential(
-            nn.Sequential(
-                nn.Conv2d(
-                    img_channels + clim_channels, img_channels, kernel_size=3, padding=1
-                ),
-                nn.ReLU(inplace=True),
-                nn.Conv2d(img_channels, img_channels, kernel_size=1),
-                nn.Sigmoid(),  # Gate values between 0 and 1
-            )
-        )
-
-    def forward(self, img_feat, clim_feat):
-        gate = self.gate(torch.cat([img_feat, clim_feat], dim=1))
-        return gate * img_feat + (1 - gate) * clim_feat
-
-
-def _build_upsampler(
-    in_channels: int, target_channels: int, target_h: int
-) -> nn.Sequential:
-    layers = []
-    current_h = 1
-
-    # Expand to target channels early (e.g., 1x1 → 1x1 with target_channels)
-    layers += [nn.Conv2d(in_channels, target_channels, kernel_size=1), nn.GELU()]
-
-    # Upsample spatially to target_h
-    while current_h < target_h:
-        next_h = min(current_h * 2, target_h)
-        layers += [
-            nn.Upsample(scale_factor=2, mode="nearest"),
-            nn.Conv2d(target_channels, target_channels, kernel_size=3, padding=1),
-            nn.GELU(),
-        ]
-        current_h = next_h
-
-    return nn.Sequential(*layers)
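For anyone trying out the simplified module, a minimal usage sketch follows. It assumes the file is loaded as part of a package (it uses the relative import "from .convlstm import ConvLSTM", so the import path below is hypothetical) and that inputs follow the (B, T, C, H, W) layout implied by _encode_images and the ConvLSTM comment. The argument values mirror the defaults of the removed ACTUConfig and are illustrative only.

import torch

# Hypothetical package path; adjust to wherever modeling_actu.py lives.
from actu.modeling_actu import ACTU

# Values mirror the removed ACTUConfig defaults.
model = ACTU(
    in_channels=3,
    kernel_size=(3, 3),
    padding="same",
    stride=(1, 1),
    backbone="resnet34",
    original_resolution=(256, 256),
    act_layer="sigmoid",
    n_classes=1,
)
model.eval()

# A batch of 2 sequences, each with 5 time steps of 3-channel 256x256 frames.
x = torch.randn(2, 5, 3, 256, 256)
with torch.no_grad():
    out = model(x)  # _decode upsamples the trend map back to the input resolution
print(out.shape)  # expected: torch.Size([2, 1, 256, 256])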