"""RAG question-answering script: refine-chain retrieval QA with per-file page/section source reporting."""
| import os | |
| import argparse | |
| import sys | |
| from langchain.chains import RetrievalQA | |
| from langchain.prompts import PromptTemplate | |
| from vector_store import get_embeddings, load_vector_store | |
| from llm_loader import load_llama_model | |
def create_refine_prompts_with_pages(language="ko"):
    """Build the (question_prompt, refine_prompt) pair for a refine-type QA chain.

    Parameters:
        language: "ko" for Korean prompts, anything else falls back to English.

    Returns:
        Tuple of (question_prompt, refine_prompt) PromptTemplate objects.
        The question prompt consumes {context_str}/{question}; the refine
        prompt additionally consumes {existing_answer}.
    """
    # NOTE(review): the Korean template text below appears mojibake-garbled by
    # an earlier extraction step — verify the literals against the original file.
    if language == "ko":
        question_prompt = PromptTemplate(
            input_variables=["context_str", "question"],
            template="""
๋ค์์ ๊ฒ์๋ ๋ฌธ์ ์กฐ๊ฐ๋ค์ ๋๋ค:
{context_str}
์ ๋ฌธ์๋ค์ ์ฐธ๊ณ ํ์ฌ ์ง๋ฌธ์ ๋ต๋ณํด์ฃผ์ธ์.
**์ค์ํ ๊ท์น:**
- ๋ต๋ณ ์ ์ฐธ๊ณ ํ ๋ฌธ์๊ฐ ์๋ค๋ฉด ํด๋น ์ ๋ณด๋ฅผ ์ธ์ฉํ์ธ์
- ๋ฌธ์์ ๋ช ์๋ ์ ๋ณด๋ง ์ฌ์ฉํ๊ณ , ์ถ์ธกํ์ง ๋ง์ธ์
- ํ์ด์ง ๋ฒํธ๋ ์ถ์ฒ๋ ์ ๋ฌธ์์์ ํ์ธ๋ ๊ฒ๋ง ์ธ๊ธํ์ธ์
- ํ์คํ์ง ์์ ์ ๋ณด๋ "๋ฌธ์์์ ํ์ธ๋์ง ์์"์ด๋ผ๊ณ ๋ช ์ํ์ธ์
์ง๋ฌธ: {question}
๋ต๋ณ:"""
        )
        refine_prompt = PromptTemplate(
            input_variables=["question", "existing_answer", "context_str"],
            template="""
๊ธฐ์กด ๋ต๋ณ:
{existing_answer}
์ถ๊ฐ ๋ฌธ์:
{context_str}
๊ธฐ์กด ๋ต๋ณ์ ์ ์ถ๊ฐ ๋ฌธ์๋ฅผ ๋ฐํ์ผ๋ก ๋ณด์ํ๊ฑฐ๋ ์์ ํด์ฃผ์ธ์.
**๊ท์น:**
- ์๋ก์ด ์ ๋ณด๊ฐ ๊ธฐ์กด ๋ต๋ณ๊ณผ ๋ค๋ฅด๋ค๋ฉด ์์ ํ์ธ์
- ์ถ๊ฐ ๋ฌธ์์ ๋ช ์๋ ์ ๋ณด๋ง ์ฌ์ฉํ์ธ์
- ํ๋์ ์๊ฒฐ๋ ๋ต๋ณ์ผ๋ก ์์ฑํ์ธ์
- ํ์คํ์ง ์์ ์ถ์ฒ๋ ํ์ด์ง๋ ์ธ๊ธํ์ง ๋ง์ธ์
์ง๋ฌธ: {question}
๋ต๋ณ:"""
        )
    else:
        # English fallback prompts: same contract as the Korean pair.
        question_prompt = PromptTemplate(
            input_variables=["context_str", "question"],
            template="""
Here are the retrieved document fragments:
{context_str}
Please answer the question based on the above documents.
**Important rules:**
- Only use information explicitly stated in the documents
- If citing sources, only mention what is clearly indicated in the documents above
- Do not guess or infer page numbers not shown in the context
- If unsure, state "not confirmed in the provided documents"
Question: {question}
Answer:"""
        )
        refine_prompt = PromptTemplate(
            input_variables=["question", "existing_answer", "context_str"],
            template="""
Existing answer:
{existing_answer}
Additional documents:
{context_str}
Refine the existing answer using the additional documents.
**Rules:**
- Only use information explicitly stated in the additional documents
- Create one coherent final answer
- Do not mention uncertain sources or page numbers
Question: {question}
Answer:"""
        )
    return question_prompt, refine_prompt
