import os
import logging
import re
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
import chardet
import gradio as gr
import pandas as pd
import json
from nltk.corpus import words

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def clean_api_key(key):
    """Strip non-ASCII characters that would break HTTP headers."""
    return ''.join(c for c in key if ord(c) < 128)


# Load the GROQ API key
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
    raise ValueError("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
api_key = clean_api_key(api_key).strip()


def clean_text(text):
    """Drop non-ASCII characters from document text before indexing."""
    return text.encode("ascii", errors="ignore").decode()
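
# Illustrative effect of clean_text (assumed input; not executed here):
#   clean_text("café") -> "caf"   (the accented character is removed, not transliterated)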


def load_documents(file_paths):
    """Load CSV, JSON, and TXT files into LangChain Document objects."""
    docs = []
    for file_path in file_paths:
        ext = os.path.splitext(file_path)[-1].lower()
        try:
            if ext == ".csv":
                # Detect the encoding first, since CSV exports are not always UTF-8.
                with open(file_path, 'rb') as f:
                    result = chardet.detect(f.read())
                    encoding = result['encoding']
                data = pd.read_csv(file_path, encoding=encoding)
                # One Document per row keeps retrieval granular.
                for _, row in data.iterrows():
                    content = clean_text(row.to_string())
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".json":
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if isinstance(data, list):
                    # One Document per top-level entry.
                    for entry in data:
                        content = clean_text(json.dumps(entry))
                        docs.append(Document(page_content=content, metadata={"source": file_path}))
                elif isinstance(data, dict):
                    content = clean_text(json.dumps(data))
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".txt":
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = clean_text(f.read())
                docs.append(Document(page_content=content, metadata={"source": file_path}))
            else:
                logger.warning(f"Unsupported file format: {file_path}")
        except Exception as e:
            logger.error(f"Error processing file {file_path}: {e}")
    return docs
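
# Illustrative usage of load_documents (hypothetical file name; not executed here):
#   docs = load_documents(["wellness_tips.csv"])
#   docs[0].page_content  # cleaned text of the first CSV row
#   docs[0].metadata      # {"source": "wellness_tips.csv"}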


# Enhanced input validation
# Load the NLTK English word list, downloading it on first use if it is missing.
try:
    english_words = set(words.words())
except LookupError:
    import nltk
    nltk.download('words')
    english_words = set(words.words())


def is_valid_input(text):
    """Validate the user's input question."""
    if not text or text.strip() == "":
        return False, "Input cannot be empty. Please provide a meaningful question."
    if len(text.strip()) < 2:
        return False, "Input is too short. Please provide more context or details."
    # Reject input that contains no recognizable English words.
    words_in_text = re.findall(r'\b\w+\b', text.lower())
    recognized_words = [word for word in words_in_text if word in english_words]
    if not recognized_words:
        return False, "Input appears unclear. Please use valid words in your question."
    return True, "Valid input."
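
# Illustrative behaviour (sketch, assuming the NLTK word list has loaded):
#   is_valid_input("")                        -> (False, "Input cannot be empty. ...")
#   is_valid_input("xqzv qwtk")               -> (False, "Input appears unclear. ...")
#   is_valid_input("What is box breathing?")  -> (True, "Valid input.")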


def initialize_llm(model, temperature, max_tokens):
    """Create a ChatGroq client, reserving roughly 20% of max_tokens for the prompt."""
    prompt_allocation = int(max_tokens * 0.2)
    response_max_tokens = max_tokens - prompt_allocation
    if response_max_tokens <= 50:
        raise ValueError("max_tokens is too small; the remaining response budget must exceed 50 tokens.")
    llm = ChatGroq(
        model=model,
        temperature=temperature,
        max_tokens=response_max_tokens,
        api_key=api_key
    )
    return llm
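
# Token budgeting example (illustrative): with the default max_tokens=500,
# int(500 * 0.2) = 100 tokens are set aside for the prompt, so the Groq client
# is created with max_tokens=400 for the generated answer.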


def create_rag_pipeline(file_paths, model, temperature, max_tokens):
    """Load documents, chunk and embed them into Chroma, and wrap the retriever in a RetrievalQA chain."""
    llm = initialize_llm(model, temperature, max_tokens)
    docs = load_documents(file_paths)
    if not docs:
        return None, "No documents were loaded."
    # Split documents into overlapping chunks so retrieved passages keep local context.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embedding_model,
        persist_directory="/tmp/chroma_db"
    )
    retriever = vectorstore.as_retriever()
    custom_prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
You are an AI assistant specialized in daily wellness. Provide a concise, thorough, and stand-alone answer to the user's question based on the given context. Include relevant examples or schedules where beneficial. **When listing steps or guidelines, format them as a numbered list with appropriate markdown formatting.** The final answer should be coherent, self-contained, and end with a complete sentence.
Context:
{context}
Question:
{question}
Final Answer:
"""
    )
    # "stuff" concatenates all retrieved chunks directly into the prompt's {context}.
    rag_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": custom_prompt_template}
    )
    return rag_chain, "Pipeline created successfully."
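
# Illustrative usage of create_rag_pipeline (mirrors the module-level call below; not executed here):
#   chain, status = create_rag_pipeline(["AIChatbot.csv"], "llama3-8b-8192", 0.7, 500)
#   if chain is not None:
#       print(chain.run("What is box breathing?"))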

# Build the pipeline once at import time with the default settings.
file_paths = ['AIChatbot.csv']
model = "llama3-8b-8192"
temperature = 0.7
max_tokens = 500
rag_chain, message = create_rag_pipeline(file_paths, model, temperature, max_tokens)


def answer_question(model, temperature, max_tokens, question):
    """Validate the question and answer it with the RAG chain built above.

    The model/temperature/max_tokens arguments are accepted from the UI, but the
    pipeline is not rebuilt per request; the chain created at import time is reused.
    """
    is_valid, message = is_valid_input(question)
    if not is_valid:
        return message
    if rag_chain is None:
        return "The system is currently unavailable. Please try again later."
    try:
        answer = rag_chain.run(question)
        return answer.strip()
    except Exception as e_inner:
        logger.error(f"Error: {e_inner}")
        return "An error occurred while processing your request."


def gradio_interface(model, temperature, max_tokens, question):
    return answer_question(model, temperature, max_tokens, question)


interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Model Name", value=model),
        gr.Slider(label="Temperature", minimum=0, maximum=1, step=0.01, value=temperature),
        gr.Slider(label="Max Tokens", minimum=200, maximum=2048, step=1, value=max_tokens),
        gr.Textbox(label="Question", placeholder="e.g., What is box breathing and how does it help reduce anxiety?")
    ],
    outputs=gr.Markdown(label="Answer"),
    title="Daily Wellness AI",
    description="Ask questions about daily wellness and receive a concise, complete answer.",
    examples=[
        ["llama3-8b-8192", 0.7, 500, "What is box breathing and how does it help reduce anxiety?"],
        ["llama3-8b-8192", 0.6, 600, "Give me a weekly fitness schedule incorporating mindfulness exercises."]
    ],
    allow_flagging="never"
)

if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860, debug=True)