vickyvigneshmass commited on
Commit
a5faf3c
·
verified ·
1 Parent(s): 102bd9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -58
app.py CHANGED
@@ -1,60 +1,7 @@
1
- from fastapi import FastAPI, Query, HTTPException
2
- from transformers import CLIPModel, CLIPProcessor
3
- import torch
4
- import os
5
 
6
- app = FastAPI(title="CLIP-based Document Retrieval API")
7
 
8
- # Load model and processor (requires Pillow)
9
- try:
10
- model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
11
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
12
- model.eval()
13
- except ImportError as e:
14
- raise RuntimeError(
15
- "Missing dependencies. Make sure to install 'pillow'.\nInstall using:\n\npip install pillow"
16
- )
17
- except Exception as e:
18
- raise RuntimeError(f"Failed to load model or processor: {str(e)}")
19
-
20
- # Load and encode sentences
21
- document_path = "test.txt"
22
- if not os.path.exists(document_path):
23
- raise FileNotFoundError(f"❌ Document not found: {document_path}")
24
-
25
- with open(document_path, "r", encoding="utf-8") as f:
26
- sentences = [line.strip() for line in f if line.strip()]
27
-
28
- with torch.no_grad():
29
- sentence_inputs = processor(text=sentences, return_tensors="pt", padding=True, truncation=True)
30
- sentence_embeddings = model.get_text_features(**sentence_inputs)
31
-
32
- @app.get("/", tags=["Welcome"])
33
- async def root():
34
- return {"message": "✅ CLIP-based Document Retrieval API is Running"}
35
-
36
- @app.get("/search", tags=["Search"])
37
- async def search(
38
- query: str = Query(..., description="Search text query"),
39
- top_k: int = Query(5, gt=0, le=20, description="Number of top results to return (max 20)")
40
- ):
41
- if not query.strip():
42
- raise HTTPException(status_code=400, detail="Query must not be empty")
43
-
44
- with torch.no_grad():
45
- query_inputs = processor(text=[query], return_tensors="pt", padding=True, truncation=True)
46
- query_embedding = model.get_text_features(**query_inputs)
47
-
48
- # Cosine similarity
49
- similarities = torch.nn.functional.cosine_similarity(query_embedding, sentence_embeddings)[0]
50
- top_indices = torch.topk(similarities, k=top_k).indices
51
-
52
- results = [{
53
- "sentence": sentences[i],
54
- "score": round(similarities[i].item(), 4)
55
- } for i in top_indices]
56
-
57
- return {
58
- "query": query,
59
- "results": results
60
- }
 
1
+ from fastapi import FastAPI
 
 
 
2
 
3
+ app = FastAPI()
4
 
5
+ @app.get("/")
6
+ def greet_json():
7
+ return {"welcome": "Created!"}