def build_rag_chain(llm, vectorstore, language="ko", k=7):
    """Assemble a refine-style RetrievalQA chain over the given vector store.

    Parameters:
        llm: language model used for both the initial answer and refinements.
        vectorstore: store exposing ``as_retriever``; top-``k`` chunks are fetched.
        language: prompt language forwarded to create_refine_prompts_with_pages.
        k: number of document chunks retrieved per query.

    Returns:
        A RetrievalQA chain that also returns its source documents.
    """
    q_prompt, r_prompt = create_refine_prompts_with_pages(language)
    retriever = vectorstore.as_retriever(search_kwargs={"k": k})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="refine",
        retriever=retriever,
        chain_type_kwargs={
            "question_prompt": q_prompt,
            "refine_prompt": r_prompt,
        },
        return_source_documents=True,
    )
def _collect_source_info(source_documents):
    """Group source-document metadata per filename: pages, sections, types, total_pages."""
    source_info = {}
    for doc in source_documents:
        source = doc.metadata.get('source', 'N/A')
        page = doc.metadata.get('page', 'N/A')
        doc_type = doc.metadata.get('type', 'N/A')
        section = doc.metadata.get('section', None)
        total_pages = doc.metadata.get('total_pages', None)
        filename = doc.metadata.get('filename', 'N/A')
        if filename == 'N/A':
            # Fall back to the basename of the source path when no filename is set.
            filename = os.path.basename(source) if source != 'N/A' else 'N/A'
        info = source_info.setdefault(filename, {
            'pages': set(),
            'sections': set(),
            'types': set(),
            'total_pages': None,
        })
        # Fix: take total_pages from ANY chunk of the file, not only the first one.
        if info['total_pages'] is None and total_pages is not None:
            info['total_pages'] = total_pages
        if page != 'N/A':
            # String pages prefixed with the section marker are sections, not pages.
            if isinstance(page, str) and page.startswith('์น์ '):
                info['sections'].add(page)
            else:
                info['pages'].add(page)
        if section is not None:
            info['sections'].add(f"์น์ {section}")
        info['types'].add(doc_type)
    return source_info


def _print_source_summary(source_info):
    """Print the per-file summary produced by _collect_source_info."""
    for filename, info in source_info.items():
        # Fix: original printed a literal "(unknown)" instead of the computed filename.
        print(f"\n- {filename}")
        if info['total_pages']:
            print(f"  ์ ์ฒด ํ์ด์ง ์: {info['total_pages']}")
        if info['pages']:
            # Fix: sort pages for stable output (sections were sorted; pages were not).
            # key=str tolerates mixed int/str page values.
            pages_list = sorted(info['pages'], key=str)
            print(f"  ํ์ด์ง: {', '.join(map(str, pages_list))}")
        if info['sections']:
            sections_list = sorted(info['sections'])
            print(f"  ์น์ : {', '.join(sections_list)}")
        if not info['pages'] and not info['sections']:
            print(f"  ํ์ด์ง: ์ ๋ณด ์์")
        types_str = ', '.join(sorted(info['types']))
        print(f"  ์ ํ: {types_str}")


