import hashlib
from pathlib import Path
import anndata as ad
import numpy as np
import pytest
import torch
from teddy.data_processing.tokenization.tokenization import (
_bin_values,
_build_batch_tensors,
_check_genes_in_tokenizer,
_prepare_tokenizer_args,
_rank_continuous,
tokenize,
)
##############################################################################
# Unit Tests: Helpers
##############################################################################
def test_bin_values_no_sorting():
    """
    Test _bin_values with no_sorting=True on a simple array.
    """
    vals_list = [np.array([0, 1, 2, 3], dtype=float)]

    class TokenArgs:
        include_zero_genes = True
        bins = 2

    token_args = TokenArgs()
    binned = _bin_values(vals_list, token_args, no_sorting=True)
    assert len(binned) == 1
    # The result is a single torch tensor of bucket IDs. With bins=2, the edges
    # span the min->max range [0..3] (roughly [0, 1.5, 3]); after bucketizing,
    # bin IDs are clamped to a minimum of 1, so [0, 1, 2, 3] maps to [1, 1, 2, 2].
    result = binned[0]
    expected = np.array([1, 1, 2, 2], dtype=float)
    assert (result.numpy() == expected).all(), f"Got {result.numpy()}, expected {expected}"
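
# A minimal follow-up sketch, assuming _bin_values returns one binned tensor per
# input array in vals_list; exact bucket IDs are deliberately not asserted here,
# only the per-array structure.
def test_bin_values_multiple_arrays():
    """
    Sanity-check that _bin_values handles more than one array in vals_list.
    """
    vals_list = [
        np.array([0, 1, 2, 3], dtype=float),
        np.array([0, 0, 4, 8], dtype=float),
    ]

    class TokenArgs:
        include_zero_genes = True
        bins = 2

    binned = _bin_values(vals_list, TokenArgs(), no_sorting=True)
    # One binned tensor per input array, each matching its input length.
    assert len(binned) == 2
    assert len(binned[0]) == 4
    assert len(binned[1]) == 4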

def test_bin_values_with_sorting():
    """
    Test _bin_values with no_sorting=False (the 'positional chunk' approach).
    """
    vals_list = [np.array([5, 4, 3, 0], dtype=float)]

    class TokenArgs:
        include_zero_genes = True
        bins = 2

    token_args = TokenArgs()
    binned = _bin_values(vals_list, token_args, no_sorting=False)
    assert len(binned) == 1
    # The 'positional chunk' logic assigns bins by sorted position rather than by
    # raw value. We don't replicate the full math here; just confirm the output
    # has one bin per input value, including the zero entry.
    result = binned[0]
    assert len(result) == 4, f"Expected 4 bins in the result, got {len(result)}"

def test_rank_continuous_normal():
    """
    Should produce a descending linear scale from +1 down to ~-1 across the entire array.
    """
    arr = np.array([3, 2, 1, 0], dtype=float)

    class TokenArgs:
        pass

    token_args = TokenArgs()
    ranked = _rank_continuous(arr, token_args)
    # With len=4, torch.linspace(-1, 1, steps=4) gives [-1, -0.3333, +0.3333, +1];
    # flipped, the scale runs from +1 down to -1. Only check shape, min, and
    # monotonicity here, since exact values may differ slightly if the code applies
    # micro offsets.
    assert ranked.shape[0] == 4
    assert ranked[0] > ranked[1]  # Descending
    assert torch.isclose(ranked.min(), torch.tensor(-1.0), atol=1e-5)
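
# A minimal follow-up sketch on a longer array, assuming the flipped linspace
# scale tops out near +1 and bottoms out near -1 regardless of length; only
# shape, endpoints, and monotonicity are checked.
def test_rank_continuous_larger_array():
    arr = np.arange(10, 0, -1).astype(float)  # 10 strictly descending values

    class TokenArgs:
        pass

    ranked = _rank_continuous(arr, TokenArgs())
    assert ranked.shape[0] == 10
    assert torch.isclose(ranked.max(), torch.tensor(1.0), atol=1e-5)
    assert torch.isclose(ranked.min(), torch.tensor(-1.0), atol=1e-5)
    # Strictly descending across the whole scale.
    assert torch.all(ranked[:-1] > ranked[1:])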

def test_prepare_tokenizer_args_dict():
    """
    Test that a dict tokenization_args is converted to a TokenizationArgs object
    properly, and that random seeds are set if gene_seed is not None.
    """
    args_dict = {
        "load_dir": "/mock/load",
        "save_dir": "/mock/save",
        "gene_seed": 42,
        "tokenizer_name_or_path": "some/tokenizer",
    }
    token_args_obj, load_dir, save_dir = _prepare_tokenizer_args(args_dict)
    assert load_dir == "/mock/load"
    assert save_dir == "/mock/save"
    assert token_args_obj.gene_seed == 42
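
# A minimal follow-up sketch, assuming the remaining dict entries (here
# tokenizer_name_or_path) are carried onto the resulting TokenizationArgs
# object as attributes, mirroring how gene_seed is carried over above.
def test_prepare_tokenizer_args_dict_carries_fields():
    args_dict = {
        "load_dir": "/mock/load",
        "save_dir": "/mock/save",
        "gene_seed": 7,
        "tokenizer_name_or_path": "some/tokenizer",
    }
    token_args_obj, _, _ = _prepare_tokenizer_args(args_dict)
    assert token_args_obj.gene_seed == 7
    assert token_args_obj.tokenizer_name_or_path == "some/tokenizer"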

def test_check_genes_in_tokenizer():
    """
    Test _check_genes_in_tokenizer with a minimal mock GeneTokenizer vocab.
    """
    # Mock an AnnData with 4 genes, 2 of which are in the vocab.
    data = ad.AnnData(X=np.zeros((10, 4)))  # 10 cells, 4 genes
    data.var["gene_name"] = ["G1", "G2", "G3", "G4"]

    # Create a mock tokenizer with vocab = {G2, G3}.
    class MockGeneTokenizer:
        def __init__(self, vocab):
            self.vocab = vocab

    mock_tokenizer = MockGeneTokenizer({"G2": 1, "G3": 2})
    gene_in_vocab, coding_genes, ratio = _check_genes_in_tokenizer(
        data,
        gene_id_column="gene_name",
        tokenizer=mock_tokenizer,
    )
    # We expect G2 and G3 to match => 2 matches out of 4 => ratio = 0.5.
    assert len(gene_in_vocab) == 2
    assert ratio == 0.5
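
# A minimal follow-up sketch: with every gene present in the mock vocab, the
# match ratio should be 1.0 (4 matches out of 4), assuming ratio is simply
# matches / total genes as in the test above.
def test_check_genes_in_tokenizer_all_in_vocab():
    data = ad.AnnData(X=np.zeros((10, 4)))  # 10 cells, 4 genes
    data.var["gene_name"] = ["G1", "G2", "G3", "G4"]

    class MockGeneTokenizer:
        def __init__(self, vocab):
            self.vocab = vocab

    mock_tokenizer = MockGeneTokenizer({"G1": 1, "G2": 2, "G3": 3, "G4": 4})
    gene_in_vocab, coding_genes, ratio = _check_genes_in_tokenizer(
        data,
        gene_id_column="gene_name",
        tokenizer=mock_tokenizer,
    )
    assert len(gene_in_vocab) == 4
    assert ratio == 1.0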

def test_build_batch_tensors():
    """
    Confirm _build_batch_tensors performs topk or random_genes if specified, and returns expected shapes.
    We'll do a simple topk test.
    """
    X_batch = torch.tensor([[5, 0, 1], [0, 2, 2]], dtype=torch.float)  # 2 cells x 3 genes
    token_array = torch.tensor([10, 11, 12])  # "gene IDs" for the 3 columns

    class TokenArgs:
        max_seq_len = 3
        add_cls = False
        random_genes = False

    gene_ids, vals, labels_list, decoder_vals = _build_batch_tensors(X_batch, token_array, TokenArgs())
    # topk => each row keeps its best 3 values (all of them, since there are only 3 genes), sorted descending:
    #   Row 0 => values [5, 1, 0], gene IDs [10, 12, 11]
    #   Row 1 => values [2, 2, 0], gene IDs [11, 12, 10] (tie order depends on how PyTorch breaks ties)
    assert len(gene_ids) == 2
    assert labels_list is None
    # Check shapes
    assert len(gene_ids[0]) == 3
    assert len(vals[0]) == 3
    assert decoder_vals is None
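
# A minimal follow-up sketch, assuming max_seq_len caps how many genes are kept
# per cell via the same topk selection exercised above.
def test_build_batch_tensors_truncates_to_max_seq_len():
    X_batch = torch.tensor([[5, 0, 1], [0, 2, 2]], dtype=torch.float)
    token_array = torch.tensor([10, 11, 12])

    class TokenArgs:
        max_seq_len = 2
        add_cls = False
        random_genes = False

    gene_ids, vals, labels_list, decoder_vals = _build_batch_tensors(X_batch, token_array, TokenArgs())
    # Each cell should keep only its top-2 expressed genes.
    assert len(gene_ids) == 2
    assert len(gene_ids[0]) == 2
    assert len(vals[0]) == 2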