| | import sys |
| | import torch |
| | import random |
| | import numpy as np |
| | from tqdm import tqdm |
| | from src.utils.model_utils import _print |
| |
|
class UnconditionalSampler:
    """Iterative mask-predict sampler for unconditional sequence generation.

    Starting from a (partially) masked token tensor, it repeatedly queries the
    model, commits predicted tokens, and stochastically remasks low-confidence
    positions according to an unmasking schedule ``kappa_fn``.
    """

    def __init__(self, tokenizer, model):
        self.model = model
        self.tokenizer = tokenizer

        self.device = self.model.device
        self.mask_id = self.tokenizer.mask_token_id
        # NOTE(review): seed is hard-coded to 42 for reproducibility of every
        # sampler instance; kept as-is to preserve the existing interface.
        self.seed_everything(seed=42)

    @torch.inference_mode()
    def sample_unconditional(self, xt, num_steps, tau=0.7, kappa_fn=lambda t: t, eta=1, alpha=1., banned_token_ids=None, return_logits=None):
        """
        Stochastic remasking sampling method for iterative refinement of sequences.

        Args:
            xt (Tensor): Initial token tensor; positions not equal to the mask
                id are treated as fixed. Modified in place.
            num_steps (int): Number of refinement steps.
            tau (float): Temperature parameter for softmax sampling (0 = greedy,
                no Gumbel noise).
            kappa_fn (callable): Function controlling the unmasking schedule;
                maps step fraction t in (0, 1] to the fraction of tokens kept.
            eta (float): Scaling factor applied to scores of tokens unmasked on
                earlier steps, controlling how readily they are remasked.
            alpha (float): Weighting between sampled log-prob and negative
                entropy in the confidence score.
            banned_token_ids (list[int] | None): Token ids excluded from sampling.
            return_logits (bool | None): If truthy, also return the final-step logits.

        Returns:
            Tensor: Final sampled sequence tensor, or the tuple
            ``(sequence, logits)`` when ``return_logits`` is truthy.
        """
        dt = 1 / num_steps
        # Positions supplied already-unmasked by the caller: never remasked.
        fix_mask = xt != self.mask_id
        attention_mask = torch.ones_like(xt).to(self.device)

        for i in range(1, num_steps + 1):
            kappa_t = kappa_fn(i * dt)
            logits = self.model(input_ids=xt, attention_mask=attention_mask)
            last_mask = xt == self.mask_id
            # Tokens unmasked on a previous step (not fixed, not masked now).
            unmask_t = ~last_mask & ~fix_mask

            x0, logp = self.stochastic_sample_from_categorical(logits, tau, banned_token_ids=banned_token_ids)

            # Confidence score: blend of sampled log-prob and negative entropy.
            entropy = torch.distributions.Categorical(logits=logits).entropy()
            score = alpha * logp + (1 - alpha) * -entropy
            # +inf keeps fixed positions out of the lowest-k remask set.
            score = score.masked_fill(fix_mask, float('inf'))

            # Rescale previously-unmasked tokens so eta tunes their remask odds.
            score[unmask_t] = score[unmask_t] * eta

            # Remask the lowest-scoring tokens so a kappa_t fraction stays unmasked.
            num_to_mask = ((~fix_mask).sum(1, keepdim=True).float() * (1 - kappa_t)).long()
            lowest_k_mask = self.topk_lowest_masking(score, num_to_mask)

            xt[lowest_k_mask] = self.mask_id
            # Commit fresh predictions at positions that were masked and
            # survived the remasking cut.
            mask_2_x0 = last_mask & ~lowest_k_mask
            xt[mask_2_x0] = x0[mask_2_x0]

        # Fill any remaining masks with the final-step predictions.
        xt[xt == self.mask_id] = x0[xt == self.mask_id]
        # BUG FIX: the original `return xt, logits if return_logits else xt`
        # parsed as `return (xt, (logits if return_logits else xt))`, so a
        # tuple was returned unconditionally — `(xt, xt)` when the flag was
        # falsy. Parenthesize so the flag selects the return shape, matching
        # the docstring.
        return (xt, logits) if return_logits else xt

    def stochastic_sample_from_categorical(self, logits, temperature, noise_scale=1.0, banned_token_ids=None):
        """
        Sample from a categorical distribution with optional temperature scaling
        and Gumbel noise.

        Args:
            logits (Tensor): Unnormalized scores, vocab on the last dim.
            temperature (float): Softmax temperature; 0 disables noise and
                yields a deterministic argmax.
            noise_scale (float): Multiplier on the Gumbel noise term.
            banned_token_ids (list[int] | None): Token ids forced to -inf.

        Returns:
            tuple[Tensor, Tensor]: Sampled token ids and their log-probabilities
            (log-softmax of the possibly-noised logits).
        """
        # Double precision keeps the log-softmax numerically stable.
        logits = logits.double()

        if banned_token_ids is not None:
            banned_token_mask = torch.zeros_like(logits, device=logits.device).bool()
            for token_id in banned_token_ids:
                banned_token_mask[..., token_id] = True
            logits = logits.masked_fill(banned_token_mask, float('-inf'))

        if temperature != 0:
            # Gumbel-max trick: argmax of (logits/T + Gumbel noise) is a sample
            # from the tempered categorical. Epsilons guard log(0).
            gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
            logits = logits / temperature + noise_scale * gumbel_noise
        scores, tokens = logits.log_softmax(dim=-1).max(dim=-1)

        return tokens, scores

    def topk_lowest_masking(self, scores, cutoff_len):
        """
        scores: [b, n]
        cutoff_len: [b, 1]
        returns:
            mask: [b, n], with 1 if the token is in top-k lowest scores, 0 otherwise
        """
        # The k-th smallest score is the cutoff; strictly-lower scores are masked,
        # so exactly-equal ties are kept unmasked.
        sorted_index = scores.sort(-1)[0]
        cutoff = sorted_index.gather(dim=-1, index=cutoff_len)
        return scores < cutoff

    def seed_everything(self, seed):
        """
        Set the seed for reproducibility across various libraries.

        Args:
            seed (int | None): Seed value; ``None`` leaves all RNG state untouched.
        """
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            # Deterministic cuDNN kernels trade speed for reproducibility.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False