vickyvigneshmass commited on
Commit
9d1540f
·
verified ·
1 Parent(s): 64efb08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -1,12 +1,12 @@
1
- from fastapi import FastAPI, UploadFile, File
2
  from transformers import CLIPProcessor, CLIPModel
3
- from PIL import Image
4
  import torch
5
  import io
6
 
7
  app = FastAPI()
8
 
9
- # Load a compatible CLIP model and processor
10
  model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
11
  processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
12
 
@@ -16,17 +16,21 @@ def home():
16
 
17
  @app.post("/search/")
18
  async def search_image(file: UploadFile = File(...), query: str = "a photo"):
19
- # Load the uploaded image
20
- contents = await file.read()
21
- image = Image.open(io.BytesIO(contents)).convert("RGB")
 
 
 
22
 
23
  # Preprocess image and text
24
  inputs = processor(text=[query], images=image, return_tensors="pt", padding=True)
25
 
26
- # Forward pass
27
- outputs = model(**inputs)
28
- logits_per_image = outputs.logits_per_image
29
- probs = logits_per_image.softmax(dim=1)
 
30
 
31
  return {
32
  "query": query,
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
  from transformers import CLIPProcessor, CLIPModel
3
+ from PIL import Image, UnidentifiedImageError
4
  import torch
5
  import io
6
 
7
  app = FastAPI()
8
 
9
+ # Load the CLIP model and processor
10
  model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
11
  processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
12
 
 
16
 
17
  @app.post("/search/")
18
  async def search_image(file: UploadFile = File(...), query: str = "a photo"):
19
+ try:
20
+ # Read and decode image
21
+ contents = await file.read()
22
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
23
+ except UnidentifiedImageError:
24
+ raise HTTPException(status_code=400, detail="Invalid image file format.")
25
 
26
  # Preprocess image and text
27
  inputs = processor(text=[query], images=image, return_tensors="pt", padding=True)
28
 
29
+ # Forward pass through the model
30
+ with torch.no_grad():
31
+ outputs = model(**inputs)
32
+ logits_per_image = outputs.logits_per_image
33
+ probs = logits_per_image.softmax(dim=1)
34
 
35
  return {
36
  "query": query,