| | import ClassUtils |
| | import LoadUtils |
| |
|
| | import torch |
| | import torchvision |
| | import torchvision.models as models |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import random |
| |
|
| | import warnings |
| |
|
| | |
# Silence DeprecationWarning coming from any module (r".*" matches every
# module name) so console output stays readable.
warnings.filterwarnings("ignore", category=DeprecationWarning, module=r".*")

# Checkpoint and dataset locations, given as relative paths.
vgg16_state_path = "VGG16_Full_State_Dict.pth"
mobileNet_path = "MobileNetV3_state_dict_big_train.pth"
data_path = "zebra_annotations/classification_data"

# Placeholders for the active classifier and its preprocessing transform.
classify = transform = None
| |
|
| | |
def load_vgg_classifier(state_dict_path):
    """Build a two-output VGG16 and restore all of its weights from a checkpoint.

    The network is created without pretrained weights; every parameter comes
    from the checkpoint, which must therefore hold the state dict of the whole
    model.

    Args:
        state_dict_path: Checkpoint location, in any form ``torch.load``
            accepts.

    Returns:
        The VGG16 model, switched to evaluation mode.
    """
    model = models.vgg16()

    # Replace the last classifier layer with a 2-output head *before* loading
    # so the parameter shapes match the checkpoint.
    model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)

    # map_location="cpu" lets a checkpoint that was saved from a GPU be loaded
    # on a host without CUDA; load_state_dict copies the values into the
    # model's own parameters either way.
    state_dict = torch.load(state_dict_path, weights_only=True, map_location="cpu")
    model.load_state_dict(state_dict)

    model.eval()
    return model
| |
|
| | |
| | |
| | |
def partial_vgg_load(classifier_state_dict_path):
    """Build a pretrained VGG16 and restore only its classifier head.

    The feature extractor keeps torchvision's default pretrained weights; the
    classifier (with its last layer replaced by a 2-output head) is restored
    from the given checkpoint.

    Args:
        classifier_state_dict_path: Location of a checkpoint (anything
            ``torch.load`` accepts) holding the state dict of
            ``model.classifier``. An already-loaded state-dict mapping is
            accepted as well.

    Returns:
        The VGG16 model, switched to evaluation mode.
    """
    from collections.abc import Mapping

    model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
    model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)

    if isinstance(classifier_state_dict_path, Mapping):
        classifier_state = classifier_state_dict_path
    else:
        # load_state_dict() needs the deserialised mapping, not the path to
        # it, so the checkpoint has to be read from disk first.
        classifier_state = torch.load(
            classifier_state_dict_path, weights_only=True, map_location="cpu"
        )
    model.classifier.load_state_dict(classifier_state)

    model.eval()
    return model
| |
|
| | |
def load_resnet_classifier(state_dict_path):
    """Build a single-output ResNet-18 and restore its weights from a checkpoint.

    Note that, unlike the VGG16 and MobileNetV3 loaders in this file, the
    replacement head here has one output rather than two.

    Args:
        state_dict_path: Checkpoint location, in any form ``torch.load``
            accepts. It must hold the state dict of the whole model.

    Returns:
        The ResNet-18 model, switched to evaluation mode.
    """
    # No pretrained weights are requested: load_state_dict() below replaces
    # every parameter and buffer, so fetching them would be wasted work (and
    # the ``pretrained=`` flag is deprecated in torchvision).
    resnet = models.resnet18()
    resnet.fc = torch.nn.Linear(resnet.fc.in_features, 1)

    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only host.
    state_dict = torch.load(state_dict_path, weights_only=True, map_location="cpu")
    resnet.load_state_dict(state_dict)

    resnet.eval()
    return resnet
| |
|
| | |
| | |
def load_mobileNet_classifier(state_dict_path):
    """Build a two-output MobileNetV3-Small and restore it from a checkpoint.

    The network is created without pretrained weights; every parameter comes
    from the checkpoint, which must hold the state dict of the whole model.

    Args:
        state_dict_path: Checkpoint location, in any form ``torch.load``
            accepts.

    Returns:
        The MobileNetV3-Small model, switched to evaluation mode.
    """
    model = models.mobilenet_v3_small()

    # Replace the final classifier layer with a 2-output head before loading
    # so the parameter shapes match the checkpoint.
    model.classifier[3] = torch.nn.Linear(model.classifier[3].in_features, 2)

    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only host.
    state_dict = torch.load(state_dict_path, weights_only=True, map_location="cpu")
    model.load_state_dict(state_dict)

    model.eval()
    return model
| |
|
| | |
| | |
| |
|
# Default classifier and preprocessing used by the inference helpers.
# NOTE: this reads the MobileNetV3 checkpoint as soon as the module is loaded.
classify = load_mobileNet_classifier(mobileNet_path)
transform = ClassUtils.mob3_transform


