import os
import zipfile

import gradio as gr
import chromadb
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

# 🔹 Load API Token from Hugging Face Secrets
HF_TOKEN = os.getenv("api_key")  # ✅ Securely load API key (stored under the `api_key` secret)

# 🔹 Ensure API Token is Loaded
if HF_TOKEN is None:
    raise ValueError("❌ Hugging Face API token not found. Add an `api_key` secret to the Space.")

# 🔹 Load Mistral-7B-Instruct for LLM Responses (gated model, requires authentication)
llm_name = "mistralai/Mistral-7B-Instruct-v0.1"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_name, token=HF_TOKEN)
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_name,
    torch_dtype=torch.float16,
    device_map="auto",
    token=HF_TOKEN
)
# 🔹 Optimize Mistral for Faster Inference
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
llm_model = torch.compile(llm_model)
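# (Hedged note) torch.compile is only available on PyTorch 2.x; if this Space could run on an
# older runtime, the call above can be guarded as in this sketch without changing the output:
#
#     if hasattr(torch, "compile"):
#         llm_model = torch.compile(llm_model)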
# 🔹 Initialize ChromaDB

# 🔹 Unzip ChromaDB database if not extracted
if not os.path.exists("./chroma_db"):
    with zipfile.ZipFile("chroma_db.zip", 'r') as zip_ref:
        zip_ref.extractall("./")
    print("✅ ChromaDB database loaded!")

# 🔹 Load ChromaDB from local storage
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="hepB_knowledge")
print("✅ ChromaDB initialized!")
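# (Hedged sketch) Optional sanity check that the unzipped database actually contains indexed
# chunks; collection.count() is part of the ChromaDB collection API:
#
#     print(f"Collection holds {collection.count()} chunks")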
# 🔹 Function to Generate LLM Responses

# 🔹 Detect Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"✅ Using device: {device}")

def generate_humanized_response(query, retrieved_text):
    """Passes retrieved chunks through Mistral-7B to improve readability."""
    # 🔹 Truncate retrieved text to avoid long input errors
    retrieved_text = retrieved_text[:500]

    prompt = f"""You are a medical assistant. Answer the following question based on the retrieved text:
Retrieved Text:
{retrieved_text}
User Query: {query}
Provide a well-structured, human-like response:
"""

    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)  # ✅ Uses GPU if available, otherwise CPU
    output = llm_model.generate(**inputs, max_new_tokens=150, do_sample=True)
    response = llm_tokenizer.decode(output[0], skip_special_tokens=True)
    return response
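# (Hedged sketch) Mistral-7B-Instruct was tuned on an [INST] ... [/INST] chat format. If answers
# come back poorly structured, the prompt can be built through the tokenizer's chat template
# instead of a raw string (assumes a transformers version with chat-template support):
#
#     messages = [{"role": "user", "content": prompt}]
#     chat_prompt = llm_tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )
#     inputs = llm_tokenizer(chat_prompt, return_tensors="pt").to(device)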
# 🔹 Load BioMedBERT for Embeddings
embed_model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
embed_model = AutoModel.from_pretrained(embed_model_name)
embed_model.to(device)

# 🔹 Function to Generate Text Embeddings
def get_embedding(text):
    """Generates BioMedBERT embeddings using the CLS token (max 512 tokens)."""
    inputs = embed_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512
    ).to(device)
    with torch.no_grad():
        outputs = embed_model(**inputs)
    cls_embedding = outputs.last_hidden_state[:, 0, :].cpu()  # Move back to CPU
    return cls_embedding.squeeze().numpy().tolist()
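# (Hedged sketch) This Space only reads the prebuilt chroma_db.zip; the collection is assumed to
# have been populated offline along these lines, where `chunks` is a hypothetical list of text
# passages taken from the WHO guidelines:
#
#     chunks = ["Hepatitis B is a viral infection that attacks the liver...", "..."]
#     collection.add(
#         ids=[f"chunk-{i}" for i in range(len(chunks))],
#         documents=chunks,
#         embeddings=[get_embedding(c) for c in chunks],
#     )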
# 🔹 Function to Retrieve Similar Chunks
def retrieve_similar_chunks(query, top_k=5, similarity_threshold=0.5):
    """Finds top-k similar chunks from ChromaDB using cosine similarity."""
    print("🔹 Generating embedding for query...")
    query_embedding = get_embedding(query)

    print("🔹 Querying ChromaDB...")
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k
    )

    # ✅ Check if results are empty before accessing scores (Chroma nests results per query)
    documents = results["documents"][0] if results["documents"] else []
    distances = results["distances"][0] if results["distances"] else []
    if not documents:
        print("❌ No relevant chunks found in ChromaDB.")
        return ["Sorry, I couldn't find relevant information."]

    print(f"🔹 Retrieved {len(documents)} chunks from ChromaDB.")

    # 🔍 Print similarity scores (ChromaDB returns distances: lower = more similar)
    for i, dist in enumerate(distances):
        print(f"Chunk {i+1} Distance: {dist}")

    # 🔍 Filter out low-score chunks (assumes a cosine-distance collection, so similarity = 1 - distance)
    filtered_results = [
        doc for doc, dist in zip(documents, distances)
        if (1 - dist) >= similarity_threshold
    ]

    print("✅ Retrieval completed.")
    return filtered_results if filtered_results else ["Sorry, I couldn't find relevant information."]
# 🔹 Chatbot Function
def chatbot(query):
    """Returns a structured and human-like answer using Mistral-7B."""
    retrieved_chunks = retrieve_similar_chunks(query)
    if not retrieved_chunks or retrieved_chunks == ["Sorry, I couldn't find relevant information."]:
        return "Sorry, I couldn't find relevant information."
    retrieved_texts = [chunk if isinstance(chunk, str) else " ".join(chunk) for chunk in retrieved_chunks]
    retrieved_text = "\n\n".join(retrieved_texts)[:500]
    response = generate_humanized_response(query, retrieved_text)
    return response
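# (Hedged sketch) Quick local smoke test of the full retrieve-then-generate path before starting
# the UI; the question is illustrative only:
#
#     print(chatbot("How is hepatitis B transmitted?"))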
# 🔹 Gradio Chat Interface
ui = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(lines=2, placeholder="Ask about Hepatitis B..."),
    outputs=gr.Textbox(),
    title="💡 Hepatitis B Chatbot",
    description="⚕️ Ask questions based on WHO Hepatitis B guidelines (2024). Uses ChromaDB & Mistral-7B for responses.",
)

# 🔥 Run the Chatbot
if __name__ == "__main__":
    ui.launch()