import sys

import torch
import torch.nn as nn


def _print(s):
    """Print and flush immediately, so output is not held back by stdout buffering."""
    print(s)
    sys.stdout.flush()


def get_latents(model, tokenizer, sequence, device):
    """Return the final-layer hidden states for `sequence` as a (seq_len, hidden_dim) tensor."""
    tokens = tokenizer(sequence, return_tensors="pt").to(device)
    with torch.no_grad():
        # Request hidden states explicitly; otherwise `outputs.hidden_states` is
        # None unless the model config already sets output_hidden_states=True.
        outputs = model(**tokens, output_hidden_states=True)
    embeds = outputs.hidden_states[-1].squeeze(0)
    return embeds
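
# Usage sketch (illustrative only, not part of this module): assumes the
# Hugging Face `transformers` package and the public ESM-2 checkpoint below;
# any model/tokenizer pair whose forward pass can return hidden states works
# the same way.
#
#   from transformers import AutoModel, AutoTokenizer
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
#   model = AutoModel.from_pretrained("facebook/esm2_t6_8M_UR50D").to(device).eval()
#   latents = get_latents(model, tokenizer, "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", device)
#   print(latents.shape)  # (sequence_length, hidden_dim)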


def freeze_model(model: nn.Module):
    """Disable gradients for every parameter in `model`."""
    for param in model.parameters():
        param.requires_grad = False


def apply_gptj_freezing(model, N_layers):
    """Freeze the whole model, then unfreeze attention in the last N_layers GPT-J blocks."""

    def unfreeze_n_layers(model, N_layers):
        model_layers = len(model.transformer.h)
        for i, h in enumerate(model.transformer.h):
            if i >= model_layers - N_layers:
                # Unfreeze every parameter of the block's attention sub-module.
                for param in h.attn.parameters():
                    param.requires_grad = True

    def check_frozen_model(model, N_layers: int):
        """
        Verify that only the last N_layers of model.transformer.h are unfrozen.
        Source: https://github.com/enijkamp/progen2/blob/main/progen/modeling_progen.py
        """
        model_layers = len(model.transformer.h)
        frozen_layers = 0
        unfrozen_layers = 0
        for i, h in enumerate(model.transformer.h):
            if i >= model_layers - N_layers:
                if any(param.requires_grad for param in h.parameters()):
                    unfrozen_layers += 1
                else:
                    print(f"Layer {i} has all parameters frozen, but it should be unfrozen.")
            else:
                if any(param.requires_grad for param in h.parameters()):
                    print(f"Layer {i} is not frozen, but it should be frozen.")
                else:
                    frozen_layers += 1

        assert frozen_layers == model_layers - N_layers and unfrozen_layers == N_layers, \
            f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}"

        print(f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}")

    freeze_model(model)
    unfreeze_n_layers(model, N_layers)
    check_frozen_model(model, N_layers)
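
# Usage sketch (illustrative only; assumes the Hugging Face `transformers`
# package and the public GPT-J checkpoint below, though any model exposing
# `model.transformer.h` blocks with an `.attn` sub-module works):
#
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
#   apply_gptj_freezing(model, N_layers=2)  # train attention in the last 2 blocks only
#   trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
#   print(f"trainable params: {trainable:,}")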


def apply_rdm_freezing(model: nn.Module, N_layers: int, model_type: str):
    """
    Freeze all layers except the attention projections in the last N encoder
    layers of ESM-like architectures.

    Args:
        model (nn.Module): model to freeze
        N_layers (int): number of encoder layers to unfreeze
        model_type (str): one of {"esm", "evoflow", "dplm"}
    """
    # The encoder layer list lives in a different place depending on the wrapper.
    if model_type == "dplm":
        encoder_layers = model.net.esm.encoder.layer
    elif model_type in ("esm", "evoflow"):
        encoder_layers = model.esm.encoder.layer
    else:
        raise ValueError(f"Unknown model_type: {model_type}")

    def unfreeze_n_layers(layers, N_layers: int):
        model_layers = len(layers)
        for i, layer in enumerate(layers):
            if i >= model_layers - N_layers:
                # Unfreeze only the query/key/value projections of self-attention.
                attn = layer.attention.self
                for proj in (attn.query, attn.key, attn.value):
                    for param in proj.parameters():
                        param.requires_grad = True

    def check_model(layers, N_layers: int):
        model_layers = len(layers)
        frozen_layers = 0
        unfrozen_layers = 0

        for i, layer in enumerate(layers):
            if i >= model_layers - N_layers:
                # The layer counts as unfrozen if any q/k/v parameter requires grad.
                attn = layer.attention.self
                layer_frozen = not any(
                    param.requires_grad
                    for proj in (attn.query, attn.key, attn.value)
                    for param in proj.parameters()
                )
                if layer_frozen:
                    print(f"layer {i} has all parameters frozen, but it should be unfrozen.")
                else:
                    unfrozen_layers += 1
            else:
                if any(param.requires_grad for param in layer.parameters()):
                    print(f"layer {i} is not frozen, but it should be.")
                else:
                    frozen_layers += 1

        assert (frozen_layers == model_layers - N_layers) and (unfrozen_layers == N_layers), \
            f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}"

    freeze_model(model)
    unfreeze_n_layers(encoder_layers, N_layers)
    check_model(encoder_layers, N_layers)
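
# Usage sketch (illustrative only; assumes the Hugging Face `transformers`
# package and the public ESM-2 checkpoint below; the "dplm" and "evoflow"
# wrappers are project-specific and take the same call with their own models):
#
#   from transformers import AutoModelForMaskedLM
#
#   model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
#   apply_rdm_freezing(model, N_layers=4, model_type="esm")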