"""
Module: tokenization_args.py

This module defines the `TokenizationArgs` dataclass, which encapsulates all the configurable parameters
required for the tokenization process in the TEDDY project. These parameters control how gene expression
data and biological annotations are tokenized for training.

Main Features:
- Provides a structured way to define and manage tokenization arguments.
- Supports configuration for gene selection, sequence truncation, and annotation inclusion.
- Includes options for handling PerturbSeq-specific flags and preprocessing steps.
- Allows for flexible mapping of biological annotations (e.g., disease, tissue, cell type, sex).
- Enables reproducibility through random seed control for gene selection.

Dependencies:
- `dataclasses`: For defining the `TokenizationArgs` dataclass.
- `typing`: For `Optional` field annotations.

Usage:
1. Import the `TokenizationArgs` class:
    ```python
    from teddy.tokenizer.tokenization_args import TokenizationArgs
    ```
2. Define tokenization arguments for a specific tokenization task:
    ```python
    tokenization_args = TokenizationArgs(
        tokenizer_name_or_path="path/to/tokenizer",
        ...
    )
    ```
3. Pass the `tokenization_args` object to the tokenization function:
    ```python
    tokenized_data = tokenize(data, tokenization_args)
    ```
"""

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class TokenizationArgs:
    tokenizer_name_or_path: str = field(metadata={"help": "Path to the tokenizer to use."})
    gene_id_column: str = field(default="index", metadata={"help": "Field to use when accessing gene_ids for expression values."})
    random_genes: bool = field(
        default=False, metadata={"help": "Whether to select genes at random (True) or take the top-expressed ones (False)."}
    )
    include_zero_genes: bool = field(default=False, metadata={"help": "Whether to include genes with zero expression in the tokenized sequence."})
    add_cls: bool = field(default=False, metadata={"help": "Whether to add a cls token at the start of the sequence."})
    cls_token_id: Optional[int] = field(default=None, metadata={"help": "Token id for the cls token."})
    perturbseq: bool = field(
        default=False,
        metadata={"help": "[PerturbSeq specific flag] Whether to add perturbation token during tokenization."},
    )
    tokenize_perturbseq_for_train: bool = field(
        default=True,
        metadata={
            "help": "[PerturbSeq specific flag] Whether to tokenize labels to prepare data for training, or to simply prepare tokenized perturbation flags for inference."
        },
    )
    add_tokens: tuple = field(
        default=(),
        metadata={
            "help": "Tuple of string token values. They will be prepended to the gene id sequence. Can be used instead of add_cls."
        },
    )

    add_disease_annotation: bool = field(default=False, metadata={"help": "Whether to add a disease annotation token."})

    label_column: Optional[str] = field(
        default=None, metadata={"help": "Which column to use as the label for a classification task."}
    )
    max_shard_samples: int = field(default=500, metadata={"help": "Maximum number of samples per shard."})
    max_seq_len: int = field(default=3001, metadata={"help": "Maximum sequence length used for data processing."})
    pad_length: int = field(
        default=3001,
        metadata={"help": "Pad sequences to this length so that all arrays in all batches have the same length."},
    )
    truncation_method: str = field(
        default="max",
        metadata={
            "help": "How to restrict the number of genes to max_seq_len from the full set of expression values. Options: max, random."
        },
    )
    bins: Optional[int] = field(default=None, metadata={"help": "Number of bins used when binning is required for data processing."})

    rescale_labels: bool = field(default=False, metadata={"help": "If true, labels are binned or continuously ranked."})

    continuous_rank: bool = field(
        default=False, metadata={"help": "If true, gene values are overwritten by rank with values from linspace(-1, 1)."}
    )
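    # Example for continuous_rank=True (inferred from the help text above; the
    # actual rescaling lives in the tokenization function): a cell with four
    # genes would have its values replaced by their ranks mapped onto
    # numpy.linspace(-1, 1, 4), i.e. approximately [-1.0, -0.33, 0.33, 1.0].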

    bio_annotations: bool = field(
        default=False, metadata={"help": "If true, include disease, tissue type, cell type, and sex annotations."}
    )

    bio_annotation_masking_prob: float = field(
        default=0.15, metadata={"help": "Mask annotation tokens with this probability."}
    )

    disease_mapping: Optional[str] = field(
        default=None, metadata={"help": "Path to a JSON mapping from disease names to standard disease categories."}
    )

    tissue_mapping: Optional[str] = field(
        default=None, metadata={"help": "Path to a JSON mapping from tissue names to standard tissue categories."}
    )

    cell_mapping: Optional[str] = field(
        default=None, metadata={"help": "Path to a JSON mapping from cell type names to standard cell types."}
    )

    sex_mapping: Optional[str] = field(
        default=None, metadata={"help": "Path to a JSON mapping from sex names to standard sex categories."}
    )
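    # Each *_mapping field above points to a JSON file keyed by raw annotation
    # names. A hypothetical disease mapping might look like (the exact schema
    # is an assumption; it is not defined in this module):
    #     {"dilated cardiomyopathy": "cardiomyopathy", "DCM": "cardiomyopathy"}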

    load_dir: str = field(default="", metadata={"help": "Directory where h5ad data is loaded from."})

    save_dir: str = field(
        default="",
        metadata={
            "help": "Directory where the tokenization function will save data. tokenize() saves tokenized data to data_path.replace(load_dir, save_dir)."
        },
    )

    gene_seed: int = field(default=42, metadata={"help": "Random seed that controls the randomness of gene selection."})
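

if __name__ == "__main__":
    # Minimal usage sketch: the values below are illustrative only and mirror
    # the docstring example.
    args = TokenizationArgs(
        tokenizer_name_or_path="path/to/tokenizer",
        random_genes=False,  # keep the top-expressed genes
        max_seq_len=2048,
        pad_length=2048,
        bio_annotations=True,  # include disease / tissue / cell type / sex tokens
    )
    print(args)

    # The metadata={"help": ...} convention matches what
    # transformers.HfArgumentParser expects, so the same dataclass can likely
    # also be populated from the command line (assumption: transformers is
    # available in the environment; this module does not depend on it):
    #
    #     from transformers import HfArgumentParser
    #     (args,) = HfArgumentParser(TokenizationArgs).parse_args_into_dataclasses()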