Spaces:
Runtime error
Runtime error
| import os | |
| import random | |
| from collections import defaultdict | |
| import dijkprofile_annotator.preprocessing as preprocessing | |
| import dijkprofile_annotator.config as config | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import seaborn as sns | |
| import torch | |
| import torch.nn.functional as F | |
| from sklearn.isotonic import IsotonicRegression | |
| from sklearn.preprocessing import MinMaxScaler, StandardScaler | |
def extract_img(size, in_tensor):
    """Cut a centred window of length ``size`` out of axis 2 of a tensor.

    Args:
        size (int): length of the window to keep along dimension 2.
        in_tensor (tensor): tensor to be cut, indexed as ``[:, :, start:stop]``.

    Returns:
        tensor: the slice of ``in_tensor`` between the computed start and stop.
    """
    total = in_tensor.size()[2]
    # Half of the surplus length is dropped on each side; int() truncates
    # toward zero, so an odd surplus is split unevenly.
    margin = (total - size) / 2
    start = int(margin)
    stop = int(size + margin)
    return in_tensor[:, :, start:stop]
def ffill(arr):
    """Forward fill NaN values along axis 1.

    Every NaN is replaced by the closest non-NaN value to its left in the
    same row. NaNs with no valid value to their left take the value in
    column 0 of their row (which may itself be NaN).

    Args:
        arr (np.array): numpy array to fill, indexed as (row, column).

    Returns:
        np.array: new array with the NaNs filled; ``arr`` is not modified.
    """
    nan_mask = np.isnan(arr)
    column_ids = np.arange(nan_mask.shape[1])
    # For each cell: its own column index if valid, otherwise 0.
    last_valid = np.where(nan_mask, 0, column_ids)
    # A running maximum turns that into "index of the last valid column so far".
    last_valid = np.maximum.accumulate(last_valid, axis=1)
    row_ids = np.arange(last_valid.shape[0])[:, None]
    return arr[row_ids, last_valid]
def train_scaler(profile_dict, scaler_type='minmax'):
    """Train a scaler on all height values of a profile dict.

    Args:
        profile_dict (dict): dict whose values are dicts with a 'profile'
            entry holding the heights. All profiles must have the same
            length, because they are stacked into one 2-D array.
        scaler_type (str, optional): 'minmax' for a MinMaxScaler with range
            (-1, 1) or 'standard' for a StandardScaler. Defaults to 'minmax'.

    Raises:
        NotImplementedError: if ``scaler_type`` is not 'minmax' or 'standard'.

    Returns:
        sklearn MinMaxScaler or StandardScaler: fitted scaler in sklearn format
    """
    if scaler_type == 'minmax':
        scaler = MinMaxScaler(feature_range=(-1, 1))  # for neural networks -1,1 is better than 0,1
    elif scaler_type == 'standard':
        scaler = StandardScaler()
    else:
        # Report the requested type; `scaler` is not bound on this path.
        raise NotImplementedError(f"no scaler: {scaler_type}")

    # Any key works for reading the profile length, since every profile must
    # fit into the same accumulator row width.
    randkey = random.choice(list(profile_dict.keys()))
    profile_length = profile_dict[randkey]['profile'].shape[0]

    accumulator = np.zeros((len(profile_dict), profile_length))
    for i, key in enumerate(profile_dict.keys()):
        accumulator[i, :] = profile_dict[key]['profile']

    # Fit on one long column so a single scale is learned for all heights.
    scaler.fit(accumulator.reshape(-1, 1))
    return scaler
