import hashlib
from pathlib import Path

import anndata as ad
import numpy as np
import pytest
import torch

from teddy.data_processing.tokenization.tokenization import (
    _bin_values,
    _build_batch_tensors,
    _check_genes_in_tokenizer,
    _prepare_tokenizer_args,
    _rank_continuous,
    tokenize,
)

##############################################################################
# Unit Tests: Helpers
##############################################################################
def test_bin_values_no_sorting():
    """
    Test _bin_values with no_sorting=True on a simple array.
    """
    vals_list = [np.array([0, 1, 2, 3], dtype=float)]

    class TokenArgs:
        include_zero_genes = True
        bins = 2

    token_args = TokenArgs()

    binned = _bin_values(vals_list, token_args, no_sorting=True)
    assert len(binned) == 1
    # The result is a single torch tensor of bin IDs. With bins=2 the value range
    # [0, 3] gives edges ~[0, 1.5, 3]; bucketize maps the values to [0, 1, 2, 2],
    # and clamping to a minimum bin of 1 yields [1, 1, 2, 2] (the expectation below).
    result = binned[0]
    expected = np.array([1, 1, 2, 2], dtype=float)
    assert (result.numpy() == expected).all(), f"Got {result.numpy()}, expected {expected}"
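

# A minimal sketch (assumption, not the library's implementation) of the equal-width
# binning that the expectation above reflects: linspace edges over the value range,
# torch.bucketize, then clamping to a minimum bin ID of 1.
def _reference_equal_width_binning(values, bins):
    vals = torch.as_tensor(values, dtype=torch.float)
    edges = torch.linspace(float(vals.min()), float(vals.max()), bins + 1)  # e.g. [0, 1.5, 3]
    return torch.bucketize(vals, edges).clamp(min=1)  # [0, 1, 2, 2] -> [1, 1, 2, 2]


def test_reference_binning_sketch():
    """Sanity-check the sketch itself against the same [1, 1, 2, 2] expectation."""
    out = _reference_equal_width_binning(np.array([0, 1, 2, 3], dtype=float), bins=2)
    assert (out.numpy() == np.array([1, 1, 2, 2])).all()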


def test_bin_values_with_sorting():
    """
    Test _bin_values with no_sorting=False (the 'positional chunk' approach).
    """
    vals_list = [np.array([5, 4, 3, 0], dtype=float)]

    class TokenArgs:
        include_zero_genes = True
        bins = 2

    token_args = TokenArgs()
    binned = _bin_values(vals_list, token_args, no_sorting=False)
    assert len(binned) == 1
    # The 'positional chunk' logic is exercised here; we don't replicate the full
    # math, just confirm the output has one bin per input value (zero entry included).
    result = binned[0]
    assert len(result) == 4, f"Expected 4 bins in the result, got {len(result)}"


def test_rank_continuous_normal():
    """
    Should produce a descending linear scale from +1 down to -1 across the entire array.
    """
    arr = np.array([3, 2, 1, 0], dtype=float)

    class TokenArgs:
        pass

    token_args = TokenArgs()
    ranked = _rank_continuous(arr, token_args)
    # Should be a descending scale from +1 down to -1 (a flipped linspace).
    # len=4 => torch.linspace(-1, 1, steps=4) => [-1, -0.3333, +0.3333, +1], flipped => [+1, +0.3333, -0.3333, -1].
    # Only check the shape, the ordering, and the minimum; exact values may differ slightly if the implementation adds small offsets.
    assert ranked.shape[0] == 4
    assert ranked[0] > ranked[1]  # Descending
    assert torch.isclose(ranked.min(), torch.tensor(-1.0), atol=1e-5)
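

# A minimal sketch (assumption) of the flipped-linspace ranking described in the
# comments above, valid when the input is already sorted in descending order.
def _reference_rank_continuous(n):
    scale = torch.linspace(-1.0, 1.0, steps=n)  # [-1, ..., +1]
    return torch.flip(scale, dims=[0])          # [+1, ..., -1]


def test_reference_rank_sketch():
    """The sketch reproduces the [+1, +1/3, -1/3, -1] pattern noted above."""
    ranked = _reference_rank_continuous(4)
    assert torch.allclose(ranked, torch.tensor([1.0, 1 / 3, -1 / 3, -1.0]), atol=1e-4)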



def test_prepare_tokenizer_args_dict():
    """
    Test that a dict tokenization_args is converted to a TokenizationArgs object
    properly and that random seeds are set when gene_seed is not None.
    """

    args_dict = {
        "load_dir": "/mock/load",
        "save_dir": "/mock/save",
        "gene_seed": 42,
        "tokenizer_name_or_path": "some/tokenizer",
    }
    token_args_obj, load_dir, save_dir = _prepare_tokenizer_args(args_dict)
    assert load_dir == "/mock/load"
    assert save_dir == "/mock/save"
    assert token_args_obj.gene_seed == 42
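

# A minimal sketch (assumption, not the real TokenizationArgs class) of the
# dict -> args-object conversion and RNG seeding that the docstring above describes.
class _SketchTokenizationArgs:
    def __init__(self, **kwargs):
        self.gene_seed = None
        self.__dict__.update(kwargs)


def _sketch_prepare_tokenizer_args(args_dict):
    token_args = _SketchTokenizationArgs(**args_dict)
    if token_args.gene_seed is not None:
        np.random.seed(token_args.gene_seed)     # seed numpy's global RNG
        torch.manual_seed(token_args.gene_seed)  # seed torch's global RNG
    return token_args, token_args.load_dir, token_args.save_dir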


def test_check_genes_in_tokenizer():
    """
    Test _check_genes_in_tokenizer with a minimal mock GeneTokenizer vocab.
    """
    # Mock an AnnData with 4 genes, 2 of which are in the tokenizer vocab.

    data = ad.AnnData(X=np.zeros((10, 4)))  # 10 cells, 4 genes
    data.var["gene_name"] = ["G1", "G2", "G3", "G4"]
    # Create a mock tokenizer with vocab = {G2, G3}
    class MockGeneTokenizer:
        def __init__(self, vocab):
            self.vocab = vocab

    mock_tokenizer = MockGeneTokenizer({"G2": 1, "G3": 2})

    gene_in_vocab, coding_genes, ratio = _check_genes_in_tokenizer(
        data,
        gene_id_column="gene_name",
        tokenizer=mock_tokenizer
    )

    # We expect G2, G3 => 2 matches out of 4 => ratio=0.5
    assert len(gene_in_vocab) == 2
    assert ratio == 0.5
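

# A minimal sketch (assumption) of the vocabulary check asserted above: intersect
# the AnnData gene names with the tokenizer vocab and report the matched fraction.
def _sketch_genes_in_vocab(adata, gene_id_column, tokenizer):
    genes = adata.var[gene_id_column]
    mask = genes.isin(set(tokenizer.vocab))  # True for genes the tokenizer knows
    ratio = float(mask.sum()) / len(genes)   # e.g. 2 matches / 4 genes == 0.5
    return genes[mask].tolist(), ratio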


def test_build_batch_tensors():
    """
    Confirm that _build_batch_tensors applies top-k selection (or random gene
    selection when random_genes is set) and returns the expected shapes.
    We test the simple top-k path here.
    """
    X_batch = torch.tensor([[5, 0, 1], [0, 2, 2]], dtype=torch.float)  # 2 cells x 3 genes
    token_array = torch.tensor([10, 11, 12])  # "gene IDs" for the 3 columns

    class TokenArgs:
        max_seq_len = 3
        add_cls = False
        random_genes = False

    gene_ids, vals, labels_list, decoder_vals = _build_batch_tensors(X_batch, token_array, TokenArgs())
    # topk => each row keeps its top 3 values (all of them, since there are only 3 genes), sorted descending.
    # Row 0 => values [5, 1, 0], gene IDs => [10, 12, 11]
    # Row 1 => values [2, 2, 0], gene IDs => [11, 12, 10] or [12, 11, 10] (the tie order depends on torch.topk)
    assert len(gene_ids) == 2
    assert labels_list is None
    # Check shapes
    assert len(gene_ids[0]) == 3
    assert len(vals[0]) == 3
    assert decoder_vals is None
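

# A minimal sketch (assumption) of the per-cell top-k selection described in the
# comments above: torch.topk over expression values, then gathering the gene IDs.
def _sketch_topk_genes(X_batch, token_array, k):
    vals, idx = torch.topk(X_batch, k=k, dim=1)  # values sorted descending per row
    gene_ids = token_array[idx]                  # row 0 above -> [10, 12, 11]
    return gene_ids, vals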