Spaces:
Paused
Paused
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import timm | |
| from .utils import FORMAT_INFO, to_device | |
| from .tokenizer import SOS_ID, EOS_ID, PAD_ID, MASK_ID | |
| from .inference import GreedySearch, BeamSearch | |
| from .transformer import TransformerDecoder, Embeddings | |
class Encoder(nn.Module):
    """Image encoder wrapping a timm backbone (ResNet, Swin or EfficientNet).

    The classification head (and, for the CNN backbones, the global pooling)
    is replaced with ``nn.Identity`` so the backbone's features are returned
    instead of class logits. ``n_features`` stores the backbone's
    ``num_features``.
    """

    def __init__(self, args, pretrained=False):
        """Create the backbone selected by ``args.encoder``.

        Args:
            args: configuration object. ``args.encoder`` is always read;
                ``args.use_checkpoint`` is read for Swin models only.
            pretrained: forwarded to ``timm.create_model``.

        Raises:
            NotImplementedError: if ``args.encoder`` does not name a
                resnet, swin or efficientnet model.
        """
        super().__init__()
        model_name = args.encoder
        self.model_name = model_name
        if model_name.startswith('resnet'):
            self.model_type = 'resnet'
            self.cnn = timm.create_model(model_name, pretrained=pretrained)
            self.n_features = self.cnn.num_features  # encoder_dim
            self.cnn.global_pool = nn.Identity()
            self.cnn.fc = nn.Identity()
        elif model_name.startswith('swin'):
            self.model_type = 'swin'
            self.transformer = timm.create_model(model_name, pretrained=pretrained, pretrained_strict=False,
                                                 use_checkpoint=args.use_checkpoint)
            self.n_features = self.transformer.num_features
            self.transformer.head = nn.Identity()
        elif 'efficientnet' in model_name:
            self.model_type = 'efficientnet'
            self.cnn = timm.create_model(model_name, pretrained=pretrained)
            self.n_features = self.cnn.num_features
            self.cnn.global_pool = nn.Identity()
            self.cnn.classifier = nn.Identity()
        else:
            # `NotImplemented` is a comparison sentinel, not an exception;
            # raising it produces a confusing TypeError.
            raise NotImplementedError(f'Unsupported encoder: {model_name!r}')

    def swin_forward(self, transformer, x):
        """Run a Swin model stage by stage, collecting each stage's output.

        Args:
            transformer: the Swin model; its ``patch_embed``,
                ``absolute_pos_embed``, ``pos_drop``, ``layers`` and ``norm``
                attributes are used.
            x: input batch passed to ``transformer.patch_embed``.

        Returns:
            ``(x, hiddens)`` where ``x`` is the output of ``transformer.norm``
            and ``hiddens`` holds one ``(B, H, W, C)`` view per stage, taken
            before that stage's downsampling. The last entry is replaced by
            the normalized final output.
        """
        # Imported explicitly so the submodule is guaranteed to be loaded.
        from torch.utils.checkpoint import checkpoint

        x = transformer.patch_embed(x)
        if transformer.absolute_pos_embed is not None:
            x = x + transformer.absolute_pos_embed
        x = transformer.pos_drop(x)

        def layer_forward(layer, x, hiddens):
            for blk in layer.blocks:
                if not torch.jit.is_scripting() and layer.use_checkpoint:
                    x = checkpoint(blk, x)
                else:
                    x = blk(x)
            H, W = layer.input_resolution
            B, L, C = x.shape
            hiddens.append(x.view(B, H, W, C))
            if layer.downsample is not None:
                x = layer.downsample(x)
            return x, hiddens

        hiddens = []
        for layer in transformer.layers:
            x, hiddens = layer_forward(layer, x, hiddens)
        x = transformer.norm(x)  # B L C
        hiddens[-1] = x.view_as(hiddens[-1])
        return x, hiddens

    def forward(self, x, refs=None):
        """Encode a batch of images.

        Args:
            x: input batch for the backbone.
            refs: unused; kept for interface compatibility.

        Returns:
            ``(features, hiddens)``. For CNN backbones ``features`` is the
            backbone output with dim 1 moved to the last position and
            ``hiddens`` is an empty list.

        Raises:
            NotImplementedError: if ``self.model_type`` is not recognized.
        """
        if self.model_type in ['resnet', 'efficientnet']:
            features = self.cnn(x)
            # Move dim 1 last (presumably channels: B,C,H,W -> B,H,W,C).
            features = features.permute(0, 2, 3, 1)
            hiddens = []
        elif self.model_type == 'swin':
            if 'patch' in self.model_name:
                features, hiddens = self.swin_forward(self.transformer, x)
            else:
                # NOTE(review): expects the model itself to return a
                # (features, hiddens) pair -- confirm against the timm
                # variant in use.
                features, hiddens = self.transformer(x)
        else:
            raise NotImplementedError(f'Unsupported model type: {self.model_type!r}')
        return features, hiddens