def infer(image, infer_model=classify, infer_transform=transform):
    """Return the element-wise sigmoid of the model output for an image.

    The default model and transform are bound when this function is defined,
    i.e. to the objects assigned just above; rebinding the module-level names
    afterwards does not change these defaults.

    Args:
        image: A tensor, or an object that ``infer_transform`` converts into
            one. Tensors are used as-is (the transform is skipped). Inputs
            with at most three dimensions get a leading batch dimension.
        infer_model: Callable mapping the batched tensor to logits.
        infer_transform: Preprocessing applied to non-tensor inputs.

    Returns:
        A NumPy array of sigmoid probabilities with the same shape as the
        model output.

    Raises:
        TypeError: If ``infer_model`` or ``infer_transform`` is ``None``.
    """
    if infer_model is None or infer_transform is None:
        raise TypeError("Error: The inference classes have not been initialised properly.")

    if not torch.is_tensor(image):
        image = infer_transform(image)

    # Add the batch dimension the model expects for a single image.
    if len(image.shape) <= 3:
        image = image.unsqueeze(0)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logit_pred = infer_model(image)

    # torch.sigmoid stays finite for large-magnitude logits, whereas a
    # hand-written 1 / (1 + exp(-x)) overflows in exp(). .cpu() keeps the
    # NumPy conversion valid if the output lives on an accelerator.
    probs = torch.sigmoid(logit_pred).detach().cpu().numpy()
    return probs
| |
|
| |
|
| | |
| | |
def PIL_infer(image, threshold=0.35):
    """Classify a PIL image with the default model.

    The image is converted to a float tensor scaled to [0, 1] and handed to
    ``infer`` as a tensor, so ``infer`` does not apply its preprocessing
    transform to it.

    Args:
        image: PIL image. Modes other than RGB (e.g. greyscale, RGBA,
            palette) are converted to RGB first.
        threshold: Probability above which the result is positive.

    Returns:
        A truthy value when the probability of the first model output for
        the (single) image exceeds ``threshold``.
    """
    # torchvision's MobileNetV3 (the default model) takes 3-channel input;
    # without this, greyscale / RGBA / palette images fail inside the model.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    tensor_im = torchvision.transforms.functional.pil_to_tensor(image).float() / 255
    prediction = infer(tensor_im)
    classification = prediction[0][0] > threshold
    return classification
| |
|
| | |
def infer_and_display(image, threshold, actual_label, onlyWrong=False):
    """Run inference on one image and plot it with the predicted/actual class.

    Args:
        image: Image tensor laid out so that permuting its axes with
            ``(1, 2, 0)`` yields the layout ``plt.imshow`` expects
            (presumably channel-first, C x H x W).
        threshold: Probability above which the first model output counts as a
            positive prediction.
        actual_label: Indexable label; ``actual_label[0] == 1`` marks a
            positive sample.
        onlyWrong: When true, nothing is plotted for correct predictions.

    Returns:
        The probability array produced by ``infer`` (on every path).
    """
    probability = infer(image)

    # Decide on the first output only, matching what the title reports.
    # Comparing against the whole array gives an ambiguous truth value as soon
    # as the model has more than one output.
    prediction = bool(probability[0][0] > threshold)
    actual = bool(actual_label[0] == 1)
    is_correct = prediction == actual

    if onlyWrong and is_correct:
        # Same return value as the plotting path, so callers always receive
        # probabilities.
        return probability

    plt.imshow(torch.permute(image, (1, 2, 0)).detach().numpy())
    # The probability is a 0-1 value; format it as a real percentage.
    plt.title(
        f"Prediction: {prediction} with confidence "
        f"{float(probability[0][0]):.1%}, Actual: {actual}"
    )
    plt.axis("off")
    plt.show()

    return probability
| |
|
| |
|
| | |
def example_init(examples=20, display=True):
    """Evaluate the default classifier on randomly drawn dataset samples.

    Indices are drawn with replacement. A sample counts as a predicted
    positive when the first model output exceeds 0.5; the guess is compared
    with the sample's label as a two-element list. A summary line with the
    tallies is printed once all samples have been processed.

    Args:
        examples: Number of samples to draw.
        display: When true, each sample is also shown through
            ``infer_and_display`` (which uses a threshold of 0.4) and a line
            is printed for it.
    """
    dataset = ClassUtils.CrosswalkDataset(data_path)

    sample_indices = [random.randint(0, len(dataset) - 1) for _ in range(examples)]

    correct = incorrect = falsepos = falseneg = 0
    for sample_index in sample_indices:
        image, label = dataset[sample_index]

        says_crosswalk = infer(image)[0][0] > 0.5
        class_guess = [1, 0] if says_crosswalk else [0, 1]

        if class_guess == label.tolist():
            correct += 1
        else:
            incorrect += 1
            # A wrong positive guess is a false positive, otherwise a miss.
            if says_crosswalk:
                falsepos += 1
            else:
                falseneg += 1

        if display:
            print(f"Prediction of {infer_and_display(image, 0.4, label)}% of a crosswalk (Crosswalk: {label[0]==1})")

    print(f"correct: {correct}, incorrect: {incorrect}, of which false positives were {falsepos} and false negatives were {falseneg}")
| |
|
# When imported, the module only announces itself; when run as a script it
# evaluates the classifier on 200 random samples without plotting.
if __name__ != "__main__":
    print(f"Module: [{__name__}] has been loaded")
else:
    example_init(examples=200, display=False)
| |
|
| |
|
| |
|