Spaces:
Build error
Build error
Update rag_system.py
Browse files- rag_system.py +134 -37
rag_system.py
CHANGED
|
@@ -13,10 +13,10 @@ from langchain_community.vectorstores import FAISS
|
|
| 13 |
try:
|
| 14 |
import fitz # PyMuPDF
|
| 15 |
PYMUPDF_AVAILABLE = True
|
| 16 |
-
print("
|
| 17 |
except ImportError:
|
| 18 |
PYMUPDF_AVAILABLE = False
|
| 19 |
-
print("
|
| 20 |
|
| 21 |
# PDF processing utilities
|
| 22 |
import pytesseract
|
|
@@ -396,45 +396,142 @@ def split_documents(documents, chunk_size=800, chunk_overlap=100):
|
|
| 396 |
# Main Execution
|
| 397 |
# --------------------------------
|
| 398 |
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
source = doc.metadata.get('source', 'unknown')
|
| 410 |
page = doc.metadata.get('page', 'unknown')
|
| 411 |
doc_type = doc.metadata.get('type', 'unknown')
|
|
|
|
|
|
|
| 412 |
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
|
| 418 |
-
for
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
|
| 429 |
-
|
| 430 |
-
vectorstore.save_local("vector_db")
|
| 431 |
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
try:
|
| 14 |
import fitz # PyMuPDF
|
| 15 |
PYMUPDF_AVAILABLE = True
|
| 16 |
+
print("PyMuPDF library available")
|
| 17 |
except ImportError:
|
| 18 |
PYMUPDF_AVAILABLE = False
|
| 19 |
+
print("PyMuPDF library is not installed. Install with: pip install PyMuPDF")
|
| 20 |
|
| 21 |
# PDF processing utilities
|
| 22 |
import pytesseract
|
|
|
|
| 396 |
# Main Execution
|
| 397 |
# --------------------------------
|
| 398 |
|
| 399 |
+
def build_rag_chain(llm, vectorstore, language="en", k=7):
    """Assemble a "refine"-type RetrievalQA chain over the given vector store.

    Args:
        llm: Language model handed to ``RetrievalQA.from_chain_type``.
        vectorstore: Store whose ``as_retriever`` supplies the retriever.
        language: Passed to ``create_refine_prompts_with_pages`` to pick the
            question/refine prompt pair.
        k: Value forwarded as the retriever's ``search_kwargs["k"]``.

    Returns:
        The chain object created by ``RetrievalQA.from_chain_type``, set up
        with ``return_source_documents=True``.
    """
    initial_prompt, followup_prompt = create_refine_prompts_with_pages(language)

    retriever = vectorstore.as_retriever(search_kwargs={"k": k})
    prompt_kwargs = {
        "question_prompt": initial_prompt,
        "refine_prompt": followup_prompt,
    }

    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="refine",
        retriever=retriever,
        chain_type_kwargs=prompt_kwargs,
        return_source_documents=True,
    )
|
| 415 |
+
|
| 416 |
+
def _page_sort_key(page):
    """Sort numeric pages numerically and ahead of any non-numeric labels."""
    if isinstance(page, (int, float)):
        return (0, page, "")
    return (1, 0, str(page))


def _summarize_sources(source_documents):
    """Group retrieved chunks by file name, collecting pages, sections and types."""
    summary = {}

    for doc in source_documents:
        meta = doc.metadata
        source = meta.get('source', 'unknown')
        page = meta.get('page', 'unknown')
        doc_type = meta.get('type', 'unknown')
        section = meta.get('section', None)
        total_pages = meta.get('total_pages', None)

        # Fall back to the base name of the source path when no explicit
        # file name is stored in the metadata.
        filename = meta.get('filename', 'unknown')
        if filename == 'unknown' and source != 'unknown':
            filename = os.path.basename(source)

        if filename not in summary:
            summary[filename] = {
                'pages': set(),
                'sections': set(),
                'types': set(),
                'total_pages': None,
            }
        info = summary[filename]

        # Keep the first usable page count seen for this file, even if the
        # first chunk retrieved for it did not carry one.
        if not info['total_pages'] and total_pages:
            info['total_pages'] = total_pages

        if page != 'unknown':
            # String "pages" that start with 'section' are section labels.
            if isinstance(page, str) and page.startswith('section'):
                info['sections'].add(page)
            else:
                info['pages'].add(page)

        if section is not None:
            info['sections'].add(f"section {section}")

        info['types'].add(doc_type)

    return summary


def _print_source_summary(summary):
    """Print one block per file: page count, pages, sections and types."""
    for filename, info in summary.items():
        print(f"\n- {filename}")

        # Total page count information
        if info['total_pages']:
            print(f" Total page count: {info['total_pages']}")

        # Page information output (sorted so the listing is deterministic)
        if info['pages']:
            ordered_pages = sorted(info['pages'], key=_page_sort_key)
            print(f" Pages: {', '.join(map(str, ordered_pages))}")

        # Section information output
        if info['sections']:
            print(f" Sections: {', '.join(sorted(info['sections']))}")

        # If no pages or sections are present
        if not info['pages'] and not info['sections']:
            print(" Pages: No information")

        # Output document type
        print(f" Type: {', '.join(sorted(map(str, info['types'])))}")


def ask_question_with_pages(qa_chain, question):
    """Run a question through the chain and print the answer with its sources.

    Args:
        qa_chain: Callable invoked as ``qa_chain({"query": question})``; its
            result must provide ``'result'`` (answer text) and
            ``'source_documents'`` (objects exposing a ``metadata`` dict).
        question: The question text.

    Returns:
        The unmodified result object returned by ``qa_chain``.

    Side effects:
        Prints the question, the final answer, the number of chunks used and
        a per-file summary of pages, sections and document types to stdout.
    """
    result = qa_chain({"query": question})

    # Keep only the text after the last "A:" marker; when the marker is
    # absent, rpartition leaves the whole answer in the final slot.
    final_answer = result['result'].rpartition("A:")[2].strip()

    print(f"\nQuestion: {question}")
    print(f"\nFinal Answer: {final_answer}")

    # Metadata debugging info (disabled)
    # debug_metadata_info(result["source_documents"])

    # Organize reference documents by page
    print("\nReference Document Summary:")
    source_documents = result["source_documents"]
    source_info = _summarize_sources(source_documents)

    # Result output
    print(f"Total chunks used: {len(source_documents)}")
    _print_source_summary(source_info)

    return result
|
|
|
|
| 494 |
|
| 495 |
+
# Existing ask_question function is replaced with ask_question_with_pages
|
| 496 |
+
def ask_question(qa_chain, question):
    """Compatibility shim: delegate to ``ask_question_with_pages`` unchanged.

    Takes the same arguments and returns whatever
    ``ask_question_with_pages(qa_chain, question)`` returns.
    """
    outcome = ask_question_with_pages(qa_chain, question)
    return outcome
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
if __name__ == "__main__":
    # Command-line entry point: load the vector store and the LLM, build the
    # RAG chain, then answer a single --query or start an interactive loop.
    parser = argparse.ArgumentParser(description="RAG refine system (supports page numbers)")
    parser.add_argument("--vector_store", type=str, default="vector_db", help="Vector store path")
    parser.add_argument("--model", type=str, default="LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct", help="LLM model ID")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device to use")
    parser.add_argument("--k", type=int, default=7, help="Number of documents to retrieve")
    parser.add_argument("--language", type=str, default="en", choices=["ko", "en"], help="Language to use")
    parser.add_argument("--query", type=str, help="Question (runs interactive mode if not provided)")

    args = parser.parse_args()

    embeddings = get_embeddings(device=args.device)
    vectorstore = load_vector_store(embeddings, load_path=args.vector_store)
    # NOTE(review): --model is parsed but not forwarded here; load_llama_model()
    # is called without arguments exactly as before. Confirm whether it should
    # receive args.model.
    llm = load_llama_model()

    # build_rag_chain and ask_question_with_pages are defined in this module,
    # so they are used directly. Importing them from rag_system here would
    # execute the whole module a second time.
    qa_chain = build_rag_chain(llm, vectorstore, language=args.language, k=args.k)

    print("RAG system with page number support ready!")

    if args.query:
        ask_question_with_pages(qa_chain, args.query)
    else:
        print("Starting interactive mode (enter 'exit', 'quit' to finish)")
        while True:
            try:
                query = input("Question: ").strip()
                if query.lower() in ["exit", "quit"]:
                    break
                if query:  # Prevent empty input
                    ask_question_with_pages(qa_chain, query)
            except (KeyboardInterrupt, EOFError):
                # EOFError: stdin was closed (Ctrl-D or piped input ran out).
                # Without this, the generic handler below would catch it on
                # every iteration and the loop would never end.
                print("\n\nExiting program.")
                break
            except Exception as e:
                # Top-level boundary: report the failure and keep the session alive.
                print(f"Error occurred: {e}\nPlease try again.")
|