# RxnIM / mllm/engine/builder.py
# Uploaded by CYF200127 ("Upload 235 files"), commit 3e1d9f3 (verified).
from functools import partial
from typing import Tuple, Dict, Any, Type
from transformers.trainer import DataCollator
from .shikra import ShikraTrainer
from .base_engine import TrainerForMMLLM, Seq2Seq2DataCollatorWithImage
# Registry of trainer classes keyed by model type string. prepare_trainer_collator
# looks up ``model_args.type`` here to decide which trainer class to return;
# register additional model families by adding entries to this mapping.
TYPE2TRAINER: Dict[str, Type[TrainerForMMLLM]] = dict(
    shikra=ShikraTrainer,
)
def prepare_trainer_collator(
        model_args,
        preprocessor: Dict[str, Any],
        collator_kwargs: Dict[str, Any],
) -> Tuple[Type[TrainerForMMLLM], Dict[str, DataCollator]]:
    """Resolve the trainer class and build the train/eval data collators.

    Args:
        model_args: Model configuration object. Only its ``type`` attribute is
            read here; it selects the trainer class from ``TYPE2TRAINER``.
        preprocessor: Forwarded unchanged to ``Seq2Seq2DataCollatorWithImage``
            as its ``preprocessor`` keyword argument.
        collator_kwargs: Extra keyword arguments forwarded to every
            ``Seq2Seq2DataCollatorWithImage`` constructed here.

    Returns:
        A ``(trainer_cls, collators)`` pair. ``collators`` maps
        ``"train_collator"`` to a collator built with ``inference_mode=False``
        and ``"eval_collator"`` to one built with ``inference_mode=True``.

    Raises:
        KeyError: If ``model_args.type`` is not a key of ``TYPE2TRAINER``.
            The message names the offending type and lists the supported ones.
    """
    type_ = model_args.type
    try:
        trainer_cls = TYPE2TRAINER[type_]
    except KeyError:
        # Keep the exception type callers may already catch, but replace the
        # bare missing-key message with one that says what *is* supported.
        supported = ", ".join(sorted(map(repr, TYPE2TRAINER)))
        raise KeyError(
            f"unknown model type {type_!r}; supported types: {supported}"
        ) from None

    # Bind the arguments shared by both collators once; only inference_mode
    # differs between them. Keywords passed when calling a partial override
    # those bound into it, so the explicit inference_mode below always wins.
    make_collator = partial(
        Seq2Seq2DataCollatorWithImage,
        preprocessor=preprocessor,
        **collator_kwargs,
    )
    data_collator_dict = {
        "train_collator": make_collator(inference_mode=False),
        "eval_collator": make_collator(inference_mode=True),
    }
    return trainer_cls, data_collator_dict