Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import tensorflow as tf | |
| import tensorflow.keras as keras | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| from huggingface_hub import from_pretrained_keras | |
# Download the already-trained MIL attention model from the Hugging Face Hub.
# It is kept in a list so predict() can average over an ensemble of models.
trained_models = [from_pretrained_keras("buio/attention_mil_classification")]

POSITIVE_CLASS = 1  # a bag is positive if it contains at least one digit with this label
BAG_COUNT = 1000  # number of training bags (currently unused in this app)
VAL_BAG_COUNT = 300  # number of validation bags built from the MNIST test split
BAG_SIZE = 3  # number of instances (digit images) per bag
PLOT_SIZE = 1  # number of bags drawn per plot() call
# Divisor for averaging per-model results. Tie it to the actual ensemble size
# so the average stays correct if more models are appended to trained_models.
ENSEMBLE_AVG_COUNT = len(trained_models)
def create_bags(input_data, input_labels, positive_class, bag_count, instance_count):
    """Sample ``bag_count`` random bags of ``instance_count`` instances each.

    Pixel values are scaled by 1/255. Within a bag, instances are drawn
    without replacement. A bag is labelled 1 when at least one of its
    instances has the label ``positive_class``, otherwise 0. The number of
    positive and negative bags is printed.

    Returns:
        ``(bags, labels)`` where ``bags`` is a list of length
        ``instance_count`` whose i-th entry stacks the i-th instance of every
        bag (bag and instance axes swapped), and ``labels`` is an array of
        shape ``(bag_count, 1)``.
    """
    scaled = np.divide(input_data, 255.0)
    sampled_bags = []
    sampled_labels = []
    for _ in range(bag_count):
        picks = np.random.choice(scaled.shape[0], instance_count, replace=False)
        sampled_bags.append(scaled[picks])
        has_positive = positive_class in input_labels[picks]
        sampled_labels.append(np.array([1 if has_positive else 0]))

    positives = sum(int(lbl[0]) for lbl in sampled_labels)
    print(f"Positive bags: {positives}")
    print(f"Negative bags: {bag_count - positives}")

    # Reorder to one array per instance position, the layout the model expects.
    return list(np.swapaxes(sampled_bags, 0, 1)), np.array(sampled_labels)
# Load MNIST; only the second (test) split is turned into validation bags.
(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()

# Build the validation bags that the demo predicts on and plots.
val_data, val_labels = create_bags(
    input_data=x_val,
    input_labels=y_val,
    positive_class=POSITIVE_CLASS,
    bag_count=VAL_BAG_COUNT,
    instance_count=BAG_SIZE,
)
def predict(data, labels, trained_models):
    """Run an ensemble of MIL models on ``data`` and average their outputs.

    For every model this collects the class predictions, the output of the
    layer named ``"alpha"`` (the MIL attention weights), and the loss and
    accuracy from ``model.evaluate`` against ``labels``. The mean loss and
    accuracy over the ensemble are printed.

    Args:
        data: Model input; at module level this is the bag list produced by
            ``create_bags``.
        labels: Bag labels matching ``data``.
        trained_models: Non-empty sequence of Keras models, each containing a
            layer named ``"alpha"``.

    Returns:
        ``(predictions, attention_weights)``: the per-model predictions and
        attention weights averaged over ``trained_models``.

    Raises:
        ValueError: If ``trained_models`` is empty.
    """
    if not trained_models:
        raise ValueError("predict() needs at least one trained model")

    models_predictions = []
    models_attention_weights = []
    models_losses = []
    models_accuracies = []

    for model in trained_models:
        # Predict output classes on data.
        models_predictions.append(model.predict(data))

        # Sub-model exposing the MIL attention layer ("alpha") output.
        intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)
        intermediate_predictions = intermediate_model.predict(data)
        # Swap the first two axes so the weights can be indexed [bag][instance].
        # NOTE(review): np.squeeze drops *every* size-1 axis, so a batch with a
        # single bag (or BAG_SIZE == 1) would lose an axis; confirm that never happens.
        attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))
        models_attention_weights.append(attention_weights)

        model.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy"])
        loss, accuracy = model.evaluate(data, labels, verbose=0)
        models_losses.append(loss)
        models_accuracies.append(accuracy)

    # Average over the models actually supplied. Dividing by the fixed
    # ENSEMBLE_AVG_COUNT mis-scales every result whenever the number of
    # models passed in differs from that constant.
    print(
        f"The average loss and accuracy are {np.mean(models_losses, axis=0):.2f}"
        f" and {100 * np.mean(models_accuracies, axis=0):.2f} % resp."
    )
    return (
        np.mean(models_predictions, axis=0),
        np.mean(models_attention_weights, axis=0),
    )
