import base64
import os
from datetime import datetime
from io import BytesIO
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import HfApi, hf_hub_download
from medmnist import INFO
from PIL import Image
from torchvision import transforms

from model import resnet18, resnet50
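# `model` is the project's local module defining the ResNet-18/ResNet-50
# variants (with a configurable `in_channels`); it has to be deployed alongside
# this app file for the import above to resolve.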

# Prefer CUDA, then Apple Silicon (MPS), then CPU.
DEVICE = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Runtime configuration comes from environment variables (e.g. Space secrets).
AUTH_TOKEN = os.getenv("APP_TOKEN")       # password for the Gradio login
DATASET_REPO = os.getenv("Dataset_repo")  # dataset repo that stores uploaded images
HF_TOKEN = os.getenv("HF_TOKEN")          # token with write access to the dataset repo
MODEL = os.getenv("Model_repo")           # model repo hosting the trained checkpoint
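# A minimal sanity check (not part of the original flow, shown as a hypothetical
# guard) would fail fast when a required secret is missing:
#
#     for name, value in {"Model_repo": MODEL, "Dataset_repo": DATASET_REPO}.items():
#         if not value:
#             raise RuntimeError(f"Environment variable {name} is not set")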


def load_model_from_hf(
    repo_id: str,
    filename: str,
    model_type: str,
    num_classes: int,
    in_channels: int,
    device: torch.device,
) -> torch.nn.Module:
    """Load trained model from Hugging Face Hub.

    Args:
        repo_id: Hugging Face repository ID
        filename: Model checkpoint filename
        model_type: Type of model ('resnet18' or 'resnet50')
        num_classes: Number of output classes
        in_channels: Number of input channels
        device: Device to load model on

    Returns:
        Loaded model in eval mode
    """
    print(f"Downloading model from Hugging Face: {repo_id}/{filename}")
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

    if model_type == "resnet18":
        model = resnet18(num_classes=num_classes, in_channels=in_channels)
    else:
        model = resnet50(num_classes=num_classes, in_channels=in_channels)

    # The checkpoint is expected to be a dict containing "model_state_dict";
    # weights_only=True restricts unpickling to plain tensors and containers.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    return model


def get_preprocessing_pipeline() -> transforms.Compose:
    """Get preprocessing pipeline for images."""
    info = INFO["organamnist"]
    output_channels = info["n_channels"]

    # ToTensor scales pixels to [0, 1]; Normalize with mean=std=0.5 then maps
    # each channel to [-1, 1].
    mean = (0.5,) * output_channels
    std = (0.5,) * output_channels

    trans = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    return trans
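
# The Resize(256)/CenterCrop(224) pair above mirrors the standard ImageNet-style
# ResNet input size, so uploads of any resolution (including 28x28 MedMNIST
# exports) are rescaled to 224x224 before inference.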


def get_class_labels(data_flag: str = "organamnist") -> dict[str, str]:
    """Get class labels for a MedMNIST dataset, keyed by class index as a string."""
    info = INFO[data_flag]
    labels = info["label"]

    return labels
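
# For "organamnist" this is a dict with string keys "0"-"10" mapping class
# indices to organ names, which is why predictions are looked up with
# `str(top_class)` in classify_images below.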


def save_image_to_hf_folder(image_path, prediction_label):
    """Upload an image and its prediction metadata to the HF dataset repo."""
    api = HfApi()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    metadata = f"prediction: {prediction_label}\ntimestamp: {timestamp}"

    api.upload_file(
        path_or_fileobj=image_path,
        path_in_repo=f"uploads/{timestamp}_{Path(image_path).name}",
        repo_id=DATASET_REPO,
        repo_type="dataset",
        token=HF_TOKEN,
    )

    api.upload_file(
        path_or_fileobj=BytesIO(metadata.encode()),
        path_in_repo=f"uploads/{timestamp}_{Path(image_path).stem}_metadata.txt",
        repo_id=DATASET_REPO,
        repo_type="dataset",
        token=HF_TOKEN,
    )
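
# Note: each upload_file call above creates its own commit in the dataset repo.
# If upload volume ever becomes an issue, the two files could be batched into a
# single commit (for example with HfApi.create_commit and CommitOperationAdd);
# that is a possible refinement, not something this app currently does.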


def classify_images(images) -> str:
    """Classify images and return formatted HTML with embedded images."""
    if images is None:
        return "<p>No images uploaded</p>"

    if isinstance(images, str):
        images = [images]

    html = "<div style='display: flex; flex-wrap: wrap; gap: 30px; padding: 20px; justify-content: center;'>"

    for image_path in images:
        # Convert to single-channel grayscale to match the model's in_channels=1,
        # and move the batch onto the same device as the model.
        img = Image.open(image_path).convert("L")
        input_tensor = preprocess(img).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            output = model(input_tensor)
            probs = torch.nn.functional.softmax(output[0], dim=0)
            top_class = probs.argmax().item()

        label = class_labels[str(top_class)]
        filename = Path(image_path).name

        # Embed the grayscale preview directly in the HTML as a base64 data URI.
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        html += f"""
        <div style='border: 2px solid #ddd; padding: 15px; border-radius: 8px; background: #f9f9f9; width: 280px;'>
            <p style='font-size: 14px; color: #666; margin: 0 0 10px 0; text-align: center; font-weight: bold;'>{filename}</p>
            <img src='data:image/jpeg;base64,{img_str}' style='width: 250px; height: 250px; object-fit: contain; display: block; margin: 0 auto 10px;'>
            <p style='font-size: 18px; color: #0066cc; margin: 10px 0 0 0; text-align: center; font-weight: bold;'>{label}</p>
        </div>
        """

        save_image_to_hf_folder(image_path, label)

    html += "</div>"

    return html
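
# classify_images receives whatever gr.File hands it: with file_count="multiple"
# and the default file type that is a list of local file paths (or a single path
# string, which the isinstance check above normalises), so PIL can open the
# entries directly.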


# Load the model, preprocessing pipeline and label mapping once at startup.
model = load_model_from_hf(
    repo_id=MODEL,
    filename="resnet18_best.pth",
    model_type="resnet18",
    num_classes=11,
    in_channels=1,
    device=DEVICE,
)
preprocess = get_preprocessing_pipeline()
class_labels = get_class_labels()
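
# Because these globals are built at import time, the app fails at startup
# (rather than on the first request) if the Model_repo checkpoint cannot be
# downloaded or loaded.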


with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>MLOps project - MedMNIST dataset Image Classifier</h1>")

    gr.Markdown(
        "This is a Gradio web application for an MLOps course project. Uploaded images are stored in our dataset. "
        "By uploading images you agree that they will be stored by us. "
        "If you got past the login but are not connected to the project, please do not upload any images."
    )

    with gr.Column():
        gr.Markdown("<h2 style='text-align: center;'>Upload Images</h2>")

        images_input = gr.File(file_count="multiple", file_types=["image"], label="Upload Images")

        with gr.Row():
            submit_btn = gr.Button("Classify")
            reset_btn = gr.Button("Reset")

        gr.Markdown("<h2 style='text-align: center;'>Results</h2>")

        output = gr.HTML(label="Results")

    def reset():
        """Clear both the file input and the results panel."""
        return None, ""

    submit_btn.click(classify_images, inputs=images_input, outputs=output)
    reset_btn.click(reset, outputs=[images_input, output])


# Allow overriding the bind address via GRADIO_SERVER_NAME; default to localhost.
server_name = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
demo.launch(
    server_name=server_name,
    # Gradio auth takes (username, password) pairs; login is "user" / APP_TOKEN
    # when the secret is set, otherwise the app is left open.
    auth=[("user", AUTH_TOKEN)] if AUTH_TOKEN else None,
)