| """ | |
| What is this? | |
| Provider-specific Pass-Through Endpoints | |
| Use litellm with Anthropic SDK, Vertex AI SDK, Cohere SDK, etc. | |
| """ | |
| import os | |
| from typing import Optional | |
| import httpx | |
| from fastapi import APIRouter, Depends, HTTPException, Request, Response | |
| import litellm | |
| from litellm._logging import verbose_proxy_logger | |
| from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES | |
| from litellm.llms.vertex_ai.vertex_llm_base import VertexBase | |
| from litellm.proxy._types import * | |
| from litellm.proxy.auth.route_checks import RouteChecks | |
| from litellm.proxy.auth.user_api_key_auth import user_api_key_auth | |
| from litellm.proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key | |
| from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( | |
| create_pass_through_route, | |
| ) | |
| from litellm.secret_managers.main import get_secret_str | |
| from .passthrough_endpoint_router import PassthroughEndpointRouter | |
| vertex_llm_base = VertexBase() | |
| router = APIRouter() | |
| default_vertex_config = None | |
| passthrough_endpoint_router = PassthroughEndpointRouter() | |


def create_request_copy(request: Request):
    return {
        "method": request.method,
        "url": str(request.url),
        "headers": dict(request.headers),
        "cookies": request.cookies,
        "query_params": dict(request.query_params),
    }
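
# For reference, a minimal sketch of what this helper returns (values are
# illustrative, not taken from a real request):
#
#   {
#       "method": "POST",
#       "url": "http://0.0.0.0:4000/anthropic/v1/messages",
#       "headers": {"authorization": "Bearer sk-..."},
#       "cookies": {},
#       "query_params": {},
#   }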


async def llm_passthrough_factory_proxy_route(
    custom_llm_provider: str,
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Factory function for creating pass-through endpoints for LLM providers.
    """
    from litellm.types.utils import LlmProviders
    from litellm.utils import ProviderConfigManager

    provider_config = ProviderConfigManager.get_provider_model_info(
        provider=LlmProviders(custom_llm_provider),
        model=None,
    )
    if provider_config is None:
        raise HTTPException(
            status_code=404, detail=f"Provider {custom_llm_provider} not found"
        )

    base_target_url = provider_config.get_api_base()
    if base_target_url is None:
        raise HTTPException(
            status_code=404, detail=f"Provider {custom_llm_provider} api base not found"
        )

    encoded_endpoint = httpx.URL(endpoint).path
    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Look up provider credentials and build the auth headers for the upstream call
    provider_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider=custom_llm_provider,
        region_name=None,
    )
    auth_headers = provider_config.validate_environment(
        headers={},
        model="",
        messages=[],
        optional_params={},
        litellm_params={},
        api_key=provider_api_key,
        api_base=base_target_url,
    )

    ## check for streaming
    is_streaming_request = False
    # the request is streaming when 'stream' = True is in the JSON body
    if request.method == "POST":
        _request_body = await request.json()
        if _request_body.get("stream"):
            is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers=auth_headers,
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value
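
# Sketch of how this factory is used: provider-specific routes like `vllm_proxy_route`
# below simply delegate here with their provider name. A hypothetical new provider
# route (names are illustrative) would follow the same shape:
#
#   async def my_provider_proxy_route(
#       endpoint: str,
#       request: Request,
#       fastapi_response: Response,
#       user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
#   ):
#       return await llm_passthrough_factory_proxy_route(
#           custom_llm_provider="my_provider",  # must be a valid LlmProviders value
#           endpoint=endpoint,
#           request=request,
#           fastapi_response=fastapi_response,
#           user_api_key_dict=user_api_key_dict,
#       )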


async def gemini_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/google_ai_studio)
    """
    ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?key=LITELLM_API_KEY
    google_ai_studio_api_key = request.query_params.get("key") or request.headers.get(
        "x-goog-api-key"
    )

    user_api_key_dict = await user_api_key_auth(
        request=request, api_key=f"Bearer {google_ai_studio_api_key}"
    )

    base_target_url = "https://generativelanguage.googleapis.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Look up the Gemini API key to forward upstream
    gemini_api_key: Optional[str] = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="gemini",
        region_name=None,
    )
    if gemini_api_key is None:
        raise Exception(
            "Required 'GEMINI_API_KEY' in environment to make pass-through calls to Google AI Studio."
        )

    # Merge query parameters, replacing the incoming `key` with the real Gemini API key
    merged_params = dict(request.query_params)
    merged_params.update({"key": gemini_api_key})

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        query_params=merged_params,  # type: ignore
        stream=is_streaming_request,  # type: ignore
    )

    return received_value
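
# Example client call (a sketch, assuming the proxy runs at http://0.0.0.0:4000 and
# this route is mounted at /gemini/{endpoint:path}; the virtual key is illustrative).
# The `key` query param carries the litellm virtual key and is swapped for the real
# GEMINI_API_KEY before the request is forwarded:
#
#   import httpx
#
#   resp = httpx.post(
#       "http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:generateContent",
#       params={"key": "sk-1234"},
#       json={"contents": [{"parts": [{"text": "Hello"}]}]},
#   )
#   print(resp.json())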


async def cohere_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/cohere)
    """
    base_target_url = "https://api.cohere.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Look up the Cohere API key to forward upstream
    cohere_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="cohere",
        region_name=None,
    )

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"Authorization": "Bearer {}".format(cohere_api_key)},
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value
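
# Example client call (a sketch, assuming the route is mounted at /cohere/{endpoint:path};
# base URL and key are illustrative). The Cohere SDK's `base_url` is pointed at the proxy:
#
#   import cohere
#
#   client = cohere.Client(api_key="sk-1234", base_url="http://0.0.0.0:4000/cohere")
#   resp = client.chat(model="command-r-plus", message="Hello")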


async def vllm_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/vllm)
    """
    return await llm_passthrough_factory_proxy_route(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        user_api_key_dict=user_api_key_dict,
        custom_llm_provider="vllm",
    )


async def mistral_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/mistral)
    """
    base_target_url = os.getenv("MISTRAL_API_BASE") or "https://api.mistral.ai"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Look up the Mistral API key to forward upstream
    mistral_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="mistral",
        region_name=None,
    )

    ## check for streaming
    is_streaming_request = False
    # mistral is streaming when 'stream' = True is in the body
    if request.method == "POST":
        _request_body = await request.json()
        if _request_body.get("stream"):
            is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"Authorization": "Bearer {}".format(mistral_api_key)},
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value
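
# Example client call (a sketch, assuming the route is mounted at /mistral/{endpoint:path};
# base URL and key are illustrative):
#
#   import httpx
#
#   resp = httpx.post(
#       "http://0.0.0.0:4000/mistral/v1/chat/completions",
#       headers={"Authorization": "Bearer sk-1234"},
#       json={"model": "mistral-small-latest",
#             "messages": [{"role": "user", "content": "Hello"}]},
#   )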


async def anthropic_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/anthropic_completion)
    """
    base_target_url = "https://api.anthropic.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Look up the Anthropic API key to forward upstream
    anthropic_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="anthropic",
        region_name=None,
    )

    ## check for streaming
    is_streaming_request = False
    # anthropic is streaming when 'stream' = True is in the body
    if request.method == "POST":
        _request_body = await request.json()
        if _request_body.get("stream"):
            is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"x-api-key": "{}".format(anthropic_api_key)},
        _forward_headers=True,
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value
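
# Example client call (a sketch, assuming the route is mounted at /anthropic/{endpoint:path};
# base URL and key are illustrative). The Anthropic SDK is simply pointed at the proxy:
#
#   from anthropic import Anthropic
#
#   client = Anthropic(base_url="http://0.0.0.0:4000/anthropic", api_key="sk-1234")
#   msg = client.messages.create(
#       model="claude-3-5-sonnet-20241022",
#       max_tokens=100,
#       messages=[{"role": "user", "content": "Hello"}],
#   )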


async def bedrock_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/bedrock)
    """
    create_request_copy(request)

    try:
        from botocore.auth import SigV4Auth
        from botocore.awsrequest import AWSRequest
        from botocore.credentials import Credentials
    except ImportError:
        raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

    aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME")
    if _is_bedrock_agent_runtime_route(endpoint=endpoint):  # handle bedrock agents
        base_target_url = (
            f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com"
        )
    else:
        base_target_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Sign the outgoing request with SigV4 using the proxy's AWS credentials
    from litellm.llms.bedrock.chat import BedrockConverseLLM

    credentials: Credentials = BedrockConverseLLM().get_credentials()
    sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
    headers = {"Content-Type": "application/json"}

    # Assuming the body contains JSON data, parse it
    try:
        data = await request.json()
    except Exception as e:
        raise HTTPException(status_code=400, detail={"error": str(e)})
    _request = AWSRequest(
        method="POST", url=str(updated_url), data=json.dumps(data), headers=headers
    )
    sigv4.add_auth(_request)
    prepped = _request.prepare()

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(prepped.url),
        custom_headers=prepped.headers,  # type: ignore
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
        custom_body=data,  # type: ignore
        query_params={},  # type: ignore
    )

    return received_value
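
# Example client call (a sketch, assuming the route is mounted at /bedrock/{endpoint:path};
# base URL, model id, and key are illustrative). The proxy re-signs the request with
# SigV4, so the caller only needs a litellm virtual key:
#
#   import httpx
#
#   resp = httpx.post(
#       "http://0.0.0.0:4000/bedrock/model/anthropic.claude-3-sonnet-20240229-v1:0/converse",
#       headers={"Authorization": "Bearer sk-1234"},
#       json={"messages": [{"role": "user", "content": [{"text": "Hello"}]}]},
#   )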


def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
    """
    Return True if the endpoint should be routed to the `bedrock-agent-runtime` endpoint.
    """
    for _route in BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES:
        if _route in endpoint:
            return True
    return False


async def assemblyai_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://api.assemblyai.com)
    """
    from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
        AssemblyAIPassthroughLoggingHandler,
    )

    # Set base URL based on the region encoded in the requested route
    assembly_region = AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
        url=str(request.url)
    )
    base_target_url = (
        AssemblyAIPassthroughLoggingHandler._get_assembly_base_url_from_region(
            region=assembly_region
        )
    )
    encoded_endpoint = httpx.URL(endpoint).path
    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Look up the AssemblyAI API key for the resolved region
    assemblyai_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="assemblyai",
        region_name=assembly_region,
    )

    ## check for streaming
    is_streaming_request = False
    # assemblyai is streaming when 'stream' = True is in the body
    if request.method == "POST":
        _request_body = await request.json()
        if _request_body.get("stream"):
            is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"Authorization": "{}".format(assemblyai_api_key)},
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request=request,
        fastapi_response=fastapi_response,
        user_api_key_dict=user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value
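
# Example client call (a sketch, assuming the route is mounted at /assemblyai/{endpoint:path};
# base URL, key, and audio URL are illustrative):
#
#   import httpx
#
#   resp = httpx.post(
#       "http://0.0.0.0:4000/assemblyai/v2/transcript",
#       headers={"Authorization": "Bearer sk-1234"},
#       json={"audio_url": "https://example.com/audio.mp3"},
#   )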


async def azure_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Call any azure endpoint using the proxy.

    Just use `{PROXY_BASE_URL}/azure/{endpoint:path}`
    """
    base_target_url = get_secret_str(secret_name="AZURE_API_BASE")
    if base_target_url is None:
        raise Exception(
            "Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
        )
    # Look up the Azure API key to forward upstream
    azure_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider=litellm.LlmProviders.AZURE.value,
        region_name=None,
    )
    if azure_api_key is None:
        raise Exception(
            "Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
        )

    return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        user_api_key_dict=user_api_key_dict,
        base_target_url=base_target_url,
        api_key=azure_api_key,
        custom_llm_provider=litellm.LlmProviders.AZURE,
    )
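
# Example client call (a sketch, assuming AZURE_API_BASE points at your Azure OpenAI
# resource and the route is mounted at /azure/{endpoint:path}; deployment name,
# api-version, and key are illustrative):
#
#   import httpx
#
#   resp = httpx.post(
#       "http://0.0.0.0:4000/azure/openai/deployments/my-gpt-4o/chat/completions",
#       params={"api-version": "2024-02-01"},
#       headers={"Authorization": "Bearer sk-1234"},
#       json={"messages": [{"role": "user", "content": "Hello"}]},
#   )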


class BaseVertexAIPassThroughHandler(ABC):
    @staticmethod
    @abstractmethod
    def get_default_base_target_url(vertex_location: Optional[str]) -> str:
        pass

    @staticmethod
    @abstractmethod
    def update_base_target_url_with_credential_location(
        base_target_url: str, vertex_location: Optional[str]
    ) -> str:
        pass


class VertexAIDiscoveryPassThroughHandler(BaseVertexAIPassThroughHandler):
    @staticmethod
    def get_default_base_target_url(vertex_location: Optional[str]) -> str:
        return "https://discoveryengine.googleapis.com/"

    @staticmethod
    def update_base_target_url_with_credential_location(
        base_target_url: str, vertex_location: Optional[str]
    ) -> str:
        return base_target_url


class VertexAIPassThroughHandler(BaseVertexAIPassThroughHandler):
    @staticmethod
    def get_default_base_target_url(vertex_location: Optional[str]) -> str:
        return f"https://{vertex_location}-aiplatform.googleapis.com/"

    @staticmethod
    def update_base_target_url_with_credential_location(
        base_target_url: str, vertex_location: Optional[str]
    ) -> str:
        return f"https://{vertex_location}-aiplatform.googleapis.com/"


def get_vertex_pass_through_handler(
    call_type: Literal["discovery", "aiplatform"]
) -> BaseVertexAIPassThroughHandler:
    if call_type == "discovery":
        return VertexAIDiscoveryPassThroughHandler()
    elif call_type == "aiplatform":
        return VertexAIPassThroughHandler()
    else:
        raise ValueError(f"Invalid call type: {call_type}")
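
# Usage sketch: the factory returns a handler whose static methods build the
# upstream base URL, e.g.
#
#   handler = get_vertex_pass_through_handler(call_type="discovery")
#   handler.get_default_base_target_url(None)
#   # -> "https://discoveryengine.googleapis.com/"
#
#   handler = get_vertex_pass_through_handler(call_type="aiplatform")
#   handler.get_default_base_target_url("us-central1")
#   # -> "https://us-central1-aiplatform.googleapis.com/"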


async def _base_vertex_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    get_vertex_pass_through_handler: BaseVertexAIPassThroughHandler,
    user_api_key_dict: Optional[UserAPIKeyAuth] = None,
):
    """
    Base function for Vertex AI passthrough routes.
    Handles common logic for all Vertex AI services.

    Default base_target_url is `https://{vertex_location}-aiplatform.googleapis.com/`
    """
    from litellm.llms.vertex_ai.common_utils import (
        construct_target_url,
        get_vertex_location_from_url,
        get_vertex_project_id_from_url,
    )

    encoded_endpoint = httpx.URL(endpoint).path
    verbose_proxy_logger.debug("requested endpoint %s", endpoint)
    headers: dict = {}

    # Authenticate with the litellm virtual key if the caller didn't already
    if user_api_key_dict is None:
        api_key_to_use = get_litellm_virtual_key(request=request)
        user_api_key_dict = await user_api_key_auth(
            request=request,
            api_key=api_key_to_use,
        )

    vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint)
    vertex_location: Optional[str] = get_vertex_location_from_url(endpoint)
    vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(
        project_id=vertex_project,
        location=vertex_location,
    )

    base_target_url = get_vertex_pass_through_handler.get_default_base_target_url(
        vertex_location
    )

    headers_passed_through = False
    # Use headers from the incoming request if no vertex credentials are found
    if vertex_credentials is None or vertex_credentials.vertex_project is None:
        headers = dict(request.headers) or {}
        headers_passed_through = True
        verbose_proxy_logger.debug(
            "default_vertex_config not set, incoming request headers %s", headers
        )
        headers.pop("content-length", None)
        headers.pop("host", None)
    else:
        vertex_project = vertex_credentials.vertex_project
        vertex_location = vertex_credentials.vertex_location
        vertex_credentials_str = vertex_credentials.vertex_credentials

        # Construct auth headers from the stored service-account credentials
        _auth_header, vertex_project = await vertex_llm_base._ensure_access_token_async(
            credentials=vertex_credentials_str,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai_beta",
        )
        auth_header, _ = vertex_llm_base._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials_str,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base="",
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
        }

        base_target_url = get_vertex_pass_through_handler.update_base_target_url_with_credential_location(
            base_target_url, vertex_location
        )

    if base_target_url is None:
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"

    request_route = encoded_endpoint
    verbose_proxy_logger.debug("request_route %s", request_route)

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    updated_url = construct_target_url(
        base_url=base_target_url,
        requested_route=encoded_endpoint,
        vertex_location=vertex_location,
        vertex_project=vertex_project,
    )

    verbose_proxy_logger.debug("updated url %s", updated_url)

    ## check for streaming
    target = str(updated_url)
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True
        target += "?alt=sse"

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=target,
        custom_headers=headers,
    )  # dynamically construct pass-through endpoint based on incoming path

    try:
        received_value = await endpoint_func(
            request,
            fastapi_response,
            user_api_key_dict,
            stream=is_streaming_request,  # type: ignore
        )
    except ProxyException as e:
        if headers_passed_through:
            e.message = f"No credentials found on proxy for project_name={vertex_project} + location={vertex_location}, check `/model/info` for allowed project + region combinations with `use_in_pass_through: true`. Headers were passed through directly but request failed with error: {e.message}"
        raise e

    return received_value


async def vertex_discovery_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    """
    Call any vertex discovery endpoint using the proxy.

    Just use `{PROXY_BASE_URL}/vertex_ai/discovery/{endpoint:path}`

    Target url: `https://discoveryengine.googleapis.com`
    """
    discovery_handler = get_vertex_pass_through_handler(call_type="discovery")
    return await _base_vertex_proxy_route(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        get_vertex_pass_through_handler=discovery_handler,
    )


async def vertex_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Call LiteLLM proxy via Vertex AI SDK.

    [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
    """
    ai_platform_handler = get_vertex_pass_through_handler(call_type="aiplatform")
    return await _base_vertex_proxy_route(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        get_vertex_pass_through_handler=ai_platform_handler,
        user_api_key_dict=user_api_key_dict,
    )
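
# Example client call (a sketch, assuming the route is mounted at /vertex_ai/{endpoint:path};
# project, location, and base URL are illustrative). The Vertex AI SDK's REST transport
# is pointed at the proxy, following the pattern in the docs linked above:
#
#   import vertexai
#   from vertexai.generative_models import GenerativeModel
#
#   vertexai.init(
#       project="my-project",
#       location="us-central1",
#       api_endpoint="http://0.0.0.0:4000/vertex_ai",
#       api_transport="rest",
#   )
#   model = GenerativeModel("gemini-1.5-flash")
#   print(model.generate_content("Hello").text)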


async def openai_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Simple pass-through for OpenAI. Use this if you want to directly send a request to OpenAI.
    """
    base_target_url = "https://api.openai.com/"
    # Look up the OpenAI API key to forward upstream
    openai_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider=litellm.LlmProviders.OPENAI.value,
        region_name=None,
    )
    if openai_api_key is None:
        raise Exception(
            "Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."
        )

    return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        user_api_key_dict=user_api_key_dict,
        base_target_url=base_target_url,
        api_key=openai_api_key,
        custom_llm_provider=litellm.LlmProviders.OPENAI,
    )
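
# Example client call (a sketch, assuming the route is mounted at /openai/{endpoint:path};
# base URL and key are illustrative). The OpenAI SDK is pointed at the proxy:
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234")
#   resp = client.chat.completions.create(
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "Hello"}],
#   )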


class BaseOpenAIPassThroughHandler:
    @staticmethod
    async def _base_openai_pass_through_handler(
        endpoint: str,
        request: Request,
        fastapi_response: Response,
        user_api_key_dict: UserAPIKeyAuth,
        base_target_url: str,
        api_key: str,
        custom_llm_provider: litellm.LlmProviders,
    ):
        encoded_endpoint = httpx.URL(endpoint).path

        # Ensure endpoint starts with '/' for proper URL construction
        if not encoded_endpoint.startswith("/"):
            encoded_endpoint = "/" + encoded_endpoint

        # Construct the full target URL by properly joining the base URL and endpoint path
        base_url = httpx.URL(base_target_url)
        updated_url = BaseOpenAIPassThroughHandler._join_url_paths(
            base_url=base_url,
            path=encoded_endpoint,
            custom_llm_provider=custom_llm_provider,
        )

        ## check for streaming
        is_streaming_request = False
        if "stream" in str(updated_url):
            is_streaming_request = True

        ## CREATE PASS-THROUGH
        endpoint_func = create_pass_through_route(
            endpoint=endpoint,
            target=str(updated_url),
            custom_headers=BaseOpenAIPassThroughHandler._assemble_headers(
                api_key=api_key, request=request
            ),
        )  # dynamically construct pass-through endpoint based on incoming path
        received_value = await endpoint_func(
            request,
            fastapi_response,
            user_api_key_dict,
            stream=is_streaming_request,  # type: ignore
            query_params=dict(request.query_params),  # type: ignore
        )

        return received_value

    @staticmethod
    def _append_openai_beta_header(headers: dict, request: Request) -> dict:
        """
        Appends the OpenAI-Beta header to the headers if the request is an OpenAI Assistants API request
        """
        if (
            RouteChecks._is_assistants_api_request(request) is True
            and "OpenAI-Beta" not in headers
        ):
            headers["OpenAI-Beta"] = "assistants=v2"
        return headers

    @staticmethod
    def _assemble_headers(api_key: str, request: Request) -> dict:
        base_headers = {
            "authorization": "Bearer {}".format(api_key),
            "api-key": "{}".format(api_key),
        }
        return BaseOpenAIPassThroughHandler._append_openai_beta_header(
            headers=base_headers,
            request=request,
        )

    @staticmethod
    def _join_url_paths(
        base_url: httpx.URL, path: str, custom_llm_provider: litellm.LlmProviders
    ) -> str:
        """
        Properly joins a base URL with a path, preserving any existing path in the base URL.
        """
        # Join paths correctly by removing trailing/leading slashes as needed
        if not base_url.path or base_url.path == "/":
            # If base URL has no path, just use the new path
            joined_path_str = str(base_url.copy_with(path=path))
        else:
            # Otherwise, combine the paths
            base_path = base_url.path.rstrip("/")
            clean_path = path.lstrip("/")
            full_path = f"{base_path}/{clean_path}"
            joined_path_str = str(base_url.copy_with(path=full_path))

        # Apply OpenAI-specific path handling for both branches
        if (
            custom_llm_provider == litellm.LlmProviders.OPENAI
            and "/v1/" not in joined_path_str
        ):
            # Insert v1 after api.openai.com for OpenAI requests
            joined_path_str = joined_path_str.replace(
                "api.openai.com/", "api.openai.com/v1/"
            )

        return joined_path_str
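
# Worked examples for `_join_url_paths` (illustrative hostnames and deployments):
#
#   _join_url_paths(httpx.URL("https://api.openai.com/"), "/chat/completions",
#                   litellm.LlmProviders.OPENAI)
#   # -> "https://api.openai.com/v1/chat/completions"  (v1 inserted for OpenAI)
#
#   _join_url_paths(httpx.URL("https://my-resource.openai.azure.com/openai"),
#                   "/deployments/my-gpt-4o/chat/completions",
#                   litellm.LlmProviders.AZURE)
#   # -> "https://my-resource.openai.azure.com/openai/deployments/my-gpt-4o/chat/completions"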