| """This file should be imported if and only if you want to run the UI locally.""" | |
| import itertools | |
| import logging | |
| import time | |
| from collections.abc import Iterable | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr # type: ignore | |
| from fastapi import FastAPI | |
| from gradio.themes.utils.colors import slate # type: ignore | |
| from injector import inject, singleton | |
| from llama_index.core.llms import ChatMessage, ChatResponse, MessageRole | |
| from pydantic import BaseModel | |
| from private_gpt.constants import PROJECT_ROOT_PATH | |
| from private_gpt.di import global_injector | |
| from private_gpt.open_ai.extensions.context_filter import ContextFilter | |
| from private_gpt.server.chat.chat_service import ChatService, CompletionGen | |
| from private_gpt.server.chunks.chunks_service import Chunk, ChunksService | |
| from private_gpt.server.ingest.ingest_service import IngestService | |
| from private_gpt.settings.settings import settings | |
| from private_gpt.ui.images import logo_svg | |
| logger = logging.getLogger(__name__) | |
| THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH) | |
| # Should be "private_gpt/ui/avatar-bot.ico" | |
| AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "avatar-bot.ico" | |
| UI_TAB_TITLE = "My Private GPT" | |
| SOURCES_SEPARATOR = "\n\n Sources: \n" | |
| MODES = ["Query Files", "Search Files", "LLM Chat (no context from files)"] | |


class Source(BaseModel):
    file: str
    page: str
    text: str

    class Config:
        frozen = True

    # Static method so callers can use Source.curate_sources(...) directly
    @staticmethod
    def curate_sources(sources: list[Chunk]) -> list["Source"]:
        curated_sources = []

        for chunk in sources:
            doc_metadata = chunk.document.doc_metadata

            file_name = doc_metadata.get("file_name", "-") if doc_metadata else "-"
            page_label = doc_metadata.get("page_label", "-") if doc_metadata else "-"

            source = Source(file=file_name, page=page_label, text=chunk.text)
            curated_sources.append(source)

        # Unique sources only (Source is frozen, hence hashable)
        curated_sources = list(dict.fromkeys(curated_sources).keys())

        return curated_sources
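
# A minimal sketch of how the chat/search handlers below use Source.curate_sources
# (the field values here are illustrative, not taken from a real ChunksService result):
#
#   curated = Source.curate_sources(chunks)
#   # -> e.g. [Source(file="report.pdf", page="3", text="...")]
#
# Chunks whose file, page and text are all identical collapse into a single entry,
# because Source is hashable and dict.fromkeys preserves order while dropping
# duplicates.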


@singleton
class PrivateGptUi:
    @inject
    def __init__(
        self,
        ingest_service: IngestService,
        chat_service: ChatService,
        chunks_service: ChunksService,
    ) -> None:
        self._ingest_service = ingest_service
        self._chat_service = chat_service
        self._chunks_service = chunks_service

        # Cache the UI blocks
        self._ui_block = None

        self._selected_filename = None

        # Initialize system prompt based on default mode
        self.mode = MODES[0]
        self._system_prompt = self._get_default_system_prompt(self.mode)

    def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
        def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
            full_response: str = ""
            stream = completion_gen.response
            for delta in stream:
                if isinstance(delta, str):
                    full_response += str(delta)
                elif isinstance(delta, ChatResponse):
                    full_response += delta.delta or ""
                yield full_response
                time.sleep(0.02)

            if completion_gen.sources:
                full_response += SOURCES_SEPARATOR
                cur_sources = Source.curate_sources(completion_gen.sources)
                sources_text = "\n\n\n"
                used_files = set()
                for index, source in enumerate(cur_sources, start=1):
                    if f"{source.file}-{source.page}" not in used_files:
                        sources_text = (
                            sources_text
                            + f"{index}. {source.file} (page {source.page}) \n\n"
                        )
                        used_files.add(f"{source.file}-{source.page}")
                full_response += sources_text

            yield full_response

        def build_history() -> list[ChatMessage]:
            history_messages: list[ChatMessage] = list(
                itertools.chain(
                    *[
                        [
                            ChatMessage(content=interaction[0], role=MessageRole.USER),
                            ChatMessage(
                                # Strip the Sources section from the stored answer
                                content=interaction[1].split(SOURCES_SEPARATOR)[0],
                                role=MessageRole.ASSISTANT,
                            ),
                        ]
                        for interaction in history
                    ]
                )
            )

            # Keep only the last 20 messages to try to avoid context overflow
            return history_messages[-20:]

        new_message = ChatMessage(content=message, role=MessageRole.USER)
        all_messages = [*build_history(), new_message]
        # If a system prompt is set, add it as a system message
        if self._system_prompt:
            all_messages.insert(
                0,
                ChatMessage(
                    content=self._system_prompt,
                    role=MessageRole.SYSTEM,
                ),
            )

        match mode:
            case "Query Files":
                # Use only the selected file for the query
                context_filter = None
                if self._selected_filename is not None:
                    docs_ids = []
                    for ingested_document in self._ingest_service.list_ingested():
                        if (
                            ingested_document.doc_metadata["file_name"]
                            == self._selected_filename
                        ):
                            docs_ids.append(ingested_document.doc_id)
                    context_filter = ContextFilter(docs_ids=docs_ids)

                query_stream = self._chat_service.stream_chat(
                    messages=all_messages,
                    use_context=True,
                    context_filter=context_filter,
                )
                yield from yield_deltas(query_stream)

            case "LLM Chat (no context from files)":
                llm_stream = self._chat_service.stream_chat(
                    messages=all_messages,
                    use_context=False,
                )
                yield from yield_deltas(llm_stream)

            case "Search Files":
                response = self._chunks_service.retrieve_relevant(
                    text=message, limit=4, prev_next_chunks=0
                )

                sources = Source.curate_sources(response)

                yield "\n\n\n".join(
                    f"{index}. **{source.file} "
                    f"(page {source.page})**\n "
                    f"{source.text}"
                    for index, source in enumerate(sources, start=1)
                )

    # On initialization and on mode change, this function sets the system prompt
    # to the default prompt based on the mode (and user settings).
    @staticmethod
    def _get_default_system_prompt(mode: str) -> str:
        p = ""
        match mode:
            # For query chat mode, obtain default system prompt from settings
            case "Query Files":
                p = settings().ui.default_query_system_prompt
            # For chat mode, obtain default system prompt from settings
            case "LLM Chat (no context from files)":
                p = settings().ui.default_chat_system_prompt
            # For any other mode, clear the system prompt
            case _:
                p = ""
        return p

    def _set_system_prompt(self, system_prompt_input: str) -> None:
        logger.info(f"Setting system prompt to: {system_prompt_input}")
        self._system_prompt = system_prompt_input

    def _set_current_mode(self, mode: str) -> Any:
        self.mode = mode
        self._set_system_prompt(self._get_default_system_prompt(mode))
        # Update placeholder and allow interaction if default system prompt is set
        if self._system_prompt:
            return gr.update(placeholder=self._system_prompt, interactive=True)
        # Update placeholder and disable interaction if no default system prompt is set
        else:
            return gr.update(placeholder=self._system_prompt, interactive=False)

    def _list_ingested_files(self) -> list[list[str]]:
        files = set()
        for ingested_document in self._ingest_service.list_ingested():
            if ingested_document.doc_metadata is None:
                # Skipping documents without metadata
                continue
            file_name = ingested_document.doc_metadata.get(
                "file_name", "[FILE NAME MISSING]"
            )
            files.add(file_name)
        return [[row] for row in files]

    def _upload_file(self, files: list[str]) -> None:
        logger.debug("Loading count=%s files", len(files))
        paths = [Path(file) for file in files]

        # Remove all existing documents whose name matches a newly uploaded file
        file_names = [path.name for path in paths]
        doc_ids_to_delete = []
        for ingested_document in self._ingest_service.list_ingested():
            if (
                ingested_document.doc_metadata
                and ingested_document.doc_metadata["file_name"] in file_names
            ):
                doc_ids_to_delete.append(ingested_document.doc_id)
        if len(doc_ids_to_delete) > 0:
            logger.info(
                "Uploading file(s) which were already ingested: %s document(s) will be replaced.",
                len(doc_ids_to_delete),
            )
            for doc_id in doc_ids_to_delete:
                self._ingest_service.delete(doc_id)

        self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])

    def _delete_all_files(self) -> Any:
        ingested_files = self._ingest_service.list_ingested()
        logger.debug("Deleting count=%s files", len(ingested_files))
        for ingested_document in ingested_files:
            self._ingest_service.delete(ingested_document.doc_id)
        return [
            gr.List(self._list_ingested_files()),
            gr.components.Button(interactive=False),
            gr.components.Button(interactive=False),
            gr.components.Textbox("All files"),
        ]

    def _delete_selected_file(self) -> Any:
        logger.debug("Deleting selected %s", self._selected_filename)
        # Note: keep looping, since for PDFs each page was ingested as its own Document
        for ingested_document in self._ingest_service.list_ingested():
            if (
                ingested_document.doc_metadata
                and ingested_document.doc_metadata["file_name"]
                == self._selected_filename
            ):
                self._ingest_service.delete(ingested_document.doc_id)
        return [
            gr.List(self._list_ingested_files()),
            gr.components.Button(interactive=False),
            gr.components.Button(interactive=False),
            gr.components.Textbox("All files"),
        ]

    def _deselect_selected_file(self) -> Any:
        self._selected_filename = None
        return [
            gr.components.Button(interactive=False),
            gr.components.Button(interactive=False),
            gr.components.Textbox("All files"),
        ]

    def _selected_a_file(self, select_data: gr.SelectData) -> Any:
        self._selected_filename = select_data.value
        return [
            gr.components.Button(interactive=True),
            gr.components.Button(interactive=True),
            gr.components.Textbox(self._selected_filename),
        ]

    def _build_ui_blocks(self) -> gr.Blocks:
        logger.debug("Creating the UI blocks")
        with gr.Blocks(
            title=UI_TAB_TITLE,
            theme=gr.themes.Soft(primary_hue=slate),
            css=".logo { "
            "display:flex;"
            "background-color: #C7BAFF;"
            "height: 80px;"
            "border-radius: 8px;"
            "align-content: center;"
            "justify-content: center;"
            "align-items: center;"
            "}"
            ".logo img { height: 25% }"
            ".contain { display: flex !important; flex-direction: column !important; }"
            "#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
            "#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
            "#col { height: calc(100vh - 112px - 16px) !important; }",
        ) as blocks:
            with gr.Row():
| gr.HTML(f"<div class='logo'/><img src={logo_svg} alt=PrivateGPT></div") | |

            with gr.Row(equal_height=False):
                with gr.Column(scale=3):
                    mode = gr.Radio(
                        MODES,
                        label="Mode",
                        value="Query Files",
                    )
                    upload_button = gr.components.UploadButton(
                        "Upload File(s)",
                        type="filepath",
                        file_count="multiple",
                        size="sm",
                    )
                    ingested_dataset = gr.List(
                        self._list_ingested_files,
                        headers=["File name"],
                        label="Ingested Files",
                        height=235,
                        interactive=False,
                        render=False,  # Rendered under the button
                    )
                    upload_button.upload(
                        self._upload_file,
                        inputs=upload_button,
                        outputs=ingested_dataset,
                    )
                    ingested_dataset.change(
                        self._list_ingested_files,
                        outputs=ingested_dataset,
                    )
                    ingested_dataset.render()
                    deselect_file_button = gr.components.Button(
                        "De-select selected file", size="sm", interactive=False
                    )
                    selected_text = gr.components.Textbox(
                        "All files", label="Selected for Query or Deletion", max_lines=1
                    )
                    delete_file_button = gr.components.Button(
                        "🗑️ Delete selected file",
                        size="sm",
                        visible=settings().ui.delete_file_button_enabled,
                        interactive=False,
                    )
                    delete_files_button = gr.components.Button(
                        "⚠️ Delete ALL files",
                        size="sm",
                        visible=settings().ui.delete_all_files_button_enabled,
                    )
                    deselect_file_button.click(
                        self._deselect_selected_file,
                        outputs=[
                            delete_file_button,
                            deselect_file_button,
                            selected_text,
                        ],
                    )
                    ingested_dataset.select(
                        fn=self._selected_a_file,
                        outputs=[
                            delete_file_button,
                            deselect_file_button,
                            selected_text,
                        ],
                    )
                    delete_file_button.click(
                        self._delete_selected_file,
                        outputs=[
                            ingested_dataset,
                            delete_file_button,
                            deselect_file_button,
                            selected_text,
                        ],
                    )
                    delete_files_button.click(
                        self._delete_all_files,
                        outputs=[
                            ingested_dataset,
                            delete_file_button,
                            deselect_file_button,
                            selected_text,
                        ],
                    )
                    system_prompt_input = gr.Textbox(
                        placeholder=self._system_prompt,
                        label="System Prompt",
                        lines=2,
                        interactive=True,
                        render=False,
                    )
                    # When mode changes, set default system prompt
                    mode.change(
                        self._set_current_mode, inputs=mode, outputs=system_prompt_input
                    )
                    # On blur, set system prompt to use in queries
                    system_prompt_input.blur(
                        self._set_system_prompt,
                        inputs=system_prompt_input,
                    )

                def get_model_label() -> str | None:
                    """Get the model label from the 'llm mode' setting.

                    Raises:
                        ValueError: If settings are not configured.

                    Returns:
                        str | None: The corresponding model label, or None if the
                        llm mode is not recognized.
                    """
                    # Labels: llamacpp, openai, openailike, sagemaker, mock, ollama
                    config_settings = settings()
                    if config_settings is None:
                        raise ValueError("Settings are not configured.")

                    # Get llm_mode from settings
                    llm_mode = config_settings.llm.mode

                    # Mapping of 'llm_mode' to corresponding model labels
                    model_mapping = {
                        "llamacpp": config_settings.llamacpp.llm_hf_model_file,
                        "openai": config_settings.openai.model,
                        "openailike": config_settings.openai.model,
                        "sagemaker": config_settings.sagemaker.llm_endpoint_name,
                        "mock": llm_mode,
                        "ollama": config_settings.ollama.llm_model,
                    }

                    if llm_mode not in model_mapping:
                        print(f"Invalid 'llm mode': {llm_mode}")
                        return None

                    return model_mapping[llm_mode]

                with gr.Column(scale=7, elem_id="col"):
                    # Determine the model label based on the value of PGPT_PROFILES
                    model_label = get_model_label()
                    if model_label is not None:
                        label_text = (
                            f"LLM: {settings().llm.mode} | Model: {model_label}"
                        )
                    else:
                        label_text = f"LLM: {settings().llm.mode}"

                    _ = gr.ChatInterface(
                        self._chat,
                        chatbot=gr.Chatbot(
                            label=label_text,
                            show_copy_button=True,
                            elem_id="chatbot",
                            render=False,
                            avatar_images=(
                                None,
                                AVATAR_BOT,
                            ),
                        ),
                        additional_inputs=[mode, upload_button, system_prompt_input],
                    )

        return blocks

    def get_ui_blocks(self) -> gr.Blocks:
        if self._ui_block is None:
            self._ui_block = self._build_ui_blocks()
        return self._ui_block

    def mount_in_app(self, app: FastAPI, path: str) -> None:
        blocks = self.get_ui_blocks()
        blocks.queue()
        logger.info("Mounting the gradio UI, at path=%s", path)
        gr.mount_gradio_app(app, blocks, path=path)
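
    # A minimal sketch of mounting this UI under an existing FastAPI app (the app
    # instance and the "/ui" path are illustrative; any FastAPI app and path work):
    #
    #   app = FastAPI()
    #   ui = global_injector.get(PrivateGptUi)
    #   ui.mount_in_app(app, path="/ui")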
| if __name__ == "__main__": | |
| ui = global_injector.get(PrivateGptUi) | |
| _blocks = ui.get_ui_blocks() | |
| _blocks.queue() | |
| _blocks.launch(debug=False, show_api=False) | |