| | import json |
| | import logging |
| | import os |
| | import random |
| | import datetime |
| |
|
| | import numpy as np |
| | import torch |
| | from torch.utils.data import ConcatDataset |
| | from data import SeqRecDataset, ItemFeatDataset, ItemSearchDataset, FusionSeqRecDataset, SeqRecTestDataset, PreferenceObtainDataset |
| |
|
| |
|
def parse_global_args(parser):
    """Register options shared by every entry point (seed, model, output dir).

    Args:
        parser: an ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same parser, for chaining.
    """
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    # Paths for the backbone model and where checkpoints/artifacts land.
    parser.add_argument(
        "--base_model",
        type=str,
        default="./llama-7b/",
        help="basic model path",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./ckpt/",
        help="The output directory",
    )

    return parser
| |
|
def parse_dataset_args(parser):
    """Register dataset / task-mixture options.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same parser, for chaining.
    """
    parser.add_argument("--data_path", type=str, default="", help="data directory")
    parser.add_argument("--tasks", type=str, default="seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain",
                        help="Downstream tasks, separate by comma")
    parser.add_argument("--dataset", type=str, default="Games", help="Dataset name")
    # FIX: these two point at files, not a directory — help text corrected.
    parser.add_argument("--inter_path", type=str, default="", help="interaction data file path")
    parser.add_argument("--feature_path", type=str, default="", help="item feature file path")
    parser.add_argument("--index_file", type=str, default=".index.json", help="the item indices file")
    parser.add_argument("--dataloader_num_workers", type=int, default=0, help="dataloader num_workers")
    parser.add_argument("--dataloader_prefetch_factor", type=int, default=2, help="dataloader prefetch_factor")

    # History-sequence formatting.
    parser.add_argument("--max_his_len", type=int, default=20,
                        help="the max number of items in history sequence, -1 means no limit")
    parser.add_argument("--add_prefix", action="store_true", default=False,
                        help="whether add sequential prefix in history")
    parser.add_argument("--his_sep", type=str, default=", ", help="The separator used for history")
    parser.add_argument("--only_train_response", action="store_true", default=False,
                        help="whether only train on responses")

    # Per-task sampling: comma-separated, must align 1:1 with --tasks.
    parser.add_argument("--train_prompt_sample_num", type=str, default="1,1,1,1,1,1",
                        help="the number of sampling prompts for each task")
    # FIX: help text was a copy-paste of the prompt-sample option; this one
    # controls how many *data samples* are drawn per task (0 = use all).
    parser.add_argument("--train_data_sample_num", type=str, default="0,0,0,100000,0,0",
                        help="the number of sampling data for each task, 0 means using all data")

    # Validation prompt selection.
    parser.add_argument("--valid_prompt_id", type=int, default=0,
                        help="The prompt used for validation")
    # NOTE(review): store_true with default=True makes this flag a no-op — it
    # can never be turned off from the CLI. Left unchanged to preserve current
    # behavior; consider argparse.BooleanOptionalAction.
    parser.add_argument("--sample_valid", action="store_true", default=True,
                        help="use sampled prompt for validation")
    parser.add_argument("--valid_prompt_sample_num", type=int, default=2,
                        help="the number of sampling validation sequential recommendation prompts")

    return parser
| |
|
def parse_train_args(parser):
    """Register training hyper-parameters (optimizer, LoRA, schedule, precision).

    Args:
        parser: an ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same parser, for chaining.
    """
    # Core optimization settings.
    parser.add_argument("--optim", type=str, default="adamw_torch", help='The name of the optimizer')
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--per_device_batch_size", type=int, default=8)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
    parser.add_argument("--logging_step", type=int, default=10)
    parser.add_argument("--model_max_length", type=int, default=2048)
    parser.add_argument("--weight_decay", type=float, default=0.01)

    # LoRA adapter configuration.
    parser.add_argument("--lora_r", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj",
        help="separate by comma",
    )
    parser.add_argument(
        "--lora_modules_to_save",
        type=str,
        default="embed_tokens,lm_head",
        help="separate by comma",
    )

    # Resuming from an earlier run.
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="either training checkpoint or final adapter")

    # Schedule, checkpointing cadence, and mixed-precision / DeepSpeed setup.
    parser.add_argument("--warmup_ratio", type=float, default=0.01)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--save_and_eval_strategy", type=str, default="epoch")
    parser.add_argument("--save_and_eval_steps", type=int, default=1000)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--bf16", action="store_true", default=False)
    parser.add_argument("--deepspeed", type=str, default="./config/ds_z3_bf16.json")

    return parser
