{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [
"# TEDDY-G Tutorial: Generating Embeddings\n",
"\n",
"This notebook shows how to generate cell embeddings with `TEDDY-G` from a small sample of `.h5ad` data from `CellXGene`. We use the `70M` variant of `TEDDY-G` and keep the tutorial small enough to run on a CPU, so no accelerator is needed. The `TEDDY-G` repository fully supports accelerated hardware, which we recommend for the larger `160M` and `400M` variants."
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"## Table of Contents\n",
"1. [Preprocess the sample data](#preprocess-the-sample-data)\n",
"2. [Tokenize the sample data](#tokenize-the-sample-data)\n",
"3. [Load the model from the pretrained checkpoint](#load-the-model-from-the-pretrained-checkpoint)\n",
"4. [Prepare model input](#prepare-model-input)\n",
"   - [Create custom data collator](#create-custom-data-collator)\n",
"   - [Prepare dataloader](#prepare-dataloader)\n",
"5. [Implement forward loop to generate embeddings](#implement-forward-loop-to-generate-embeddings)\n",
"6. [Process embeddings](#process-embeddings)\n",
"7. [Plot the UMAP](#plot-the-umap)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"All paths in this tutorial are relative to the repository root. The next cell moves up one directory, which assumes this notebook sits one level below the root:"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"%cd .."
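] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"Optionally, confirm that the working directory is now the repository root. The next cell prints it and checks for a few paths that later cells rely on (the `teddy` package, the `70M` checkpoint folder, and the sample data files); edit the list if your files live elsewhere."
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import os\n",
"\n",
"print(\"working directory:\", os.getcwd())\n",
"# paths used later in this tutorial, relative to the repository root\n",
"expected_paths = [\n",
"    \"teddy\",\n",
"    \"teddy/models/teddy_g/70M\",\n",
"    \"data/sample_data.h5ad\",\n",
"    \"data/sample_data_metadata.json\",\n",
"]\n",
"for path in expected_paths:\n",
"    status = \"found\" if os.path.exists(path) else \"MISSING\"\n",
"    print(f\"{path}: {status}\")"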
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### General imports" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import gc\n",
"import os\n",
"import torch\n",
"from tqdm import tqdm\n",
"from torch.utils.data import DataLoader\n",
"from datasets import load_dataset\n",
"import pandas as pd\n",
"import umap\n",
"import matplotlib.pyplot as plt"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Preprocess the sample data" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"from teddy.data_processing.preprocessing.preprocess import preprocess\n",
"\n",
"preprocessing_config = {\n",
"    \"min_gene_counts\": None,\n",
"    \"remove_assays\": [],\n",
"    \"max_mitochondrial_prop\": None,\n",
"    \"remove_cell_types\": [],\n",
"    \"hvg_method\": None,\n",
"    \"normalized_total\": 10000,\n",
"    \"median_dict\": \"teddy/data_processing/utils/medians/data/teddy_gene_medians.json\",\n",
"    \"log1p\": False,\n",
"    \"compute_medians\": False,\n",
"    \"median_column\": \"index\",\n",
"    \"reference_id_only\": False,\n",
"    \"load_dir\": \"data\",\n",
"    \"save_dir\": \"data/processed\",\n",
"}\n",
"\n",
"preprocess(\n",
"    data_path=\"data/sample_data.h5ad\",\n",
"    metadata_path=\"data/sample_data_metadata.json\",\n",
"    hyperparameters=preprocessing_config\n",
")"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Tokenize the sample data" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"from teddy.data_processing.tokenization.tokenization import tokenize\n",
"\n",
"tokenizer_config = {\n",
"    \"tokenizer_name_or_path\": \"teddy/models/teddy_g/70M\",\n",
"    \"gene_id_column\": \"index\",\n",
"    \"bio_annotations\": True,\n",
"    \"disease_mapping\": \"teddy/data_processing/utils/bio_annotations/data/mappings/all_filtered_disease_mapping.json\",\n",
"    \"tissue_mapping\": \"teddy/data_processing/utils/bio_annotations/data/mappings/all_filtered_tissue_mapping.json\",\n",
"    \"cell_mapping\": \"teddy/data_processing/utils/bio_annotations/data/mappings/all_filtered_cell_mapping.json\",\n",
"    \"sex_mapping\": \"teddy/data_processing/utils/bio_annotations/data/mappings/all_filtered_sex_mapping.json\",\n",
"    \"max_shard_samples\": 500,\n",
"    \"max_seq_len\": 2048,\n",
"    \"pad_length\": 2048,\n",
"    \"add_cls\": False,\n",
"    \"bins\": 0,\n",
"    \"continuous_rank\": True,\n",
"    \"truncation_method\": \"max\",\n",
"    \"add_disease_annotation\": False,\n",
"    \"include_zero_genes\": False,\n",
"    \"load_dir\": \"data/processed\",\n",
"    \"save_dir\": \"data/tokenized\"\n",
"}\n",
"\n",
"tokenize(\n",
"    data_path=\"data/processed/sample_data.h5ad\",\n",
"    metadata_path=\"data/processed/sample_data_metadata.json\",\n",
"    tokenization_args=tokenizer_config\n",
")"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"## Load the model from the pretrained checkpoint\n",
"\n",
"The rest of this tutorial assumes that the Hugging Face checkpoint for the chosen model variant is stored in its respective folder, in this case `teddy/models/teddy_g/70M`."
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"from teddy.models.model_directory import get_architecture, model_dict\n",
"\n",
"model_name_or_path = 'teddy/models/teddy_g/70M'\n",
"\n",
"# look up the config and model classes in the TEDDY model family dictionary\n",
"architecture = get_architecture(model_name_or_path)\n",
"config_cls = model_dict[architecture][\"config_cls\"]\n",
"model_cls = model_dict[architecture][\"model_cls\"]\n",
"\n",
"# load the config and the pretrained weights\n",
"config = config_cls.from_pretrained(model_name_or_path)\n",
"model = model_cls.from_pretrained(model_name_or_path, config=config)\n",
"\n",
"# configure the model to return all token embeddings\n",
"model.return_all_embs = True"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare model input\n", "\n", "### Create custom data collator" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"def collate_fn(batch, tokenizer, max_seq_len=2048):\n",
"    \"\"\"\n",
"    Minimal collate function to handle variable-length `gene_ids`.\n",
"    Pads (or truncates) every sequence to the fixed length `max_seq_len`,\n",
"    so all batches share one shape and can be concatenated afterwards.\n",
"    Can be further customized to include ontologies by referring to the\n",
"    model code.\n",
"    \"\"\"\n",
"    batch_size = len(batch)\n",
"    # 1) Use the same fixed sequence length for every batch\n",
"    max_len = max_seq_len\n",
"    # 2) Create a tensor filled with the pad token, up to the max sequence length\n",
"    input_ids = torch.full(\n",
"        (batch_size, max_len),\n",
"        tokenizer.pad_token_id,\n",
"        dtype=torch.long\n",
"    )\n",
"    for i, sample in enumerate(batch):\n",
"        # copy each sequence in, truncating anything longer than max_len\n",
"        seq = sample[\"gene_ids\"][:max_len]\n",
"        input_ids[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)\n",
"    # 3) Build the attention mask: 1 for real tokens, 0 for padding\n",
"    attention_mask = (input_ids != tokenizer.pad_token_id).long()\n",
"    # Return dict as the TEDDY-G model expects: `gene_ids` + `attention_mask`.\n",
"    return {\n",
"        \"gene_ids\": input_ids,\n",
"        \"attention_mask\": attention_mask,\n",
"    }"
] },
{ "cell_type": "markdown",
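"metadata": {}, "source": [
"Optionally, you can sanity-check `collate_fn` on toy input before building the real dataloader. The next cell uses a hypothetical stand-in tokenizer (`types.SimpleNamespace` with `pad_token_id=0`), made-up gene ids, and a short `max_seq_len=4` purely for illustration; the real `GeneTokenizer` defines its own pad token id. Each row should be filled with the pad id up to the fixed length, and the attention mask should be 1 for real tokens and 0 for padding."
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"from types import SimpleNamespace\n",
"\n",
"# hypothetical stand-in: collate_fn only reads `pad_token_id` from the tokenizer\n",
"toy_tokenizer = SimpleNamespace(pad_token_id=0)\n",
"# two toy cells with made-up gene token ids of different lengths\n",
"toy_batch = [{\"gene_ids\": [5, 6, 7]}, {\"gene_ids\": [8]}]\n",
"\n",
"toy_out = collate_fn(toy_batch, toy_tokenizer, max_seq_len=4)\n",
"print(toy_out[\"gene_ids\"])        # rows [5, 6, 7, 0] and [8, 0, 0, 0]\n",
"print(toy_out[\"attention_mask\"])  # rows [1, 1, 1, 0] and [1, 0, 0, 0]"
] },
{ "cell_type": "markdown",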
"metadata": {}, "source": [ "### Prepare dataloader" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"from teddy.tokenizer.gene_tokenizer import GeneTokenizer\n",
"\n",
"tokenizer = GeneTokenizer.from_pretrained(model_name_or_path)\n",
"\n",
"ds = load_dataset(\"arrow\", data_files={\"train\": os.path.join(\"data/tokenized/sample_data\", \"*.arrow\")})[\"train\"]\n",
"\n",
"# choose how many cells you want to embed with max_eval_samples (None embeds all cells)\n",
"max_eval_samples = 15\n",
"if max_eval_samples is not None:\n",
"    ds = ds.select(range(min(max_eval_samples, len(ds))))\n",
"\n",
"# batch_size is an illustrative value chosen for CPU use; adjust it to your hardware\n",
"batch_size = 5\n",
"loader = DataLoader(\n",
"    ds,\n",
"    batch_size=batch_size,\n",
"    shuffle=False,  # keep the dataset order so embeddings line up with cells\n",
"    collate_fn=lambda batch: collate_fn(batch, tokenizer, max_seq_len=2048),\n",
")"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Implement forward loop to generate embeddings" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"device = torch.device(\"cpu\")\n",
"\n",
"# move the model to the device and switch to evaluation mode, since we are not training\n",
"model.to(device)\n",
"model.eval()\n",
"\n",
"all_embeddings = []\n",
"\n",
"with torch.no_grad():\n",
"    for step, batch_tensors in enumerate(tqdm(loader, desc=\"Embedding Batches\")):\n",
"        # move inputs to the device\n",
"        gene_ids = batch_tensors[\"gene_ids\"].to(device)\n",
"        attn_mask = batch_tensors[\"attention_mask\"].to(device)\n",
"        # forward pass (adapt to the model's signature here if including ontologies)\n",
"        outputs = model(\n",
"            gene_ids=gene_ids,\n",
"            attention_mask=attn_mask,\n",
"            return_outputs=True\n",
"        )\n",
"        # final embeddings are in `outputs[\"all_embs\"]`, shape [B, seq_len, dim]\n",
"        emb = outputs[\"all_embs\"].cpu()\n",
"        all_embeddings.append(emb)\n",
"\n",
"# concatenate to shape [num_samples, seq_len, dim]; this works because every batch is padded to the same max_seq_len\n",
"final_embeddings = torch.cat(all_embeddings, dim=0)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Process embeddings" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# mean-pool the token embeddings into one vector per cell\n",
"# note: this plain mean averages over every position, including padding tokens\n",
"n_cells, seq_len, hidden_dim = final_embeddings.shape\n",
"pooled_embeddings = final_embeddings.mean(dim=1)  # shape -> [n_cells, hidden_dim]\n",
"pooled_embeddings = pooled_embeddings.cpu().numpy()\n",
"\n",
"# convert to DataFrame\n",
"df_emb = pd.DataFrame(pooled_embeddings)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the UMAP" ] },
{ "cell_type": "code", "execution_count": null,
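"metadata": {}, "outputs": [], "source": [
"# Optional check before plotting: the UMAP below uses metric='cosine', i.e. distances\n",
"# of 1 - cosine similarity between cells. This small NumPy sketch (not part of the\n",
"# TEDDY-G API) inspects those similarities directly for the first few cells.\n",
"import numpy as np\n",
"\n",
"emb = df_emb.to_numpy()\n",
"# scale each cell vector to unit length, guarding against all-zero vectors\n",
"unit = emb / np.clip(np.linalg.norm(emb, axis=1, keepdims=True), 1e-12, None)\n",
"# pairwise cosine similarity matrix, shape -> [n_cells, n_cells]\n",
"cos_sim = unit @ unit.T\n",
"print(cos_sim.shape)\n",
"print(np.round(cos_sim[:5, :5], 3))"
] },
{ "cell_type": "code", "execution_count": null,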
"metadata": {}, "outputs": [], "source": [ "# UMAP dimensionality reducer\n", "n_neighbors = 5\n", "random_state = 0\n", "reducer = umap.UMAP(n_neighbors=n_neighbors, random_state=random_state, metric=\"cosine\")\n", "umap_coords = reducer.fit_transform(df_emb) # shape -> [n_cells, 2]\n", "\n", "# Plot the UMAP\n", "plt.scatter(umap_coords[:, 0], umap_coords[:, 1], s=5, alpha=0.7)\n", "plt.xlabel(\"UMAP-1\")\n", "plt.ylabel(\"UMAP-2\")\n", "plt.title(\"UMAP of Mean-Pooled Cell Embeddings\")\n", "plt.show()\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.10" } }, "nbformat": 4, "nbformat_minor": 2 }