| | import os |
| | import sys |
| | os.environ["CUDA_VISIBLE_DEVICES"] = "6" |
| |
|
| | import torch |
| | import wandb |
| |
|
| | from entangledcell_module_unseen import EntangledNetTrainCellUnseen |
| | from entangledcell_module_three import EntangledNetTrainCellThree |
| |
|
| | |
| | from dataloaders.three_branch_data import ThreeBranchTahoeDataModule |
| | from dataloaders.clonidine_v2_data import ClonidineV2DataModule |
| |
|
| | from geo_metrics.metric_factory import DataManifoldMetric |
| | from pytorch_lightning.loggers import WandbLogger |
| | from pytorch_lightning import Trainer |
| |
|
| | from torchcfm.optimal_transport import OTPlanSampler |
| | from parser import parse_args |
| | from train_utils import load_config, merge_config |
| | from bias import BiasForceTransformer, BiasForceTransformerNoVel |
| |
|
def main():
    """Train, then test, the entangled-cell model selected by the CLI/config.

    Steps, in order:
      1. Parse command-line arguments and, when ``args.config_path`` is set,
         merge the loaded config into them.
      2. Ensure ``<save_dir>/positions`` exists.
      3. Start a Weights & Biases run and seed torch.
      4. Build the optional OT sampler, the datamodule, the data-manifold
         metric, the bias network and the training module.
      5. Fit and then test with a ``Trainer``.

    ``args.data_name == "trametinib"`` selects the three-branch datamodule and
    training module; any other value selects the Clonidine-V2 datamodule with
    the "unseen" training module.
    """
    args = parse_args()
    if args.config_path:
        config = load_config(args.config_path)
        args = merge_config(args, config)

    args.training = True

    # exist_ok avoids the check-then-create race of a separate exists() test.
    positions_dir = f"{args.save_dir}/positions"
    os.makedirs(positions_dir, exist_ok=True)

    wandb.init(
        project="entangled-cell",
        config=args,
        name=args.run_name,
    )

    torch.manual_seed(args.seed)

    # The option is compared against the literal string "None", not the
    # None singleton.
    if args.optimal_transport_method != "None":
        ot_sampler = OTPlanSampler(method=args.optimal_transport_method)
    else:
        ot_sampler = None

    # Resolve the dataset-specific classes once so the datamodule and the
    # training module can never be chosen inconsistently.
    if args.data_name == "trametinib":
        datamodule_cls = ThreeBranchTahoeDataModule
        train_module_cls = EntangledNetTrainCellThree
    else:
        datamodule_cls = ClonidineV2DataModule
        train_module_cls = EntangledNetTrainCellUnseen

    datamodule = datamodule_cls(args=args)

    data_manifold_metric = DataManifoldMetric(
        args=args,
        skipped_time_points=[],
        datamodule=datamodule,
    )

    if args.vel_conditioned:
        bias_net = BiasForceTransformer(args)
    else:
        print("Using no velocity conditioned model")
        bias_net = BiasForceTransformerNoVel(args)

    timepoint_data = datamodule.get_timepoint_data()

    entangled_train = train_module_cls(
        args=args,
        bias_net=bias_net,
        data_manifold_metric=data_manifold_metric,
        timepoint_data=timepoint_data,
        ot_sampler=ot_sampler,
        vel_conditioned=args.vel_conditioned,
    )

    wandb_logger = WandbLogger()

    trainer = Trainer(
        max_epochs=args.num_rollouts,
        logger=wandb_logger,
        num_sanity_val_steps=0,
        default_root_dir=args.root_dir,
        gradient_clip_val=None,
        # Index 0 is relative to the devices left visible by the
        # CUDA_VISIBLE_DEVICES value this script sets at import time.
        devices=[0],
    )

    trainer.fit(entangled_train, datamodule=datamodule)
    trainer.test(entangled_train, datamodule=datamodule)
| |
|
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()