def get_class_dict(class_list):
    """Look up the class mapping, its inverse and the class weights in config.

    The name is matched case-insensitively against 'regional', 'simple',
    'berm', 'sloot' and 'full'.

    Args:
        class_list (string): string representing the class mappings to use

    Raises:
        NotImplementedError: raised if a class mapping that is not implemented is passed

    Returns:
        (dict,dict,list): dict with class mappings, inverse of that dict, weights for each class.
    """
    class_list = class_list.lower()

    # Each branch reads only the config attributes of the selected mapping.
    if class_list == 'regional':
        return (config.CLASS_DICT_REGIONAL,
                config.INVERSE_CLASS_DICT_REGIONAL,
                config.WEIGHT_DICT_REGIONAL)
    if class_list == 'simple':
        return (config.CLASS_DICT_SIMPLE,
                config.INVERSE_CLASS_DICT_SIMPLE,
                config.WEIGHT_DICT_SIMPLE)
    if class_list == 'berm':
        return (config.CLASS_DICT_SIMPLE_BERM,
                config.INVERSE_CLASS_DICT_SIMPLE_BERM,
                config.WEIGHT_DICT_SIMPLE_BERM)
    if class_list == 'sloot':
        return (config.CLASS_DICT_SIMPLE_SLOOT,
                config.INVERSE_CLASS_DICT_SIMPLE_SLOOT,
                config.WEIGHT_DICT_SIMPLE_SLOOT)
    if class_list == 'full':
        return (config.CLASS_DICT_FULL,
                config.INVERSE_CLASS_DICT_FULL,
                config.WEIGHT_DICT_FULL)

    raise NotImplementedError(f"No configs found for class list of type: {class_list}")
def force_sequential_predictions(predictions, method='isotonic'):
    """Force the classes in each sample to only go up from left to right.

    This makes sense because a higher class can never be left of a lower class
    in the representation chosen here. Two methods are available: isotonic
    regression ('isotonic') and a group-first method ('first') that keeps, for
    each class in ascending order, the first run that starts at or after the
    end of the previously kept run. Isotonic regression is the recommended one.

    The input tensor is not modified; the result is a new CPU tensor.

    Args:
        predictions (torch.Tensor): Tensor output of the model in shape (batch_size, channel_size, sample_size)
        method (str, optional): method to use for enforcing the sequentiality. Defaults to 'isotonic'.

    Raises:
        NotImplementedError: if the given method is not implemented

    Returns:
        torch.Tensor: Tensor in the same shape as the input, one-hot encoded
            along the channel dimension, with only increasing classes from
            left to right.
    """
    # detach() shares storage with its input and cpu() is a no-op for CPU
    # tensors, so clone before the in-place row assignments below; otherwise
    # the caller's tensor would be overwritten.
    predictions = predictions.detach().cpu().clone()
    n_classes = predictions.shape[1]  # 1 is the channel dimension

    if method == 'first':
        # loop over batch
        for j in range(predictions.shape[0]):
            pred = torch.argmax(predictions[j], dim=0)

            # construct dict of groups of (start, end) indices per class
            groups = defaultdict(list)
            current_class = pred[0]
            group_start_idx = 0
            for i in range(1, len(pred)):
                if pred[i] != current_class:
                    groups[current_class.item()].append((group_start_idx, i))
                    group_start_idx = i
                    current_class = pred[i]
            # The loop only closes a run when the class changes, so the run
            # that reaches the end of the sample has to be closed here.
            groups[current_class.item()].append((group_start_idx, len(pred)))

            # Walk the classes in ascending order and keep the first run of
            # each that does not start before the previously kept run ended;
            # other occurrences of the class are discarded.
            new_pred = torch.zeros(len(pred))
            last_index = 0
            for class_n, group_tuples in sorted(groups.items()):
                for start, end in group_tuples:
                    if start >= last_index:
                        new_pred[start:end] = class_n
                        last_index = end
                        break

            # simple forward fill: positions left at 0 take the value to their left
            for i in range(1, len(new_pred)):
                if new_pred[i] == 0:
                    new_pred[i] = new_pred[i - 1]

            # encode back to one-hot tensor (channel first)
            predictions[j] = F.one_hot(new_pred.to(torch.int64), num_classes=n_classes).permute(1, 0)
    elif method == 'isotonic':
        for i in range(predictions.shape[0]):
            pred = torch.argmax(predictions[i], dim=0)
            x = np.arange(0, len(pred))
            # Fit a non-decreasing function through the class indices and
            # round it back to whole class numbers.
            iso_reg = IsotonicRegression().fit(x, pred)
            new_pred = np.round(iso_reg.predict(x))

            # encode back to one-hot tensor (channel first)
            new_pred = F.one_hot(torch.Tensor(new_pred).to(torch.int64), num_classes=n_classes).permute(1, 0)
            predictions[i] = new_pred
    else:
        raise NotImplementedError(f"Unknown method: {method}")

    return predictions
