# TEDDY-G Tutorial: Generating Embeddings

This notebook shows how to generate embeddings with `TEDDY-G` for a small sample of `.h5ad` data from `CellXGene`. For accessibility we use the `70M` variant of `TEDDY-G` and keep the example small enough to run on a CPU. The `TEDDY-G` repository fully supports accelerated hardware, and we recommend using it with the `160M` and `400M` `TEDDY-G` variants.

## Table of Contents
1. [Preprocess the sample data](#preprocess-the-sample-data)
2. [Tokenize the sample data](#tokenize-the-sample-data)
3. [Load the model from the pretrained checkpoint](#load-the-model-from-the-pretrained-checkpoint)
4. [Prepare model input](#prepare-model-input)
 - [Create custom data collator](#create-custom-data-collator)
 - [Prepare dataloader](#prepare-dataloader)
5. [Implement forward loop to generate embeddings](#implement-forward-loop-to-generate-embeddings)
6. [Process embeddings](#process-embeddings)
7. [Plot the UMAP](#plot-the-umap)

All paths below are relative to the repository root, so first change into it:

In [None]:
import os

# The notebook lives one level below the repository root, and every path in this
# tutorial is relative to that root. Only step up when we are not already there
# (the root contains the `teddy/` package), so re-running this cell is harmless.
if not os.path.isdir("teddy"):
    os.chdir("..")
os.getcwd()

### General imports

In [None]:
import gc
import os
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from datasets import load_dataset
import pandas as pd
import umap
import matplotlib.pyplot as plt

## Preprocess the sample data

In [None]:
from teddy.data_processing.preprocessing.preprocess import preprocess

# Hyperparameters for the preprocessing step. The filtering options are switched off
# (None / empty lists); `normalized_total` and `median_dict` drive normalization —
# see the `preprocess` module for their exact semantics.
preprocessing_config = dict(
    min_gene_counts=None,
    remove_assays=[],
    max_mitochondrial_prop=None,
    remove_cell_types=[],
    hvg_method=None,
    normalized_total=10000,
    median_dict="teddy/data_processing/utils/medians/data/teddy_gene_medians.json",
    log1p=False,
    compute_medians=False,
    median_column="index",
    reference_id_only=False,
    load_dir="data",
    save_dir="data/processed",
)

# Raw sample data shipped with the repository
raw_h5ad_path = "data/sample_data.h5ad"
raw_metadata_path = "data/sample_data_metadata.json"

preprocess(
    data_path=raw_h5ad_path,
    metadata_path=raw_metadata_path,
    hyperparameters=preprocessing_config,
)

## Tokenize the sample data

In [None]:
from teddy.data_processing.tokenization.tokenization import tokenize

# All biological-annotation mapping files share one directory and one naming scheme.
_mapping_dir = "teddy/data_processing/utils/bio_annotations/data/mappings"
# Maximum sequence length and padded length are kept identical.
_token_len = 2048

tokenizer_config = {
    "tokenizer_name_or_path": "teddy/models/teddy_g/70M",
    "gene_id_column": "index",
    "bio_annotations": True,
    "disease_mapping": f"{_mapping_dir}/all_filtered_disease_mapping.json",
    "tissue_mapping": f"{_mapping_dir}/all_filtered_tissue_mapping.json",
    "cell_mapping": f"{_mapping_dir}/all_filtered_cell_mapping.json",
    "sex_mapping": f"{_mapping_dir}/all_filtered_sex_mapping.json",
    "max_shard_samples": 500,
    "max_seq_len": _token_len,
    "pad_length": _token_len,
    "add_cls": False,
    "bins": 0,
    "continuous_rank": True,
    "truncation_method": "max",
    "add_disease_annotation": False,
    "include_zero_genes": False,
    "load_dir": "data/processed",
    "save_dir": "data/tokenized",
}

# Tokenize the output written by the preprocessing step
tokenize(
    data_path="data/processed/sample_data.h5ad",
    metadata_path="data/processed/sample_data_metadata.json",
    tokenization_args=tokenizer_config,
)


## Load the model from the pretrained checkpoint
For the rest of this tutorial we assume that the Hugging Face model checkpoint is stored in its own folder — in this case, `teddy/models/teddy_g/70M`.

In [None]:
from teddy.models.model_directory import get_architecture, model_dict

model_name_or_path = 'teddy/models/teddy_g/70M'

# Resolve which TEDDY architecture this checkpoint belongs to, then fetch the
# matching config and model classes from the model-family registry.
architecture = get_architecture(model_name_or_path)
config_cls, model_cls = (
    model_dict[architecture][key] for key in ("config_cls", "model_cls")
)

# Build the config first, then load the pretrained weights with it
config = config_cls.from_pretrained(model_name_or_path)
model = model_cls.from_pretrained(model_name_or_path, config=config)

# Ask the model to also return its embeddings (presumably surfaced as
# outputs["all_embs"], which the embedding loop below reads)
model.return_all_embs = True

## Prepare model input

### Create custom data collator

In [None]:
def collate_fn(batch, tokenizer, max_seq_len=2048):
    """
    Collate tokenized cells into fixed-width ``gene_ids`` / ``attention_mask`` tensors.

    Every sample is right-padded with ``tokenizer.pad_token_id`` up to ``max_seq_len``
    (not to the longest sequence in the batch), so all batches share the same width
    and their outputs can later be concatenated along the batch axis. Sequences
    longer than ``max_seq_len`` are truncated. Can be further customized to include
    ontologies by referring to the model code.

    Args:
        batch: list of dataset rows, each with a ``"gene_ids"`` sequence of token ids.
        tokenizer: object exposing ``pad_token_id``.
        max_seq_len: padded (and maximum) sequence length.

    Returns:
        dict with ``"gene_ids"`` (``torch.long``, shape ``[len(batch), max_seq_len]``)
        and ``"attention_mask"`` (``torch.long``, same shape; 1 where the token is not
        the pad id), as the TEDDY-G model expects.
    """
    pad_id = tokenizer.pad_token_id
    input_ids = torch.full((len(batch), max_seq_len), pad_id, dtype=torch.long)
    for i, sample in enumerate(batch):
        # Truncate rather than letting the slice assignment fail on over-long inputs.
        seq = torch.as_tensor(sample["gene_ids"], dtype=torch.long)[:max_seq_len]
        input_ids[i, : seq.numel()] = seq
    # Positions holding the pad id are masked out.
    attention_mask = (input_ids != pad_id).long()
    return {
        "gene_ids": input_ids,
        "attention_mask": attention_mask,
    }

### Prepare dataloader

In [None]:
from teddy.tokenizer.gene_tokenizer import GeneTokenizer
tokenizer = GeneTokenizer.from_pretrained(model_name_or_path)

ds = load_dataset("arrow", data_files={"train": os.path.join('data/tokenized/sample_data', "*.arrow")})["train"]

# choose how many cells you want to embed with max_eval_samples (None = all cells)
max_eval_samples = 15
if max_eval_samples is not None:
    ds = ds.select(range(min(max_eval_samples, len(ds))))

# Small batches keep memory modest on CPU; increase on accelerated hardware.
batch_size = 5
# Should match the tokenizer's max_seq_len / pad_length so every batch has the same width.
max_seq_len = 2048

loader = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=False,  # keep embedding rows aligned with dataset order
    collate_fn=lambda batch: collate_fn(batch, tokenizer, max_seq_len=max_seq_len),
)

# --- Implement forward loop to generate embeddings ---
device = torch.device("cpu")
model.to(device)
model.eval()

all_embeddings = []
with torch.no_grad():
    for batch_tensors in tqdm(loader, desc="Embedding Batches"):
        # Move to device
        gene_ids = batch_tensors["gene_ids"].to(device)
        attn_mask = batch_tensors["attention_mask"].to(device)
        # Forward pass (adapt to model's signature here if including ontologies)
        outputs = model(
            gene_ids=gene_ids,
            attention_mask=attn_mask,
            return_outputs=True,
        )
        # Final embeddings are in `outputs["all_embs"]` of shape [B, seq_len, dim]
        all_embeddings.append(outputs["all_embs"].cpu())

if not all_embeddings:
    raise ValueError(
        "No cells were embedded; check max_eval_samples and the tokenized dataset path."
    )

# Concatenate to shape [num_samples, seq_len, dim]
final_embeddings = torch.cat(all_embeddings, dim=0)

# The per-batch list duplicates final_embeddings in memory; release it.
del all_embeddings
gc.collect()

final_embeddings.shape


## Process embeddings

In [None]:
# Collapse each cell's token embeddings into one vector by averaging over dim=1.
# Unpacking the shape also checks that final_embeddings is 3-D.
# NOTE(review): this is an unmasked mean, so padding positions are averaged in too
# unless the model already zeroes them — consider a mask-weighted mean.
n_cells, seq_len, hidden_dim = final_embeddings.shape
pooled_embeddings = torch.mean(final_embeddings, dim=1).cpu().numpy()  # [n_cells, hidden_dim]

# One row per cell, one column per embedding dimension
df_emb = pd.DataFrame(data=pooled_embeddings)

## Plot the UMAP

In [None]:
# Project the pooled cell embeddings to 2-D with UMAP using cosine distance.
n_neighbors = 5
random_state = 0  # fixed seed so the UMAP layout is reproducible
reducer = umap.UMAP(n_neighbors=n_neighbors, random_state=random_state, metric="cosine")
umap_coords = reducer.fit_transform(df_emb)  # shape -> [n_cells, 2]

# Draw on an explicit figure/axes instead of pyplot's implicit "current" figure,
# so the scatter can never land on a stale figure left open by another cell.
fig, ax = plt.subplots()
ax.scatter(umap_coords[:, 0], umap_coords[:, 1], s=5, alpha=0.7)
ax.set_xlabel("UMAP-1")
ax.set_ylabel("UMAP-2")
ax.set_title("UMAP of Mean-Pooled Cell Embeddings")
plt.show()
