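# NOTE: page-header residue from the file's web listing, kept as comments so the module parses:
#   HV-Khurdula's picture
#   Update rope.py
#   09597f3 verified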
# Ethically sourced from https://github.com/xjdr-alt/entropix
import torch
def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    use_scaled: bool = False,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build the rotary-embedding cos/sin lookup table.

    Args:
        dim: Number of rotated channels; ``dim // 2`` frequencies are produced.
        end: Number of positions (rows) to tabulate, i.e. positions ``0..end-1``.
        theta: Base of the geometric frequency progression.
        use_scaled: Accepted for interface compatibility; this function does
            not read it, so no frequency scaling is applied.
        dtype: Floating dtype of the returned table.

    Returns:
        Tensor of shape ``(end, dim // 2, 2)`` where ``[..., 0]`` holds the
        cosine and ``[..., 1]`` the sine of ``position * inverse_frequency``.
    """
    # Angles are computed in at least float32: half/bfloat16 cannot represent
    # larger integer positions exactly, and a complex exponential in those
    # dtypes is unreliable. The requested dtype is applied once, at the end.
    compute_dtype = torch.float64 if dtype == torch.float64 else torch.float32

    # The slice keeps exactly dim // 2 frequencies even when dim is odd.
    exponents = torch.arange(0, dim, 2, dtype=compute_dtype)[: dim // 2] / dim
    inv_freq = 1.0 / (theta ** exponents)

    positions = torch.arange(end, dtype=compute_dtype)
    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)  # (end, dim // 2)

    # cos/sin are the real/imaginary parts of exp(i * angle).
    table = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
    return table.to(dtype)
# rope.py
import torch
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
    position_ids: torch.Tensor,
    num_heads: int,
    rot_dim: int = 32,
    interleave: bool = False,
) -> torch.Tensor:
    """Apply rotary position embeddings to the leading channels of ``x``.

    RoPE as used in the original moondream2 text stack.

    Args:
        x: Tensor of shape ``(B, H, T, D)``.
        freqs_cis: Table of shape ``(max_T, rot_dim // 2, 2)`` where
            ``[..., 0]`` is cos and ``[..., 1]`` is sin.
        position_ids: Integer positions indexing ``freqs_cis``. May be a
            scalar, ``(T,)``, ``(1, T)`` or ``(B, T)``.
        num_heads: Expected size of ``x``'s head dimension (checked).
        rot_dim: Number of leading channels to rotate; must equal twice
            ``freqs_cis.shape[-2]``.
        interleave: If true, the real/imag parts of each rotated pair are
            adjacent channels ``(0, 1), (2, 3), ...`` in the input; otherwise
            the first half of the rotated slice is real and the second half
            imaginary.

    Returns:
        Tensor with the same shape and dtype as ``x``. The first
        ``min(rot_dim, D)`` channels are rotated; the remaining channels are
        passed through unchanged. The rotated channels are always written as
        adjacent (real, imag) pairs, regardless of ``interleave``.
    """
    assert rot_dim == freqs_cis.shape[-2] * 2
    assert num_heads == x.shape[1]

    # Unpacking also enforces that x is 4-D.
    _, _, _, head_dim = x.shape
    rd = min(rot_dim, head_dim)
    x_rot, x_pass = x[..., :rd], x[..., rd:]

    # Split real/imag parts depending on the input layout.
    if interleave:
        pairs = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)
        xr, xi = pairs[..., 0], pairs[..., 1]
    else:
        half = x_rot.shape[-1] // 2
        xr, xi = x_rot[..., :half], x_rot[..., half:]

    # Normalise scalar / (T,) positions to (1, T). A leading dim of 1 then
    # broadcasts over the batch, so (1, T) works for any B as well as (B, T).
    if position_ids.dim() < 2:
        position_ids = position_ids.reshape(1, -1)
    freq = freqs_cis[position_ids]  # (1 or B, T, rot_dim // 2, 2)

    rot_half = rd // 2
    # unsqueeze(1) inserts the head axis so cos/sin broadcast across heads.
    cos = freq[..., :rot_half, 0].unsqueeze(1).to(x.dtype)
    sin = freq[..., :rot_half, 1].unsqueeze(1).to(x.dtype)

    # Complex multiply: (xr + i*xi) * (cos + i*sin).
    yr = xr * cos - xi * sin
    yi = xr * sin + xi * cos
    y = torch.stack((yr, yi), dim=-1).flatten(-2)  # (B, H, T, rd)

    # The interleave path computes in float32; cast back so concatenation
    # does not promote the whole output away from x's dtype.
    return torch.cat([y.to(x.dtype), x_pass], dim=-1)