Spaces:
Sleeping
Sleeping
| import os | |
| import glob | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import torch | |
def parse_filelist(filelist_path, split_char="|"):
    """Read a delimited file list and split every line into its fields.

    Each line is stripped of surrounding whitespace and then split on
    ``split_char``. Blank lines are kept and yield ``[""]``.

    Args:
        filelist_path: path of a UTF-8 text file.
        split_char: field separator, ``"|"`` by default.

    Returns:
        A list with one list of string fields per line of the file.
    """
    rows = []
    with open(filelist_path, encoding='utf-8') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split(split_char)
            rows.append(fields)
    return rows
def load_model(model, saved_state_dict):
    """Copy matching entries from ``saved_state_dict`` into ``model``.

    Every key of ``model.state_dict()`` is looked up in ``saved_state_dict``.
    Keys missing from the checkpoint keep the model's current value and are
    reported on stdout. Keys present only in the checkpoint are ignored,
    because only the model's own keys are iterated.

    Args:
        model: object exposing ``state_dict()`` and ``load_state_dict()``
            (e.g. a ``torch.nn.Module``).
        saved_state_dict: mapping from parameter/buffer name to value.

    Returns:
        The same ``model`` instance, updated in place.
    """
    state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            # Only a missing key is an expected condition here. Any other
            # error (e.g. a checkpoint that is not a mapping) must surface
            # instead of being masked as "not in the checkpoint".
            print("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    model.load_state_dict(new_state_dict)
    return model
def latest_checkpoint_path(dir_path, regex="grad_svc_*.pt"):
    """Return the path of the highest-numbered checkpoint in ``dir_path``.

    Args:
        dir_path: directory to search.
        regex: glob pattern matched against file names inside ``dir_path``.
            Despite the name, this is a glob, not a regular expression.

    Returns:
        Path of the matching file whose file name contains the largest number.
        Matching names without any digits rank lowest. On equal numbers, the
        file listed last by ``glob`` wins.

    Raises:
        FileNotFoundError: if no file in ``dir_path`` matches ``regex``.
    """
    f_list = glob.glob(os.path.join(dir_path, regex))
    if not f_list:
        raise FileNotFoundError(
            f"No checkpoint matching {regex!r} found in {dir_path!r}")

    def _step_number(path):
        # Read digits from the file name only. Directory digits would
        # otherwise be glued onto the step number. A digit-free name (e.g.
        # "grad_svc_best.pt") would crash int("") instead of ranking lowest.
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        return int(digits) if digits else -1

    # Stable sort + last element keeps the original tie-breaking order.
    f_list.sort(key=_step_number)
    return f_list[-1]
def load_checkpoint(logdir, model, num=None):
    """Load a ``grad_svc_*.pt`` checkpoint from ``logdir`` into ``model``.

    Args:
        logdir: directory holding the checkpoint files.
        model: module whose ``load_state_dict`` receives the loaded object.
        num: checkpoint number to load (file ``grad_svc_{num}.pt``). When
            ``None``, the path is chosen by ``latest_checkpoint_path``.

    Returns:
        The same ``model`` instance, updated in place.
    """
    if num is None:
        model_path = latest_checkpoint_path(logdir, regex="grad_svc_*.pt")
    else:
        model_path = os.path.join(logdir, f"grad_svc_{num}.pt")
    print(f'Loading checkpoint {model_path}...')
    # Keep every tensor on the CPU regardless of the device it was saved
    # from. This matches the former ``lambda storage, loc: storage``, whose
    # parameter names were swapped.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model_dict = torch.load(model_path, map_location="cpu")
    # strict=False tolerates key mismatches. Report them so a partial load
    # does not go unnoticed.
    result = model.load_state_dict(model_dict, strict=False)
    missing = list(getattr(result, "missing_keys", ()))
    unexpected = list(getattr(result, "unexpected_keys", ()))
    if missing:
        print(f"Keys missing from checkpoint: {missing}")
    if unexpected:
        print(f"Unexpected keys in checkpoint: {unexpected}")
    return model
def save_figure_to_numpy(fig):
    """Return the rendered pixels of ``fig`` as an RGB ``uint8`` array.

    The figure must already have been drawn (``fig.canvas.draw()``) on an
    Agg-based canvas, which provides ``buffer_rgba()``.

    Args:
        fig: a Matplotlib figure.

    Returns:
        A writable ``numpy.ndarray`` of shape ``(height, width, 3)`` and
        dtype ``uint8``.
    """
    # buffer_rgba() replaces tostring_rgb(), which is deprecated in
    # Matplotlib 3.8 and removed in 3.10. It also carries the real
    # (height, width, 4) pixel shape, so nothing has to be re-derived from
    # get_width_height(), which reports logical size on HiDPI canvases.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop the alpha channel and copy, so the result does not alias the
    # canvas' internal buffer.
    return np.array(rgba[..., :3], dtype=np.uint8)
def plot_tensor(tensor):
    """Render ``tensor`` as a heat-map and return the image as an array.

    Resets the global Matplotlib style to ``'default'`` before plotting.

    Args:
        tensor: array-like accepted by ``Axes.imshow``; row 0 is drawn at the
            bottom (``origin="lower"``). Presumably 2-D; confirm with callers.

    Returns:
        The pixel array returned by ``save_figure_to_numpy`` for the
        rendered figure.
    """
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        # The canvas must be drawn before its pixel buffer can be read.
        fig.canvas.draw()
        data = save_figure_to_numpy(fig)
    finally:
        # Close this specific figure, even if plotting fails.
        # plt.close() with no argument closes whatever figure is current,
        # and figures left open accumulate inside pyplot.
        plt.close(fig)
    return data
def save_plot(tensor, savepath):
    """Render ``tensor`` as a heat-map and write the figure to ``savepath``.

    Resets the global Matplotlib style to ``'default'`` before plotting.

    Args:
        tensor: array-like accepted by ``Axes.imshow``; row 0 is drawn at the
            bottom (``origin="lower"``). Presumably 2-D; confirm with callers.
        savepath: destination passed to ``Figure.savefig``. The image format
            follows Matplotlib's rules (normally the file extension).

    Returns:
        None.
    """
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        # savefig renders the figure itself; no explicit canvas.draw() needed.
        fig.savefig(savepath)
    finally:
        # Close this specific figure, even if saving fails, so figures do not
        # accumulate inside pyplot.
        plt.close(fig)
def print_error(info):
    """Print ``info`` to stdout, padded with spaces and wrapped in ANSI red codes."""
    message = f"\033[31m {info} \033[0m"
    print(message)