| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| |
|
class ProbabilityPathTracer:
    """Record an oracle model's average log-likelihood of revealed tokens per step.

    Each call to ``log_step`` scores the current sequence ``xt`` with the
    oracle and stores the score in ``history`` under ``"trace_step_<step_idx>"``.
    A position counts as "revealed" when its id differs from the tokenizer's
    ``mask_token_id``.
    """

    # Passed to F.cross_entropy as ignore_index so that masked positions get an
    # exact zero loss (-100 is also PyTorch's default ignore_index).
    _IGNORE_INDEX = -100

    def __init__(self, oracle_model, tokenizer, device):
        """Store the scoring model and the id that marks hidden positions.

        Args:
            oracle_model: Callable invoked as
                ``oracle_model(input_ids=..., attention_mask=...)``; its result
                must expose a ``.logits`` tensor.
            tokenizer: Object providing ``mask_token_id``.
            device: Stored as ``self.device``; not read by the methods of this
                class.
        """
        self.oracle = oracle_model
        self.tokenizer = tokenizer
        self.device = device
        self.mask_id = tokenizer.mask_token_id
        # Maps "trace_step_<idx>" -> float score from compute_loglikeli.
        self.history = {}

    @torch.inference_mode()
    def compute_loglikeli(self, xt):
        """Return the oracle's mean per-token log-likelihood of the revealed tokens.

        Args:
            xt: Integer token-id tensor of shape ``(batch, seq_len)``. Because the
                per-row result is reduced with ``.item()``, ``batch`` must be 1.

        Returns:
            float: Mean of ``log p(xt[0, i])`` over positions ``i`` whose id is not
            ``mask_id``. Returns ``0.0`` without calling the oracle when every
            position is masked.
        """
        is_revealed = xt != self.mask_id
        if not is_revealed.any():
            return 0.0

        logits = self.oracle(
            input_ids=xt,
            attention_mask=torch.ones_like(xt),  # same device/dtype as xt
        ).logits
        # Log-softmax in fp16/bf16 is imprecise; score half-precision logits in fp32.
        if logits.dtype in (torch.float16, torch.bfloat16):
            logits = logits.float()

        # Exclude masked positions from the targets instead of scoring them and
        # multiplying by 0. Otherwise an oracle that assigns -inf to the mask
        # token would yield inf * 0 = NaN, and a mask_id outside the oracle's
        # vocabulary would make cross_entropy raise.
        targets = xt.masked_fill(~is_revealed, self._IGNORE_INDEX)

        # NOTE(review): logits at position i are scored against token i (no
        # causal shift), which fits a bidirectional/masked oracle. If the oracle
        # is a causal LM, targets should be shifted by one position — confirm.
        nll = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=self._IGNORE_INDEX,
            reduction="none",
        ).reshape(xt.shape)

        # clamp(min=1) guards rows that have no revealed tokens (their sum is 0).
        n_revealed = is_revealed.sum(dim=1).clamp(min=1)
        avg_ll = -nll.sum(dim=1) / n_revealed

        return avg_ll.item()

    def log_step(self, xt, step_idx):
        """Score ``xt`` and store it in ``history`` under ``"trace_step_<step_idx>"``.

        Logging the same ``step_idx`` twice overwrites the earlier entry.
        """
        score = self.compute_loglikeli(xt)
        self.history[f"trace_step_{step_idx}"] = score

    def get_trace(self):
        """Return the recorded scores (the live ``history`` dict, not a copy)."""
        return self.history