from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer
# Build the RAG retriever and tokenizer from the "facebook/rag-token-base"
# checkpoint. The original intent is to load the index automatically from the
# Hub.
# NOTE(review): with index_name="custom", the transformers RAG documentation
# describes `index_path` as the location of a saved FAISS index and also expects
# a `passages_path` for the dataset -- confirm that the Hub-style id
# "GOGO198/GOGO_dataset" actually resolves here.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base", index_name="custom", index_path="GOGO198/GOGO_dataset"
)
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
# Lazily created generator model; populated on first use by _get_rag_model().
_rag_model = None


def _get_rag_model():
    """Return the shared RAG generator, loading it on the first call."""
    global _rag_model
    if _rag_model is None:
        # Attach the module-level retriever so that generate() performs
        # retrieval and generation end to end.
        _rag_model = RagTokenForGeneration.from_pretrained(
            "facebook/rag-token-base", retriever=retriever
        )
    return _rag_model


def answer_question(question: str) -> str:
    """Answer *question* with retrieval-augmented generation.

    The question is tokenized, passed to the RAG generator (which queries the
    module-level ``retriever`` internally), and the generated token ids are
    decoded back to text.

    The retriever is not called directly: it only returns retrieved-context
    tensors, never an answer, and its ``__call__`` does not accept the
    tokenizer's keyword names.

    Args:
        question: The natural-language question to answer.

    Returns:
        The decoded answer text for the single input question.
    """
    inputs = tokenizer(question, return_tensors="pt")
    generated_ids = _get_rag_model().generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    # One question in, one sequence out: take the first decoded string.
    answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return answers[0]