# app_hybrid_llm.py
import os
import re

import numpy as np
import faiss
import gradio as gr
from openai import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from sentence_transformers import SentenceTransformer
DARTMOUTH_CHAT_API_KEY = os.getenv('DARTMOUTH_CHAT_API_KEY')
if DARTMOUTH_CHAT_API_KEY is None:
    raise ValueError("DARTMOUTH_CHAT_API_KEY not set.")

MODEL = "openai.gpt-4o-2024-08-06"

client = OpenAI(
    base_url="https://chat.dartmouth.edu/api",  # Replace with your endpoint URL
    api_key=DARTMOUTH_CHAT_API_KEY,             # Replace with your API key, if required
)
# --- Load and Prepare Data ---
with open("gen_agents.txt", "r", encoding="utf-8") as f:
    full_text = f.read()

# Split the article on blank lines into ~512-character chunks with a small overlap.
text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=512, chunk_overlap=20)
docs = text_splitter.create_documents([full_text])
passages = [doc.page_content for doc in docs]

# Embed each passage and store the vectors in an exact (brute-force) L2 FAISS index.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
passage_embeddings = embedder.encode(passages, convert_to_tensor=False, show_progress_bar=True)
passage_embeddings = np.array(passage_embeddings).astype("float32")

d = passage_embeddings.shape[1]
index = faiss.IndexFlatL2(d)
index.add(passage_embeddings)
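
# IndexFlatL2 ranks by Euclidean distance. A common alternative (a sketch, not
# part of this app's original setup) is cosine similarity: L2-normalize the
# embeddings and use an inner-product index. For unit-length vectors the two
# rankings coincide.
#
#     faiss.normalize_L2(passage_embeddings)
#     index = faiss.IndexFlatIP(d)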
# --- Provided Functions ---
def retrieve_passages(query, embedder, index, passages, top_k=3):
    query_embedding = embedder.encode([query], convert_to_tensor=False)
    query_embedding = np.array(query_embedding).astype('float32')
    distances, indices = index.search(query_embedding, top_k)
    retrieved = [passages[i] for i in indices[0]]
    return retrieved
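
# Hypothetical usage (the query string below is an invented example):
#
#     hits = retrieve_passages("How do agents retrieve memories?", embedder, index, passages)
#     # -> list of the 3 passages whose embeddings are nearest the query's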
def process_llm_output_with_references(text, passages):
    """
    Replace tokens like <<PASSAGE_1>> in the LLM output with HTML block quotes.
    """
    def replacement(match):
        num = int(match.group(1))
        # Tokens are zero-indexed, matching the enumerate() numbering used
        # when the prompt is built in generate_answer_with_references().
        if 0 <= num < len(passages):
            passage_text = passages[num]
            return (f"<blockquote style='background: #ffffff; color: #000000; padding: 10px; "
                    f"border-left: 5px solid #ccc; margin: 10px 0; font-size: 14px;'>{passage_text}</blockquote>")
        return match.group(0)
    return re.sub(r"<<PASSAGE_(\d+)>>", replacement, text)
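
# Illustration (hypothetical input): for an LLM answer such as
# "Agents plan their days in advance.\n<<PASSAGE_0>>", the function returns the
# same text with <<PASSAGE_0>> swapped for a styled <blockquote> holding
# passages[0]; tokens whose index falls outside the list are left untouched.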
def generate_answer_with_references(query, retrieved_text):
    """
    Generate an answer using GPT-4o with reference tokens.
    """
    # Number the retrieved passages so the model can cite them as <<PASSAGE_i>>.
    context_str = "\n".join([f"<<PASSAGE_{i}>>: \"{passage}\"" for i, passage in enumerate(retrieved_text)])
    messages = [
        {"role": "system", "content": "You are a knowledgeable technical assistant."},
        {"role": "user", "content": (
            f"Using the following textbook passages as reference:\n{context_str}\n\n"
            "In your answer, include passage block quotes as references. "
            "Refer to the passages using tokens such as <<PASSAGE_0>>, <<PASSAGE_1>>, etc. "
            "They should appear after complete thoughts on a new line.\n\n"
            f"Answer the question: {query}"
        )}
    ]
    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
    )
    answer = response.choices[0].message.content.strip()
    return answer
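
# The create() call above relies on the endpoint's default sampling settings.
# If more reproducible answers are wanted, the standard OpenAI chat parameters
# apply (an optional tweak, not used in the original), e.g.:
#
#     response = client.chat.completions.create(model=MODEL, messages=messages, temperature=0)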
# --- Gradio App Function ---
def get_hybrid_output(query):
    retrieved = retrieve_passages(query, embedder, index, passages, top_k=3)
    hybrid_raw = generate_answer_with_references(query, retrieved)
    hybrid_processed = process_llm_output_with_references(hybrid_raw, retrieved)
    return f"<div style='white-space: pre-wrap;'>{hybrid_processed}</div>"

def clear_output():
    return ""
# --- Custom CSS ---
custom_css = """
body {
    background-color: #343541 !important;
    color: #ECECEC !important;
    margin: 0;
    padding: 0;
    font-family: 'Inter', sans-serif;
}
#container {
    max-width: 900px;
    margin: 0 auto;
    padding: 20px;
}
label {
    color: #ECECEC;
    font-weight: 600;
}
textarea, input {
    background-color: #40414F;
    color: #ECECEC;
    border: 1px solid #565869;
}
button {
    background-color: #565869;
    color: #ECECEC;
    border: none;
    font-weight: 600;
    transition: background-color 0.2s ease;
}
button:hover {
    background-color: #6e7283;
}
.output-box {
    border: 1px solid #565869;
    border-radius: 4px;
    padding: 10px;
    margin-top: 8px;
    background-color: #40414F;
}
"""
# --- Build Gradio Interface ---
with gr.Blocks(css=custom_css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown(
            "## Anonymous Chatbot\n"
            "### Loaded Article: Generative Agents - Interactive Simulacra of Human Behavior (Park et al. 2023)\n"
            "[https://arxiv.org/pdf/2304.03442](https://arxiv.org/pdf/2304.03442)"
        )
        gr.Markdown("Enter any questions about the article above in the prompt!")
        query_input = gr.Textbox(label="Query", placeholder="Enter your query here...", lines=1)
        with gr.Column():
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear")
        output_box = gr.HTML(label="Output", elem_classes="output-box")
        submit_button.click(fn=get_hybrid_output, inputs=query_input, outputs=output_box)
        clear_button.click(fn=clear_output, inputs=[], outputs=output_box)

demo.launch()
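
# On Hugging Face Spaces, launch() serves the app as-is; for a temporary public
# URL from a local run, Gradio also accepts demo.launch(share=True) (optional).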