Safetensors
TEDDY / tests /test_tokenize.py
soumyatghosh's picture
Upload folder using huggingface_hub
4527b5f verified
import hashlib
from pathlib import Path
import anndata as ad
import numpy as np
import pytest
import torch
from teddy.data_processing.tokenization.tokenization import (
_bin_values,
_build_batch_tensors,
_check_genes_in_tokenizer,
_prepare_tokenizer_args,
_rank_continuous,
tokenize,
)
##############################################################################
# Unit Tests: Helpers
##############################################################################
def test_bin_values_no_sorting():
"""
Test _bin_values with no_sorting=True on a simple array.
"""
vals_list = [np.array([0, 1, 2, 3], dtype=float)]
class TokenArgs:
include_zero_genes = True
bins = 2
token_args = TokenArgs()
binned = _bin_values(vals_list, token_args, no_sorting=True)
assert len(binned) == 1
# The result is a single torch tensor with 0 => 0, others => bucket IDs (>=1).
# For example, the min->max range is [0..3], so edges ~ [0, 1.5, 3].
# Then bucketize: [0 => bin0, 1 => bin1, 2 => bin2, 3 => bin2], but we clamp min=1 for non-zero => final might be [0,1,2,2].
# Adjust as your code's logic expects. Just do a sanity check:
result = binned[0]
expected = np.array([1, 1, 2, 2], dtype=float)
assert (result.numpy() == expected).all(), f"Got {result.numpy()}, expected {expected}"
def test_bin_values_with_sorting():
"""
Test _bin_values with no_sorting=False (the 'positional chunk' approach).
"""
vals_list = [np.array([5, 4, 3, 0], dtype=float)]
class TokenArgs:
include_zero_genes = True
bins = 2
token_args = TokenArgs()
binned = _bin_values(vals_list, token_args, no_sorting=False)
assert len(binned) == 1
# The 'positional chunk' logic is tested. We won't replicate the entire math here;
# just confirm the shape and that there's a zero-bin appended for the zero entry.
result = binned[0]
assert len(result) == 4, f"Expected 4 bins in the result, got {len(result)}"
def test_rank_continuous_normal():
"""
Should produce a descending linear scale from ~-1 to +1 across the entire array.
"""
arr = np.array([3, 2, 1, 0], dtype=float)
class TokenArgs:
pass
token_args = TokenArgs()
ranked = _rank_continuous(arr, token_args)
# Should be a descending scale from near -1 => +1 (since we flip).
# len=4 => torch.linspace(-1,1,steps=4) => [-1, -0.3333, +0.3333, +1], flipped => [+1, +0.3333, -0.3333, -1].
# Just check shape, min, max, monotonic. Exact values might differ slightly if your code does micro offsets.
assert ranked.shape[0] == 4
assert ranked[0] > ranked[1] # Descending
assert torch.isclose(ranked.min(), torch.tensor(-1.0), atol=1e-5)
def test_prepare_tokenizer_args_dict():
"""
Test that a dict tokenization_args is converted to TokenizationArgs object properly,
and random seeds are set if gene_seed is not None.
"""
args_dict = {
"load_dir": "/mock/load",
"save_dir": "/mock/save",
"gene_seed": 42,
"tokenizer_name_or_path": "some/tokenizer",
}
token_args_obj, load_dir, save_dir = _prepare_tokenizer_args(args_dict)
assert load_dir == "/mock/load"
assert save_dir == "/mock/save"
assert token_args_obj.gene_seed == 42
def test_check_genes_in_tokenizer():
"""
Test _check_genes_in_tokenizer with a minimal mock GeneTokenizer vocab.
"""
# Mock an AnnData with 4 genes, 2 of which are in vocab
import anndata as ad
from teddy.tokenizer.gene_tokenizer import GeneTokenizer
data = ad.AnnData(X=np.zeros((10, 4))) # 10 cells, 4 genes
data.var["gene_name"] = ["G1", "G2", "G3", "G4"]
# Create a mock tokenizer with vocab = {G2, G3}
class MockGeneTokenizer:
def __init__(self, vocab):
self.vocab = vocab
mock_tokenizer = MockGeneTokenizer({"G2": 1, "G3": 2})
gene_in_vocab, coding_genes, ratio = _check_genes_in_tokenizer(
data,
gene_id_column="gene_name",
tokenizer=mock_tokenizer
)
# We expect G2, G3 => 2 matches out of 4 => ratio=0.5
assert len(gene_in_vocab) == 2
assert ratio == 0.5
def test_build_batch_tensors():
"""
Confirm _build_batch_tensors performs topk or random_genes if specified, and returns expected shapes.
We'll do a simple topk test.
"""
X_batch = torch.tensor([[5, 0, 1], [0, 2, 2]], dtype=torch.float) # 2 cells x 3 genes
token_array = torch.tensor([10, 11, 12]) # "gene IDs" for the 3 columns
class TokenArgs:
max_seq_len = 3
add_cls = False
random_genes = False
gene_ids, vals, labels_list, decoder_vals = _build_batch_tensors(X_batch, token_array, TokenArgs())
# topk => for each row => pick best 3 (all, since 3 genes). Sorted => descending
# Row0 => [5,1,0], gene IDs => [10,12,11]
# Row1 => [2,2,0], gene IDs => [11,12,10] (tie but stable - depends on how PyTorch breaks ties)
assert len(gene_ids) == 2
assert labels_list is None
# Check shapes
assert len(gene_ids[0]) == 3
assert len(vals[0]) == 3
assert decoder_vals is None