vickyvigneshmass committed on
Commit 3acddcc · verified · 1 Parent(s): 9d1540f

Update app.py

Files changed (1)
app.py +47 -25
app.py CHANGED
@@ -1,38 +1,60 @@
- from fastapi import FastAPI, UploadFile, File, HTTPException
  from transformers import CLIPProcessor, CLIPModel
- from PIL import Image, UnidentifiedImageError
  import torch
  import io

  app = FastAPI()

- # Load the CLIP model and processor
- model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
- processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")

- @app.get("/")
- def home():
-     return {"message": "CLIP FastAPI app is running!"}

- @app.post("/search/")
- async def search_image(file: UploadFile = File(...), query: str = "a photo"):
-     try:
-         # Read and decode image
-         contents = await file.read()
-         image = Image.open(io.BytesIO(contents)).convert("RGB")
-     except UnidentifiedImageError:
-         raise HTTPException(status_code=400, detail="Invalid image file format.")
-
-     # Preprocess image and text
-     inputs = processor(text=[query], images=image, return_tensors="pt", padding=True)
-
-     # Forward pass through the model
      with torch.no_grad():
-         outputs = model(**inputs)
-         logits_per_image = outputs.logits_per_image
-         probs = logits_per_image.softmax(dim=1)

      return {
          "query": query,
-         "match_confidence": float(probs[0][0])
      }
+ from fastapi import FastAPI, UploadFile, File, Form
  from transformers import CLIPProcessor, CLIPModel
+ from PIL import Image
  import torch
  import io
+ import uuid
+ import chromadb
+ from chromadb.config import Settings

+ # Initialize FastAPI
  app = FastAPI()

+ # Load CLIP model and processor
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

+ # Initialize ChromaDB
+ chroma_client = chromadb.Client(Settings(chroma_db_impl="duckdb+parquet", persist_directory="./chroma_storage"))
+ collection = chroma_client.get_or_create_collection(name="images")

+ # Function to extract image embeddings
+ def get_image_embedding(image: Image.Image):
+     inputs = processor(images=image, return_tensors="pt")
+     with torch.no_grad():
+         embeddings = model.get_image_features(**inputs)
+         embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
+     return embeddings[0].tolist()
+
+ # Function to extract text embeddings
+ def get_text_embedding(text: str):
+     inputs = processor(text=[text], return_tensors="pt", padding=True)
      with torch.no_grad():
+         embeddings = model.get_text_features(**inputs)
+         embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
+     return embeddings[0].tolist()

+ @app.get("/")
+ def root():
+     return {"message": "CLIP + ChromaDB image-text similarity search"}
+
+ @app.post("/add-image/")
+ async def add_image(file: UploadFile = File(...), label: str = Form(...)):
+     contents = await file.read()
+     image = Image.open(io.BytesIO(contents)).convert("RGB")
+     embedding = get_image_embedding(image)
+     uid = str(uuid.uuid4())
+     collection.add(documents=[label], embeddings=[embedding], ids=[uid], metadatas=[{"label": label}])
+     return {"message": f"Image '{label}' added with ID {uid}"}
+
+ @app.post("/search/")
+ async def search_text(query: str = Form(...), top_k: int = 3):
+     embedding = get_text_embedding(query)
+     results = collection.query(query_embeddings=[embedding], n_results=top_k)
      return {
          "query": query,
+         "results": [
+             {"label": doc, "score": score}
+             for doc, score in zip(results["documents"][0], results["distances"][0])
+         ]
      }
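
A note on the ChromaDB setup: the Settings(chroma_db_impl="duckdb+parquet", persist_directory=...) form used in the commit is the configuration style of older chromadb releases; chromadb 0.4 and later create a persistent client differently. A minimal sketch of the newer form, assuming a chromadb 0.4+ install (otherwise the Settings-based call above applies as-is):

import chromadb

# Assumption: chromadb >= 0.4 is installed. PersistentClient stores the collection
# under the given path, mirroring the persist_directory used in the commit.
chroma_client = chromadb.PersistentClient(path="./chroma_storage")
collection = chroma_client.get_or_create_collection(name="images")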
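
For reference, a hypothetical client-side sketch of how the two new endpoints could be exercised. It assumes the app is served locally on port 8000 (e.g. via uvicorn) and that the requests package is available; BASE_URL and cat.jpg are placeholders, not part of the commit. Because top_k is not declared as a Form field, FastAPI reads it from the query string, and the returned "score" is the raw ChromaDB distance, where lower means closer; since the embeddings are L2-normalized and the collection uses the default l2 space, cosine similarity can be recovered as 1 - distance / 2.

import requests

BASE_URL = "http://localhost:8000"  # assumption: uvicorn app:app --port 8000

# Index an image together with a label (multipart form upload).
with open("cat.jpg", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/add-image/",
        files={"file": ("cat.jpg", f, "image/jpeg")},
        data={"label": "a cat sitting on a sofa"},
    )
print(resp.json())

# Search by text; the query goes in the form body, top_k in the query string.
resp = requests.post(f"{BASE_URL}/search/", data={"query": "cat"}, params={"top_k": 2})
for hit in resp.json()["results"]:
    # "score" is the ChromaDB distance (lower = closer); convert to cosine similarity.
    print(hit["label"], hit["score"], 1 - hit["score"] / 2)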