from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import io
from PIL import Image
import base64
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, SwinForImageClassification
import lightning as L
import uuid
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Label mappings
label2id = {'fake': 0, 'real': 1}
id2label = {0: 'fake', 1: 'real'}
# Model hyperparameters
hyper_params = {
    "MODEL_CKPT": "microsoft/swin-small-patch4-window7-224",
    "num_labels": 2,
    "id2label": id2label,
    "label2id": label2id,
}
# Load the image processor that matches the Swin checkpoint
image_processor = AutoImageProcessor.from_pretrained(hyper_params["MODEL_CKPT"])
class DeepFakeModel(L.LightningModule):
    """Lightning wrapper around a Swin image-classification backbone."""

    def __init__(self, hyperparams: dict):
        super().__init__()
        self.model = SwinForImageClassification.from_pretrained(
            hyperparams["MODEL_CKPT"],
            num_labels=hyperparams["num_labels"],
            id2label=hyperparams["id2label"],
            label2id=hyperparams["label2id"],
            ignore_mismatched_sizes=True,  # swap the ImageNet head for a 2-class head
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()  # used during training; unused at inference

    def forward(self, pixel_values):
        # Return raw logits; softmax is applied by the caller
        return self.model(pixel_values=pixel_values).logits
# Load trained model
model = DeepFakeModel(hyper_params)
state_dict = torch.load("deepfake_new_trained.pth", map_location=torch.device(device))
model.load_state_dict(state_dict)
model.to(device)
model.eval()
print("Model loaded successfully")
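# Optional smoke test — a minimal sketch, assuming the checkpoint expects the
# Swin default 224x224 input resolution:
# with torch.inference_mode():
#     dummy_logits = model(torch.zeros(1, 3, 224, 224, device=device))
#     assert dummy_logits.shape == (1, 2)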
# Initialize FastAPI app
app = FastAPI(title="DeepFake Detector API", description="API for detecting deepfake images", version="1.0.0")
# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update with the frontend server address in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
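# A sketch of a restricted CORS policy to use in place of the wildcard above
# ("https://frontend.example.com" is a placeholder origin):
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["https://frontend.example.com"],
#     allow_credentials=True,
#     allow_methods=["GET", "POST"],
#     allow_headers=["*"],
# )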
class ImageData(BaseModel):
    image: str  # Base64-encoded image, optionally with a "data:image/...;base64," prefix


class AnalysisResult(BaseModel):
    id: str
    isDeepfake: bool
    confidence: float  # percentage in [0, 100]
    details: str
def preprocess_image(img):
    # Resize and normalize the PIL image, returning a batched tensor on the target device
    return image_processor(img, return_tensors='pt')['pixel_values'].to(device)
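# e.g. preprocess_image(Image.new("RGB", (640, 480))) yields a tensor of shape
# [1, 3, 224, 224] with this checkpoint's default processor settings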
def predict_deepfake(image):
    try:
        img_tensor = preprocess_image(image)
        with torch.inference_mode():
            logits = model(img_tensor)
            # Convert logits to class probabilities and pick the most likely label
            probabilities = F.softmax(logits, dim=-1)
            confidence, predicted_index = torch.max(probabilities, dim=-1)
        predicted_label = id2label[predicted_index.item()]
        details = "Deepfake detected." if predicted_label == "fake" else "Image appears to be real."
        return {
            "id": str(uuid.uuid4()),
            "isDeepfake": predicted_label == "fake",
            "confidence": round(confidence.item() * 100, 2),
            "details": details,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
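# Example response payload (illustrative values):
# {"id": "<uuid4>", "isDeepfake": true, "confidence": 98.7, "details": "Deepfake detected."}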
@app.post("/api/analyze", response_model=AnalysisResult)
async def analyze_image(file: UploadFile = File(...)):
    if not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
        # Return the dict directly so FastAPI validates it against AnalysisResult
        return predict_deepfake(image)
    except HTTPException:
        raise  # don't re-wrap errors that already carry a status code
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
@app.post("/api/analyze-base64", response_model=AnalysisResult)
async def analyze_base64_image(data: ImageData):
    try:
        # Strip a possible "data:image/...;base64," prefix before decoding
        image_data = data.image.split("base64,")[-1]
        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        return predict_deepfake(image)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
@app.get("/")
async def root():
    return {"message": "DeepFake Detector API is running"}
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
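# Example requests (assuming the server is running locally on port 7860;
# "sample.jpg" is a placeholder file name):
#   curl -X POST -F "file=@sample.jpg" http://localhost:7860/api/analyze
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"image": "<base64-encoded image>"}' http://localhost:7860/api/analyze-base64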