def plot(data, labels, bag_class, predictions=None, attention_weights=None):
    """Utility for plotting bags and attention weights.

    Args:
        data: Input data that contains the bags of instances, laid out so that
            ``np.array(data)[j, b]`` is instance ``j`` of bag ``b``.
        labels: The associated bag labels of the input data.
        bag_class: String name of the desired bag class.
            The options are: "positive" or "negative".
        predictions: Class-score predictions from the model. If given, bags
            are selected by ``predictions.argmax(1)``; otherwise the
            ground-truth ``labels`` are used.
        attention_weights: Attention weights for each instance within the input
            data, indexed ``[bag][instance]``. If omitted, no weights are shown.

    Returns:
        The matplotlib figure of the last plotted bag, or ``None`` if
        ``bag_class`` is unknown or no bag belongs to the requested class.
    """
    class_ids = {"positive": 1, "negative": 0}
    if bag_class not in class_ids:
        print(f"There is no class {bag_class}")
        return None
    target = class_ids[bag_class]

    if predictions is not None:
        candidates = np.where(predictions.argmax(1) == target)[0]
    else:
        candidates = np.where(np.array(labels).reshape(-1) == target)[0]
    if candidates.size == 0:
        # np.random.choice raises on an empty population; report instead of
        # crashing (e.g. when the model predicts no bag of this class).
        print(f"No bags found for class {bag_class}")
        return None

    random_labels = np.random.choice(candidates, PLOT_SIZE)
    bags = np.array(data)[:, random_labels]

    print(f"The bag class label is {bag_class}")
    figure = None
    for i in range(PLOT_SIZE):
        figure = plt.figure(figsize=(8, 8))  # one figure per bag
        # Report the bag that was actually sampled, not the i-th candidate.
        print(f"Bag number: {random_labels[i]}")
        for j in range(BAG_SIZE):
            image = bags[j][i]
            figure.add_subplot(1, BAG_SIZE, j + 1)
            plt.grid(False)
            plt.axis('off')
            if attention_weights is not None:
                plt.title(np.around(attention_weights[random_labels[i]][j], 2))
            plt.imshow(image)
        plt.show()
    return figure
# Evaluate and predict classes and attention scores on validation data.
def predict_and_plot(class_):
    """Predict on the validation bags and plot one bag of the chosen class.

    Args:
        class_: ``"positive"`` or ``"negative"``, as selected in the Gradio
            radio input. Any other value makes ``plot`` print a message and
            return ``None``.

    Returns:
        The figure returned by ``plot`` (or ``None``).
    """
    class_predictions, attention_params = predict(val_data, val_labels, trained_models)
    # Bags are selected by the model's predicted class, not by ground truth.
    return plot(
        val_data,
        val_labels,
        class_,
        predictions=class_predictions,
        attention_weights=attention_params,
    )
# One prediction at start-up (its figure is discarded); presumably a warm-up
# that also prints the validation loss/accuracy to the logs before serving.
predict_and_plot('positive')

# Gradio UI: choose a bag class, get a plot of one predicted bag of that class.
inputs = gr.Radio(choices=['positive', 'negative'])
outputs = gr.Plot(label='predicted bag')

demo = gr.Interface(fn=predict_and_plot, inputs=inputs, outputs=outputs, allow_flagging='never')
demo.launch(debug=True)