File size: 8,413 Bytes
4527b5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
# test_preprocess.py
import hashlib
import json
import os
from pathlib import Path
import anndata as ad
import numpy as np
import pandas as pd
import pytest
from scipy.sparse import csr_matrix
from teddy.data_processing.preprocessing.preprocess import (
compute_and_save_medians,
filter_cells_by_gene_counts,
filter_cells_by_mitochondrial_fraction,
filter_highly_variable_genes,
initialize_processed_layer,
log_transform_layer,
normalize_data_inplace,
preprocess,
set_raw_if_necessary,
update_metadata,
)
@pytest.fixture
def synthetic_anndata():
    """Build a tiny AnnData (4 cells x 3 genes) backed by a CSR count matrix.

    Every entry is a whole number stored as float; cell4 has a single count.
    """
    counts = csr_matrix(
        np.array(
            [
                [0, 2, 3],
                [4, 0, 6],
                [7, 8, 0],
                [0, 0, 1],
            ],
            dtype=float,
        )
    )
    genes = pd.DataFrame(index=["geneA", "geneB", "geneC"])
    cells = pd.DataFrame(index=[f"cell{i}" for i in range(1, 5)])
    return ad.AnnData(X=counts, var=genes, obs=cells)
def test_set_raw_if_necessary(synthetic_anndata):
    """set_raw_if_necessary should populate ``.raw`` for integer-valued counts.

    Every entry of the synthetic matrix is a whole number (stored as float),
    so the data looks like raw counts and ``.raw`` is expected to be set.
    """
    data = synthetic_anndata.copy()
    original_counts = synthetic_anndata.X.toarray()

    result = set_raw_if_necessary(data)

    assert result is not None
    assert result.raw is not None
    # .raw should be a snapshot of the original counts, not just "something
    # non-zero": compare the full matrix.
    raw_x = result.raw.X
    raw_dense = raw_x.toarray() if hasattr(raw_x, "toarray") else np.asarray(raw_x)
    np.testing.assert_array_equal(raw_dense, original_counts)
def test_initialize_processed_layer(synthetic_anndata):
    """initialize_processed_layer should add a "processed" layer equal to ``.raw.X``."""

    def to_dense(matrix):
        # Layers and .raw.X may be sparse or dense; compare them as dense arrays.
        return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)

    data = synthetic_anndata.copy()
    # Give the object a .raw snapshot so the function has counts to copy from.
    data.raw = data.copy()
    assert "processed" not in data.layers

    data = initialize_processed_layer(data)

    assert "processed" in data.layers
    # The layer must match .raw.X entry-for-entry, not merely contain non-zeros.
    np.testing.assert_array_equal(
        to_dense(data.layers["processed"]), to_dense(data.raw.X)
    )
def test_filter_cells_by_gene_counts(synthetic_anndata):
    """Cells below ``min_count`` are removed by filter_cells_by_gene_counts.

    Row totals in the fixture are cell1=5, cell2=10, cell3=15 and cell4=1, and
    cell4 is also the only cell with a single expressed gene. With
    ``min_count=2`` cell4 is therefore the only cell expected to be dropped.
    """
    data = synthetic_anndata.copy()
    data.raw = data
    data.layers["processed"] = data.X.copy()

    filtered = filter_cells_by_gene_counts(data, min_count=2)

    assert filtered.n_obs == 3
    assert "cell4" not in filtered.obs_names
def test_filter_cells_by_mitochondrial_fraction(synthetic_anndata):
    """Cells whose mitochondrial share exceeds ``max_mito_prop`` are removed.

    geneA is relabelled "MT-GENE1" so it counts as mitochondrial. Its share of
    each cell's total is:

    - cell1: 0 / 5  = 0.0    -> kept
    - cell2: 4 / 10 = 0.4    -> dropped (> 0.25)
    - cell3: 7 / 15 ~ 0.467  -> dropped (> 0.25)
    - cell4: 0 / 1  = 0.0    -> kept
    """
    data = synthetic_anndata.copy()
    data.raw = data
    data.var["feature_name"] = ["MT-GENE1", "geneB", "geneC"]
    data.layers["processed"] = data.X.copy()

    kept = filter_cells_by_mitochondrial_fraction(data, max_mito_prop=0.25)

    assert kept.n_obs == 2
    for dropped in ("cell2", "cell3"):
        assert dropped not in kept.obs_names
    for survivor in ("cell1", "cell4"):
        assert survivor in kept.obs_names
def test_normalize_data_inplace(synthetic_anndata):
    """normalize_data_inplace should rescale each row of a CSR layer to the target sum."""
    target_sum = 1e4
    data = synthetic_anndata.copy()
    processed = data.X.copy()
    # The function works in place on the matrix it is given; pass a CSR matrix.
    if not isinstance(processed, csr_matrix):
        processed = csr_matrix(processed)
    data.layers["processed"] = processed

    normalize_data_inplace(data.layers["processed"], target_sum)

    # Read back through the layer to confirm the change happened in place.
    row_sums = np.asarray(data.layers["processed"].sum(axis=1)).ravel()
    assert row_sums.shape == (data.n_obs,)
    # Tolerance is relative to the expected value (target_sum), not the actual one.
    np.testing.assert_allclose(row_sums, target_sum, rtol=1e-4)
