Spaces:
Running
Running
| # Import Fast API | |
| from fastapi import FastAPI, Request, UploadFile, File | |
| from fastapi.templating import Jinja2Templates | |
| import base64 | |
| # Import bytes | |
| from io import BytesIO | |
| import os | |
| # Import logging | |
| import logging | |
| # Import utilities | |
| from src.utils.utils import IMAGE_FORMATS | |
| # Import machine learning | |
| from src.predict import predict | |
| from ultralytics import YOLO | |
| from huggingface_hub import hf_hub_download | |
# Initialize the FastAPI application
app = FastAPI()
# Jinja2 templates are loaded from the "templates" directory
templates = Jinja2Templates(directory="templates")
# Module-level logger
logger = logging.getLogger(__name__)
# NOTE(review): nothing visible in this file configures logging, so this
# INFO record is probably dropped under the default WARNING level unless the
# server sets up handlers — confirm if the message is expected in the logs.
logger.info("Loading YOLO model...")
# Download the YOLOv8 face-detection weights from the Hugging Face Hub;
# hf_hub_download returns the local filesystem path of the cached file.
model_path = hf_hub_download(
    repo_id="arnabdhar/YOLOv8-Face-Detection", filename="model.pt"
)
# Load the model once at import time; request handlers reuse this global.
model = YOLO(model_path)
# Index route
# NOTE(review): no route decorator is visible in this copy of the file;
# confirm this handler is registered with FastAPI (e.g. ``@app.get("/")``).
async def root(request: Request):
    """Serve the landing page.

    Renders ``index.html`` with only the request in its context, i.e. without
    any prediction image.
    """
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
def _remove_quietly(path):
    """Delete *path*; log instead of raising if the OS refuses."""
    try:
        os.remove(path)
    except OSError as e:
        logger.error("Error deleting image: %s", e)


def _save_upload_to_temp(upload):
    """Validate *upload*'s filename and copy its bytes into a fresh temp file.

    The extension is checked against ``IMAGE_FORMATS`` *before* anything is
    written, so rejected uploads leave nothing on disk. The upload stream is
    always closed.

    Returns:
        Path of the new temporary file; the caller owns it and must delete it.

    Raises:
        ValueError: If the filename is missing or has an unsupported extension.
        OSError: If the temp file cannot be created or written (it is removed
            again before the error propagates).
    """
    import tempfile

    try:
        filename = upload.filename
        if not filename or not filename.endswith(IMAGE_FORMATS):
            raise ValueError("Invalid image format")
        contents = upload.file.read()
    finally:
        upload.file.close()

    # Never use the client-supplied name as a path: it may contain "../"
    # components, and concurrent uploads sharing a name would clobber each
    # other. Keeping the basename as a suffix means the temp path still ends
    # the same way the upload did (same extension).
    fd, path = tempfile.mkstemp(suffix="_" + os.path.basename(filename))
    try:
        with os.fdopen(fd, "wb") as out:
            out.write(contents)
    except BaseException:
        _remove_quietly(path)
        raise
    return path


# Upload images decorator
# NOTE(review): no route decorator is visible in this copy of the file;
# confirm this handler is registered with FastAPI (e.g. ``@app.post(...)``).
def predict_image(request: Request, file: UploadFile = File(...)):
    """Run the YOLO model on an uploaded image and render the result.

    The upload is validated and written to a private temporary file, passed
    to ``predict`` with the module-level ``model``, and the returned result is
    saved as JPEG, base64-encoded and rendered into ``index.html`` under the
    ``img`` key. The temporary file is deleted whether or not prediction
    succeeds.

    Args:
        request: The incoming request, forwarded to the template context.
        file: The uploaded image.

    Returns:
        A ``TemplateResponse`` for ``index.html`` on success. If the upload
        cannot be read or has an unsupported extension, a one-element list
        holding the error message.

    Errors raised by ``predict`` or the encoding step propagate to FastAPI.
    """
    try:
        image_path = _save_upload_to_temp(file)
    except Exception as e:
        # Equivalent to the previous ``{f"{e}"}`` set literal: FastAPI's
        # default JSON encoding turns either into a one-element array.
        return [str(e)]

    try:
        # NOTE(review): assumes predict() only reads the image at the given
        # path (it previously received the on-disk upload path) — confirm.
        results = predict(model, image_path)
        # TODO: derive the output format from the upload's extension instead
        # of always re-encoding as JPEG.
        buffer = BytesIO()
        results.save(buffer, "JPEG")
        img_b64 = base64.b64encode(buffer.getvalue()).decode()
    finally:
        # Always clean up the temp file, even when prediction fails.
        _remove_quietly(image_path)

    return templates.TemplateResponse(
        "index.html", {"request": request, "img": img_b64}
    )