def visualize_prediction(heights, prediction, labels, location_name, class_list):
    """Plot a profile together with its true labels and the predicted labels.

    True labels are drawn as loosely dashed vertical lines, predictions as
    dotted vertical lines with a "predicted " prefix on the text. Class 0 is
    never drawn.

    Args:
        heights (tensor): tensor containing the heights data of the profile
        prediction (tensor): tensor containing the predicted data of the profile;
            a 3-D or 2-D input is reduced with argmax over the channel axis
        labels (tensor): tensor containing the labels for each height point in heights
        location_name (str): name of the profile, just for visualization
        class_list (str): class mapping to use, determines which labels are visualized
    """
    class_dict, inverse_class_dict, _ = get_class_dict(class_list)

    plt.subplots(figsize=(20, 11))
    plt.title(location_name)
    plt.plot(heights, label='profile')

    # Reduce one-hot style output to a flat list of class indices. The two
    # checks are deliberately independent: a 3-D input that is still 2-D after
    # the first reduction goes through the second one as well.
    if prediction.dim() == 3:
        prediction = torch.argmax(torch.squeeze(prediction, dim=0), dim=0)
    if prediction.dim() == 2:
        # assuming channel first representation
        prediction = torch.argmax(prediction, dim=0)
    predicted_classes = prediction.detach().cpu().numpy()

    # Text annotations start at the lowest height and move up one step each.
    lowest = np.min(heights)
    text_y = lowest
    n_labels = len(np.unique(labels))
    text_step = (np.max(heights) - lowest) / (n_labels * 2)
    palette = sns.color_palette("Set2", len(set(class_dict.values())))

    # true labels: one marker at the start of every run of a non-zero class
    previous = 999  # sentinel that matches no real class
    for position, class_n in enumerate(labels):
        if class_n == 0:
            continue
        if class_n != previous:
            plt.axvline(position, 0, 5, color=palette[class_n], linestyle=(0, (5, 10)))  # loose dashes
            plt.text(position, text_y, inverse_class_dict[class_n], rotation=0)
            text_y += text_step
        previous = class_n

    # predicted labels: same idea, but classes recorded in `seen` are skipped
    seen = []
    previous = 999
    for position, class_n in enumerate(predicted_classes):
        if class_n == 0 or class_n in seen:
            continue
        if class_n != previous:
            plt.axvline(position, 0, 5, color=palette[class_n], linestyle=(0, (1, 1)))  # small dots
            plt.text(position, text_y, "predicted " + inverse_class_dict[class_n], rotation=0)
            text_y += text_step
            # the class that just ended is the one that gets recorded
            seen.append(previous)
        previous = class_n

    plt.show()
def visualize_sample(heights, labels, location_name, class_list):
    """Plot a profile with a dashed marker at the start of every labelled run.

    Args:
        heights (tensor): tensor containing the heights data of the profile
        labels (tensor): tensor containing the labels for each height point in heights
        location_name (str): name of the profile, just for visualization
        class_list (str): class mapping to use, determines which labels are visualized
    """
    class_dict, inverse_class_dict, _ = get_class_dict(class_list)

    plt.subplots(figsize=(20, 11))
    plt.title(location_name)
    plt.plot(heights, label='profile')

    # Label texts start at the lowest height and climb by a fixed step.
    lowest = np.min(heights)
    text_y = lowest
    n_labels = len(np.unique(labels))
    text_step = (np.max(heights) - lowest) / (n_labels * 2)
    palette = sns.color_palette("Set2", len(set(class_dict.values())))

    previous = 999  # sentinel that matches no real class
    for position, class_n in enumerate(labels):
        if class_n == 0:
            # class 0 is not drawn and does not interrupt a run
            continue
        starts_new_run = class_n != previous
        previous = class_n
        if not starts_new_run:
            continue
        plt.axvline(position, 0, 5, color=palette[class_n], linestyle=(0, (5, 10)))  # loose dashes
        plt.text(position, text_y, inverse_class_dict[class_n], rotation=0)
        text_y += text_step

    plt.show()
