Kalpokoch committed on
Commit 9dfc4a0 · verified · 1 Parent(s): 39c476b

Update app/app.py

Files changed (1):
  1. app/app.py (+18 −20)
app/app.py CHANGED
@@ -3,33 +3,37 @@ from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-from app.policy_vector_db import PolicyVectorDB # Import your class
+from app.policy_vector_db import PolicyVectorDB # Import your class
 
 # --- 1. Initialize the Vector Database and LLM ---
 
-# Load the vector database.
-# This connects to the persistent ChromaDB storage created by policy_vector_db.py
+# Load the vector database from /tmp (safest in Docker/HF Spaces)
 print("Loading Vector Database...")
-db = PolicyVectorDB(persist_directory="../policy_vector_db")
+db = PolicyVectorDB(persist_directory="/tmp/policy_vector_db")
 print("Vector Database loaded successfully!")
 
-# Load your fine-tuned model from Hugging Face Hub
-model_id = "Kalpokoch/QuntizedTinyLama"
+# Load your quantized model from Hugging Face Hub
+model_id = "Kalpokoch/QuantizedTinyLlama" # Correct spelling assumed
 print(f"Loading model: {model_id}...")
 tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# Choose dtype depending on device support
+dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
+
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
-    torch_dtype=torch.bfloat16,
+    torch_dtype=dtype,
     device_map="auto"
 )
 
-# Create a text-generation pipeline for the LLM
+# Create a text-generation pipeline
 pipe = pipeline(
     "text-generation",
     model=model,
     tokenizer=tokenizer,
     max_new_tokens=256
 )
+
 print("LLM and pipeline loaded successfully!")
 
 
@@ -44,13 +48,16 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
 @app.get("/")
 def read_root():
-    return {"message": "RAG chatbot backend is running with Kalpokoch/QuntizedTinyLama and ChromaDB!"}
+    return {"message": "RAG chatbot backend is running with Kalpokoch/QuantizedTinyLlama and ChromaDB!"}
+
 
 class ChatRequest(BaseModel):
     question: str
 
+
 @app.post("/chat")
 def chat(request: ChatRequest):
     question = request.question.strip()
@@ -58,21 +65,17 @@ def chat(request: ChatRequest):
         return {"response": "Please ask a question."}
 
     # --- 3. RAG Retrieval using PolicyVectorDB ---
-    # Use the search method from your class to find relevant context
     print(f"Searching for context for question: '{question}'")
     search_results = db.search(query_text=question, top_k=3)
-
-    # Check if any results were found
+
     if not search_results:
         retrieved_context = "No relevant context found."
     else:
-        # Format the retrieved documents into a single context string
        retrieved_context = "\n\n".join([result['text'] for result in search_results])
-
+
     print(f"Retrieved Context:\n{retrieved_context[:500]}...")
 
     # --- 4. Prompt Engineering and Generation ---
-    # Build the prompt with the retrieved context
     prompt = (
         f"<|system|>\nYou are a helpful assistant for NEEPCO policies. "
         f"Use the following context to answer the user's question. If the context doesn't contain the answer, say that.\n"
@@ -81,16 +84,11 @@ def chat(request: ChatRequest):
         f"<|assistant|>"
     )
 
-    # Generate a response using the pipeline
     try:
         outputs = pipe(prompt)
         reply = outputs[0]['generated_text']
-
-        # Extract only the assistant's newly generated reply
         assistant_reply = reply.split("<|assistant|>")[1].strip()
-
         return {"response": assistant_reply}
     except Exception as e:
         print(f"Error during model inference: {e}")
         return {"response": "Sorry, I encountered an error while generating a response."}
-