from dotenv import load_dotenv
import os
import asyncio
import tempfile
from collections import deque
import time
import uuid
import json
import re
import pandas as pd
import tiktoken
import logging
import yaml
import shutil
from fastapi import Body
from fastapi import FastAPI, HTTPException, Request, BackgroundTasks, Depends
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any, Union
from contextlib import asynccontextmanager
from web import DuckDuckGoSearchAPIWrapper
from functools import lru_cache
import requests
import subprocess
import argparse

# GraphRAG related imports
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.question_gen.local_gen import LocalQuestionGen
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
from graphrag.query.structured_search.global_search.search import GlobalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv('indexing/.env')
LLM_API_BASE = os.getenv('LLM_API_BASE', '')
LLM_MODEL = os.getenv('LLM_MODEL')
LLM_PROVIDER = os.getenv('LLM_PROVIDER', 'openai').lower()
EMBEDDINGS_API_BASE = os.getenv('EMBEDDINGS_API_BASE', '')
EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL')
EMBEDDINGS_PROVIDER = os.getenv('EMBEDDINGS_PROVIDER', 'openai').lower()
INPUT_DIR = os.getenv('INPUT_DIR', './indexing/output')
ROOT_DIR = os.getenv('ROOT_DIR', 'indexing')
PORT = int(os.getenv('API_PORT', 8012))

LANCEDB_URI = f"{INPUT_DIR}/lancedb"
COMMUNITY_REPORT_TABLE = "create_final_community_reports"
ENTITY_TABLE = "create_final_nodes"
ENTITY_EMBEDDING_TABLE = "create_final_entities"
RELATIONSHIP_TABLE = "create_final_relationships"
COVARIATE_TABLE = "create_final_covariates"
TEXT_UNIT_TABLE = "create_final_text_units"
COMMUNITY_LEVEL = 2

# Global variables for storing search engines and question generator
local_search_engine = None
global_search_engine = None
question_generator = None

# Data models
class Message(BaseModel):
    role: str
    content: str


class QueryOptions(BaseModel):
    query_type: str
    preset: Optional[str] = None
    community_level: Optional[int] = None
    response_type: Optional[str] = None
    custom_cli_args: Optional[str] = None
    selected_folder: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    query_options: Optional[QueryOptions] = None


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: Message
    finish_reason: Optional[str] = None


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    # Optional so responses without token accounting (e.g. direct chat) can pass usage=None.
    usage: Optional[Usage] = None
    system_fingerprint: Optional[str] = None
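
# Illustrative request body for the models above (a sketch; values such as the
# folder name are placeholders, not from the source). It follows the OpenAI
# chat-completions shape plus the GraphRAG-specific "query_options" extension:
#
# {
#   "model": "graphrag-local-search:latest",
#   "messages": [{"role": "user", "content": "What are the main themes?"}],
#   "temperature": 0.7,
#   "query_options": {
#     "query_type": "graphrag-local-search",
#     "preset": "Default Local Search",
#     "community_level": 2,
#     "response_type": "Multiple Paragraphs",
#     "selected_folder": "my-index-run"
#   }
# }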

def list_output_folders():
    return [f for f in os.listdir(INPUT_DIR) if os.path.isdir(os.path.join(INPUT_DIR, f))]


def list_folder_contents(folder_name):
    folder_path = os.path.join(INPUT_DIR, folder_name, "artifacts")
    if not os.path.exists(folder_path):
        return []
    return [item for item in os.listdir(folder_path) if item.endswith('.parquet')]

def normalize_api_base(api_base: str) -> str:
    """Normalize the API base URL by removing trailing slashes and /v1 or /api suffixes."""
    api_base = api_base.rstrip('/')
    for suffix in ('/v1', '/api'):
        if api_base.endswith(suffix):
            api_base = api_base[:-len(suffix)]
            break
    return api_base

def get_models_endpoint(api_base: str, api_type: str) -> str:
    """Get the appropriate models endpoint based on the API type."""
    normalized_base = normalize_api_base(api_base)
    if api_type.lower() == 'openai':
        return f"{normalized_base}/v1/models"
    elif api_type.lower() == 'azure':
        return f"{normalized_base}/openai/deployments?api-version=2022-12-01"
    else:  # For other API types (e.g., local LLMs)
        return f"{normalized_base}/models"

async def fetch_available_models(settings: Dict[str, Any]) -> List[str]:
    """Fetch available models from the API."""
    api_base = settings['api_base']
    api_type = settings['api_type']
    api_key = settings['api_key']
    models_endpoint = get_models_endpoint(api_base, api_type)
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    try:
        response = requests.get(models_endpoint, headers=headers, timeout=10)
        response.raise_for_status()
        data = response.json()
        if api_type.lower() == 'openai':
            return [model['id'] for model in data['data']]
        elif api_type.lower() == 'azure':
            return [model['id'] for model in data['value']]
        else:
            # Adjust this based on the actual response format of your local LLM API
            return [model['name'] for model in data['models']]
    except requests.exceptions.RequestException as e:
        logger.error(f"Error fetching models: {str(e)}")
        return []

def load_settings():
    config_path = os.getenv('GRAPHRAG_CONFIG', 'config.yaml')
    if os.path.exists(config_path):
        with open(config_path, 'r') as config_file:
            # Treat an empty config file as an empty config
            config = yaml.safe_load(config_file) or {}
    else:
        config = {}

    settings = {
        'llm_model': os.getenv('LLM_MODEL', config.get('llm_model')),
        'embedding_model': os.getenv('EMBEDDINGS_MODEL', config.get('embedding_model')),
        'community_level': int(os.getenv('COMMUNITY_LEVEL', config.get('community_level', 2))),
        'token_limit': int(os.getenv('TOKEN_LIMIT', config.get('token_limit', 4096))),
        'api_key': os.getenv('GRAPHRAG_API_KEY', config.get('api_key')),
        'api_base': os.getenv('LLM_API_BASE', config.get('api_base')),
        'embeddings_api_base': os.getenv('EMBEDDINGS_API_BASE', config.get('embeddings_api_base')),
        'api_type': os.getenv('API_TYPE', config.get('api_type', 'openai')),
    }
    return settings
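
# Illustrative config.yaml matching the keys read above (a sketch with placeholder
# values, not shipped with the source; environment variables take precedence):
#
# llm_model: llama3
# embedding_model: nomic-embed-text
# community_level: 2
# token_limit: 4096
# api_key: your-api-key
# api_base: http://localhost:11434
# embeddings_api_base: http://localhost:11434
# api_type: openai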

async def setup_llm_and_embedder(settings):
    logger.info("Setting up LLM and embedder")
    try:
        llm = ChatOpenAI(
            api_key=settings['api_key'],
            api_base=f"{settings['api_base']}/v1",
            model=settings['llm_model'],
            api_type=OpenaiApiType[settings['api_type'].capitalize()],
            max_retries=20,
        )
        token_encoder = tiktoken.get_encoding("cl100k_base")
        text_embedder = OpenAIEmbedding(
            api_key=settings['api_key'],
            api_base=f"{settings['embeddings_api_base']}/v1",
            api_type=OpenaiApiType[settings['api_type'].capitalize()],
            model=settings['embedding_model'],
            deployment_name=settings['embedding_model'],
            max_retries=20,
        )
        logger.info("LLM and embedder setup complete")
        return llm, token_encoder, text_embedder
    except Exception as e:
        logger.error(f"Error setting up LLM and embedder: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to set up LLM and embedder: {str(e)}")

async def load_context(selected_folder, settings):
    """
    Load context data including entities, relationships, reports, text units, and covariates.
    """
    logger.info("Loading context data")
    try:
        input_dir = os.path.join(INPUT_DIR, selected_folder, "artifacts")

        entity_df = pd.read_parquet(f"{input_dir}/{ENTITY_TABLE}.parquet")
        entity_embedding_df = pd.read_parquet(f"{input_dir}/{ENTITY_EMBEDDING_TABLE}.parquet")
        entities = read_indexer_entities(entity_df, entity_embedding_df, settings['community_level'])

        description_embedding_store = LanceDBVectorStore(collection_name="entity_description_embeddings")
        description_embedding_store.connect(db_uri=LANCEDB_URI)
        store_entity_semantic_embeddings(entities=entities, vectorstore=description_embedding_store)

        relationship_df = pd.read_parquet(f"{input_dir}/{RELATIONSHIP_TABLE}.parquet")
        relationships = read_indexer_relationships(relationship_df)

        report_df = pd.read_parquet(f"{input_dir}/{COMMUNITY_REPORT_TABLE}.parquet")
        reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)

        text_unit_df = pd.read_parquet(f"{input_dir}/{TEXT_UNIT_TABLE}.parquet")
        text_units = read_indexer_text_units(text_unit_df)

        covariate_df = pd.read_parquet(f"{input_dir}/{COVARIATE_TABLE}.parquet")
        claims = read_indexer_covariates(covariate_df)
        logger.info(f"Number of claim records: {len(claims)}")
        covariates = {"claims": claims}

        logger.info("Context data loading complete")
        return entities, relationships, reports, text_units, description_embedding_store, covariates
    except Exception as e:
        logger.error(f"Error loading context data: {str(e)}")
        raise

async def setup_search_engines(llm, token_encoder, text_embedder, entities, relationships, reports, text_units,
                               description_embedding_store, covariates):
    """
    Set up local and global search engines.
    """
    logger.info("Setting up search engines")

    # Set up local search engine
    local_context_builder = LocalSearchMixedContext(
        community_reports=reports,
        text_units=text_units,
        entities=entities,
        relationships=relationships,
        covariates=covariates,
        entity_text_embeddings=description_embedding_store,
        embedding_vectorstore_key=EntityVectorStoreKey.ID,
        text_embedder=text_embedder,
        token_encoder=token_encoder,
    )
    local_context_params = {
        "text_unit_prop": 0.5,
        "community_prop": 0.1,
        "conversation_history_max_turns": 5,
        "conversation_history_user_turns_only": True,
        "top_k_mapped_entities": 10,
        "top_k_relationships": 10,
        "include_entity_rank": True,
        "include_relationship_weight": True,
        "include_community_rank": False,
        "return_candidate_context": False,
        "embedding_vectorstore_key": EntityVectorStoreKey.ID,
        "max_tokens": 12_000,
    }
    local_llm_params = {
        "max_tokens": 2_000,
        "temperature": 0.0,
    }
    local_search_engine = LocalSearch(
        llm=llm,
        context_builder=local_context_builder,
        token_encoder=token_encoder,
        llm_params=local_llm_params,
        context_builder_params=local_context_params,
        response_type="multiple paragraphs",
    )

    # Set up global search engine
    global_context_builder = GlobalCommunityContext(
        community_reports=reports,
        entities=entities,
        token_encoder=token_encoder,
    )
    global_context_builder_params = {
        "use_community_summary": False,
        "shuffle_data": True,
        "include_community_rank": True,
        "min_community_rank": 0,
        "community_rank_name": "rank",
        "include_community_weight": True,
        "community_weight_name": "occurrence weight",
        "normalize_community_weight": True,
        "max_tokens": 12_000,
        "context_name": "Reports",
    }
    map_llm_params = {
        "max_tokens": 1000,
        "temperature": 0.0,
        "response_format": {"type": "json_object"},
    }
    reduce_llm_params = {
        "max_tokens": 2000,
        "temperature": 0.0,
    }
    global_search_engine = GlobalSearch(
        llm=llm,
        context_builder=global_context_builder,
        token_encoder=token_encoder,
        max_data_tokens=12_000,
        map_llm_params=map_llm_params,
        reduce_llm_params=reduce_llm_params,
        allow_general_knowledge=False,
        json_mode=True,
        context_builder_params=global_context_builder_params,
        concurrent_coroutines=32,
        response_type="multiple paragraphs",
    )

    logger.info("Search engines setup complete")
    return local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params

def format_response(response):
    """
    Format the response by adding appropriate line breaks and paragraph separations.
    """
    paragraphs = re.split(r'\n{2,}', response)
    formatted_paragraphs = []
    for para in paragraphs:
        if '```' in para:
            parts = para.split('```')
            for i, part in enumerate(parts):
                if i % 2 == 1:  # This is a code block
                    parts[i] = f"\n```\n{part.strip()}\n```\n"
            para = ''.join(parts)
        else:
            para = para.replace('. ', '.\n')
        formatted_paragraphs.append(para.strip())
    return '\n\n'.join(formatted_paragraphs)

@asynccontextmanager
async def lifespan(app: FastAPI):
    global settings
    try:
        logger.info("Loading settings...")
        settings = load_settings()
        logger.info("Settings loaded successfully.")
    except Exception as e:
        logger.error(f"Error loading settings: {str(e)}")
        raise
    yield
    logger.info("Shutting down...")


app = FastAPI(lifespan=lifespan)

# Create a cache for loaded contexts
context_cache = {}

def get_settings():
    return load_settings()


async def get_context(selected_folder: str, settings: dict = Depends(get_settings)):
    if selected_folder not in context_cache:
        try:
            llm, token_encoder, text_embedder = await setup_llm_and_embedder(settings)
            entities, relationships, reports, text_units, description_embedding_store, covariates = await load_context(selected_folder, settings)
            local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params = await setup_search_engines(
                llm, token_encoder, text_embedder, entities, relationships, reports, text_units,
                description_embedding_store, covariates
            )
            question_generator = LocalQuestionGen(
                llm=llm,
                context_builder=local_context_builder,
                token_encoder=token_encoder,
                llm_params=local_llm_params,
                context_builder_params=local_context_params,
            )
            context_cache[selected_folder] = {
                "local_search_engine": local_search_engine,
                "global_search_engine": global_search_engine,
                "question_generator": question_generator
            }
        except Exception as e:
            logger.error(f"Error loading context for folder {selected_folder}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to load context for folder {selected_folder}")
    return context_cache[selected_folder]

# Route decorator restored; the OpenAI-compatible path below is an assumption.
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    try:
        logger.info(f"Received request for model: {request.model}")
        if request.model == "direct-chat":
            logger.info("Routing to direct chat")
            return await run_direct_chat(request)
        elif request.model.startswith("graphrag-"):
            logger.info("Routing to GraphRAG query")
            if not request.query_options or not request.query_options.selected_folder:
                raise HTTPException(status_code=400, detail="Selected folder is required for GraphRAG queries")
            return await run_graphrag_query(request)
        elif request.model == "duckduckgo-search:latest":
            logger.info("Routing to DuckDuckGo search")
            return await run_duckduckgo_search(request)
        elif request.model == "full-model:latest":
            logger.info("Routing to full model search")
            return await run_full_model_search(request)
        else:
            raise HTTPException(status_code=400, detail=f"Invalid model specified: {request.model}")
    except HTTPException as he:
        logger.error(f"HTTP Exception: {str(he)}")
        raise he
    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

async def run_direct_chat(request: ChatCompletionRequest) -> ChatCompletionResponse:
    try:
        if not LLM_API_BASE:
            raise ValueError("LLM_API_BASE environment variable is not set")

        headers = {"Content-Type": "application/json"}
        payload = {
            "model": LLM_MODEL,
            "messages": [{"role": msg.role, "content": msg.content} for msg in request.messages],
            "stream": False
        }
        # Optional parameters
        if request.temperature is not None:
            payload["temperature"] = request.temperature
        if request.max_tokens is not None:
            payload["max_tokens"] = request.max_tokens

        full_url = f"{normalize_api_base(LLM_API_BASE)}/v1/chat/completions"
        logger.info(f"Sending request to: {full_url}")
        logger.info(f"Payload: {payload}")

        try:
            response = requests.post(full_url, json=payload, headers=headers, timeout=10)
            response.raise_for_status()
        except requests.exceptions.RequestException as req_ex:
            logger.error(f"Request to LLM API failed: {str(req_ex)}")
            if isinstance(req_ex, requests.exceptions.ConnectionError):
                raise HTTPException(status_code=503, detail="Unable to connect to LLM API. Please check your API settings.")
            elif isinstance(req_ex, requests.exceptions.Timeout):
                raise HTTPException(status_code=504, detail="Request to LLM API timed out")
            else:
                raise HTTPException(status_code=500, detail=f"Request to LLM API failed: {str(req_ex)}")

        result = response.json()
        logger.info(f"Received response: {result}")
        content = result['choices'][0]['message']['content']

        return ChatCompletionResponse(
            model=LLM_MODEL,
            choices=[
                ChatCompletionResponseChoice(
                    index=0,
                    message=Message(
                        role="assistant",
                        content=content
                    ),
                    finish_reason=None
                )
            ],
            usage=None
        )
    except HTTPException as he:
        logger.error(f"HTTP Exception in direct chat: {str(he)}")
        raise he
    except Exception as e:
        logger.error(f"Unexpected error in direct chat: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during the direct chat: {str(e)}")

def get_embeddings(text: str) -> List[float]:
    settings = load_settings()
    embeddings_api_base = settings['embeddings_api_base']
    headers = {"Content-Type": "application/json"}

    if EMBEDDINGS_PROVIDER == 'ollama':
        payload = {
            "model": EMBEDDINGS_MODEL,
            "prompt": text
        }
        full_url = f"{embeddings_api_base}/api/embeddings"
    else:  # OpenAI-compatible API
        payload = {
            "model": EMBEDDINGS_MODEL,
            "input": text
        }
        full_url = f"{embeddings_api_base}/v1/embeddings"

    try:
        response = requests.post(full_url, json=payload, headers=headers)
        response.raise_for_status()
    except requests.exceptions.RequestException as req_ex:
        logger.error(f"Request to Embeddings API failed: {str(req_ex)}")
        raise HTTPException(status_code=500, detail=f"Failed to get embeddings: {str(req_ex)}")

    result = response.json()
    if EMBEDDINGS_PROVIDER == 'ollama':
        return result['embedding']
    else:
        return result['data'][0]['embedding']

async def run_graphrag_query(request: ChatCompletionRequest) -> ChatCompletionResponse:
    try:
        query_options = request.query_options
        query = request.messages[-1].content  # Get the last user message as the query

        cmd = ["python", "-m", "graphrag.query"]
        cmd.extend(["--data", f"./indexing/output/{query_options.selected_folder}/artifacts"])
        cmd.extend(["--method", query_options.query_type.split('-')[1]])  # 'global' or 'local'
        if query_options.community_level:
            cmd.extend(["--community_level", str(query_options.community_level)])
        if query_options.response_type:
            cmd.extend(["--response_type", query_options.response_type])

        # Handle preset CLI args
        if query_options.preset and query_options.preset != "Custom Query":
            preset_args = get_preset_args(query_options.preset)
            cmd.extend(preset_args)

        # Handle custom CLI args
        if query_options.custom_cli_args:
            cmd.extend(query_options.custom_cli_args.split())

        cmd.append(query)

        logger.info(f"Executing GraphRAG query: {' '.join(cmd)}")
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise Exception(f"GraphRAG query failed: {result.stderr}")

        return ChatCompletionResponse(
            model=request.model,
            choices=[
                ChatCompletionResponseChoice(
                    index=0,
                    message=Message(
                        role="assistant",
                        content=result.stdout
                    ),
                    finish_reason="stop"
                )
            ],
            usage=Usage(
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0
            )
        )
    except Exception as e:
        logger.error(f"Error in GraphRAG query: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred during the GraphRAG query: {str(e)}")

def get_preset_args(preset: str) -> List[str]:
    preset_args = {
        "Default Global Search": ["--community_level", "2", "--response_type", "Multiple Paragraphs"],
        "Default Local Search": ["--community_level", "2", "--response_type", "Multiple Paragraphs"],
        "Detailed Global Analysis": ["--community_level", "3", "--response_type", "Multi-Page Report"],
        "Detailed Local Analysis": ["--community_level", "3", "--response_type", "Multi-Page Report"],
        "Quick Global Summary": ["--community_level", "1", "--response_type", "Single Paragraph"],
        "Quick Local Summary": ["--community_level", "1", "--response_type", "Single Paragraph"],
        "Global Bullet Points": ["--community_level", "2", "--response_type", "List of 3-7 Points"],
        "Local Bullet Points": ["--community_level", "2", "--response_type", "List of 3-7 Points"],
        "Comprehensive Global Report": ["--community_level", "4", "--response_type", "Multi-Page Report"],
        "Comprehensive Local Report": ["--community_level", "4", "--response_type", "Multi-Page Report"],
        "High-Level Global Overview": ["--community_level", "1", "--response_type", "Single Page"],
        "High-Level Local Overview": ["--community_level", "1", "--response_type", "Single Page"],
        "Focused Global Insight": ["--community_level", "3", "--response_type", "Single Paragraph"],
        "Focused Local Insight": ["--community_level", "3", "--response_type", "Single Paragraph"],
    }
    return preset_args.get(preset, [])

ddg_search = DuckDuckGoSearchAPIWrapper(max_results=5)


async def run_duckduckgo_search(request: ChatCompletionRequest) -> ChatCompletionResponse:
    query = request.messages[-1].content
    results = ddg_search.results(query, max_results=5)

    if not results:
        content = "No results found for the given query."
    else:
        content = "DuckDuckGo Search Results:\n\n"
        for result in results:
            content += f"Title: {result['title']}\n"
            content += f"Snippet: {result['snippet']}\n"
            content += f"Link: {result['link']}\n"
            if 'date' in result:
                content += f"Date: {result['date']}\n"
            if 'source' in result:
                content += f"Source: {result['source']}\n"
            content += "\n"

    return ChatCompletionResponse(
        model=request.model,
        choices=[
            ChatCompletionResponseChoice(
                index=0,
                message=Message(
                    role="assistant",
                    content=content
                ),
                finish_reason="stop"
            )
        ],
        usage=Usage(
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0
        )
    )

async def run_full_model_search(request: ChatCompletionRequest) -> ChatCompletionResponse:
    query = request.messages[-1].content

    # Run all search types
    graphrag_global = await run_graphrag_query(ChatCompletionRequest(model="graphrag-global-search:latest", messages=request.messages, query_options=request.query_options))
    graphrag_local = await run_graphrag_query(ChatCompletionRequest(model="graphrag-local-search:latest", messages=request.messages, query_options=request.query_options))
    duckduckgo = await run_duckduckgo_search(request)

    # Combine results
    combined_content = f"""Full Model Search Results:
Global Search:
{graphrag_global.choices[0].message.content}
Local Search:
{graphrag_local.choices[0].message.content}
DuckDuckGo Search:
{duckduckgo.choices[0].message.content}
"""

    return ChatCompletionResponse(
        model=request.model,
        choices=[
            ChatCompletionResponseChoice(
                index=0,
                message=Message(
                    role="assistant",
                    content=combined_content
                ),
                finish_reason="stop"
            )
        ],
        usage=Usage(
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0
        )
    )

# Route path assumed.
@app.get("/health")
async def health_check():
    return {"status": "ok"}

# Route path assumed (OpenAI-compatible model listing).
@app.get("/v1/models")
async def list_models():
    settings = load_settings()
    try:
        api_models = await fetch_available_models(settings)
    except Exception as e:
        logger.error(f"Error fetching API models: {str(e)}")
        api_models = []

    # Include the hardcoded models
    hardcoded_models = [
        {"id": "graphrag-local-search:latest", "object": "model", "owned_by": "graphrag"},
        {"id": "graphrag-global-search:latest", "object": "model", "owned_by": "graphrag"},
        {"id": "duckduckgo-search:latest", "object": "model", "owned_by": "duckduckgo"},
        {"id": "full-model:latest", "object": "model", "owned_by": "combined"},
    ]

    # Combine API models with hardcoded models
    all_models = [{"id": model, "object": "model", "owned_by": "api"} for model in api_models] + hardcoded_models
    return JSONResponse(content={"data": all_models})

class PromptTuneRequest(BaseModel):
    root: str = f"./{ROOT_DIR}"
    domain: Optional[str] = None
    method: str = "random"
    limit: int = 15
    language: Optional[str] = None
    max_tokens: int = 2000
    chunk_size: int = 200
    no_entity_types: bool = False
    output: str = f"./{ROOT_DIR}/prompts"


class PromptTuneResponse(BaseModel):
    status: str
    message: str
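
# Illustrative request body for the prompt-tuning endpoint (a sketch with
# placeholder values; every field above has a default, so only overrides are needed):
#
# {
#   "root": "./indexing",
#   "domain": "chemistry",
#   "method": "random",
#   "limit": 15,
#   "language": "English",
#   "max_tokens": 2000,
#   "chunk_size": 200,
#   "no_entity_types": false
# }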

# Global variable to store the latest logs
prompt_tune_logs = deque(maxlen=100)

async def run_prompt_tuning(request: PromptTuneRequest):
    cmd = ["python", "-m", "graphrag.prompt_tune"]

    # Create a temporary directory for output
    with tempfile.TemporaryDirectory() as temp_output:
        # Expand environment variables in the root path
        root_path = os.path.expandvars(request.root)
        cmd.extend(["--root", root_path])
        cmd.extend(["--method", request.method])
        cmd.extend(["--limit", str(request.limit)])
        if request.domain:
            cmd.extend(["--domain", request.domain])
        if request.language:
            cmd.extend(["--language", request.language])
        cmd.extend(["--max-tokens", str(request.max_tokens)])
        cmd.extend(["--chunk-size", str(request.chunk_size)])
        if request.no_entity_types:
            cmd.append("--no-entity-types")

        # Use the temporary directory for output
        cmd.extend(["--output", temp_output])

        logger.info(f"Executing prompt tuning command: {' '.join(cmd)}")
        try:
            process = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE
            )

            async def read_stream(stream):
                while True:
                    line = await stream.readline()
                    if not line:
                        break
                    line = line.decode().strip()
                    prompt_tune_logs.append(line)
                    logger.info(line)

            await asyncio.gather(
                read_stream(process.stdout),
                read_stream(process.stderr)
            )
            await process.wait()

            if process.returncode == 0:
                logger.info("Prompt tuning completed successfully")
                # Replace the existing template files with the newly generated prompts
                dest_dir = os.path.join(ROOT_DIR, "prompts")
                for filename in os.listdir(temp_output):
                    if filename.endswith(".txt"):
                        source_file = os.path.join(temp_output, filename)
                        dest_file = os.path.join(dest_dir, filename)
                        shutil.move(source_file, dest_file)
                        logger.info(f"Replaced {filename} in {dest_file}")
                return PromptTuneResponse(status="success", message="Prompt tuning completed successfully. Existing prompts have been replaced.")
            else:
                logger.error("Prompt tuning failed")
                return PromptTuneResponse(status="error", message="Prompt tuning failed. Check logs for details.")
        except Exception as e:
            logger.error(f"Prompt tuning failed: {str(e)}")
            return PromptTuneResponse(status="error", message=f"Prompt tuning failed: {str(e)}")

# Route paths assumed.
@app.post("/v1/prompt_tune")
async def prompt_tune(request: PromptTuneRequest, background_tasks: BackgroundTasks):
    background_tasks.add_task(run_prompt_tuning, request)
    return {"status": "started", "message": "Prompt tuning process has been started in the background"}


@app.get("/v1/prompt_tune_status")
async def prompt_tune_status():
    return {
        "status": "running" if prompt_tune_logs else "idle",
        "logs": list(prompt_tune_logs)
    }

class IndexingRequest(BaseModel):
    llm_model: str
    embed_model: str
    llm_api_base: str
    embed_api_base: str
    root: str
    verbose: bool = False
    nocache: bool = False
    resume: Optional[str] = None
    reporter: str = "rich"
    emit: List[str] = ["parquet"]
    custom_args: Optional[str] = None
    llm_params: Dict[str, Any] = Field(default_factory=dict)
    embed_params: Dict[str, Any] = Field(default_factory=dict)
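
# Illustrative request body for the indexing endpoint (a sketch with placeholder
# values, not from the source); llm_params/embed_params are exported as
# GRAPHRAG_LLM_* / GRAPHRAG_EMBED_* environment variables by run_indexing below:
#
# {
#   "llm_model": "llama3",
#   "embed_model": "nomic-embed-text",
#   "llm_api_base": "http://localhost:11434",
#   "embed_api_base": "http://localhost:11434",
#   "root": "./indexing",
#   "verbose": true,
#   "emit": ["parquet"],
#   "llm_params": {"max_tokens": 2000},
#   "embed_params": {"batch_size": 16}
# }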

# Global variable to store the latest indexing logs
indexing_logs = deque(maxlen=100)


async def run_indexing(request: IndexingRequest):
    cmd = ["python", "-m", "graphrag.index"]
    cmd.extend(["--root", request.root])
    if request.verbose:
        cmd.append("--verbose")
    if request.nocache:
        cmd.append("--nocache")
    if request.resume:
        cmd.extend(["--resume", request.resume])
    cmd.extend(["--reporter", request.reporter])
    cmd.extend(["--emit", ",".join(request.emit)])

    # Set environment variables for LLM and embedding models
    env: Dict[str, Any] = os.environ.copy()
    env["GRAPHRAG_LLM_MODEL"] = request.llm_model
    env["GRAPHRAG_EMBED_MODEL"] = request.embed_model
    env["GRAPHRAG_LLM_API_BASE"] = LLM_API_BASE
    env["GRAPHRAG_EMBED_API_BASE"] = EMBEDDINGS_API_BASE

    # Set environment variables for LLM parameters
    for key, value in request.llm_params.items():
        env[f"GRAPHRAG_LLM_{key.upper()}"] = str(value)

    # Set environment variables for embedding parameters
    for key, value in request.embed_params.items():
        env[f"GRAPHRAG_EMBED_{key.upper()}"] = str(value)

    # Add custom CLI arguments
    if request.custom_args:
        cmd.extend(request.custom_args.split())

    logger.info(f"Executing indexing command: {' '.join(cmd)}")
    logger.info(f"Environment variables: {env}")

    try:
        process = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            env=env
        )

        async def read_stream(stream):
            while True:
                line = await stream.readline()
                if not line:
                    break
                line = line.decode().strip()
                indexing_logs.append(line)
                logger.info(line)

        await asyncio.gather(
            read_stream(process.stdout),
            read_stream(process.stderr)
        )
        await process.wait()

        if process.returncode == 0:
            logger.info("Indexing completed successfully")
            return {"status": "success", "message": "Indexing completed successfully"}
        else:
            logger.error("Indexing failed")
            return {"status": "error", "message": "Indexing failed. Check logs for details."}
    except Exception as e:
        logger.error(f"Indexing failed: {str(e)}")
        return {"status": "error", "message": f"Indexing failed: {str(e)}"}

# Route paths assumed.
@app.post("/v1/index")
async def start_indexing(request: IndexingRequest, background_tasks: BackgroundTasks):
    background_tasks.add_task(run_indexing, request)
    return {"status": "started", "message": "Indexing process has been started in the background"}


@app.get("/v1/index_status")
async def indexing_status():
    return {
        "status": "running" if indexing_logs else "idle",
        "logs": list(indexing_logs)
    }

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Launch the GraphRAG API server")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to bind the server to")
    parser.add_argument("--port", type=int, default=PORT, help="Port to bind the server to")
    parser.add_argument("--reload", action="store_true", help="Enable auto-reload mode")
    args = parser.parse_args()

    import uvicorn
    uvicorn.run(
        "api:app",
        host=args.host,
        port=args.port,
        reload=args.reload
    )
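
# A sketch of how a client might call the running server (host/port, route path,
# and folder name are assumptions consistent with the routes and defaults above):
#
# curl http://127.0.0.1:8012/v1/chat/completions \
#   -H "Content-Type: application/json" \
#   -d '{"model": "graphrag-global-search:latest",
#        "messages": [{"role": "user", "content": "Summarize the dataset"}],
#        "query_options": {"query_type": "graphrag-global-search",
#                          "selected_folder": "my-index-run"}}'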