import gradio as gr
import requests
import logging
import os
import json
import shutil
import glob
import queue
import lancedb
from datetime import datetime
from dotenv import load_dotenv, set_key
import yaml
import pandas as pd
from typing import List, Optional
from pydantic import BaseModel

# Set up logging
log_queue = queue.Queue()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

load_dotenv('indexing/.env')
API_BASE_URL = os.getenv('API_BASE_URL', 'http://localhost:8012')
LLM_API_BASE = os.getenv('LLM_API_BASE', 'http://localhost:11434')
EMBEDDINGS_API_BASE = os.getenv('EMBEDDINGS_API_BASE', 'http://localhost:11434')
ROOT_DIR = os.getenv('ROOT_DIR', 'indexing')

# Data models
class IndexingRequest(BaseModel):
    llm_model: str
    embed_model: str
    llm_api_base: str
    embed_api_base: str
    root: str
    verbose: bool = False
    nocache: bool = False
    resume: Optional[str] = None
    reporter: str = "rich"
    emit: List[str] = ["parquet"]
    custom_args: Optional[str] = None
class PromptTuneRequest(BaseModel):
    # The defaults below need the f-prefix; the originals were plain strings,
    # so the literal text "./{ROOT_DIR}" was never interpolated.
    root: str = f"./{ROOT_DIR}"
    domain: Optional[str] = None
    method: str = "random"
    limit: int = 15
    language: Optional[str] = None
    max_tokens: int = 2000
    chunk_size: int = 200
    no_entity_types: bool = False
    output: str = f"./{ROOT_DIR}/prompts"
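
# Illustrative only: when POSTed to the backend, an IndexingRequest serializes
# to JSON along these lines (the field values are made-up examples, not
# defaults the API guarantees):
#   {"llm_model": "llama3", "embed_model": "nomic-embed-text",
#    "llm_api_base": "http://localhost:11434",
#    "embed_api_base": "http://localhost:11434", "root": "indexing",
#    "verbose": false, "nocache": false, "resume": null, "reporter": "rich",
#    "emit": ["parquet"], "custom_args": null}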
class QueueHandler(logging.Handler):
    def __init__(self, log_queue):
        super().__init__()
        self.log_queue = log_queue

    def emit(self, record):
        self.log_queue.put(self.format(record))

queue_handler = QueueHandler(log_queue)
logging.getLogger().addHandler(queue_handler)

def update_logs():
    logs = []
    while not log_queue.empty():
        logs.append(log_queue.get())
    return "\n".join(logs)
##########SETTINGS################
def load_settings():
    config_path = os.getenv('GRAPHRAG_CONFIG', 'config.yaml')
    if os.path.exists(config_path):
        with open(config_path, 'r') as config_file:
            config = yaml.safe_load(config_file) or {}  # an empty file yields None
    else:
        config = {}

    settings = {
        'llm_model': os.getenv('LLM_MODEL', config.get('llm_model')),
        'embedding_model': os.getenv('EMBEDDINGS_MODEL', config.get('embedding_model')),
        'community_level': int(os.getenv('COMMUNITY_LEVEL', config.get('community_level', 2))),
        'token_limit': int(os.getenv('TOKEN_LIMIT', config.get('token_limit', 4096))),
        'api_key': os.getenv('GRAPHRAG_API_KEY', config.get('api_key')),
        'api_base': os.getenv('LLM_API_BASE', config.get('api_base')),
        'embeddings_api_base': os.getenv('EMBEDDINGS_API_BASE', config.get('embeddings_api_base')),
        'api_type': os.getenv('API_TYPE', config.get('api_type', 'openai')),
    }
    return settings
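
# A minimal config.yaml that load_settings() can read might look like this
# (hypothetical values; environment variables always take precedence):
#   llm_model: llama3
#   embedding_model: nomic-embed-text
#   community_level: 2
#   token_limit: 4096
#   api_base: http://localhost:11434
#   embeddings_api_base: http://localhost:11434
#   api_type: openai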
#######FILE_MANAGEMENT##############
def list_output_files(root_dir):
    output_dir = os.path.join(root_dir, "output")
    files = []
    for root, _, filenames in os.walk(output_dir):
        for filename in filenames:
            files.append(os.path.join(root, filename))
    return files

def update_file_list():
    files = list_input_files()
    return gr.update(choices=[f["path"] for f in files])

def update_file_content(file_path):
    if not file_path:
        return ""
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
        return content
    except Exception as e:
        logging.error(f"Error reading file: {str(e)}")
        return f"Error reading file: {str(e)}"
def list_output_folders():
    output_dir = os.path.join(ROOT_DIR, "output")
    if not os.path.exists(output_dir):
        return []
    folders = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f))]
    return sorted(folders, reverse=True)

def update_output_folder_list():
    folders = list_output_folders()
    return gr.update(choices=folders, value=folders[0] if folders else None)
def list_folder_contents(folder_path):
    # Takes a full path; callers are responsible for joining
    # ROOT_DIR/output/<run>/artifacts themselves. (Previously this function
    # re-joined the path, so callers passing full paths got a bogus location.)
    contents = []
    if os.path.exists(folder_path):
        for item in os.listdir(folder_path):
            item_path = os.path.join(folder_path, item)
            if os.path.isdir(item_path):
                contents.append(f"[DIR] {item}")
            else:
                _, ext = os.path.splitext(item)
                contents.append(f"[{ext[1:].upper()}] {item}")
    return contents

def update_folder_content_list(folder_name):
    if isinstance(folder_name, list) and folder_name:
        folder_name = folder_name[0]
    elif not folder_name:
        return gr.update(choices=[])
    contents = list_folder_contents(os.path.join(ROOT_DIR, "output", folder_name, "artifacts"))
    return gr.update(choices=contents)
def handle_content_selection(folder_name, selected_item):
    if isinstance(selected_item, list) and selected_item:
        selected_item = selected_item[0]  # Take the first item if it's a list
    if isinstance(selected_item, str) and selected_item.startswith("[DIR]"):
        dir_name = selected_item[6:]  # Remove the "[DIR] " prefix
        sub_contents = list_folder_contents(os.path.join(ROOT_DIR, "output", folder_name, "artifacts", dir_name))
        return gr.update(choices=sub_contents), "", ""
    elif isinstance(selected_item, str):
        file_name = selected_item.split("] ")[1] if "]" in selected_item else selected_item  # Remove the file-type prefix if present
        file_path = os.path.join(ROOT_DIR, "output", folder_name, "artifacts", file_name)
        file_size = os.path.getsize(file_path)
        file_type = os.path.splitext(file_name)[1]
        file_info = f"File: {file_name}\nSize: {file_size} bytes\nType: {file_type}"
        content = read_file_content(file_path)
        return gr.update(), file_info, content
    else:
        return gr.update(), "", ""
def initialize_selected_folder(folder_name):
    if not folder_name:
        return "Please select a folder first.", gr.update(choices=[])
    folder_path = os.path.join(ROOT_DIR, "output", folder_name, "artifacts")
    if not os.path.exists(folder_path):
        return f"Artifacts folder not found in '{folder_name}'.", gr.update(choices=[])
    contents = list_folder_contents(folder_path)
    return f"Folder '{folder_name}/artifacts' initialized with {len(contents)} items.", gr.update(choices=contents)
def upload_file(file):
    if file is not None:
        input_dir = os.path.join(ROOT_DIR, 'input')
        os.makedirs(input_dir, exist_ok=True)
        # Get the original filename from the uploaded file
        original_filename = file.name
        # Create the destination path
        destination_path = os.path.join(input_dir, os.path.basename(original_filename))
        # Move the uploaded file to the destination path
        shutil.move(file.name, destination_path)
        logging.info(f"File uploaded and moved to: {destination_path}")
        status = f"File uploaded: {os.path.basename(original_filename)}"
    else:
        status = "No file uploaded"
    # Get the updated file list
    updated_file_list = [f["path"] for f in list_input_files()]
    return status, gr.update(choices=updated_file_list), update_logs()

def list_input_files():
    input_dir = os.path.join(ROOT_DIR, 'input')
    files = []
    if os.path.exists(input_dir):
        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
    return [{"name": f, "path": os.path.join(input_dir, f)} for f in files]

def delete_file(file_path):
    try:
        os.remove(file_path)
        logging.info(f"File deleted: {file_path}")
        status = f"File deleted: {os.path.basename(file_path)}"
    except Exception as e:
        logging.error(f"Error deleting file: {str(e)}")
        status = f"Error deleting file: {str(e)}"
    # Get the updated file list
    updated_file_list = [f["path"] for f in list_input_files()]
    return status, gr.update(choices=updated_file_list), update_logs()
def read_file_content(file_path):
    try:
        if file_path.endswith('.parquet'):
            df = pd.read_parquet(file_path)
            # Get basic information about the DataFrame
            info = f"Parquet File: {os.path.basename(file_path)}\n"
            info += f"Rows: {len(df)}, Columns: {len(df.columns)}\n\n"
            info += "Column Names:\n" + "\n".join(df.columns) + "\n\n"
            # Display first few rows
            info += "First 5 rows:\n"
            info += df.head().to_string() + "\n\n"
            # Display basic statistics
            info += "Basic Statistics:\n"
            info += df.describe().to_string()
            return info
        else:
            with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
                content = file.read()
            return content
    except Exception as e:
        logging.error(f"Error reading file: {str(e)}")
        return f"Error reading file: {str(e)}"
def save_file_content(file_path, content):
    try:
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(content)
        logging.info(f"File saved: {file_path}")
        status = f"File saved: {os.path.basename(file_path)}"
    except Exception as e:
        logging.error(f"Error saving file: {str(e)}")
        status = f"Error saving file: {str(e)}"
    return status, update_logs()
def manage_data():
    db = lancedb.connect(f"{ROOT_DIR}/lancedb")
    tables = db.table_names()
    table_info = ""
    if tables:
        table = db.open_table(tables[0])
        table_info = f"Table: {tables[0]}\nSchema: {table.schema}"
    input_files = list_input_files()
    return {
        "database_info": f"Tables: {', '.join(tables)}\n\n{table_info}",
        "input_files": input_files
    }
def find_latest_graph_file(root_dir):
    pattern = os.path.join(root_dir, "output", "*", "artifacts", "*.graphml")
    graph_files = glob.glob(pattern)
    if not graph_files:
        # If nothing matched, retry against the newest run directory,
        # skipping .DS_Store entries
        output_dir = os.path.join(root_dir, "output")
        run_dirs = [d for d in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, d)) and d != ".DS_Store"]
        if run_dirs:
            latest_run = max(run_dirs)
            pattern = os.path.join(root_dir, "output", latest_run, "artifacts", "*.graphml")
            graph_files = glob.glob(pattern)
    if not graph_files:
        return None
    # Pick the most recently modified graph file
    latest_file = max(graph_files, key=os.path.getmtime)
    return latest_file
def find_latest_output_folder():
    root_dir = f"{ROOT_DIR}/output"
    folders = [f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))]
    if not folders:
        raise ValueError("No output folders found")
    # Sort folders by creation time, most recent first
    sorted_folders = sorted(folders, key=lambda x: os.path.getctime(os.path.join(root_dir, x)), reverse=True)
    latest_folder = None
    timestamp = None
    for folder in sorted_folders:
        try:
            # Try to parse the folder name as a timestamp
            timestamp = datetime.strptime(folder, "%Y%m%d-%H%M%S")
            latest_folder = folder
            break
        except ValueError:
            # If the folder name is not a valid timestamp, skip it
            continue
    if latest_folder is None:
        raise ValueError("No valid timestamp folders found")
    latest_path = os.path.join(root_dir, latest_folder)
    artifacts_path = os.path.join(latest_path, "artifacts")
    if not os.path.exists(artifacts_path):
        raise ValueError(f"Artifacts folder not found in {latest_path}")
    return latest_path, latest_folder
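
# find_latest_output_folder() assumes GraphRAG's convention of timestamped run
# folders, e.g. (illustrative paths):
#   indexing/output/20240701-153000/artifacts/create_final_nodes.parquet
#   indexing/output/20240630-101500/artifacts/create_final_nodes.parquet
# Folder names that do not parse as %Y%m%d-%H%M%S are skipped.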
def initialize_data():
    global entity_df, relationship_df, text_unit_df, report_df, covariate_df

    tables = {
        "entity_df": "create_final_nodes",
        "relationship_df": "create_final_edges",
        "text_unit_df": "create_final_text_units",
        "report_df": "create_final_reports",
        "covariate_df": "create_final_covariates"
    }

    timestamp = None  # Initialize timestamp to None

    try:
        latest_output_folder, timestamp = find_latest_output_folder()
        artifacts_folder = os.path.join(latest_output_folder, "artifacts")
        for df_name, file_prefix in tables.items():
            file_pattern = os.path.join(artifacts_folder, f"{file_prefix}*.parquet")
            matching_files = glob.glob(file_pattern)
            if matching_files:
                latest_file = max(matching_files, key=os.path.getctime)
                df = pd.read_parquet(latest_file)
                globals()[df_name] = df
                logging.info(f"Successfully loaded {df_name} from {latest_file}")
            else:
                logging.warning(f"No matching file found for {df_name} in {artifacts_folder}. Initializing as an empty DataFrame.")
                globals()[df_name] = pd.DataFrame()
    except Exception as e:
        logging.error(f"Error initializing data: {str(e)}")
        for df_name in tables.keys():
            globals()[df_name] = pd.DataFrame()

    return timestamp

# Call initialize_data and store the timestamp
current_timestamp = initialize_data()
###########MODELS##################
def normalize_api_base(api_base: str) -> str:
    """Normalize the API base URL by removing trailing slashes and /v1 or /api suffixes."""
    api_base = api_base.rstrip('/')
    if api_base.endswith('/v1'):
        api_base = api_base[:-3]
    elif api_base.endswith('/api'):
        api_base = api_base[:-4]  # '/api' is four characters, not three
    return api_base
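
# Examples:
#   normalize_api_base("http://localhost:11434/")    -> "http://localhost:11434"
#   normalize_api_base("http://localhost:11434/v1")  -> "http://localhost:11434"
#   normalize_api_base("http://localhost:11434/api") -> "http://localhost:11434"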
def is_ollama_api(base_url: str) -> bool:
    """Check if the given base URL is for the Ollama API."""
    try:
        response = requests.get(f"{normalize_api_base(base_url)}/api/tags", timeout=5)
        return response.status_code == 200
    except requests.RequestException:
        return False
def get_ollama_models(base_url: str) -> List[str]:
    """Fetch available models from the Ollama API."""
    try:
        response = requests.get(f"{normalize_api_base(base_url)}/api/tags", timeout=5)
        response.raise_for_status()
        models = response.json().get('models', [])
        return [model['name'] for model in models]
    except requests.RequestException as e:
        logger.error(f"Error fetching Ollama models: {str(e)}")
        return []
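
# Ollama's GET /api/tags returns a payload shaped roughly like
#   {"models": [{"name": "llama3:latest", "size": ..., ...}, ...]}
# so the comprehension above extracts just the model names.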
def get_openai_compatible_models(base_url: str) -> List[str]:
    """Fetch available models from an OpenAI-compatible API."""
    try:
        response = requests.get(f"{normalize_api_base(base_url)}/v1/models", timeout=5)
        response.raise_for_status()
        models = response.json().get('data', [])
        return [model['id'] for model in models]
    except requests.RequestException as e:
        logger.error(f"Error fetching OpenAI-compatible models: {str(e)}")
        return []
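
# OpenAI-compatible servers answer GET /v1/models with
#   {"object": "list", "data": [{"id": "model-name", ...}, ...]}
# hence the 'data' key and 'id' field above.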
def get_local_models(base_url: str) -> List[str]:
    """Get available models based on the API type."""
    if is_ollama_api(base_url):
        return get_ollama_models(base_url)
    else:
        return get_openai_compatible_models(base_url)
def get_model_params(base_url: str, model_name: str) -> dict:
    """Get model parameters for Ollama models; return {} for other backends."""
    if is_ollama_api(base_url):
        try:
            response = requests.post(f"{normalize_api_base(base_url)}/api/show", json={"name": model_name}, timeout=5)
            response.raise_for_status()
            model_info = response.json()
            return model_info.get('parameters', {})
        except requests.RequestException as e:
            logger.error(f"Error fetching Ollama model parameters: {str(e)}")
            return {}
    return {}  # Non-Ollama backends: no parameter introspection, so return empty
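
# Note: Ollama's POST /api/show typically returns 'parameters' as a
# newline-separated string (e.g. "num_ctx 4096\n...") rather than a dict, so
# callers should treat the result as opaque display text.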
#########API###########
def start_indexing(request: IndexingRequest):
    url = f"{API_BASE_URL}/v1/index"
    try:
        response = requests.post(url, json=request.dict())
        response.raise_for_status()
        result = response.json()
        return result['message'], gr.update(interactive=False), gr.update(interactive=True)
    except requests.RequestException as e:
        logger.error(f"Error starting indexing: {str(e)}")
        return f"Error: {str(e)}", gr.update(interactive=True), gr.update(interactive=False)

def check_indexing_status():
    url = f"{API_BASE_URL}/v1/index_status"
    try:
        response = requests.get(url)
        response.raise_for_status()
        result = response.json()
        return result['status'], "\n".join(result['logs'])
    except requests.RequestException as e:
        logger.error(f"Error checking indexing status: {str(e)}")
        return "Error", f"Failed to check indexing status: {str(e)}"
def start_prompt_tuning(request: PromptTuneRequest):
    url = f"{API_BASE_URL}/v1/prompt_tune"
    try:
        response = requests.post(url, json=request.dict())
        response.raise_for_status()
        result = response.json()
        return result['message'], gr.update(interactive=False)
    except requests.RequestException as e:
        logger.error(f"Error starting prompt tuning: {str(e)}")
        return f"Error: {str(e)}", gr.update(interactive=True)

def check_prompt_tuning_status():
    url = f"{API_BASE_URL}/v1/prompt_tune_status"
    try:
        response = requests.get(url)
        response.raise_for_status()
        result = response.json()
        return result['status'], "\n".join(result['logs'])
    except requests.RequestException as e:
        logger.error(f"Error checking prompt tuning status: {str(e)}")
        return "Error", f"Failed to check prompt tuning status: {str(e)}"
def update_model_params(model_name):
    # get_model_params needs the API base as well as the model name; the LLM
    # base URL is assumed here.
    params = get_model_params(LLM_API_BASE, model_name)
    return gr.update(value=json.dumps(params, indent=2))
###########################
css = """
html, body {
    margin: 0;
    padding: 0;
    height: 100vh;
    overflow: hidden;
}

.gradio-container {
    margin: 0 !important;
    padding: 0 !important;
    width: 100vw !important;
    max-width: 100vw !important;
    height: 100vh !important;
    max-height: 100vh !important;
    overflow: auto;
    display: flex;
    flex-direction: column;
}

#main-container {
    flex: 1;
    display: flex;
    overflow: hidden;
}

#left-column, #right-column {
    height: 100%;
    overflow-y: auto;
    padding: 10px;
}

#left-column {
    flex: 1;
}

#right-column {
    flex: 2;
    display: flex;
    flex-direction: column;
}

#chat-container {
    flex: 0 0 auto;  /* Don't allow this to grow */
    height: 100%;
    display: flex;
    flex-direction: column;
    overflow: hidden;
    border: 1px solid var(--color-accent);
    border-radius: 8px;
    padding: 10px;
    overflow-y: auto;
}

#chatbot {
    overflow-y: hidden;
    height: 100%;
}

#chat-input-row {
    margin-top: 10px;
}

#visualization-plot {
    width: 100%;
    aspect-ratio: 1 / 1;
    max-height: 600px;  /* Adjust this value as needed */
}

#vis-controls-row {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-top: 10px;
}

#vis-controls-row > * {
    flex: 1;
    margin: 0 5px;
}

#vis-status {
    margin-top: 10px;
}

/* Chat input styling */
#chat-input-row {
    display: flex;
    flex-direction: column;
}

#chat-input-row > div {
    width: 100% !important;
}

#chat-input-row input[type="text"] {
    width: 100% !important;
}

/* Adjust padding for all containers */
.gr-box, .gr-form, .gr-panel {
    padding: 10px !important;
}

/* Ensure all textboxes and textareas have full height */
.gr-textbox, .gr-textarea {
    height: auto !important;
    min-height: 100px !important;
}

/* Ensure all dropdowns have full width */
.gr-dropdown {
    width: 100% !important;
}

:root {
    --color-background: #2C3639;
    --color-foreground: #3F4E4F;
    --color-accent: #A27B5C;
    --color-text: #DCD7C9;
}

body, .gradio-container {
    background-color: var(--color-background);
    color: var(--color-text);
}

.gr-button {
    background-color: var(--color-accent);
    color: var(--color-text);
}

.gr-input, .gr-textarea, .gr-dropdown {
    background-color: var(--color-foreground);
    color: var(--color-text);
    border: 1px solid var(--color-accent);
}

.gr-panel {
    background-color: var(--color-foreground);
    border: 1px solid var(--color-accent);
}

.gr-box {
    border-radius: 8px;
    margin-bottom: 10px;
    background-color: var(--color-foreground);
}

.gr-padded {
    padding: 10px;
}

.gr-form {
    background-color: var(--color-foreground);
}

.gr-input-label, .gr-radio-label {
    color: var(--color-text);
}

.gr-checkbox-label {
    color: var(--color-text);
}

.gr-markdown {
    color: var(--color-text);
}

.gr-accordion {
    background-color: var(--color-foreground);
    border: 1px solid var(--color-accent);
}

.gr-accordion-header {
    background-color: var(--color-accent);
    color: var(--color-text);
}

#visualization-container {
    display: flex;
    flex-direction: column;
    border: 2px solid var(--color-accent);
    border-radius: 8px;
    margin-top: 20px;
    padding: 10px;
    background-color: var(--color-foreground);
    height: calc(100vh - 300px);  /* Adjust this value as needed */
}

#visualization-plot {
    width: 100%;
    height: 100%;
}
#log-container {
    background-color: var(--color-foreground);
    border: 1px solid var(--color-accent);
    border-radius: 8px;
    padding: 10px;
    margin-top: 20px;
    max-height: none;  /* 'auto' is not a valid max-height value */
    overflow-y: auto;
}
.setting-accordion .label-wrap {
    cursor: pointer;
}

.setting-accordion .icon {
    transition: transform 0.3s ease;
}

.setting-accordion[open] .icon {
    transform: rotate(90deg);
}

.gr-form.gr-box {
    border: none !important;
    background: none !important;
}

.model-params {
    border-top: 1px solid var(--color-accent);
    margin-top: 10px;
    padding-top: 10px;
}
"""
def create_interface():
    settings = load_settings()
    llm_api_base = normalize_api_base(settings['api_base'])
    embeddings_api_base = normalize_api_base(settings['embeddings_api_base'])

    with gr.Blocks(theme=gr.themes.Base(), css=css) as demo:
        gr.Markdown("# GraphRAG Indexer")
        with gr.Tabs():
            with gr.TabItem("Indexing"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("## Indexing Configuration")
                        with gr.Row():
                            llm_name = gr.Dropdown(label="LLM Model", choices=[], value=settings['llm_model'], allow_custom_value=True)
                            refresh_llm_btn = gr.Button("🔄", size='sm', scale=0)
                        with gr.Row():
                            embed_name = gr.Dropdown(label="Embedding Model", choices=[], value=settings['embedding_model'], allow_custom_value=True)
                            refresh_embed_btn = gr.Button("🔄", size='sm', scale=0)
                        save_config_button = gr.Button("Save Configuration", variant="primary")
                        config_status = gr.Textbox(label="Configuration Status", lines=2)
                with gr.Row():
                    with gr.Column(scale=1):
                        root_dir = gr.Textbox(label="Root Directory (Edit in .env file)", value=ROOT_DIR)
                        with gr.Group():
                            verbose = gr.Checkbox(label="Verbose", interactive=True, value=True)
                            nocache = gr.Checkbox(label="No Cache", interactive=True, value=True)
                        with gr.Accordion("Advanced Options", open=True):
                            resume = gr.Textbox(label="Resume Timestamp (optional)")
                            reporter = gr.Dropdown(
                                label="Reporter",
                                choices=["rich", "print", "none"],
                                value="rich",
                                interactive=True
                            )
                            emit_formats = gr.CheckboxGroup(
                                label="Emit Formats",
                                choices=["json", "csv", "parquet"],
                                value=["parquet"],
                                interactive=True
                            )
                            custom_args = gr.Textbox(label="Custom CLI Arguments", placeholder="--arg1 value1 --arg2 value2")
                    with gr.Column(scale=1):
                        gr.Markdown("## Indexing Output")
                        index_output = gr.Textbox(label="Output", lines=10)
                        index_status = gr.Textbox(label="Status", lines=2)
                        run_index_button = gr.Button("Run Indexing", variant="primary")
                        check_status_button = gr.Button("Check Indexing Status")

            with gr.TabItem("Prompt Tuning"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("## Prompt Tuning Configuration")
                        pt_root = gr.Textbox(label="Root Directory", value=ROOT_DIR, interactive=True)
                        pt_domain = gr.Textbox(label="Domain (optional)")
                        pt_method = gr.Dropdown(
                            label="Method",
                            choices=["random", "top", "all"],
                            value="random",
                            interactive=True
                        )
                        pt_limit = gr.Number(label="Limit", value=15, precision=0, interactive=True)
                        pt_language = gr.Textbox(label="Language (optional)")
                        pt_max_tokens = gr.Number(label="Max Tokens", value=2000, precision=0, interactive=True)
                        pt_chunk_size = gr.Number(label="Chunk Size", value=200, precision=0, interactive=True)
                        pt_no_entity_types = gr.Checkbox(label="No Entity Types", value=False)
                        pt_output_dir = gr.Textbox(label="Output Directory", value=f"{ROOT_DIR}/prompts", interactive=True)
                        save_pt_config_button = gr.Button("Save Prompt Tuning Configuration", variant="primary")
                    with gr.Column(scale=1):
                        gr.Markdown("## Prompt Tuning Output")
                        pt_output = gr.Textbox(label="Output", lines=10)
                        pt_status = gr.Textbox(label="Status", lines=10)
                        run_pt_button = gr.Button("Run Prompt Tuning", variant="primary")
                        check_pt_status_button = gr.Button("Check Prompt Tuning Status")

            with gr.TabItem("Data Management"):
                with gr.Row():
                    with gr.Column(scale=1):
                        with gr.Accordion("File Upload", open=True):
                            file_upload = gr.File(label="Upload File", file_types=[".txt", ".csv", ".parquet"])
                            upload_btn = gr.Button("Upload File", variant="primary")
                            upload_output = gr.Textbox(label="Upload Status", visible=True)
                        with gr.Accordion("File Management", open=True):
                            file_list = gr.Dropdown(label="Select File", choices=[], interactive=True)
                            refresh_btn = gr.Button("Refresh File List", variant="secondary")
                            file_content = gr.TextArea(label="File Content", lines=10)
                            with gr.Row():
                                delete_btn = gr.Button("Delete Selected File", variant="stop")
                                save_btn = gr.Button("Save Changes", variant="primary")
                            operation_status = gr.Textbox(label="Operation Status", visible=True)
                    with gr.Column(scale=1):
                        with gr.Accordion("Output Folders", open=True):
                            output_folder_list = gr.Dropdown(label="Select Output Folder", choices=[], interactive=True)
                            refresh_output_btn = gr.Button("Refresh Output Folders", variant="secondary")
                            folder_content_list = gr.Dropdown(label="Folder Contents", choices=[], interactive=True, multiselect=False)
                            file_info = gr.Textbox(label="File Info", lines=3)
                            output_content = gr.TextArea(label="File Content", lines=10)
        # Event handlers
        def refresh_llm_models():
            models = get_local_models(llm_api_base)
            return gr.update(choices=models)

        def refresh_embed_models():
            models = get_local_models(embeddings_api_base)
            return gr.update(choices=models)

        refresh_llm_btn.click(
            refresh_llm_models,
            outputs=[llm_name]
        )
        refresh_embed_btn.click(
            refresh_embed_models,
            outputs=[embed_name]
        )

        # Initialize model lists on page load
        demo.load(refresh_llm_models, outputs=[llm_name])
        demo.load(refresh_embed_models, outputs=[embed_name])
        def run_indexing(llm_model, embed_model, root, verbose_flag, nocache_flag, resume_value, reporter_value, emit_values, custom):
            # Build the request from the current widget values passed in by
            # Gradio. (Reading component.value directly only ever returns the
            # initial values, so the components are wired up as inputs instead.)
            request = IndexingRequest(
                llm_model=llm_model,
                embed_model=embed_model,
                llm_api_base=llm_api_base,
                embed_api_base=embeddings_api_base,
                root=root,
                verbose=verbose_flag,
                nocache=nocache_flag,
                resume=resume_value if resume_value else None,
                reporter=reporter_value,
                emit=list(emit_values),
                custom_args=custom if custom else None
            )
            return start_indexing(request)

        run_index_button.click(
            run_indexing,
            inputs=[llm_name, embed_name, root_dir, verbose, nocache, resume, reporter, emit_formats, custom_args],
            outputs=[index_output, run_index_button, check_status_button]
        )
        check_status_button.click(
            check_indexing_status,
            outputs=[index_status, index_output]
        )
        def run_prompt_tuning(root, domain, method, limit, language, max_tokens, chunk_size, no_entity_types, output_dir):
            # Same pattern as run_indexing: current widget values arrive as inputs.
            request = PromptTuneRequest(
                root=root,
                domain=domain if domain else None,
                method=method,
                limit=int(limit),
                language=language if language else None,
                max_tokens=int(max_tokens),
                chunk_size=int(chunk_size),
                no_entity_types=no_entity_types,
                output=output_dir
            )
            result, button_update = start_prompt_tuning(request)
            return result, button_update, gr.update(value=f"Request: {request.dict()}")

        run_pt_button.click(
            run_prompt_tuning,
            inputs=[pt_root, pt_domain, pt_method, pt_limit, pt_language, pt_max_tokens, pt_chunk_size, pt_no_entity_types, pt_output_dir],
            outputs=[pt_output, run_pt_button, pt_status]
        )
        check_pt_status_button.click(
            check_prompt_tuning_status,
            outputs=[pt_status, pt_output]
        )
        # Add event handlers for real-time updates
        pt_root.change(lambda x: gr.update(value=f"Root Directory changed to: {x}"), inputs=[pt_root], outputs=[pt_status])
        pt_limit.change(lambda x: gr.update(value=f"Limit changed to: {x}"), inputs=[pt_limit], outputs=[pt_status])
        pt_max_tokens.change(lambda x: gr.update(value=f"Max Tokens changed to: {x}"), inputs=[pt_max_tokens], outputs=[pt_status])
        pt_chunk_size.change(lambda x: gr.update(value=f"Chunk Size changed to: {x}"), inputs=[pt_chunk_size], outputs=[pt_status])
        pt_output_dir.change(lambda x: gr.update(value=f"Output Directory changed to: {x}"), inputs=[pt_output_dir], outputs=[pt_status])
        # Event handlers for Data Management
        upload_btn.click(
            upload_file,
            inputs=[file_upload],
            outputs=[upload_output, file_list, operation_status]
        )
        refresh_btn.click(
            update_file_list,
            outputs=[file_list]
        )
        refresh_output_btn.click(
            update_output_folder_list,
            outputs=[output_folder_list]
        )
        file_list.change(
            update_file_content,
            inputs=[file_list],
            outputs=[file_content]
        )
        delete_btn.click(
            delete_file,
            inputs=[file_list],
            outputs=[operation_status, file_list, operation_status]
        )
        save_btn.click(
            save_file_content,
            inputs=[file_list, file_content],
            outputs=[operation_status, operation_status]
        )
        output_folder_list.change(
            update_folder_content_list,
            inputs=[output_folder_list],
            outputs=[folder_content_list]
        )
        folder_content_list.change(
            handle_content_selection,
            inputs=[output_folder_list, folder_content_list],
            outputs=[folder_content_list, file_info, output_content]
        )

        # Event handler for saving configuration
        save_config_button.click(
            update_env_file,
            inputs=[llm_name, embed_name],
            outputs=[config_status]
        )

        # Event handler for saving prompt tuning configuration
        save_pt_config_button.click(
            save_prompt_tuning_config,
            inputs=[pt_root, pt_domain, pt_method, pt_limit, pt_language, pt_max_tokens, pt_chunk_size, pt_no_entity_types, pt_output_dir],
            outputs=[pt_status]
        )

        # Initialize file list and output folder list
        demo.load(update_file_list, outputs=[file_list])
        demo.load(update_output_folder_list, outputs=[output_folder_list])

    return demo
def update_env_file(llm_model, embed_model):
    env_path = os.path.join(ROOT_DIR, '.env')
    set_key(env_path, 'LLM_MODEL', llm_model)
    set_key(env_path, 'EMBEDDINGS_MODEL', embed_model)
    # Reload the environment variables
    load_dotenv(env_path, override=True)
    return f"Environment updated: LLM_MODEL={llm_model}, EMBEDDINGS_MODEL={embed_model}"
def save_prompt_tuning_config(root, domain, method, limit, language, max_tokens, chunk_size, no_entity_types, output_dir):
    config = {
        'prompt_tuning': {
            'root': root,
            'domain': domain,
            'method': method,
            'limit': limit,
            'language': language,
            'max_tokens': max_tokens,
            'chunk_size': chunk_size,
            'no_entity_types': no_entity_types,
            'output': output_dir
        }
    }
    config_path = os.path.join(ROOT_DIR, 'prompt_tuning_config.yaml')
    with open(config_path, 'w') as f:
        yaml.dump(config, f)
    return f"Prompt Tuning configuration saved to {config_path}"
demo = create_interface()

if __name__ == "__main__":
    demo.launch(server_port=7861)