def visualize_files(linesfp, pointsfp, max_profile_size=512, class_list='simple', location_index=0, return_dict=False):
    """Load a surfacelines/points file pair and plot one of its profiles.

    Args:
        linesfp (str): path to surfacelines file.
        pointsfp (str): path to points file.
        max_profile_size (int, optional): cutoff size of the profile, can be left on default here. Defaults to 512.
        class_list (str, optional): class mapping to use. Defaults to 'simple'.
        location_index (int, optional): index of the profile to visualize. Defaults to 0.
        return_dict (bool, optional): return the profile dict for faster visualization. Defaults to False.

    Returns:
        [dict, optional]: profile dict containing the profiles of the given
            files when ``return_dict`` is true, otherwise None.
    """
    profile_label_dict = preprocessing.filepath_pair_to_labeled_sample(
        linesfp,
        pointsfp,
        max_profile_size=max_profile_size,
        class_list=class_list,
    )

    # Pick the profile by its position in the dict's key order.
    location_name = list(profile_label_dict.keys())[location_index]
    sample = profile_label_dict[location_name]
    heights = sample['profile']
    labels = sample['label']

    class_dict, inverse_class_dict, _ = get_class_dict(class_list)

    plt.subplots(figsize=(20, 11))
    plt.title(location_name)
    plt.plot(heights, label='profile')

    # Label texts start at the lowest height and climb by a fixed step.
    lowest = np.min(heights)
    text_y = lowest
    n_labels = len(np.unique(labels))
    text_step = (np.max(heights) - lowest) / (n_labels)
    palette = sns.color_palette("Set2", len(set(class_dict.values())))

    # one marker at the start of every run of a non-zero class
    previous = 999  # sentinel that matches no real class
    for position, class_n in enumerate(labels):
        if class_n == 0:
            continue
        if class_n != previous:
            plt.axvline(position, 0, 5, color=palette[class_n], linestyle=(0, (5, 10)))  # loose dashes
            plt.text(position, text_y, inverse_class_dict[class_n], rotation=0)
            text_y += text_step
        previous = class_n

    plt.show()

    return profile_label_dict if return_dict else None
def visualize_dict(profile_label_dict, class_list='simple', location_index=0):
    """Plot one profile of a profile dict together with its labels.

    Args:
        profile_label_dict (dict): dict containing profiles and labels
        class_list (str, optional): class mapping to use for visualization. Defaults to 'simple'.
        location_index (int, optional): index of the profile to visualize, in the dict's key order. Defaults to 0.
    """
    location_name = list(profile_label_dict.keys())[location_index]
    sample = profile_label_dict[location_name]
    heights = sample['profile']
    labels = sample['label']

    class_dict, inverse_class_dict, _ = get_class_dict(class_list)

    plt.subplots(figsize=(20, 11))
    plt.title(location_name)
    plt.plot(heights, label='profile')

    # Label texts start at the lowest height and climb by a fixed step.
    lowest = np.min(heights)
    text_y = lowest
    n_labels = len(np.unique(labels))
    text_step = (np.max(heights) - lowest) / (n_labels)
    palette = sns.color_palette("Set2", len(set(class_dict.values())))

    previous = 999  # sentinel that matches no real class
    for position, class_n in enumerate(labels):
        if class_n == 0:
            # class 0 is not drawn and does not interrupt a run
            continue
        starts_new_run = class_n != previous
        previous = class_n
        if not starts_new_run:
            continue
        plt.axvline(position, 0, 5, color=palette[class_n], linestyle=(0, (5, 10)))  # loose dashes
        plt.text(position, text_y, inverse_class_dict[class_n], rotation=0)
        text_y += text_step

    plt.show()