from typing import Literal

from pydantic import BaseModel, Field

from private_gpt.settings.settings_loader import load_active_settings

class CorsSettings(BaseModel):
    """CORS configuration.

    For more details on the CORS configuration, see:
    * https://fastapi.tiangolo.com/tutorial/cors/
    * https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
    """

    enabled: bool = Field(
        description="Flag indicating if CORS headers are set or not. "
        "If set to True, the CORS headers will be set to allow all origins, methods and headers.",
        default=False,
    )
    allow_credentials: bool = Field(
        description="Indicate that cookies should be supported for cross-origin requests.",
        default=False,
    )
    allow_origins: list[str] = Field(
        description="A list of origins that should be permitted to make cross-origin requests.",
        default=[],
    )
    allow_origin_regex: str | None = Field(
        description="A regex string to match against origins that should be permitted to make cross-origin requests.",
        default=None,
    )
    allow_methods: list[str] = Field(
        description="A list of HTTP methods that should be allowed for cross-origin requests.",
        default=[
            "GET",
        ],
    )
    allow_headers: list[str] = Field(
        description="A list of HTTP request headers that should be supported for cross-origin requests.",
        default=[],
    )

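# A minimal usage sketch (hypothetical values, not shipped defaults): allow a
# local frontend to call the API with cookies. When `enabled` stays False, no
# CORS headers are expected to be set at all.
#
#     cors = CorsSettings(
#         enabled=True,
#         allow_credentials=True,
#         allow_origins=["http://localhost:3000"],
#         allow_methods=["GET", "POST"],
#     )
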
class AuthSettings(BaseModel):
    """Authentication configuration.

    The implementation of the authentication strategy must validate the
    'Authorization' header of incoming requests against the configured secret.
    """

    enabled: bool = Field(
        description="Flag indicating if authentication is enabled or not.",
        default=False,
    )
    secret: str = Field(
        description="The secret to be used for authentication. "
        "It can be any non-blank string. For HTTP basic authentication, "
        "this value should be the whole 'Authorization' header that is expected.",
    )

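# A hedged sketch (hypothetical credentials): for HTTP basic auth, the secret
# is the full expected 'Authorization' header, i.e. "Basic " followed by
# base64("user:password").
#
#     import base64
#     token = base64.b64encode(b"admin:changeme").decode()
#     auth = AuthSettings(enabled=True, secret=f"Basic {token}")
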
class ServerSettings(BaseModel):
    env_name: str = Field(
        description="Name of the environment (prod, staging, local...)"
    )
    port: int = Field(description="Port of PrivateGPT FastAPI server, defaults to 8001")
    cors: CorsSettings = Field(
        description="CORS configuration", default=CorsSettings(enabled=False)
    )
    auth: AuthSettings = Field(
        description="Authentication configuration",
        default_factory=lambda: AuthSettings(enabled=False, secret="secret-key"),
    )

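# A minimal sketch (hypothetical values): the nested settings compose, so a
# full server block can be built in one expression.
#
#     server = ServerSettings(
#         env_name="local",
#         port=8001,
#         cors=CorsSettings(enabled=False),
#         auth=AuthSettings(enabled=False, secret="secret-key"),
#     )
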
class DataSettings(BaseModel):
    local_data_folder: str = Field(
        description="Path to local storage. "
        "It will be treated as an absolute path if it starts with /"
    )

class LLMSettings(BaseModel):
    mode: Literal[
        "llamacpp", "openai", "openailike", "azopenai", "sagemaker", "mock", "ollama"
    ]
    max_new_tokens: int = Field(
        256,
        description="The maximum number of tokens that the LLM is authorized to generate in one completion.",
    )
    context_window: int = Field(
        3900,
        description="The maximum number of context tokens for the model.",
    )
    tokenizer: str | None = Field(
        None,
        description="The model id of a predefined tokenizer hosted inside a model repo on "
        "huggingface.co. Valid model ids can be located at the root-level, like "
        "`bert-base-uncased`, or namespaced under a user or organization name, "
        "like `HuggingFaceH4/zephyr-7b-beta`. If not set, will load a tokenizer matching "
        "the gpt-3.5-turbo LLM.",
    )
    temperature: float = Field(
        0.1,
        description="The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual.",
    )
    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field(
        "llama2",
        description=(
            "The prompt style to use for the chat engine. "
            "If `default` - use the default prompt style from llama_index. It should look like `role: message`.\n"
            "If `llama2` - use the llama2 prompt style from llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
            "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`.\n"
            "If `mistral` - use the `mistral` prompt style. It should look like `<s>[INST] {System Prompt} [/INST]</s>[INST] {UserInstructions} [/INST]`.\n"
            "`llama2` is the historic behaviour. `default` might work better with your custom models."
        ),
    )

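# A hedged sketch (hypothetical values): a local llama.cpp setup with a
# llama2-family model would typically pair `mode="llamacpp"` with the matching
# prompt style.
#
#     llm = LLMSettings(
#         mode="llamacpp",
#         prompt_style="llama2",
#         max_new_tokens=256,
#         context_window=3900,
#     )
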
class VectorstoreSettings(BaseModel):
    database: Literal["chroma", "qdrant", "postgres"]


class NodeStoreSettings(BaseModel):
    database: Literal["simple", "postgres"]

class LlamaCPPSettings(BaseModel):
    llm_hf_repo_id: str
    llm_hf_model_file: str
    tfs_z: float = Field(
        1.0,
        description="Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting.",
    )
    top_k: int = Field(
        40,
        description="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)",
    )
    top_p: float = Field(
        0.9,
        description="Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)",
    )
    repeat_penalty: float = Field(
        1.1,
        description="Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)",
    )

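# A hedged sketch (hypothetical repo and file names): the sampling knobs move
# together - lower top_k/top_p gives more focused output, higher values give
# more varied output.
#
#     conservative = LlamaCPPSettings(
#         llm_hf_repo_id="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
#         llm_hf_model_file="mistral-7b-instruct-v0.2.Q4_K_M.gguf",
#         top_k=10,
#         top_p=0.5,
#     )
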
class HuggingFaceSettings(BaseModel):
    embedding_hf_model_name: str = Field(
        description="Name of the HuggingFace model to use for embeddings"
    )
    access_token: str | None = Field(
        None,
        description="Hugging Face access token, required to download some models",
    )

class EmbeddingSettings(BaseModel):
    mode: Literal["huggingface", "openai", "azopenai", "sagemaker", "ollama", "mock"]
    ingest_mode: Literal["simple", "batch", "parallel", "pipeline"] = Field(
        "simple",
        description=(
            "The ingest mode to use for the embedding engine:\n"
            "If `simple` - ingest files sequentially and one by one. It is the historic behaviour.\n"
            "If `batch` - if multiple files, parse all the files in parallel, "
            "and send them in batch to the embedding model.\n"
            "If `pipeline` - the embedding engine is kept as busy as possible.\n"
            "If `parallel` - parse the files in parallel using multiple cores, and embed them in parallel.\n"
            "`parallel` is the fastest mode for local setup, as it parallelizes IO reads and writes in the index.\n"
            "For modes that leverage parallelization, you can specify the number of "
            "workers to use with `count_workers`.\n"
        ),
    )
    count_workers: int = Field(
        2,
        description=(
            "The number of workers to use for file ingestion.\n"
            "In `batch` mode, this is the number of workers used to parse the files.\n"
            "In `parallel` mode, this is the number of workers used to parse the files and embed them.\n"
            "In `pipeline` mode, this is the number of workers that can perform embeddings.\n"
            "This is only used if `ingest_mode` is not `simple`.\n"
            "Do not go too high with this number, as it might cause memory issues (especially in `parallel` mode).\n"
            "Do not set it higher than the number of threads of your CPU."
        ),
    )
    embed_dim: int = Field(
        384,
        description="The dimension of the embeddings stored in the Postgres database",
    )

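# A hedged sketch (hypothetical values): a local bulk ingestion on a multi-core
# machine might use parallel parsing and embedding, with a bounded worker count
# to avoid memory pressure.
#
#     embedding = EmbeddingSettings(
#         mode="huggingface",
#         ingest_mode="parallel",
#         count_workers=4,
#     )
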
class SagemakerSettings(BaseModel):
    llm_endpoint_name: str
    embedding_endpoint_name: str

class OpenAISettings(BaseModel):
    api_base: str | None = Field(
        None,
        description="Base URL of OpenAI API. Example: 'https://api.openai.com/v1'.",
    )
    api_key: str
    model: str = Field(
        "gpt-3.5-turbo",
        description="OpenAI Model to use. Example: 'gpt-4'.",
    )

class OllamaSettings(BaseModel):
    api_base: str = Field(
        "http://localhost:11434",
        description="Base URL of Ollama API. Example: 'https://localhost:11434'.",
    )
    embedding_api_base: str = Field(
        "http://localhost:11434",
        description="Base URL of Ollama embedding API. Example: 'https://localhost:11434'.",
    )
    llm_model: str | None = Field(
        None,
        description="Model to use. Example: 'llama2-uncensored'.",
    )
    embedding_model: str | None = Field(
        None,
        description="Model to use. Example: 'nomic-embed-text'.",
    )
    keep_alive: str = Field(
        "5m",
        description="Time the model will stay loaded in memory after a request. Examples: '5m', '5h', '-1'.",
    )
    tfs_z: float = Field(
        1.0,
        description="Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting.",
    )
    num_predict: int | None = Field(
        None,
        description="Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context)",
    )
    top_k: int = Field(
        40,
        description="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)",
    )
    top_p: float = Field(
        0.9,
        description="Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)",
    )
    repeat_last_n: int = Field(
        64,
        description="Sets how far back the model looks to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)",
    )
    repeat_penalty: float = Field(
        1.1,
        description="Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)",
    )
    request_timeout: float = Field(
        120.0,
        description="Time in seconds before Ollama times out the request. (Default: 120.0)",
    )

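# A hedged sketch (hypothetical model names): pointing both the LLM and the
# embedding engine at a single local Ollama daemon.
#
#     ollama = OllamaSettings(
#         api_base="http://localhost:11434",
#         llm_model="llama2-uncensored",
#         embedding_model="nomic-embed-text",
#         keep_alive="5m",
#     )
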
class AzureOpenAISettings(BaseModel):
    api_key: str
    azure_endpoint: str
    api_version: str = Field(
        "2023-05-15",
        description="The API version to use for this operation. This follows the YYYY-MM-DD format.",
    )
    embedding_deployment_name: str
    embedding_model: str = Field(
        "text-embedding-ada-002",
        description="OpenAI Model to use. Example: 'text-embedding-ada-002'.",
    )
    llm_deployment_name: str
    llm_model: str = Field(
        "gpt-35-turbo",
        description="OpenAI Model to use. Example: 'gpt-4'.",
    )

class UISettings(BaseModel):
    enabled: bool
    path: str
    default_chat_system_prompt: str | None = Field(
        None,
        description="The default system prompt to use for the chat mode.",
    )
    default_query_system_prompt: str | None = Field(
        None, description="The default system prompt to use for the query mode."
    )
    delete_file_button_enabled: bool = Field(
        True, description="If the button to delete a file is enabled or not."
    )
    delete_all_files_button_enabled: bool = Field(
        False, description="If the button to delete all files is enabled or not."
    )

class RerankSettings(BaseModel):
    enabled: bool = Field(
        False,
        description="This value controls whether a reranker should be included in the RAG pipeline.",
    )
    model: str = Field(
        "cross-encoder/ms-marco-MiniLM-L-2-v2",
        description="Rerank model to use. Limited to SentenceTransformer cross-encoder models.",
    )
    top_n: int = Field(
        2,
        description="This value controls the number of documents returned by the RAG pipeline.",
    )

class RagSettings(BaseModel):
    similarity_top_k: int = Field(
        2,
        description="This value controls the number of documents returned by the RAG pipeline, or considered for reranking if enabled.",
    )
    similarity_value: float | None = Field(
        None,
        description="If set, any documents retrieved from the RAG pipeline must meet this minimum similarity score. Acceptable values are between 0 and 1.",
    )
    rerank: RerankSettings

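# A hedged sketch (hypothetical values) of how the two stages interact: the
# retriever first fetches `similarity_top_k` candidates, then the reranker,
# when enabled, narrows them down to `top_n`.
#
#     rag = RagSettings(
#         similarity_top_k=10,
#         rerank=RerankSettings(enabled=True, top_n=2),
#     )
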
class PostgresSettings(BaseModel):
    host: str = Field(
        "localhost",
        description="The server hosting the Postgres database",
    )
    port: int = Field(
        5432,
        description="The port on which the Postgres database is accessible",
    )
    user: str = Field(
        "postgres",
        description="The user to use to connect to the Postgres database",
    )
    password: str = Field(
        "postgres",
        description="The password to use to connect to the Postgres database",
    )
    database: str = Field(
        "postgres",
        description="The name of the database to connect to on the Postgres server",
    )
    schema_name: str = Field(
        "public",
        description="The name of the schema in the Postgres database to use",
    )

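# A hedged sketch (hypothetical DSN, for illustration only): the defaults above
# correspond to a conventional local connection string, with tables created in
# the "public" schema.
#
#     pg = PostgresSettings()  # all defaults
#     dsn = f"postgresql://{pg.user}:{pg.password}@{pg.host}:{pg.port}/{pg.database}"
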
class QdrantSettings(BaseModel):
    location: str | None = Field(
        None,
        description=(
            "If `:memory:` - use in-memory Qdrant instance.\n"
            "If `str` - use it as a `url` parameter.\n"
        ),
    )
    url: str | None = Field(
        None,
        description=(
            "Either host or str of 'Optional[scheme], host, Optional[port], Optional[prefix]'."
        ),
    )
    port: int | None = Field(6333, description="Port of the REST API interface.")
    grpc_port: int | None = Field(6334, description="Port of the gRPC interface.")
    prefer_grpc: bool | None = Field(
        False,
        description="If `true` - use gRPC interface whenever possible in custom methods.",
    )
    https: bool | None = Field(
        None,
        description="If `true` - use HTTPS(SSL) protocol.",
    )
    api_key: str | None = Field(
        None,
        description="API key for authentication in Qdrant Cloud.",
    )
    prefix: str | None = Field(
        None,
        description=(
            "Prefix to add to the REST URL path. "
            "Example: `service/v1` will result in "
            "'http://localhost:6333/service/v1/{qdrant-endpoint}' for REST API."
        ),
    )
    timeout: float | None = Field(
        None,
        description="Timeout for REST and gRPC API requests.",
    )
    host: str | None = Field(
        None,
        description="Host name of Qdrant service. If url and host are None, set to 'localhost'.",
    )
    path: str | None = Field(None, description="Persistence path for QdrantLocal.")
    force_disable_check_same_thread: bool | None = Field(
        True,
        description=(
            "For QdrantLocal, force disable check_same_thread. Default: `True`. "
            "Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient."
        ),
    )

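# A hedged sketch (hypothetical values): the same settings class covers three
# deployment shapes - in-memory, on-disk local, and a remote server.
#
#     in_memory = QdrantSettings(location=":memory:")
#     on_disk = QdrantSettings(path="local_data/private_gpt/qdrant")
#     remote = QdrantSettings(url="http://my-qdrant:6333", api_key="...")
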
class Settings(BaseModel):
    server: ServerSettings
    data: DataSettings
    ui: UISettings
    llm: LLMSettings
    embedding: EmbeddingSettings
    llamacpp: LlamaCPPSettings
    huggingface: HuggingFaceSettings
    sagemaker: SagemakerSettings
    openai: OpenAISettings
    ollama: OllamaSettings
    azopenai: AzureOpenAISettings
    vectorstore: VectorstoreSettings
    nodestore: NodeStoreSettings
    rag: RagSettings
    qdrant: QdrantSettings | None = None
    postgres: PostgresSettings | None = None


"""
This is visible just for DI or testing purposes.

Use dependency injection or the `settings()` method instead.
"""
unsafe_settings = load_active_settings()

"""
This is visible just for DI or testing purposes.

Use dependency injection or the `settings()` method instead.
"""
unsafe_typed_settings = Settings(**unsafe_settings)

def settings() -> Settings:
    """Get the currently loaded settings from the DI container.

    This method exists to keep compatibility with the existing code,
    which requires global access to the settings.

    For regular components use dependency injection instead.
    """
    from private_gpt.di import global_injector

    return global_injector.get(Settings)
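
# Usage sketch (hypothetical call site): components that cannot receive the
# Settings object through the injector can still read it lazily.
#
#     from private_gpt.settings.settings import settings
#     port = settings().server.port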