| | import torch |
| |
|
| | from .utils import kabsch |
| | |
def rbf(positions, target_position, sigma):
    """
    Log of a Gaussian (RBF) kernel between aligned positions and a target.

    ``positions`` is first transformed with the ``(R, t)`` pair returned by
    ``kabsch`` (presumably the optimal rigid superposition onto
    ``target_position`` -- confirm against ``utils.kabsch``). The transform is
    computed from detached copies of both inputs, so autograd treats ``R`` and
    ``t`` as constants: gradients flow through ``positions`` only.

    Returns ``-0.5 / sigma**2`` times the squared deviation averaged over the
    last two dimensions, i.e. one value per leading (batch) index.
    """
    # Alignment is computed outside the autograd graph.
    rotation, translation = kabsch(positions.detach(), target_position.detach())

    # Apply the transform to the (possibly grad-tracking) input positions.
    aligned = positions @ rotation.transpose(-2, -1) + translation

    # Mean squared deviation over the trailing two dims.
    residual = aligned - target_position
    mean_sq_dev = torch.mean(residual.square(), dim=(-2, -1))

    scale = -0.5 / sigma**2
    return scale * mean_sq_dev
| |
|
def grad_log_wrt_positions(positions, target_position, sigma):
    """
    Gradient of log kernel w.r.t. the ORIGINAL positions: same shape as positions (..., N, 3).

    The gradient is taken of ``rbf(positions, target_position, sigma).sum()``
    with respect to a detached copy of ``positions``, so the caller's tensor
    and any graph attached to it are left untouched, and the returned tensor
    carries no autograd history.

    Grad mode is enabled locally, so the function also works when called from
    inside a ``torch.no_grad()`` block (e.g. a sampling loop); without that,
    ``torch.autograd.grad`` would fail because no graph is recorded.
    """
    with torch.enable_grad():
        # Fresh leaf tensor: detach first so the copy is never recorded in the
        # caller's graph, then track gradients on the copy only.
        pos = positions.detach().clone().requires_grad_(True)
        log_ri = rbf(pos, target_position, sigma)

        # Summing over the batch gives per-sample gradients in one backward
        # pass, since each log-kernel value depends only on its own sample.
        (grad_pos,) = torch.autograd.grad(
            log_ri.sum(), pos, create_graph=False, retain_graph=False
        )
    return grad_pos