import hashlib
from pathlib import Path

import anndata as ad
import numpy as np
import pytest
import torch

from teddy.data_processing.tokenization.tokenization import (
    _bin_values,
    _build_batch_tensors,
    _check_genes_in_tokenizer,
    _prepare_tokenizer_args,
    _rank_continuous,
    tokenize,
)


def test_bin_values_no_sorting():
    """Test _bin_values with no_sorting=True on a simple array."""
    vals_list = [np.array([0, 1, 2, 3], dtype=float)]

    class TokenArgs:
        include_zero_genes = True
        bins = 2

    token_args = TokenArgs()
    binned = _bin_values(vals_list, token_args, no_sorting=True)
    assert len(binned) == 1

    # With bins=2, the lower half maps to bin 1 and the upper half to bin 2.
    result = binned[0]
    expected = np.array([1, 1, 2, 2], dtype=float)
    assert (result.numpy() == expected).all(), f"Got {result.numpy()}, expected {expected}"
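

# Companion sketch: the expected [1, 1, 2, 2] above matches simple
# equal-frequency (quantile) binning with labels starting at 1. This is an
# assumption about _bin_values' contract used for orientation, not a reading
# of its implementation.
def test_bin_values_no_sorting_matches_quantile_reference():
    vals = np.array([0, 1, 2, 3], dtype=float)

    class TokenArgs:
        include_zero_genes = True
        bins = 2

    binned = _bin_values([vals], TokenArgs(), no_sorting=True)[0]

    # Reference: digitizing against the median splits the array into two
    # equal-frequency bins labelled 1 and 2.
    edges = np.quantile(vals, [0.5])
    reference = np.digitize(vals, edges) + 1
    assert (binned.numpy() == reference).all()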


def test_bin_values_with_sorting():
    """Test _bin_values with no_sorting=False (the 'positional chunk' approach)."""
    vals_list = [np.array([5, 4, 3, 0], dtype=float)]

    class TokenArgs:
        include_zero_genes = True
        bins = 2

    token_args = TokenArgs()
    binned = _bin_values(vals_list, token_args, no_sorting=False)
    assert len(binned) == 1

    # One binned value per input value, regardless of the number of bins.
    result = binned[0]
    assert len(result) == 4, f"Expected 4 binned values, got {len(result)}"
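

# Hedged sketch of the 'positional chunk' idea named above: rank the values,
# split the ranked positions into `bins` equal chunks, and label each value by
# the chunk its rank falls into. Assumed semantics, shown for illustration;
# the test above only asserts the length invariant.
def _positional_chunk_reference(vals: np.ndarray, bins: int) -> np.ndarray:
    order = np.argsort(-vals)  # positions sorted by descending value
    labels = np.empty(len(vals), dtype=int)
    labels[order] = np.arange(len(vals)) * bins // len(vals) + 1
    return labels  # e.g. [5, 4, 3, 0] with bins=2 -> [1, 1, 2, 2]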


def test_rank_continuous_normal():
    """Should produce a descending linear scale from ~-1 to +1 across the entire array."""
    arr = np.array([3, 2, 1, 0], dtype=float)

    class TokenArgs:
        pass

    token_args = TokenArgs()
    ranked = _rank_continuous(arr, token_args)

    assert ranked.shape[0] == 4
    assert ranked[0] > ranked[1]  # descending
    assert torch.isclose(ranked.min(), torch.tensor(-1.0), atol=1e-5)
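    # The "+1" end of the scale from the docstring, mirroring the min check.
    # Assumes the endpoints are symmetric at exactly ±1.
    assert torch.isclose(ranked.max(), torch.tensor(1.0), atol=1e-5)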


def test_prepare_tokenizer_args_dict():
    """
    Test that a dict tokenization_args is converted to a TokenizationArgs object
    properly, and that random seeds are set if gene_seed is not None.
    """
    args_dict = {
        "load_dir": "/mock/load",
        "save_dir": "/mock/save",
        "gene_seed": 42,
        "tokenizer_name_or_path": "some/tokenizer",
    }
    token_args_obj, load_dir, save_dir = _prepare_tokenizer_args(args_dict)
    assert load_dir == "/mock/load"
    assert save_dir == "/mock/save"
    assert token_args_obj.gene_seed == 42
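

# Hedged follow-up: if _prepare_tokenizer_args seeds numpy's global RNG when
# gene_seed is not None (as the docstring above suggests), two calls with the
# same gene_seed should make subsequent draws reproducible. The numpy-specific
# seeding is an assumption; the implementation may seed torch/random instead.
def test_prepare_tokenizer_args_seed_reproducible():
    args_dict = {
        "load_dir": "/mock/load",
        "save_dir": "/mock/save",
        "gene_seed": 42,
        "tokenizer_name_or_path": "some/tokenizer",
    }
    _prepare_tokenizer_args(dict(args_dict))
    first = np.random.randint(0, 1_000_000, size=3)
    _prepare_tokenizer_args(dict(args_dict))
    second = np.random.randint(0, 1_000_000, size=3)
    assert (first == second).all()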


def test_check_genes_in_tokenizer():
    """Test _check_genes_in_tokenizer with a minimal mock GeneTokenizer vocab."""
    data = ad.AnnData(X=np.zeros((10, 4)))
    data.var["gene_name"] = ["G1", "G2", "G3", "G4"]

    class MockGeneTokenizer:
        def __init__(self, vocab):
            self.vocab = vocab

    mock_tokenizer = MockGeneTokenizer({"G2": 1, "G3": 2})

    gene_in_vocab, coding_genes, ratio = _check_genes_in_tokenizer(
        data,
        gene_id_column="gene_name",
        tokenizer=mock_tokenizer,
    )

    # Two of the four genes (G2, G3) are in the mock vocab.
    assert len(gene_in_vocab) == 2
    assert ratio == 0.5
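

# Degenerate-case companion: with an empty vocab no genes should match, and
# the coverage ratio is presumably 0.0, mirroring the 2-of-4 -> 0.5 arithmetic
# above. Assumed behavior, not confirmed against the implementation.
def test_check_genes_in_tokenizer_empty_vocab():
    data = ad.AnnData(X=np.zeros((10, 4)))
    data.var["gene_name"] = ["G1", "G2", "G3", "G4"]

    class MockGeneTokenizer:
        def __init__(self, vocab):
            self.vocab = vocab

    gene_in_vocab, coding_genes, ratio = _check_genes_in_tokenizer(
        data,
        gene_id_column="gene_name",
        tokenizer=MockGeneTokenizer({}),
    )
    assert len(gene_in_vocab) == 0
    assert ratio == 0.0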


def test_build_batch_tensors():
    """
    Confirm _build_batch_tensors performs topk or random_genes if specified,
    and returns the expected shapes. This is a simple topk test.
    """
    X_batch = torch.tensor([[5, 0, 1], [0, 2, 2]], dtype=torch.float)
    token_array = torch.tensor([10, 11, 12])

    class TokenArgs:
        max_seq_len = 3
        add_cls = False
        random_genes = False

    gene_ids, vals, labels_list, decoder_vals = _build_batch_tensors(X_batch, token_array, TokenArgs())

    # One sequence per row of X_batch; no labels or decoder values requested.
    assert len(gene_ids) == 2
    assert labels_list is None
    # max_seq_len=3 keeps all three genes per row.
    assert len(gene_ids[0]) == 3
    assert len(vals[0]) == 3
    assert decoder_vals is None
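

# Hedged topk companion: with max_seq_len=2 the builder presumably keeps the
# two highest-valued genes per row (the "topk" path the docstring mentions),
# mapping gene indices to ids via token_array. Both points are assumptions.
def test_build_batch_tensors_topk_truncates():
    X_batch = torch.tensor([[5, 0, 1], [0, 2, 2]], dtype=torch.float)
    token_array = torch.tensor([10, 11, 12])

    class TokenArgs:
        max_seq_len = 2
        add_cls = False
        random_genes = False

    gene_ids, vals, _, _ = _build_batch_tensors(X_batch, token_array, TokenArgs())
    assert len(gene_ids[0]) == 2 and len(vals[0]) == 2
    # Row 0's two largest values are 5 and 1 -> token ids 10 and 12 (assumed).
    assert set(int(g) for g in gene_ids[0]) == {10, 12}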