Update chatNT.py
chatNT.py
CHANGED
@@ -405,9 +405,7 @@ class TorchBioBrainDecoder(nn.Module):
         """

         # Compute English token embeddings
-        print("(debug) in biobraindecoder, english tokens ids : ", english_token_ids.shape)
         tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
-        print("(debug) tokens_embeddings shape : ", tokens_embeddings.shape)

         if projected_bio_embeddings is not None:
             (
@@ -419,8 +417,10 @@ class TorchBioBrainDecoder(nn.Module):

             # Insert the bio embeddings at the SEQ token positions
             processed_tokens_ids = english_token_ids.clone()
-            print("(debug)
-            print("(debug)
+            print("(debug) Before call tokens embeddings shape : ", tokens_embeddings.shape)
+            print("(debug) Before call Processed tokens ids shape : ", processed_tokens_ids.shape)
+            print("(debug) Before call Projected bio embeddings shape : ", projected_bio_embeddings.shape)
+            print("num bio sequences : ", num_bio_sequences)
             for bio_seq_num in range(num_bio_sequences):
                 tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
                     processed_tokens_ids,
@@ -431,7 +431,6 @@ class TorchBioBrainDecoder(nn.Module):
                 print("After call : ", tokens_embeddings.shape)

         # Regular GPT pass through
-        print("(debug) tokens embeddings shape : ", tokens_embeddings.shape)
         embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
         embeddings = self.gpt_model.final_norm(embeddings)

@@ -472,6 +471,11 @@ class TorchBioBrainDecoder(nn.Module):
             - input_embeddings with resampled_embeddings inserted at the SEQ token
             - tokens with the SEQ token set to -1
         """
+        print("Insert_embeddings input shape : ")
+        print("Tokens : ", tokens.shape)
+        print("Input embeddings : ", input_embeddings.shape)
+        print("Resampled embeddings : ", resampled_embeddings.shape)
+        print("Bio seq num : ", bio_seq_num)

         def _insert(
             tokens_1d: torch.Tensor,
@@ -485,6 +489,7 @@ class TorchBioBrainDecoder(nn.Module):
                 resampled_embeddings (torch.Tensor):
                     Shape (bio_sequence_length, embed_dim,)
             """
+            print("_insert input : ", input_embeddings_1d.shape, resampled_embeddings_1d.shape)
             indices = torch.where(tokens_1d == self.seq_token_id)[0]
             if indices.numel() > 0:
                 idx = indices[0].item()
@@ -501,6 +506,7 @@ class TorchBioBrainDecoder(nn.Module):
                     :-1, :
                 ]
                 tokens_1d[idx] = -1
+                print("_insert output : ", x.shape)
                 return x, tokens_1d
             else:
                 return (
@@ -519,8 +525,11 @@ class TorchBioBrainDecoder(nn.Module):
             )
             tokens_acc.append(tokens_out)
             embeddings_acc.append(embeddings_out)
+
+        print("(Embeddings_acc[0] shape : ", embeddings_acc[0].shape)
         tokens_acc = torch.stack(tokens_acc)
         embeddings_acc = torch.stack(embeddings_acc)
+        print("Embeddings acc shape : ", embeddings_acc.shape)

         return embeddings_acc, tokens_acc

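The hunks above instrument TorchBioBrainDecoder.insert_embeddings, which splices each projected bio embedding into the English token embeddings at the first unconsumed SEQ placeholder and then marks that placeholder as used by setting its token id to -1. A minimal single-example sketch of that splice, with a hypothetical seq_token_id, toy shapes, and a length-preserving truncation assumed for illustration (this is not the model's own code):

import torch

def splice_at_seq_token(tokens_1d, embeddings_1d, bio_embeddings_1d, seq_token_id):
    # Locate the first SEQ placeholder that has not been consumed yet.
    indices = torch.where(tokens_1d == seq_token_id)[0]
    if indices.numel() == 0:
        return embeddings_1d, tokens_1d
    idx = indices[0].item()
    # Splice the bio embeddings in at that position, then truncate so the
    # sequence keeps its original length (an assumption made for this sketch).
    spliced = torch.cat(
        [embeddings_1d[:idx], bio_embeddings_1d, embeddings_1d[idx + 1 :]], dim=0
    )[: embeddings_1d.shape[0]]
    # Mark the placeholder as consumed so the next bio sequence targets the
    # following SEQ token.
    tokens_1d = tokens_1d.clone()
    tokens_1d[idx] = -1
    return spliced, tokens_1d

# Toy usage: 6 text positions, a 3-step bio sequence, embedding size 4.
seq_token_id = 32000  # hypothetical placeholder id
tokens = torch.tensor([5, 7, seq_token_id, 9, 11, 13])
out_emb, out_tokens = splice_at_seq_token(tokens, torch.randn(6, 4), torch.randn(3, 4), seq_token_id)
print(out_emb.shape, out_tokens)  # torch.Size([6, 4]); the SEQ id is now -1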
@@ -701,13 +710,11 @@ class TorchMultiOmicsModel(PreTrainedModel):

         if projected_bio_embeddings is None:
             # Compute bio sequences embeddings
-            print("(debug) shape bio tokens ids : ", bio_token_ids.shape)
             bio_embeddings_list = [
                 self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                 for bio_seq_num in range(num_bio_sequences)
             ]

-            print("(debug) shape of embeddings : ", bio_embeddings_list[0].shape)

             # Project these embeddings
             projected_bio_embeddings = [
@@ -718,14 +725,9 @@ class TorchMultiOmicsModel(PreTrainedModel):
                 )
                 for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
             ]
-            print("(debug) Shape output projection model : ", projected_bio_embeddings[0].shape)
             projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
-            print("(debug) Shape projected bio embeddings : ", projected_bio_embeddings.shape)

         # decode
-        print("(debug) Going in biobrain decoder : ")
-        print("(debug) English token ids : ", english_token_ids.shape)
-        print("(debug) Projected bio embeddings : ", projected_bio_embeddings.shape)
         logits = self.biobrain_decoder(
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
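For context on the TorchMultiOmicsModel.forward hunks above: each bio sequence is encoded, projected into the decoder's embedding space, the per-sequence projections are stacked along dim=1, and the stacked tensor is passed to the decoder together with the English token ids. A shape-only sketch of that flow with stand-in modules and made-up dimensions (the real encoder, projection model, and decoder are not reproduced here):

import torch
import torch.nn as nn

batch, num_bio_sequences, bio_len, text_len = 2, 3, 16, 32
bio_dim, text_dim, bio_vocab = 64, 128, 32

encoder = nn.Embedding(bio_vocab, bio_dim)  # stand-in for biobrain_encoder
projection = nn.Linear(bio_dim, text_dim)   # stand-in for the projection model

bio_token_ids = torch.randint(0, bio_vocab, (batch, num_bio_sequences, bio_len))
english_token_ids = torch.randint(0, 1000, (batch, text_len))

# Encode each bio sequence separately: a list of (batch, bio_len, bio_dim) tensors.
bio_embeddings_list = [encoder(bio_token_ids[:, i]) for i in range(num_bio_sequences)]

# Project into the decoder's space, then stack along dim=1:
# (batch, num_bio_sequences, bio_len, text_dim)
projected_bio_embeddings = torch.stack([projection(e) for e in bio_embeddings_list], dim=1)

# The decoder then inserts these at the SEQ placeholders before its GPT pass;
# only the shapes that reach that call are shown here.
print(projected_bio_embeddings.shape, english_token_ids.shape)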
@@ -899,7 +901,6 @@ class TorchGptGroupedQueryAttention(nn.Module):
         value_inputs: torch.Tensor,
         attention_mask: torch.Tensor = None,
     ) -> torch.Tensor:
-        print("(debug) Query input shape : ", query_inputs.shape)
         batch_size, seq_len, _ = query_inputs.shape

         queries = self.query_linear(query_inputs).view( # noqa
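The TorchGptGroupedQueryAttention hunk above only drops a shape print. As background, grouped-query attention shares a small set of key/value heads across a larger set of query heads. A generic sketch of that grouping with hypothetical head counts (not this class's actual implementation):

import torch
import torch.nn.functional as F

batch, seq_len, num_heads, num_kv_heads, head_dim = 2, 8, 8, 2, 8

q = torch.randn(batch, seq_len, num_heads, head_dim)
k = torch.randn(batch, seq_len, num_kv_heads, head_dim)
v = torch.randn(batch, seq_len, num_kv_heads, head_dim)

# Repeat each key/value head so it lines up with its group of query heads.
group_size = num_heads // num_kv_heads
k = k.repeat_interleave(group_size, dim=2)
v = v.repeat_interleave(group_size, dim=2)

# Standard scaled dot-product attention over (batch, heads, seq, head_dim).
q, k, v = (t.transpose(1, 2) for t in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.transpose(1, 2).reshape(batch, seq_len, num_heads * head_dim).shape)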
@@ -981,7 +982,6 @@ class TorchGptDecoder(nn.Module):
         if attention_mask is None:
             attention_mask = build_causal_attention_mask(1, embeddings.shape[1])
         for layer in self.layers:
-            print("Embedding shape in apply_transformer_layers : ", embeddings.shape)
             embeddings = layer(embeddings, attention_mask)

         return embeddings
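apply_transformer_layers above falls back to build_causal_attention_mask(1, embeddings.shape[1]) when no mask is supplied. That helper is not part of this diff; assuming it produces a lower-triangular boolean mask broadcastable over heads, a plausible sketch is:

import torch

def build_causal_mask_sketch(batch_size: int, seq_len: int) -> torch.Tensor:
    # Lower-triangular matrix: position i may attend to positions <= i.
    mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
    # Broadcast to (batch_size, 1, seq_len, seq_len) so one mask serves every head.
    return mask[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)

print(build_causal_mask_sketch(1, 4).int())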
|