Update region.py
fix: update internal decode size.
region.py
CHANGED
|
@@ -71,26 +71,20 @@ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
|
|
| 71 |
return w.size_encoder(fourier_features(size, w.size_features))
|
| 72 |
|
| 73 |
|
|
|
|
| 74 |
def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
|
| 75 |
"""
|
| 76 |
-
|
| 77 |
-
for 1024 bins representing width and height in log-scale.
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
where size values are clamped to be at least 1/1024.
|
| 82 |
-
|
| 83 |
-
To convert from bin back to size:
|
| 84 |
-
size = 2^((bin / 1023.0) * 10.0 - 10.0)
|
| 85 |
-
|
| 86 |
-
Args:
|
| 87 |
-
hidden_state: The final hidden state tensor from the text model.
|
| 88 |
-
|
| 89 |
-
Returns:
|
| 90 |
-
A tensor containing logits for 1024 bins for width and height.
|
| 91 |
-
Shape is (2, 1024) where the first dimension corresponds to width and height.
|
| 92 |
"""
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> torch.Tensor:
|
|
|
|
| 71 |
return w.size_encoder(fourier_features(size, w.size_features))
|
| 72 |
|
| 73 |
|
| 74 |
+
# region.py
|
| 75 |
def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """Decode width/height bin logits from the text model's hidden state.

    Projects the hidden state through the size-decoder MLP and splits the
    last dimension into separate width and height halves, keeping every
    leading (batch/sequence) dimension intact.

    Args:
        hidden_state: Hidden-state tensor of shape (..., C).
        w: Module providing the ``size_decoder`` MLP weights.

    Returns:
        Tensor of shape (..., 2, bins), where index 0 of the penultimate
        dimension holds width logits and index 1 holds height logits.
        # NOTE(review): width-before-height ordering is inferred from the
        # prior docstring — confirm against the training targets.

    Raises:
        RuntimeError: If the MLP output dimension is odd and cannot be
            split evenly into the two width/height halves.
    """
    logits = mlp(hidden_state, w.size_decoder)  # (..., size_out_dim)
    out_dim = logits.shape[-1]
    if out_dim % 2 != 0:
        raise RuntimeError(f"size_out_dim must be even, got {out_dim}")
    # reshape (not view): view raises on non-contiguous MLP output,
    # while reshape is a no-cost alias in the contiguous case.
    return logits.reshape(*logits.shape[:-1], 2, out_dim // 2)  # (..., 2, bins)
|
| 88 |
|
| 89 |
|
| 90 |
def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> torch.Tensor:
|