import torch
from pathlib import Path
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from medmnist import INFO
import gradio as gr
import os
import base64
from io import BytesIO
from huggingface_hub import HfApi
from datetime import datetime
import io

from model import resnet18, resnet50

DEVICE = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

AUTH_TOKEN = os.getenv("APP_TOKEN")  # to access the app
DATASET_REPO = os.getenv("Dataset_repo")  # "G44mlops/API_received"
HF_TOKEN = os.getenv("HF_TOKEN")  # to access the dataset repo
MODEL = os.getenv("Model_repo")  # "G44mlops/ResNet-medmnist"


# taken from Mikolaj's code with closed PR
def load_model_from_hf(
    repo_id: str,
    filename: str,
    model_type: str,
    num_classes: int,
    in_channels: int,
    device: str,
) -> torch.nn.Module:
    """Load trained model from Hugging Face Hub.

    Args:
        repo_id: Hugging Face repository ID
        filename: Model checkpoint filename
        model_type: Type of model ('resnet18' or 'resnet50')
        num_classes: Number of output classes
        in_channels: Number of input channels
        device: Device to load model on

    Returns:
        Loaded model in eval mode
    """
    print(f"Downloading model from Hugging Face: {repo_id}/{filename}")
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # Create model
    if model_type == "resnet18":
        model = resnet18(num_classes=num_classes, in_channels=in_channels)
    else:
        model = resnet50(num_classes=num_classes, in_channels=in_channels)

    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    return model


# taken from Mikolaj's code with closed PR
# Image preprocessing pipeline (basic so far, can be improved)
def get_preprocessing_pipeline() -> transforms.Compose:
    """Get preprocessing pipeline for images."""
    # Getting information on the number of image channels (RGB or grayscale) for the trained model
    info = INFO["organamnist"]  # Using organamnist as reference
    output_channels = info["n_channels"]  # RGB or grayscale

    # Choosing 'standard' mean and std values for normalization, since dataset statistics are not available
    mean = (0.5,) * output_channels
    std = (0.5,) * output_channels

    # Preparing the transformation pipeline
    trans = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Returning the transformation pipeline
    return trans


def get_class_labels(data_flag: str = "organamnist") -> dict[str, str]:
    """Get class labels for a MedMNIST dataset (mapping from class index to label name)."""
    # Retrieving dataset info; INFO stores labels as a dict keyed by class index
    info = INFO[data_flag]
    labels = info["label"]

    # Returning class labels
    return labels


def save_image_to_hf_folder(image_path, prediction_label):
    """Upload an image to the HF dataset folder."""
    api = HfApi()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Create a text file with metadata
    metadata = f"prediction: {prediction_label}\ntimestamp: {timestamp}"
    metadata_path = f"{Path(image_path).stem}_metadata.txt"

    # Upload image
    api.upload_file(
        path_or_fileobj=image_path,
        path_in_repo=f"uploads/{timestamp}_{Path(image_path).name}",
        repo_id=DATASET_REPO,
        repo_type="dataset",
        token=HF_TOKEN,
    )

    # Upload metadata as a separate file
    api.upload_file(
        path_or_fileobj=io.BytesIO(metadata.encode()),
        path_in_repo=f"uploads/{timestamp}_{Path(image_path).stem}_metadata.txt",
        repo_id=DATASET_REPO,
        repo_type="dataset",
        token=HF_TOKEN,
    )
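
# Illustrative sketch (not part of the app's request flow): how the helpers
# above can be combined to classify a single image. The checkpoint filename
# "best_model.pth" and the "resnet18" model type are assumptions made for the
# example, not values defined elsewhere in this file.
def _example_predict_single(image_path: str) -> str:
    """Sketch: load the model from the Hub, preprocess one image, return its label."""
    labels = get_class_labels("organamnist")
    model = load_model_from_hf(
        repo_id=MODEL,
        filename="best_model.pth",  # assumed checkpoint name
        model_type="resnet18",      # assumed architecture
        num_classes=len(labels),
        in_channels=INFO["organamnist"]["n_channels"],
        device=DEVICE,
    )
    preprocess = get_preprocessing_pipeline()

    # organamnist models expect grayscale input, matching the pipeline above
    image = Image.open(image_path).convert("L")
    batch = preprocess(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(batch)
    predicted_class = logits.argmax(dim=1).item()
    return labels[str(predicted_class)]
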
def classify_images(images) -> str:
    """Classify images and return formatted HTML with embedded images."""
    # Handle case with no images
    if images is None:
        return "<p>No images uploaded</p>"

    # Ensure images is a list (the case when only one image is uploaded is problematic without it)
    if isinstance(images, str):
        images = [images]

    # Creating HTML structure for results: one block per image, showing its filename and predicted label
    html = """
        <p>{filename}</p>
        <p>{label}</p>
    """