| |
|
def parse_test_args(parser):
    """Register evaluation/inference options (checkpoint, decoding, metrics).

    Args:
        parser: an ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same parser, for chaining.
    """
    # Which weights to evaluate and how to load them.
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="",
        help="The checkpoint path",
    )
    parser.add_argument("--lora", action="store_true", default=False)
    parser.add_argument("--filter_items", action="store_true", default=False,
                        help="whether filter illegal items")

    # Where predictions are written.
    parser.add_argument(
        "--results_file",
        type=str,
        default="./results/test-ddp.json",
        help="result output path",
    )

    # Decoding and evaluation settings.
    parser.add_argument("--test_batch_size", type=int, default=1)
    parser.add_argument("--num_beams", type=int, default=20)
    parser.add_argument("--sample_num", type=int, default=-1,
                        help="test sample number, -1 represents using all test data")
    parser.add_argument("--gpu_id", type=int, default=0,
                        help="GPU ID when testing with single GPU")
    parser.add_argument("--test_prompt_ids", type=str, default="0",
                        help="test prompt ids, separate by comma. 'all' represents using all")
    parser.add_argument("--metrics", type=str, default="hit@1,hit@5,hit@10,ndcg@5,ndcg@10",
                        help="test metrics, separate by comma")
    parser.add_argument("--test_task", type=str, default="SeqRec")

    return parser
| |
|
| |
|
def get_local_time():
    """Return the current local time as a ``Mon-DD-YYYY_HH-MM-SS`` string."""
    return datetime.datetime.now().strftime("%b-%d-%Y_%H-%M-%S")
| |
|
| |
|
def set_seed(seed):
    """Seed every RNG in use (python ``random``, numpy, torch CPU + all CUDA devices)
    and force cuDNN into a deterministic configuration.

    Args:
        seed: integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Turn off cuDNN autotuning and non-deterministic kernels for reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False
| |
|
def ensure_dir(dir_path):
    """Create *dir_path* (including parents) if it does not already exist; idempotent."""
    if not os.path.isdir(dir_path):
        # exist_ok guards against a concurrent creator between the check and makedirs.
        os.makedirs(dir_path, exist_ok=True)
| |
|
| |
|
def load_datasets(args):
    """Build the training and validation data from the configured task mixture.

    Each task listed in ``args.tasks`` gets its own dataset, constructed with the
    per-task prompt-sample and data-sample counts; the per-task datasets are then
    concatenated into one training set. Validation always uses SeqRec.

    Args:
        args: parsed namespace with ``tasks``, ``train_prompt_sample_num``,
            ``train_data_sample_num`` and ``valid_prompt_sample_num``.

    Returns:
        ``(train_data, valid_data)`` — a ``ConcatDataset`` over all task datasets
        and a SeqRec validation dataset.

    Raises:
        NotImplementedError: if a task name is not recognized.
        AssertionError: if the per-task count lists do not align with the tasks.
    """
    task_names = args.tasks.split(",")

    prompt_counts = [int(n) for n in args.train_prompt_sample_num.split(",")]
    assert len(task_names) == len(prompt_counts), "prompt sample number does not match task number"
    data_counts = [int(n) for n in args.train_data_sample_num.split(",")]
    assert len(task_names) == len(data_counts), "data sample number does not match task number"

    per_task = []
    for raw_name, n_prompts, n_samples in zip(task_names, prompt_counts, data_counts):
        task = raw_name.lower()
        if task == "seqrec":
            built = SeqRecDataset(args, mode="train", prompt_sample_num=n_prompts, sample_num=n_samples)
        elif task in ("item2index", "index2item"):
            built = ItemFeatDataset(args, task=task, prompt_sample_num=n_prompts, sample_num=n_samples)
        elif task == "fusionseqrec":
            built = FusionSeqRecDataset(args, mode="train", prompt_sample_num=n_prompts, sample_num=n_samples)
        elif task == "itemsearch":
            built = ItemSearchDataset(args, mode="train", prompt_sample_num=n_prompts, sample_num=n_samples)
        elif task == "preferenceobtain":
            built = PreferenceObtainDataset(args, prompt_sample_num=n_prompts, sample_num=n_samples)
        else:
            raise NotImplementedError
        per_task.append(built)

    train_data = ConcatDataset(per_task)
    valid_data = SeqRecDataset(args, "valid", args.valid_prompt_sample_num)
    return train_data, valid_data
| |
|
def load_test_dataset(args):
    """Construct the test dataset selected by ``args.test_task`` (case-insensitive).

    Args:
        args: parsed namespace with ``test_task`` and ``sample_num``.

    Returns:
        The dataset instance for the requested test task.

    Raises:
        NotImplementedError: if ``args.test_task`` is not one of the known tasks.
    """
    task = args.test_task.lower()
    if task == "seqrec":
        return SeqRecDataset(args, mode="test", sample_num=args.sample_num)
    if task == "itemsearch":
        return ItemSearchDataset(args, mode="test", sample_num=args.sample_num)
    if task == "fusionseqrec":
        return FusionSeqRecDataset(args, mode="test", sample_num=args.sample_num)
    raise NotImplementedError
| |
|
def load_json(file):
    """Deserialize and return the JSON document stored at *file*."""
    with open(file, 'r') as fp:
        return json.load(fp)
| |
|