Safetensors
File size: 8,413 Bytes
4527b5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# test_preprocess.py
import hashlib
import json
import os
from pathlib import Path

import anndata as ad
import numpy as np
import pandas as pd
import pytest
from scipy.sparse import csr_matrix

from teddy.data_processing.preprocessing.preprocess import (
    compute_and_save_medians,
    filter_cells_by_gene_counts,
    filter_cells_by_mitochondrial_fraction,
    filter_highly_variable_genes,
    initialize_processed_layer,
    log_transform_layer,
    normalize_data_inplace,
    preprocess,
    set_raw_if_necessary,
    update_metadata,
)


@pytest.fixture
def synthetic_anndata():
    """Build a tiny 4-cell x 3-gene sparse AnnData shared by the tests below."""
    counts = np.array(
        [
            [0, 2, 3],
            [4, 0, 6],
            [7, 8, 0],
            [0, 0, 1],
        ],
        dtype=float,
    )
    gene_index = pd.DataFrame(index=["geneA", "geneB", "geneC"])
    cell_index = pd.DataFrame(index=["cell1", "cell2", "cell3", "cell4"])
    return ad.AnnData(X=csr_matrix(counts), var=gene_index, obs=cell_index)


def test_set_raw_if_necessary(synthetic_anndata):
    """set_raw_if_necessary should populate .raw for integer count data."""
    data = synthetic_anndata.copy()
    # The synthetic matrix contains only whole-number values, so the
    # function is expected to treat it as raw counts and snapshot it into
    # .raw. (NOTE(review): the original comments here were mangled by a
    # find/replace of "data" -> "data_processing"; intent reconstructed.)
    result = set_raw_if_necessary(data)
    assert result is not None
    assert result.raw is not None
    # The snapshot must retain the matrix's non-zero entries.
    assert (result.raw.X != 0).toarray().any()


def test_initialize_processed_layer(synthetic_anndata):
    """initialize_processed_layer should add a 'processed' layer from .raw."""
    adata = synthetic_anndata.copy()
    # The function copies counts out of .raw, so seed .raw first.
    adata.raw = adata.copy()

    # No processed layer exists before the call.
    assert "processed" not in adata.layers
    adata = initialize_processed_layer(adata)
    assert "processed" in adata.layers
    # The new layer must carry over the non-zero raw counts.
    assert (adata.layers["processed"] != 0).toarray().any()


def test_filter_cells_by_gene_counts(synthetic_anndata):
    """Cells whose total counts fall below min_count must be dropped."""
    adata = synthetic_anndata.copy()
    adata.raw = adata  # so the processed layer can be set
    adata.layers["processed"] = adata.X.copy()

    # Per-cell totals in the fixture: cell1=5, cell2=10, cell3=15, cell4=1.
    filtered = filter_cells_by_gene_counts(adata, min_count=2)

    # Only cell4 (total 1) falls under the threshold of 2.
    assert filtered.n_obs == 3
    assert "cell4" not in filtered.obs_names


def test_filter_cells_by_mitochondrial_fraction(synthetic_anndata):
    """Cells whose mitochondrial fraction exceeds max_mito_prop are removed."""
    adata = synthetic_anndata.copy()
    adata.raw = adata  # needed so the processed layer can be added
    # Mark the first gene as mitochondrial via its feature name.
    adata.var["feature_name"] = ["MT-GENE1", "geneB", "geneC"]
    adata.layers["processed"] = adata.X.copy()

    filtered = filter_cells_by_mitochondrial_fraction(adata, max_mito_prop=0.25)

    # MT fractions: cell1 0/5=0, cell2 4/10=0.4, cell3 7/15~0.467, cell4 0/1=0.
    # cell2 and cell3 exceed 0.25 and should be filtered out.
    assert filtered.n_obs == 2
    for dropped in ("cell2", "cell3"):
        assert dropped not in filtered.obs_names
    for kept in ("cell1", "cell4"):
        assert kept in filtered.obs_names


def test_normalize_data_inplace(synthetic_anndata):
    """normalize_data_inplace should scale each cell to sum to the target."""
    data = synthetic_anndata.copy()
    data.layers["processed"] = data.X.copy()
    # The in-place normalizer operates on CSR, so coerce if needed.
    if not isinstance(data.layers["processed"], csr_matrix):
        data.layers["processed"] = csr_matrix(data.layers["processed"])

    target = 1e4
    normalize_data_inplace(data.layers["processed"], target)

    # After normalization each row should sum to ~10000.
    row_sums = np.array(data.layers["processed"].sum(axis=1)).flatten()
    for s in row_sums:
        # Fix: pytest.approx must wrap the *expected* value with an explicit
        # relative tolerance; the original wrapped the actual value and
        # passed the tolerance positionally.
        assert s == pytest.approx(target, rel=1e-4)