def ask_question_with_pages(qa_chain, question):
    """Run *question* through the chain, print the answer and a source summary.

    Parameters:
        qa_chain: chain exposing ``invoke({"query": ...})`` and returning a dict
            with 'result' and 'source_documents' keys.
        question: user question string.

    Returns:
        The raw chain result dict (so callers can inspect source_documents).
    """
    result = qa_chain.invoke({"query": question})
    # Keep only the text after the last "A:" marker if the model echoed one.
    answer = result['result']
    final_answer = answer.split("A:")[-1].strip() if "A:" in answer else answer.strip()
    print(f"\n๐งพ ์ง๋ฌธ: {question}")
    print(f"\n๐ข ์ต์ข ๋ต๋ณ: {final_answer}")
    # Metadata debug output (disabled)
    # debug_metadata_info(result["source_documents"])
    print("\n๐ ์ฐธ๊ณ ๋ฌธ์ ์์ฝ:")
    print(f"์ด ์ฌ์ฉ๋ ์ฒญํฌ ์: {len(result['source_documents'])}")
    _print_source_summary(_collect_source_info(result["source_documents"]))
    return result
| # ๊ธฐ์กด ask_question ํจ์๋ ask_question_with_pages๋ก ๊ต์ฒด | |
def ask_question(qa_chain, question):
    """Backward-compatibility wrapper; delegates to ask_question_with_pages."""
    return ask_question_with_pages(qa_chain, question)
if __name__ == "__main__":
    # CLI entry point: build the chain, then answer one --query or loop interactively.
    parser = argparse.ArgumentParser(description="RAG refine system (ํ์ด์ง ๋ฒํธ ์ง์)")
    parser.add_argument("--vector_store", type=str, default="vector_db", help="๋ฒกํฐ ์คํ ์ด ๊ฒฝ๋ก")
    parser.add_argument("--model", type=str, default="LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct", help="LLM ๋ชจ๋ธ ID")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="์ฌ์ฉํ ๋๋ฐ์ด์ค")
    parser.add_argument("--k", type=int, default=7, help="๊ฒ์ํ ๋ฌธ์ ์")
    parser.add_argument("--language", type=str, default="ko", choices=["ko", "en"], help="์ฌ์ฉํ ์ธ์ด")
    parser.add_argument("--query", type=str, help="์ง๋ฌธ (์์ผ๋ฉด ๋ํํ ๋ชจ๋ ์คํ)")
    args = parser.parse_args()

    # Wire up embeddings -> vector store -> LLM -> QA chain.
    # NOTE(review): args.model and args.device (for the LLM) are parsed but never
    # passed to load_llama_model() — confirm whether that is intentional.
    embeddings = get_embeddings(device=args.device)
    vectorstore = load_vector_store(embeddings, load_path=args.vector_store)
    llm = load_llama_model()
    qa_chain = build_rag_chain(llm, vectorstore, language=args.language, k=args.k)
    print("๐ข RAG ํ์ด์ง ๋ฒํธ ์ง์ ์์คํ ์ค๋น ์๋ฃ!")

    if args.query:
        # One-shot mode.
        ask_question_with_pages(qa_chain, args.query)
    else:
        # Interactive REPL until an exit keyword or Ctrl-C.
        print("๐ฌ ๋ํํ ๋ชจ๋ ์์ (์ข ๋ฃํ๋ ค๋ฉด 'exit', 'quit', '์ข ๋ฃ' ์ ๋ ฅ)")
        while True:
            try:
                user_query = input("\n์ง๋ฌธ: ").strip()
                if user_query.lower() in ("exit", "quit", "์ข ๋ฃ"):
                    break
                if not user_query:  # skip empty input
                    continue
                ask_question_with_pages(qa_chain, user_query)
            except KeyboardInterrupt:
                print("\n\nํ๋ก๊ทธ๋จ์ ์ข ๋ฃํฉ๋๋ค.")
                break
            except Exception as e:
                print(f"โ ์ค๋ฅ ๋ฐ์: {e}\n๋ค์ ์๋ํด์ฃผ์ธ์.")