Spaces:
Paused
Paused
| import os | |
| import torch | |
| from lib import utility | |
def test(args):
    """Evaluate a trained network on the test split and log its metrics.

    Builds the run configuration from ``args``, restores the network weights
    from a checkpoint, runs ``utility.forward`` over the test dataloader and
    writes one summary line per non-empty metric entry to ``config.logger``.
    All logging is skipped when ``config.logger`` is ``None``.

    Args:
        args: Parsed options. This function reads ``args.device_ids`` (the
            first entry becomes ``config.device_id``; the value ``[-1]``
            makes the checkpoint load onto the CPU), ``args.config_name``
            (used only in a log message) and ``args.pretrained_weight``
            (checkpoint path; ``None`` falls back to ``train.pkl`` inside
            ``config.model_dir``).

    Returns:
        None. Results are only reported through the logger.
    """
    # --- configuration and environment ---------------------------------
    config = utility.get_config(args)
    config.device_id = args.device_ids[0]
    utility.set_environment(config)
    config.init_instance()

    if config.logger is not None:
        config.logger.info("Loaded configure file %s: %s" % (args.config_name, config.id))
        # Dump every config attribute, one "name: value" pair per line.
        attribute_lines = ("%s: %s" % (key, value) for key, value in vars(config).items())
        config.logger.info("\n" + "\n".join(attribute_lines))

    # --- network and weights -------------------------------------------
    net = utility.get_net(config)

    if args.pretrained_weight is None:
        weight_path = os.path.join(config.model_dir, "train.pkl")
    else:
        weight_path = args.pretrained_weight

    # Only the CPU-only setting (device_ids == [-1]) remaps tensor storage;
    # otherwise torch.load is called with its default placement.
    # NOTE(review): torch.load unpickles the file, so weight_path should only
    # ever point at a trusted checkpoint.
    load_options = {"map_location": "cpu"} if args.device_ids == [-1] else {}
    state = torch.load(weight_path, **load_options)
    net.load_state_dict(state["net"])

    if config.logger is not None:
        config.logger.info("Loaded network")

    # --- test data ------------------------------------------------------
    test_loader = utility.get_dataloader(config, "test")
    if config.logger is not None:
        config.logger.info("Loaded data from {:}".format(config.test_tsv_file))

    # --- inference ------------------------------------------------------
    # The first element returned by utility.forward is not used here.
    _, metrics = utility.forward(config, test_loader, net)
    if config.logger is not None:
        config.logger.info("Finished inference")

    # --- report ---------------------------------------------------------
    # Entries 0, 1 and 2 of each metric are printed as NME, FR and AUC.
    for metric in metrics:
        if config.logger is None or len(metric) == 0:
            continue
        summary = "Tested {} dataset, the Size is {}, Metric: [NME {:.6f}, FR {:.6f}, AUC {:.6f}]".format(
            config.type, len(test_loader.dataset), metric[0], metric[1], metric[2])
        config.logger.info(summary)