# Safetensors
# TEDDY / teddy /tokenizer /tokenization_args.py
# soumyatghosh's picture
# Upload folder using huggingface_hub
# 4527b5f verified
"""
Module: tokenization_args.py
This module defines the `TokenizationArgs` dataclass, which encapsulates all the configurable parameters
required for the tokenization process in the TEDDY project. These parameters control how gene expression
data and biological annotations are tokenized for training.
Main Features:
- Provides a structured way to define and manage tokenization arguments.
- Supports configuration for gene selection, sequence truncation, and annotation inclusion.
- Includes options for handling PerturbSeq-specific flags and preprocessing steps.
- Allows for flexible mapping of biological annotations (e.g., disease, tissue, cell type, sex).
- Enables reproducibility through random seed control for gene selection.
Dependencies:
- `dataclasses`: For defining the `TokenizationArgs` dataclass.
Usage:
1. Import the `TokenizationArgs` class:
```python
from teddy.tokenizer.tokenization_args import TokenizationArgs
```
2. Define tokenization arguments for a specific tokenization task:
```python
tokenization_args = TokenizationArgs(
tokenizer_name_or_path="path/to/tokenizer",
...
)
```
3. Pass the `tokenization_args` object to the tokenization function:
```python
tokenized_data = tokenize(data, tokenization_args)
```
"""
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class TokenizationArgs:
tokenizer_name_or_path: str = field(metadata={"help": "Path to tokenizer used."})
gene_id_column: str = field(default="index", metadata={"help": "Field to use while accessing gene_ids for values."})
random_genes: bool = field(
default=False, metadata={"help": "whether we want random genes (True) selection or top expressed ones (False)"}
)
include_zero_genes: bool = field(default=False, metadata={"help": "Path to tokenizer used."})
add_cls: bool = field(default=False, metadata={"help": "Whether to add cls token to the start of the sequence."})
cls_token_id: int = field(default=None, metadata={"help": "Token id for cls token."})
perturbseq: bool = field(
default=False,
metadata={"help": "[PerturbSeq specific flag] Whether to add perturbation token during tokenization."},
)
tokenize_perturbseq_for_train: bool = field(
default=True,
metadata={
"help": "[PerturbSeq specific flag] Whether to tokenize labels to prepare data for training or to simply prepare tokennized perturbation flags for inference."
},
)
add_tokens: tuple = field(
default=(),
metadata={
"help": "Enter a tuple of string values for tokens. Will be pre-pended to the gene id sequence. Can be used instead of add_cls"
},
)
add_disease_annotation: bool = field(default=False)
label_column: str = field(
default=None, metadata={"help": "Which column to use as a label for a classification task."}
)
max_shard_samples: int = field(default=500, metadata={"help": "Number of samples included in sharding."})
max_seq_len: int = field(default=3001, metadata={"help": "Max seq length used for data processing"})
pad_length: int = field(default=3001, metadata={"help": "Pad sequence to x length so that all arrays in all batches are same length"})
truncation_method: str = field(
default="max",
metadata={
"help": "Indicate here how to restrict the number of genes to obtain max_seq_len from the full set of expresison values. Options: max, random"
},
)
bins: int = field(default=None, metadata={"help": "Number of bins used when required for data processing"})
rescale_labels: bool = field(default=False, metadata={"help": "If true, labels are binned or continiously ranked"})
continuous_rank: bool = field(
default=False, metadata={"help": "If true, gene values are overwritten with linspace[-1, 1] by rank."}
)
bio_annotations: bool = field(
default=False, metadata={"help": "If true, include disease, tissue type, cell type, sex"}
)
bio_annotation_masking_prob: float = field(
default=0.15, metadata={"help": "Mask annotation tokens with this probability"}
)
disease_mapping: str = field(
default=None, metadata={"help": "Path to json mapping from disease names to standard disease categories"}
)
tissue_mapping: str = field(
default=None, metadata={"help": "Path to json mapping from tissue names to standard tissue categories"}
)
cell_mapping: str = field(
default=None, metadata={"help": "Path to json mapping from cell type names to standard cell types"}
)
sex_mapping: str = field(
default=None, metadata={"help": "Path to json mapping from sex names to standard sex categories"}
)
load_dir: str = field(default="", metadata={"help": "Directory where h5ad data is loaded from."})
save_dir: str = field(
default="",
metadata={
"help": "Directory where tokenization function will save data. tokenize() saves tokenized in data_path.replace(load_dir, save_dir)"
},
)
gene_seed: int = field(default=42, metadata={"help": "Random seed that controls randomness of gene selection"})