Spaces:
Paused
Paused
| """ | |
| This module is used to pass through requests to the LLM APIs. | |
| """ | |
| import asyncio | |
| import contextvars | |
| from functools import partial | |
| from typing import Any, Coroutine, Optional, Union | |
| from urllib.parse import urlencode | |
| import httpx | |
| from httpx._types import CookieTypes, QueryParamTypes, RequestFiles | |
| import litellm | |
| from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider | |
| from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler | |
| from litellm.utils import client | |
| from .utils import BasePassthroughUtils | |
async def allm_passthrough_route(
    *,
    method: str,
    endpoint: str,
    model: str,
    custom_llm_provider: Optional[str] = None,
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    request_query_params: Optional[dict] = None,
    request_headers: Optional[dict] = None,
    stream: bool = False,
    content: Optional[Any] = None,
    data: Optional[dict] = None,
    files: Optional[RequestFiles] = None,
    json: Optional[Any] = None,
    params: Optional[QueryParamTypes] = None,
    cookies: Optional[CookieTypes] = None,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    **kwargs,
) -> httpx.Response:
    """
    Async: pass through a request to an LLM provider's API.

    Calls ``llm_passthrough_route`` with ``allm_passthrough_route=True``.
    When ``client`` is None, that flag makes it use
    ``litellm.module_level_aclient``. The call runs in the event loop's
    default executor, inside a copy of the current ``contextvars`` context.
    If it returns a coroutine (as it does with an async client), the
    coroutine is awaited here on the running loop.

    Args:
        All arguments have the same meaning as in ``llm_passthrough_route``.
        ``model`` is required, as it is there. Extra ``kwargs`` are forwarded
        unchanged; any ``allm_passthrough_route`` entry in them is
        overridden with True.

    Returns:
        The provider's ``httpx.Response``.
    """
    loop = asyncio.get_running_loop()
    kwargs["allm_passthrough_route"] = True
    func = partial(
        llm_passthrough_route,
        method=method,
        endpoint=endpoint,
        model=model,
        custom_llm_provider=custom_llm_provider,
        api_base=api_base,
        api_key=api_key,
        request_query_params=request_query_params,
        request_headers=request_headers,
        stream=stream,
        content=content,
        data=data,
        files=files,
        json=json,
        params=params,
        cookies=cookies,
        client=client,
        **kwargs,
    )
    # Run inside a copy of the caller's context so context variables set by
    # the caller are visible in the worker thread.
    ctx = contextvars.copy_context()
    init_response = await loop.run_in_executor(None, partial(ctx.run, func))
    # With an async client the sync helper returns an un-awaited coroutine
    # (from the async client's ``send``); await it on this loop.
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response
def llm_passthrough_route(
    *,
    method: str,
    endpoint: str,
    model: str,
    custom_llm_provider: Optional[str] = None,
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    request_query_params: Optional[dict] = None,
    request_headers: Optional[dict] = None,
    allm_passthrough_route: bool = False,
    stream: bool = False,
    content: Optional[Any] = None,
    data: Optional[dict] = None,
    files: Optional[RequestFiles] = None,
    json: Optional[Any] = None,
    params: Optional[QueryParamTypes] = None,
    cookies: Optional[CookieTypes] = None,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    **kwargs,
) -> Union[httpx.Response, Coroutine[Any, Any, httpx.Response]]:
    """
    Pass through a request to an LLM provider's API.

    Step 1. Resolve the provider, the target URL and the auth headers.
    Step 2. Build the request and send it with ``client.client``.
    Step 3. Return the response.

    Args:
        method: HTTP method, e.g. "GET" or "POST".
        endpoint: Only its path component is used. That path replaces the
            path of the provider's api base; any query string in
            ``endpoint`` is dropped.
        model: Model name, resolved through ``get_llm_provider``. The
            resolved name overwrites ``"model"`` in a dict ``json`` body, but
            only when that key is already present.
        custom_llm_provider, api_base, api_key: Optional overrides passed to
            ``get_llm_provider`` and then to the provider config.
        request_query_params: If non-empty, urlencoded and set as the
            target URL's query string. This replaces any query already on
            the api base.
        request_headers: Incoming request headers. They are currently NOT
            forwarded (``forward_headers=False``); only the headers returned
            by the provider's ``validate_environment`` are sent.
        allm_passthrough_route: When True and ``client`` is None, use
            ``litellm.module_level_aclient`` instead of
            ``litellm.module_level_client``.
        stream: Forwarded to the HTTP client's ``send``.
        content, data, files, json, params, cookies: Forwarded to the HTTP
            client's ``build_request``. ``json`` is never mutated; a copy is
            sent when the model is swapped.
        client: Handler whose ``.client`` builds and sends the request.

    Returns:
        Whatever ``client.client.send`` returns: an ``httpx.Response`` for a
        sync client, or a coroutine resolving to one for an async client.

    Raises:
        Exception: If no provider config or no api base can be resolved.

    [TODO] Refactor this into a provider-config pattern, once we expand this
    to non-vllm providers.
    """
    if client is None:
        client = (
            litellm.module_level_aclient
            if allm_passthrough_route
            else litellm.module_level_client
        )

    model, custom_llm_provider, api_key, api_base = get_llm_provider(
        model=model,
        custom_llm_provider=custom_llm_provider,
        api_base=api_base,
        api_key=api_key,
    )

    # Kept as function-local imports, as in the original.
    from litellm.types.utils import LlmProviders
    from litellm.utils import ProviderConfigManager

    provider_config = ProviderConfigManager.get_provider_model_info(
        provider=LlmProviders(custom_llm_provider),
        model=model,
    )
    if provider_config is None:
        raise Exception(f"Provider {custom_llm_provider} not found")

    base_target_url = provider_config.get_api_base(api_base)
    if base_target_url is None:
        raise Exception(f"Provider {custom_llm_provider} api base not found")

    # Ensure the endpoint path starts with '/' for proper URL construction.
    encoded_endpoint = httpx.URL(endpoint).path
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # The endpoint path replaces the api base's path (scheme/host are kept).
    updated_url = httpx.URL(base_target_url).copy_with(path=encoded_endpoint)
    if request_query_params:
        # Replace (not merge) the query string. doseq=True encodes list/tuple
        # values as repeated keys instead of their Python repr; scalar
        # values encode the same either way.
        updated_url = updated_url.copy_with(
            query=urlencode(request_query_params, doseq=True).encode("ascii")
        )

    provider_api_key = provider_config.get_api_key(api_key)
    auth_headers = provider_config.validate_environment(
        headers={},
        model=model,
        messages=[],
        optional_params={},
        litellm_params={},
        api_key=provider_api_key,
        api_base=base_target_url,
    )
    headers = BasePassthroughUtils.forward_headers_from_request(
        request_headers=request_headers or {},
        headers=auth_headers,
        forward_headers=False,
    )

    ## SWAP MODEL IN JSON BODY
    # Build a shallow copy rather than writing into the caller's dict.
    if isinstance(json, dict) and "model" in json:
        json = {**json, "model": model}

    request = client.client.build_request(
        method=method,
        url=updated_url,
        content=content,
        data=data,
        files=files,
        json=json,
        params=params,
        headers=headers,
        cookies=cookies,
    )
    return client.client.send(request=request, stream=stream)