import argparse
import os
import re
import stat
import time

from functools import partial
from typing import List, Tuple

import numpy as np
import torch
import torch.distributed as dist

from sat import mpu, get_args, get_tokenizer
from sat.helpers import print_rank0

from model_utils import MSAGPT, FineTuneMSAGPT
from utils import AdvancedBaseStrategy, BeamSearchStrategy, chat_api
| |
|
| |
|
| |
|
def _build_cli_parser():
    """Build the parser for the inference-specific command-line options.

    Returns an ``argparse.ArgumentParser`` with ``add_help=False`` so it can
    be chained with the model-specific and SAT argument parsers.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy",
                        help="Type of sampling strategy.")
    parser.add_argument("--min-gen-length", type=int, default=0,
                        help="The minimum length each blank should generate.")
    # Help text fixed: this option is the *maximum* generation length.
    parser.add_argument("--max-gen-length", type=int, default=512,
                        help="The maximum length each blank should generate.")
    parser.add_argument("--is-valid", action="store_true",
                        help="Validation mode.")
    parser.add_argument("--print-all-beams", action="store_true",
                        help="Print all output generated by beam search strategy.")
    parser.add_argument("--multiline_stream", action="store_true",
                        help="streaming multiline output.")
    parser.add_argument("--no-gap", action="store_true",
                        help="do not generate gaps.")
    parser.add_argument("--from_pretrained", type=str, default="./checkpoints/MSAGPT",
                        help="pretrained ckpt")
    parser.add_argument("--chinese", action="store_true", help="Chinese interface")
    parser.add_argument("--stream_chat", action="store_true", help="streaming output")
    return parser


def _build_strategy(args, tokenizer):
    """Construct the decoding strategy selected by ``--sampling-strategy``.

    Raises:
        ValueError: for an unknown strategy name, or when
            ``--print-all-beams`` is combined with ``BaseStrategy``.
    """
    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
    # Token ids that must never be sampled — presumably special/control
    # tokens of the protein vocabulary; TODO confirm against the tokenizer.
    invalid_slices = [0, 26, 28, 29, 30, 31, 32]
    if args.no_gap:
        invalid_slices.append(tokenizer.TokenToId('-'))

    if args.sampling_strategy == "BaseStrategy":
        # Explicit validation instead of `assert`, which is stripped under -O.
        if args.print_all_beams:
            raise ValueError("BaseStrategy doesn't support print all beams.")
        return AdvancedBaseStrategy(
            batch_size=1,
            invalid_slices=invalid_slices,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            min_gen_length=args.min_gen_length,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            end_tokens=end_tokens,
        )
    if args.sampling_strategy == "BeamSearchStrategy":
        return BeamSearchStrategy(
            1,
            args.num_beams,
            length_penalty=args.length_penalty,
            consider_end=True,
            end_tokens=end_tokens,
            invalid_slices=invalid_slices,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            min_gen_length=args.min_gen_length,
            deterministic=True,
        )
    raise ValueError(f"unknown strategy {args.sampling_strategy}")


def _print_msa(args, response):
    """Pretty-print generated MSA sequences (rank-0, non-streaming path)."""
    if args.chinese:
        print(f"{'生成的MSA'.center(20, '*')}")
    else:
        print(f"{'Virtual MSA'.center(20, '*')}")
    if args.print_all_beams:
        for idx, beam in enumerate(response):
            print(f"Beam: {idx}".center(11, '@'))
            for seq in beam:
                print(seq)
            print()
    else:
        for seq in response[0]:
            print(seq)
        print()


def _run_chat(args, model, tokenizer, strategy, rank, world_size):
    """Interactive CLI loop.

    Rank 0 reads a query from stdin; with multiple ranks the query is
    broadcast so every model-parallel worker runs the same generation.
    Typing "stop" exits the loop.
    """
    if rank == 0:
        if args.chinese:
            print('欢迎使用 MSAGPT-CLI ,输入需要生成虚拟MSA的蛋白序列(或加上少量MSA作为prompt,以"<M>"相连),例如:"PEGKQGDPGIPGEPGPPGPPGPQGARGPPG<M>VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG",其中"PEGKQGDPGIPGEPGPPGPPGPQGARGPPG"为主序列,"VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG"为MSA prompt。 stop 终止程序'.center(20, "*"))
        else:
            print('Welcome to MSAGPT-CLI. Enter the protein sequence you need to generate virtual MSAs (or add a few MSAs as a prompt, connected by "<M>"), for example: "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG<M>VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG", where "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG" is the main sequence, and "VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG" are MSA prompts. Type "stop" to end the program.'.center(20, "*"))
    with torch.no_grad():
        while True:
            protein_input = None
            if rank == 0:
                if args.chinese:
                    protein_input = input("请输入需要生成虚拟MSA的蛋白序列(或加上少量MSA作为prompt,以'<M>'相连):")
                else:
                    protein_input = input("Enter the protein sequence you need to generate virtual MSAs (or add a few MSAs as a prompt, connected by '<M>': ")
            if world_size > 1:
                # BUGFIX: torch.distributed has no `broadcast_object`; the
                # documented API is `broadcast_object_list`, which mutates the
                # given list in place on every non-source rank.
                payload = [protein_input]
                dist.broadcast_object_list(payload, src=0)
                protein_input = payload[0]
            # BUGFIX: check for None *before* .strip() — the original stripped
            # first, which would crash on non-rank-0 processes.
            if protein_input is None:
                raise RuntimeError("no query received on rank %d" % rank)
            protein_input = protein_input.strip()
            if protein_input == 'stop':
                break
            try:
                response = chat_api(
                    args=args,
                    query=protein_input,
                    model=model,
                    tokenizer=tokenizer,
                    strategy=strategy,
                )
            except Exception as e:
                # Top-level boundary: report the failure and leave the loop.
                print(e)
                break
            # In stream_chat mode chat_api already emitted the output itself.
            if rank == 0 and not args.stream_chat:
                _print_msa(args, response)


if __name__ == "__main__":
    py_parser = _build_cli_parser()
    py_parser = MSAGPT.add_model_specific_args(py_parser)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    # Merge SAT's parsed args with the script-specific ones into one namespace.
    args = argparse.Namespace(**vars(args), **vars(known))
    model, args = MSAGPT.from_pretrained(
        args.from_pretrained,
        args,
        overwrite_args={'model_parallel_size': args.model_parallel_size}
        if args.model_parallel_size != 1 else {},
    )
    model.eval()
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if torch.cuda.is_available():
        model = model.to('cuda')
    # Imported lazily, matching the original script's behavior.
    from utils import proteinglm_tokenizer
    tokenizer = proteinglm_tokenizer()
    strategy = _build_strategy(args, tokenizer)
    if args.input_source == 'chat':
        _run_chat(args, model, tokenizer, strategy, rank, world_size)
    else:
        # Non-interactive mode: chat_api presumably reads queries from
        # args.input_source itself — confirm against utils.chat_api.
        chat_api(args=args, model=model, tokenizer=tokenizer, strategy=strategy)