def test_log_transform_layer(synthetic_anndata):
    """log_transform_layer should apply log1p to the named layer in place."""
    data = synthetic_anndata.copy()
    data.layers["processed"] = data.X.copy()

    log_transform_layer(data, "processed")

    # cell2/geneA held a raw count of 4, so it should now read log1p(4)
    # (~1.609).
    observed = data.layers["processed"][1, 0]
    # Fix: pytest.approx must wrap the *expected* value with an explicit
    # relative tolerance; the original wrapped the observed value and
    # passed the tolerance positionally.
    assert observed == pytest.approx(np.log1p(4), rel=1e-5)


def test_compute_and_save_medians(tmp_path, synthetic_anndata):
    """
    compute_and_save_medians should write a JSON file containing the medians
    of non-zero entries, one entry per gene.
    """
    adata = synthetic_anndata.copy()
    adata.layers["processed"] = adata.X.copy()

    # Persist the AnnData so the function has a data path to key off.
    h5ad_path = str(tmp_path / "test.h5ad")
    adata.write_h5ad(h5ad_path)

    hyperparams = {
        "load_dir": str(tmp_path),
        "save_dir": str(tmp_path),
        "median_column": "index",
    }
    compute_and_save_medians(adata, h5ad_path, hyperparams)

    # A sibling *_medians.json file should now exist.
    median_path = h5ad_path.replace(".h5ad", "_medians.json")
    assert os.path.exists(median_path)
    with open(median_path, "r") as handle:
        medians = json.load(handle)
    # One median entry is expected per gene.
    assert len(medians) == adata.n_vars


def test_update_metadata(synthetic_anndata):
    """update_metadata should record the cell count and the processing args."""
    adata = synthetic_anndata.copy()
    adata.obs["some_col"] = [1, 2, 3, 4]  # just to show we have 4 cells

    updated = update_metadata({}, adata, {"some": "arg"})

    assert updated.get("cell_count", None) == 4
    # Tolerate either spelling of the key used across versions.
    assert "processings_args" in updated or "processing_args" in updated


def test_preprocess_pipeline_end_to_end(tmp_path):
    """
    Run the full preprocess pipeline with minimal hyperparameters on a small
    synthetic dataset and check it completes and writes its output file.
    """
    # Build a 10-cell x 5-gene AnnData in memory. Float values are fine here
    # because every filtering/normalization stage is disabled below.
    X = np.random.rand(10, 5)
    adata = ad.AnnData(
        X=X,
        obs=pd.DataFrame(index=[f"cell_{i}" for i in range(X.shape[0])]),
        var=pd.DataFrame(index=[f"gene_{j}" for j in range(X.shape[1])]),
    )

    # Snapshot into .raw, then rebuild .raw with a sparse matrix: assigning
    # to .raw stores X as a dense ndarray by default, and the pipeline
    # expects CSR there.
    adata.raw = adata
    raw_adata = adata.raw.to_adata()
    raw_adata.X = csr_matrix(raw_adata.X)
    adata.raw = raw_adata

    data_path = tmp_path / "before.h5ad"
    adata.write_h5ad(data_path)

    # Minimal sidecar metadata file required by the pipeline.
    metadata_path = tmp_path / "metadata.json"
    with open(metadata_path, "w") as f:
        json.dump({"sample": "test_sample"}, f, indent=4)

    # Minimal hyperparameters: every optional stage is switched off.
    hyperparameters = {
        "load_dir": str(tmp_path),
        "save_dir": str(tmp_path),
        "reference_id_only": False,
        "remove_assays": [],
        "min_gene_counts": 0,
        "max_mitochondrial_prop": None,
        "hvg_method": None,
        "normalized_total": None,
        "median_dict": None,
        "median_column": "index",
        "log1p": False,
        "compute_medians": False,
    }

    returned = preprocess(str(data_path), str(metadata_path), hyperparameters)
    assert returned is not None

    # load_dir == save_dir, so the pipeline writes back to the same path
    # (data_path.replace(load_dir, save_dir) is the identity here). No
    # medians file is expected because compute_medians is False.
    # Fix: dropped the original's dead code — an unused `after_path`
    # variable, a no-op replace("before.h5ad", "before.h5ad"), and two
    # debug print() calls.
    assert os.path.exists(str(data_path))