def test_log_transform_layer(synthetic_anndata):
    """log_transform_layer should replace the named layer with log1p of its values."""
    data = synthetic_anndata.copy()
    original_counts = data.X.toarray()
    data.layers["processed"] = data.X.copy()

    log_transform_layer(data, "processed")

    # Spot check: cell2/geneA was 4, so it should now be log(5) ~ 1.609.
    # Expected value goes inside pytest.approx so the tolerance is relative to it.
    assert data.layers["processed"][1, 0] == pytest.approx(np.log1p(4), rel=1e-5)

    # Every entry (zeros included, since log1p(0) == 0) should be transformed.
    layer = data.layers["processed"]
    dense = layer.toarray() if hasattr(layer, "toarray") else np.asarray(layer)
    np.testing.assert_allclose(dense, np.log1p(original_counts), rtol=1e-5)
def test_compute_and_save_medians(tmp_path, synthetic_anndata):
    """compute_and_save_medians writes ``<name>_medians.json`` with one entry per gene.

    The medians are intended to be taken over non-zero entries. This test checks
    only that the JSON file appears next to the .h5ad and covers every gene.
    """
    data = synthetic_anndata.copy()
    data.layers["processed"] = data.X.copy()

    h5ad_file = str(tmp_path / "test.h5ad")
    data.write_h5ad(h5ad_file)
    params = dict(
        load_dir=str(tmp_path),
        save_dir=str(tmp_path),
        median_column="index",
    )

    compute_and_save_medians(data, h5ad_file, params)

    medians_file = h5ad_file.replace(".h5ad", "_medians.json")
    assert os.path.exists(medians_file)
    with open(medians_file, "r") as handle:
        medians = json.load(handle)
    assert len(medians) == data.n_vars
def test_update_metadata(synthetic_anndata):
    """update_metadata should record the cell count and the processing arguments."""
    data = synthetic_anndata.copy()
    data.obs["some_col"] = [1, 2, 3, 4]  # one value per cell: 4 cells
    args = {"some": "arg"}

    updated = update_metadata({}, data, args)

    assert updated.get("cell_count") == 4
    # Either spelling of the arguments key is accepted.
    assert any(key in updated for key in ("processings_args", "processing_args"))
def test_preprocess_pipeline_end_to_end(tmp_path):
    """Run the whole preprocess pipeline on a tiny synthetic dataset.

    All optional steps (cell filtering, HVG selection, normalization, log1p,
    medians) are disabled. The pipeline should run to completion and leave a
    readable .h5ad at the output path.
    """
    # 1) Build a small, reproducible dataset (10 cells x 5 genes). A seeded
    #    generator keeps the test deterministic across runs.
    rng = np.random.default_rng(0)
    X = rng.random((10, 5))
    adata = ad.AnnData(
        X=X,
        obs=pd.DataFrame(index=[f"cell_{i}" for i in range(X.shape[0])]),
        var=pd.DataFrame(index=[f"gene_{j}" for j in range(X.shape[1])]),
    )

    # 2) Attach a .raw snapshot whose X is a CSR matrix. `adata.raw = adata`
    #    alone would keep a dense X, so build the snapshot as its own AnnData.
    raw_snapshot = adata.copy()
    raw_snapshot.X = csr_matrix(raw_snapshot.X)
    adata.raw = raw_snapshot

    # 3) Write the inputs the pipeline reads: an .h5ad and a metadata JSON.
    data_path = tmp_path / "before.h5ad"
    adata.write_h5ad(data_path)
    metadata_path = tmp_path / "metadata.json"
    metadata_path.write_text(json.dumps({"sample": "test_sample"}, indent=4))

    # 4) Minimal hyperparameters: every optional processing step is switched off.
    hyperparameters = {
        "load_dir": str(tmp_path),
        "save_dir": str(tmp_path),
        "reference_id_only": False,
        "remove_assays": [],
        "min_gene_counts": 0,
        "max_mitochondrial_prop": None,
        "hvg_method": None,
        "normalized_total": None,
        "median_dict": None,
        "median_column": "index",
        "log1p": False,
        "compute_medians": False,
    }

    # 5) Run the pipeline.
    returned = preprocess(str(data_path), str(metadata_path), hyperparameters)
    assert returned is not None

    # 6) The pipeline writes to data_path.replace(load_dir, save_dir). Because
    #    load_dir == save_dir here, that is the input path itself. Its existence
    #    alone proves nothing, since this test wrote that file above. So reload
    #    it and check it is still a valid AnnData with the input's shape: every
    #    cell has a positive total and no gene selection is enabled.
    final_h5ad = str(data_path).replace(
        hyperparameters["load_dir"], hyperparameters["save_dir"]
    )
    assert os.path.exists(final_h5ad)
    reloaded = ad.read_h5ad(final_h5ad)
    assert reloaded.n_obs == adata.n_obs
    assert reloaded.n_vars == adata.n_vars
|