import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

from models.resnet34 import resnet34

# --- 1. Define CIFAR-100 Class Names ---
# The classes are in the dataset's canonical (alphabetical) order, so the model's
# output index corresponds directly to a position in this list.
cifar100_labels = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
    'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
    'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank',
    'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip',
    'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
]

# --- 2. Load Your Model ---
# Initialize the ResNet-34 architecture for 100 classes.
model = resnet34(num_classes=100, pretrained=False)
model.eval()

# IMPORTANT: Load your trained weights here.
# Make sure 'resnet34_cifar100_frozen.pth' is present in your Hugging Face Space
# repository. If your file has a different name, update the path below.
try:
    # map_location ensures the checkpoint loads correctly in a CPU-only environment.
    checkpoint = torch.load('resnet34_cifar100_frozen.pth', map_location=torch.device('cpu'))

    # The checkpoint may be a full training checkpoint (a dict containing
    # 'model_state_dict') or a bare state dict; handle both.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    print("Model weights loaded successfully.")
except FileNotFoundError:
    print("WARNING: resnet34_cifar100_frozen.pth not found. The model is using random weights.")
    print("Please upload your trained model file to the Hugging Face Space.")

# --- 3. Define Image Transformations ---
# Adjust these to match the transformations you used during training.
# CIFAR-100 images are 32x32, so uploaded images are resized to that size first.
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.50707516, 0.48654887, 0.44091784],
                         std=[0.26733429, 0.25643846, 0.27615047]),
])

# --- 4. Create the Prediction Function ---
def predict(image: Image.Image):
    """
    Take a PIL image, preprocess it, and return per-class probabilities.
    """
    if image is None:
        # Nothing to classify; gr.Label renders an empty output for None.
        return None

    # Ensure a 3-channel RGB image (uploads may be RGBA or grayscale).
    image = image.convert('RGB')

    # Preprocess the image and add a batch dimension (the model expects a batch).
    img_tensor = preprocess(image).unsqueeze(0)

    # Make the prediction.
    with torch.no_grad():
        outputs = model(img_tensor)

    # Apply softmax to turn the logits into probabilities.
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

    # Map every label to its probability; gr.Label picks out the top 5.
    confidences = {cifar100_labels[i]: float(probabilities[i]) for i in range(100)}
    return confidences

# --- 5. Build the Gradio Interface ---
title = "CIFAR-100 Image Classifier"
description = """
Upload an image or use one of the examples below to see the model's prediction.
This app uses a ResNet-34 model trained on the CIFAR-100 dataset.

**IMPORTANT**: You must upload your own `resnet34_cifar100_frozen.pth` file for this
Space to work with your trained model.
"""

# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Label(num_top_classes=5, label="Top 5 Predictions"),
    title=title,
    description=description,
    examples=[
        ["examples/bicycle.png"],
        ["examples/house.png"],
        ["examples/plain.png"],
        ["examples/tulip.png"]
    ],
    allow_flagging="never",
    live=True
)

# Launch the app when this file is executed directly
if __name__ == "__main__":
    iface.launch()
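
# --- Optional: quick local smoke test (a minimal sketch, not part of the Space) ---
# This is an assumed usage example: it presumes this file is saved as app.py and that
# 'examples/bicycle.png' exists in the repository. Run it from a separate script or a
# Python shell to check the prediction pipeline without starting the Gradio UI:
#
#   from PIL import Image
#   from app import predict
#
#   confidences = predict(Image.open("examples/bicycle.png"))
#   top5 = sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)[:5]
#   print(top5)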