# Legal Case Summarizer — Streamlit app (Hugging Face Space).
# NOTE(review): the Space was reporting "Runtime error" at scrape time.
import faiss
import numpy as np
import streamlit as st
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, T5ForConditionalGeneration
# Registry of successfully loaded Hugging Face datasets, keyed by display name.
# Populated by load_datasets(); entries are absent for corpora that failed to load.
datasets_dict = {}
def load_datasets():
    """Load the BillSum and EurLex corpora into ``datasets_dict``.

    Each download failure is reported through the Streamlit UI rather than
    raised, so one broken corpus does not take down the whole app.
    """
    global datasets_dict
    # (display name, Hub dataset id, extra load_dataset kwargs)
    sources = [
        ("BillSum", "billsum", {}),
        # EurLex ships a dataset loading script, hence trust_remote_code=True.
        ("EurLex", "eurlex", {"trust_remote_code": True}),
    ]
    for label, hub_id, extra in sources:
        try:
            datasets_dict[label] = load_dataset(hub_id, **extra)
        except Exception as e:
            st.error(f"Error loading {label} dataset: {e}")


# Fetch both corpora once, at app start-up.
load_datasets()
# Summarization backbone: pretrained T5 (base checkpoint) plus its tokenizer.
_T5_CHECKPOINT = "t5-base"
t5_tokenizer = AutoTokenizer.from_pretrained(_T5_CHECKPOINT)
t5_model = T5ForConditionalGeneration.from_pretrained(_T5_CHECKPOINT)
# Retrieval state shared by the helper functions below; prepare_dataset()
# overwrites the document/title lists once the user picks a corpus in the UI.
selected_dataset = "BillSum"  # default choice before the selectbox renders
documents, titles = [], []
def prepare_dataset(dataset_name, source=None, limit=100):
    """Select the working document/title lists for retrieval.

    Args:
        dataset_name: Key identifying the corpus (e.g. ``"BillSum"``).
        source: Mapping of name -> dataset; defaults to the module-level
            ``datasets_dict``.
        limit: Maximum number of training examples to keep (demo-sized).

    Returns:
        Tuple ``(documents, titles)``; the same lists are also stored in the
        module globals of those names for the retrieval helpers.
    """
    global documents, titles
    catalog = datasets_dict if source is None else source
    dataset = catalog.get(dataset_name)
    if dataset is None:
        # The corpus failed to download at start-up (load errors are swallowed
        # and surfaced in the UI), so fall back to an empty corpus instead of
        # crashing the script with a KeyError.
        documents, titles = [], []
        return documents, titles
    train = dataset["train"]
    documents = list(train["text"][:limit])   # subset for demo purposes
    titles = list(train["title"][:limit])     # titles aligned with documents
    return documents, titles
def retrieve_cases(query, docs=None, doc_titles=None):
    """Return ``(document, title)`` pairs whose text contains *query*.

    Matching is a case-insensitive substring test — a deliberately simple
    stand-in for a real retriever in this demo.  An empty query matches
    every document.

    Args:
        query: Search string.
        docs: Documents to search; defaults to the module-level ``documents``.
        doc_titles: Titles aligned with ``docs``; defaults to ``titles``.
    """
    corpus = documents if docs is None else docs
    labels = titles if doc_titles is None else doc_titles
    # Lowercase the query once instead of on every document.
    needle = query.lower()
    return [(doc, title) for doc, title in zip(corpus, labels) if needle in doc.lower()]
def summarize_cases(cases):
    """Summarize each retrieved ``(text, title)`` pair with the T5 model.

    Returns a list of summary strings, one per case, in input order.
    """

    def _summarize(text):
        # Truncate long case texts to T5's 512-token input window.
        token_ids = t5_tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
        generated = t5_model.generate(token_ids, max_length=60, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
        return t5_tokenizer.decode(generated[0], skip_special_tokens=True)

    return [_summarize(case[0]) for case in cases]
# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------
st.title("Legal Case Summarizer")
st.write("Select a dataset and enter keywords to retrieve and summarize relevant cases.")

# Only datasets that actually loaded appear in the dropdown.
dataset_options = list(datasets_dict.keys())
selected_dataset = st.selectbox("Choose a dataset:", dataset_options)

# Refresh the working document/title lists for the chosen corpus.
prepare_dataset(selected_dataset)

query = st.text_input("Enter search keywords:", "healthcare")

if st.button("Retrieve and Summarize Cases"):
    with st.spinner("Retrieving and summarizing cases..."):
        cases = retrieve_cases(query)
        if cases:
            summaries = summarize_cases(cases)
            for i, ((case, title), summary) in enumerate(zip(cases, summaries), start=1):
                st.write(f"### Case {i}")
                st.write(f"**Title:** {title}")
                # BUG FIX: `case` is the document string after tuple unpacking,
                # so the original `case[0]` printed only its first character.
                st.write(f"**Case Text:** {case}")
                st.write(f"**Summary:** {summary}")
        else:
            st.write("No cases found for the given query.")

st.write("Using T5 for summarization and retrieval.")