HV-Khurdula committed
Commit c0e9503 · verified · 1 Parent(s): f542ccb

Update region.py

fix: update internal decode size.

Files changed (1):
  1. region.py (+10 -16)
region.py CHANGED
@@ -71,26 +71,20 @@ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
     return w.size_encoder(fourier_features(size, w.size_features))
 
 
+# region.py
 def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
-    Takes as input the last hidden state from the text model and outputs logits
-    for 1024 bins representing width and height in log-scale.
+    Returns logits for width & height bins without collapsing batch/seq dims.
 
-    The bins are distributed according to the formula:
-        bin = (log2(size) + 10.0) / 10.0 * 1023.0
-    where size values are clamped to be at least 1/1024.
-
-    To convert from bin back to size:
-        size = 2^((bin / 1023.0) * 10.0 - 10.0)
-
-    Args:
-        hidden_state: The final hidden state tensor from the text model.
-
-    Returns:
-        A tensor containing logits for 1024 bins for width and height.
-        Shape is (2, 1024) where the first dimension corresponds to width and height.
+    Input (hidden_state): (..., C)
+    Output: (..., 2, bins)  # keeps all leading dims intact
     """
-    return mlp(hidden_state, w.size_decoder).view(2, -1)
+    x = mlp(hidden_state, w.size_decoder)  # (..., size_out_dim)
+    last = x.shape[-1]
+    if last % 2 != 0:
+        raise RuntimeError(f"size_out_dim must be even, got {last}")
+    return x.view(*x.shape[:-1], 2, last // 2)  # (..., 2, bins)
+
 
 
 def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> torch.Tensor:
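
A quick sketch of why the new reshape matters. The (4, 7, 2048) tensor below is an illustrative stand-in for the MLP output, not a shape taken from region.py: the old .view(2, -1) folded the batch and sequence dimensions into the bins axis, while the new view splits only the last dimension.

import torch

hidden = torch.randn(4, 7, 2048)   # (batch, seq, size_out_dim) -- illustrative sizes

# Old behaviour: every leading dim is flattened into the bins axis.
old = hidden.view(2, -1)           # shape (2, 28672) -- batch/seq structure is lost

# New behaviour: only the last dim is split into (width/height, bins).
new = hidden.view(*hidden.shape[:-1], 2, hidden.shape[-1] // 2)
print(new.shape)                   # torch.Size([4, 7, 2, 1024])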
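
The docstring removed above also documents the log-scale binning for width and height. A minimal sketch of that mapping, assuming the 1024-bin scheme it describes (the helper names size_to_bin and bin_to_size are illustrative, not functions in region.py):

import torch

def size_to_bin(size: torch.Tensor) -> torch.Tensor:
    # bin = (log2(size) + 10.0) / 10.0 * 1023.0, with size clamped to at least 1/1024
    return (torch.log2(size.clamp(min=1.0 / 1024.0)) + 10.0) / 10.0 * 1023.0

def bin_to_size(bin_idx: torch.Tensor) -> torch.Tensor:
    # size = 2^((bin / 1023.0) * 10.0 - 10.0)
    return 2.0 ** ((bin_idx / 1023.0) * 10.0 - 10.0)

sizes = torch.tensor([0.25, 0.5, 1.0])
print(bin_to_size(size_to_bin(sizes)))   # round-trips back to [0.25, 0.5, 1.0]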