class TransformerDecoderBase(nn.Module):
    """Shared base for decoders: encoder-feature projection plus a
    ``TransformerDecoder`` stack configured from ``args``."""

    def __init__(self, args):
        """Build the projection layer, optional positional embedding and decoder.

        Args:
            args: configuration object providing ``encoder_dim``,
                ``dec_hidden_size``, ``enc_pos_emb``, ``dec_num_layers``,
                ``dec_attn_heads``, ``hidden_dropout``, ``attn_dropout`` and
                ``max_relative_positions``.
        """
        super().__init__()
        self.args = args
        hidden_size = args.dec_hidden_size

        # Projects encoder features to the decoder width.
        self.enc_trans_layer = nn.Sequential(
            nn.Linear(args.encoder_dim, hidden_size)
        )

        # Optional learned position embedding with a fixed table of 144 slots.
        if args.enc_pos_emb:
            self.enc_pos_emb = nn.Embedding(144, args.encoder_dim)
        else:
            self.enc_pos_emb = None

        self.decoder = TransformerDecoder(
            num_layers=args.dec_num_layers,
            d_model=hidden_size,
            heads=args.dec_attn_heads,
            d_ff=hidden_size * 4,
            copy_attn=False,
            self_attn_type="scaled-dot",
            dropout=args.hidden_dropout,
            attention_dropout=args.attn_dropout,
            max_relative_positions=args.max_relative_positions,
            aan_useffn=False,
            full_context_alignment=False,
            alignment_layer=0,
            alignment_heads=0,
            pos_ffn_activation_fn='gelu'
        )

    def enc_transform(self, encoder_out):
        """Flatten encoder features to ``(batch, positions, dim)`` and project them.

        Args:
            encoder_out: tensor whose first dim is the batch and whose last
                dim is the feature size; all dims in between are merged.

        Returns:
            The output of ``enc_trans_layer`` applied to the flattened
            features (with position embeddings added first when enabled).
        """
        feature_dim = encoder_out.size(-1)
        flat = encoder_out.view(encoder_out.size(0), -1, feature_dim)  # (batch_size, num_pixels, encoder_dim)
        if self.enc_pos_emb is not None:
            positions = torch.arange(flat.size(1), device=flat.device)
            flat = flat + self.enc_pos_emb(positions).unsqueeze(0)
        return self.enc_trans_layer(flat)
class TransformerDecoderAR(TransformerDecoderBase):
    """Autoregressive Transformer Decoder"""

    def __init__(self, args, tokenizer):
        """Add token embeddings and the vocabulary projection to the base decoder.

        Args:
            args: configuration object; ``dec_hidden_size`` and
                ``hidden_dropout`` are read here in addition to what the
                base class reads.
            tokenizer: tokenizer whose ``len()`` gives the vocabulary size;
                its ``output_constraint`` and ``get_output_mask`` are used
                during decoding.
        """
        super().__init__(args)
        self.tokenizer = tokenizer
        self.vocab_size = len(self.tokenizer)
        self.output_layer = nn.Linear(args.dec_hidden_size, self.vocab_size, bias=True)
        self.embeddings = Embeddings(
            word_vec_size=args.dec_hidden_size,
            word_vocab_size=self.vocab_size,
            word_padding_idx=PAD_ID,
            position_encoding=True,
            dropout=args.hidden_dropout)

    def dec_embedding(self, tgt, step=None):
        """Embed target tokens and build the padding mask.

        Args:
            tgt: token ids shaped ``(batch, length, 1)``.
            step: forwarded to the embedding layer.

        Returns:
            ``(emb, tgt_pad_mask)``: the 3-D embeddings and a mask that is
            True where ``tgt`` equals the padding index.
        """
        pad_idx = self.embeddings.word_padding_idx
        tgt_pad_mask = tgt.data.eq(pad_idx).transpose(1, 2)  # [B, 1, T_tgt]
        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3  # batch x len x embedding_dim
        return emb, tgt_pad_mask

    def forward(self, encoder_out, labels, label_lengths):
        """Training mode: teacher-forced pass over the whole label sequence.

        Args:
            encoder_out: encoder features; any shape accepted by
                ``enc_transform`` (batch first, feature dim last).
            labels: ``(batch, length)`` target token ids.
            label_lengths: unused; kept for interface compatibility.

        Returns:
            ``(logits[:, :-1], labels[:, 1:], dec_out)`` -- the logits for
            every position but the last, the labels shifted by one, and the
            decoder hidden states.
        """
        # enc_transform flattens the features itself, so no fixed rank is
        # required of encoder_out here.
        memory_bank = self.enc_transform(encoder_out)
        tgt = labels.unsqueeze(-1)  # (b, t, 1)
        tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
        dec_out, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank, tgt_pad_mask=tgt_pad_mask)
        logits = self.output_layer(dec_out)  # (b, t, h) -> (b, t, v)
        return logits[:, :-1], labels[:, 1:], dec_out

    def decode(self, encoder_out, beam_size: int, n_best: int, min_length: int = 1, max_length: int = 256,
               labels=None):
        """Inference mode. Autoregressively decode the sequence. Only greedy search is supported now. Beam search is
        out-dated. The labels is used for partial prediction, i.e. part of the sequence is given. In standard decoding,
        labels=None.

        Args:
            encoder_out: encoder features (batch first, feature dim last).
            beam_size: 1 selects ``GreedySearch``; anything else ``BeamSearch``.
            n_best: forwarded to ``BeamSearch``.
            min_length: minimum decoded length, forwarded to the strategy.
            max_length: maximum decoded length, forwarded to the strategy.
            labels: optional ``(batch, length)`` ids; positions not equal to
                ``MASK_ID`` override the model's own prediction.

        Returns:
            ``(predictions, scores, token_scores, hidden)`` as collected by
            the decode strategy.
        """
        # Only the batch size is needed; enc_transform handles the rest of
        # the shape, so higher-rank encoder features are accepted too.
        batch_size = encoder_out.size(0)
        memory_bank = self.enc_transform(encoder_out)
        orig_labels = labels

        if beam_size == 1:
            decode_strategy = GreedySearch(
                sampling_temp=0.0, keep_topk=1, batch_size=batch_size, min_length=min_length, max_length=max_length,
                pad=PAD_ID, bos=SOS_ID, eos=EOS_ID,
                return_attention=False, return_hidden=True)
        else:
            decode_strategy = BeamSearch(
                beam_size=beam_size, n_best=n_best, batch_size=batch_size, min_length=min_length, max_length=max_length,
                pad=PAD_ID, bos=SOS_ID, eos=EOS_ID,
                return_attention=False)

        # adapted from onmt.translate.translator
        results = {
            "predictions": None,
            "scores": None,
            "attention": None
        }

        # (2) prep decode_strategy. Possibly repeat src objects.
        _, memory_bank = decode_strategy.initialize(memory_bank=memory_bank)

        # (3) Begin decoding step by step:
        for step in range(decode_strategy.max_length):
            tgt = decode_strategy.current_predictions.view(-1, 1, 1)
            if labels is not None:
                # Keep the model's token where the label is MASK_ID,
                # otherwise force the given label.
                label = labels[:, step].view(-1, 1, 1)
                mask = label.eq(MASK_ID).long()
                tgt = tgt * mask + label * (1 - mask)
            tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
            dec_out, dec_attn, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank,
                                                 tgt_pad_mask=tgt_pad_mask, step=step)

            attn = dec_attn.get("std", None)

            dec_logits = self.output_layer(dec_out)  # [b, t, h] => [b, t, v]
            dec_logits = dec_logits.squeeze(1)
            log_probs = F.log_softmax(dec_logits, dim=-1)

            if self.tokenizer.output_constraint:
                # Suppress tokens the tokenizer disallows after the current one.
                output_mask = [self.tokenizer.get_output_mask(token_id) for token_id in tgt.view(-1).tolist()]
                output_mask = torch.tensor(output_mask, device=log_probs.device)
                log_probs.masked_fill_(output_mask, -10000)

            label = labels[:, step + 1] if labels is not None and step + 1 < labels.size(1) else None
            decode_strategy.advance(log_probs, attn, dec_out, label)
            any_finished = decode_strategy.is_finished.any()
            if any_finished:
                decode_strategy.update_finished()
                if decode_strategy.done:
                    break

            select_indices = decode_strategy.select_indices
            if any_finished:
                # Reorder states.
                memory_bank = memory_bank.index_select(0, select_indices)
                if labels is not None:
                    labels = labels.index_select(0, select_indices)
                self.map_state(lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = decode_strategy.scores  # fixed to be average of token scores
        results["token_scores"] = decode_strategy.token_scores
        results["predictions"] = decode_strategy.predictions
        results["attention"] = decode_strategy.attention
        results["hidden"] = decode_strategy.hidden

        if orig_labels is not None:
            # Re-impose the given (non-masked) labels on the top prediction.
            for i in range(batch_size):
                pred = results["predictions"][i][0]
                label = orig_labels[i][1:len(pred) + 1]
                mask = label.eq(MASK_ID).long()
                pred = pred[:len(label)]
                results["predictions"][i][0] = pred * mask + label * (1 - mask)

        return results["predictions"], results['scores'], results["token_scores"], results["hidden"]

    # adapted from onmt.decoders.transformer
    def map_state(self, fn):
        """Apply ``fn(tensor, batch_dim)`` to every non-None leaf of the decoder cache.

        Nested dicts are traversed recursively and leaves are replaced in
        place. Nothing happens when ``self.decoder.state["cache"]`` is None.
        """
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        if self.decoder.state["cache"] is not None:
            _recursive_map(self.decoder.state["cache"])
class GraphPredictor(nn.Module):
    """Predicts pairwise edge logits (7 classes) between selected hidden
    states and, optionally, a 2-value coordinate output per selected state."""

    def __init__(self, decoder_dim, coords=False):
        """Create the edge MLP and, when ``coords`` is set, the coordinate MLP.

        Args:
            decoder_dim: size of each hidden state.
            coords: whether to also build and apply ``coords_mlp``.
        """
        super().__init__()
        self.coords = coords
        pair_dim = decoder_dim * 2  # two hidden states are concatenated per pair
        self.mlp = nn.Sequential(
            nn.Linear(pair_dim, decoder_dim),
            nn.GELU(),
            nn.Linear(decoder_dim, 7),
        )
        if coords:
            self.coords_mlp = nn.Sequential(
                nn.Linear(decoder_dim, decoder_dim),
                nn.GELU(),
                nn.Linear(decoder_dim, 2),
            )

    def forward(self, hidden, indices=None):
        """Score every ordered pair of selected hidden states.

        Args:
            hidden: ``(batch, length, dim)`` hidden states.
            indices: optional ``(batch, k)`` positions to select per sample.
                When omitted, positions 3, 6, 9, ... are used.

        Returns:
            Dict with ``'edges'`` of shape ``(batch, 7, k, k)`` and, when
            coordinates are enabled, ``'coords'`` of shape ``(batch, k, 2)``.
        """
        batch, length, dim = hidden.size()
        if indices is not None:
            # Gather hidden[b, indices[b, j]] for every sample b.
            rows = torch.arange(batch).unsqueeze(1).expand_as(indices).reshape(-1)
            hidden = hidden[rows, indices.view(-1)].view(batch, -1, dim)
        else:
            hidden = hidden[:, list(range(3, length, 3))]

        batch, length, dim = hidden.size()
        first = hidden.unsqueeze(2).expand(batch, length, length, dim)   # [b, i, j] -> state i
        second = hidden.unsqueeze(1).expand(batch, length, length, dim)  # [b, i, j] -> state j
        pair_features = torch.cat([first, second], dim=3)

        results = {'edges': self.mlp(pair_features).permute(0, 3, 1, 2)}
        if self.coords:
            results['coords'] = self.coords_mlp(hidden)
        return results
def get_edge_prediction(edge_prob):
    """Symmetrize pairwise edge probabilities and pick the best class per pair.

    For every pair ``i != j``:

    * classes 0-4 are averaged between ``[i][j]`` and ``[j][i]``;
    * classes 5 and 6 are treated as mirror images of each other:
      ``[i][j][5]`` is averaged with ``[j][i][6]`` and ``[i][j][6]`` with
      ``[j][i][5]`` (presumably direction-dependent edge classes -- confirm
      against the label definition).

    Diagonal entries are left untouched. The input is not modified.

    Args:
        edge_prob: ``n x n x 7`` nested list (or array) of probabilities;
            may be empty or None.

    Returns:
        ``(prediction, score)``: two ``n x n`` nested lists with the argmax
        class and its symmetrized probability. ``([], [])`` for empty input.
    """
    if edge_prob is None or len(edge_prob) == 0:
        return [], []

    # Work on a private float copy so the caller's data stays intact.
    prob = np.array(edge_prob, dtype=float)
    mirrored = prob.transpose(1, 0, 2)  # mirrored[i, j] == prob[j, i]

    sym = prob.copy()
    sym[..., :5] = (prob[..., :5] + mirrored[..., :5]) / 2
    sym[..., 5] = (prob[..., 5] + mirrored[..., 6]) / 2
    sym[..., 6] = (prob[..., 6] + mirrored[..., 5]) / 2

    # Only off-diagonal pairs are symmetrized; restore the diagonal.
    diag = np.arange(len(prob))
    sym[diag, diag] = prob[diag, diag]

    prediction = np.argmax(sym, axis=2).tolist()
    score = np.max(sym, axis=2).tolist()
    return prediction, score
class Decoder(nn.Module):
    """This class is a wrapper for different decoder architectures, and support multiple decoders."""

    def __init__(self, args, tokenizer):
        """Create one sub-decoder per entry of ``args.formats``.

        Args:
            args: configuration object; ``formats``, ``dec_hidden_size``,
                ``continuous_coords`` and ``compute_confidence`` are read
                here.
            tokenizer: mapping from format name to that format's tokenizer.
        """
        super().__init__()
        self.args = args
        self.formats = args.formats
        self.tokenizer = tokenizer
        decoder = {}
        for format_ in args.formats:
            if format_ == 'edges':
                decoder['edges'] = GraphPredictor(args.dec_hidden_size, coords=args.continuous_coords)
            else:
                decoder[format_] = TransformerDecoderAR(args, tokenizer[format_])
        self.decoder = nn.ModuleDict(decoder)
        self.compute_confidence = args.compute_confidence

    @staticmethod
    def _atom_format(results):
        """Return the coordinate-bearing format already present in ``results``.

        Raises:
            NotImplementedError: if neither ``'atomtok_coords'`` nor
                ``'chartok_coords'`` has been decoded yet.
        """
        for name in ('atomtok_coords', 'chartok_coords'):
            if name in results:
                return name
        # `NotImplemented` is not an exception; raising it yields a TypeError.
        raise NotImplementedError(
            "'edges' requires 'atomtok_coords' or 'chartok_coords' to be decoded first")

    def forward(self, encoder_out, hiddens, refs):
        """Training mode. Compute the logits with teacher forcing.

        Args:
            encoder_out: encoder features passed to each sequence decoder.
            hiddens: unused here.
            refs: per-format references; moved to ``encoder_out``'s device.

        Returns:
            Dict keyed by format. Sequence formats map to the sub-decoder's
            output; ``'edges'`` maps to ``(predictions, targets)``.

        Raises:
            NotImplementedError: if ``'edges'`` is requested before a
                coordinate-bearing sequence format has been processed.
        """
        results = {}
        refs = to_device(refs, encoder_out.device)
        for format_ in self.formats:
            if format_ == 'edges':
                # The edge predictor consumes the hidden states (index 2) of
                # the atom-sequence decoder.
                dec_out = results[self._atom_format(results)][2]
                predictions = self.decoder['edges'](dec_out, indices=refs['atom_indices'][0])
                targets = {'edges': refs['edges']}
                if 'coords' in predictions:
                    targets['coords'] = refs['coords']
                results['edges'] = (predictions, targets)
            else:
                labels, label_lengths = refs[format_]
                results[format_] = self.decoder[format_](encoder_out, labels, label_lengths)
        return results

    def decode(self, encoder_out, hiddens=None, refs=None, beam_size=1, n_best=1):
        """Inference mode. Call each decoder's decode method (if required), convert the output format (e.g. token to
        sequence). Beam search is not supported yet.

        Returns:
            A list with one dict per sample, holding the top prediction for
            the decoded sequence format and, when requested, ``'edges'`` and
            confidence scores.

        Raises:
            NotImplementedError: if ``'edges'`` is requested before a
                coordinate-bearing sequence format has been decoded.
        """
        results = {}
        predictions = []
        for format_ in self.formats:
            if format_ in ['atomtok', 'atomtok_coords', 'chartok_coords']:
                max_len = FORMAT_INFO[format_]['max_len']
                results[format_] = self.decoder[format_].decode(encoder_out, beam_size, n_best, max_length=max_len)
                outputs, scores, token_scores, *_ = results[format_]
                beam_preds = [[self.tokenizer[format_].sequence_to_smiles(x.tolist()) for x in pred]
                              for pred in outputs]
                # Keep only the top hypothesis of each sample.
                predictions = [{format_: pred[0]} for pred in beam_preds]
                if self.compute_confidence:
                    for i, prediction in enumerate(predictions):
                        # -1: y score, -2: x score, -3: symbol score
                        indices = np.array(prediction[format_]['indices']) - 3
                        if format_ == 'chartok_coords':
                            # Geometric mean over the characters of each symbol.
                            atom_scores = []
                            for symbol, index in zip(prediction[format_]['symbols'], indices):
                                atom_score = (np.prod(token_scores[i][0][index - len(symbol) + 1:index + 1])
                                              ** (1 / len(symbol))).item()
                                atom_scores.append(atom_score)
                        else:
                            atom_scores = np.array(token_scores[i][0])[indices].tolist()
                        prediction[format_]['atom_scores'] = atom_scores
                        prediction[format_]['average_token_score'] = scores[i][0]
            if format_ == 'edges':
                atom_format = self._atom_format(results)
                dec_out = results[atom_format][3]  # batch x n_best x len x dim
                for i in range(len(dec_out)):
                    hidden = dec_out[i][0].unsqueeze(0)  # 1 * len * dim
                    indices = torch.LongTensor(predictions[i][atom_format]['indices']).unsqueeze(0)  # 1 * k
                    pred = self.decoder['edges'](hidden, indices)  # k * k
                    prob = F.softmax(pred['edges'].squeeze(0).permute(1, 2, 0), dim=2).tolist()  # k * k * 7
                    edge_pred, edge_score = get_edge_prediction(prob)
                    predictions[i]['edges'] = edge_pred
                    if self.compute_confidence:
                        predictions[i]['edge_scores'] = edge_score
                        # Intermediate values are folded into overall_score
                        # and then removed from the output.
                        predictions[i]['edge_score_product'] = np.sqrt(np.prod(edge_score)).item()
                        predictions[i]['overall_score'] = predictions[i][atom_format]['average_token_score'] * \
                            predictions[i]['edge_score_product']
                        predictions[i][atom_format].pop('average_token_score')
                        predictions[i].pop('edge_score_product')
        return predictions