Spaces:
Paused
Paused
| # What is this? | |
| ## Common checks for /v1/models and `/model/info` | |
| from typing import Dict, List, Optional, Set | |
| import litellm | |
| from litellm._logging import verbose_proxy_logger | |
| from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth | |
| from litellm.router import Router | |
| from litellm.types.router import LiteLLM_Params | |
| from litellm.utils import get_valid_models | |
| def _check_wildcard_routing(model: str) -> bool: | |
| """ | |
| Returns True if a model is a provider wildcard. | |
| eg: | |
| - anthropic/* | |
| - openai/* | |
| - * | |
| """ | |
| if "*" in model: | |
| return True | |
| return False | |
def get_provider_models(
    provider: str, litellm_params: Optional[LiteLLM_Params] = None
) -> Optional[List[str]]:
    """
    Returns the list of known models by provider

    Args:
        provider: Provider name, or "*" to ask for the valid models of
            every provider.
        litellm_params: Optional params forwarded to `get_valid_models`.

    Returns:
        - For "*": whatever `get_valid_models` returns, unchanged.
        - For a provider present in `litellm.models_by_provider`: a new list
          in which every model name starts with "<provider>/".
        - None for any other provider.
    """
    if provider == "*":
        return get_valid_models(litellm_params=litellm_params)

    if provider not in litellm.models_by_provider:
        return None

    provider_models = get_valid_models(
        custom_llm_provider=provider, litellm_params=litellm_params
    )
    provider_prefix = f"{provider}/"
    # Test for the actual "<provider>/" prefix rather than a substring: a name
    # such as "deepseek-chat" contains "deepseek" but still needs the prefix.
    # A new list is built so the list handed back by get_valid_models is not
    # modified in place.
    return [
        _model if _model.startswith(provider_prefix) else f"{provider_prefix}{_model}"
        for _model in provider_models
    ]
| def _get_models_from_access_groups( | |
| model_access_groups: Dict[str, List[str]], | |
| all_models: List[str], | |
| ) -> List[str]: | |
| idx_to_remove = [] | |
| new_models = [] | |
| for idx, model in enumerate(all_models): | |
| if model in model_access_groups: | |
| idx_to_remove.append(idx) | |
| new_models.extend(model_access_groups[model]) | |
| for idx in sorted(idx_to_remove, reverse=True): | |
| all_models.pop(idx) | |
| all_models.extend(new_models) | |
| return all_models | |
def get_key_models(
    user_api_key_dict: UserAPIKeyAuth,
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
) -> List[str]:
    """
    Resolve the model names an API key may use.

    Returns:
    - List of model name strings
    - Empty list if the key has no models set
    - The key's team models if the key lists the "all team models" special name
    - The proxy model list if the resulting list holds the "all proxy models"
      special name
    - Access-group names in the result are replaced by the models of that group
    """
    resolved: List[str] = []
    if len(user_api_key_dict.models) > 0:
        resolved = user_api_key_dict.models

    # The second check runs against whatever the first one selected.
    if SpecialModelNames.all_team_models.value in resolved:
        resolved = user_api_key_dict.team_models
    if SpecialModelNames.all_proxy_models.value in resolved:
        resolved = proxy_model_list

    resolved = _get_models_from_access_groups(
        model_access_groups=model_access_groups, all_models=resolved
    )

    model_count = len(resolved)
    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(model_count))
    return resolved
def get_team_models(
    team_models: List[str],
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
) -> List[str]:
    """
    Resolve the model names a team may use.

    Returns:
    - List of model name strings
    - Empty list if no team models are set
    - The proxy model list if the team lists the "all proxy models" special name
    - Access-group names in the result are replaced by the models of that group
    """
    # Truthiness also covers a missing (None) team model list.
    all_models: List[str] = team_models if team_models else []

    # The "all team models" special name needs no handling here: the list it
    # would resolve to is team_models itself.
    if SpecialModelNames.all_proxy_models.value in all_models:
        all_models = proxy_model_list

    all_models = _get_models_from_access_groups(
        model_access_groups=model_access_groups, all_models=all_models
    )

    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
    return all_models
def get_complete_model_list(
    key_models: List[str],
    team_models: List[str],
    proxy_model_list: List[str],
    user_model: Optional[str],
    infer_model_from_keys: Optional[bool],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
) -> List[str]:
    """
    Logic for returning complete model list for a given key + team pair

    - If key list is empty -> defer to team list
    - If team list is empty -> defer to proxy model list
    - `user_model`, when set, is always included
    - `infer_model_from_keys` adds everything `get_valid_models()` reports

    If list contains wildcard -> return known provider models
    """
    # First non-empty source wins: key, then team, then the proxy-wide list.
    base_models = key_models or team_models or proxy_model_list
    unique_models: Set[str] = set()
    unique_models.update(base_models)

    if user_model:
        unique_models.add(user_model)

    if infer_model_from_keys:
        unique_models.update(get_valid_models())

    # Must run before the set is turned into a list: the helper may remove
    # wildcard entries from `unique_models`.
    expanded_wildcards = _get_wildcard_models(
        unique_models=unique_models,
        return_wildcard_routes=return_wildcard_routes,
        llm_router=llm_router,
    )

    return [*unique_models, *expanded_wildcards]
def get_known_models_from_wildcard(
    wildcard_model: str, litellm_params: Optional[LiteLLM_Params] = None
) -> List[str]:
    """
    Expand a "<provider>/<pattern>" wildcard into the known models it covers.

    Args:
        wildcard_model: Wildcard route such as "anthropic/*" or
            "openai/gpt-4*". A value without a "/" yields an empty list.
        litellm_params: Optional params forwarded to `get_provider_models`.

    Returns:
        - [] if the wildcard has no "/" or the provider has no known models
        - every known model of the provider when the pattern is exactly "*"
        - otherwise the provider models whose name (the part after the first
          "/") starts with the pattern with all "*" characters stripped,
          i.e. only prefix matching is supported
    """
    provider, separator, model = wildcard_model.partition("/")
    if not separator:  # safely fail
        return []

    # get all known provider models
    wildcard_models = get_provider_models(
        provider=provider, litellm_params=litellm_params
    )
    if wildcard_models is None:
        return []

    if model == "*":
        return wildcard_models

    model_prefix = model.replace("*", "")
    # Split only once and take the last piece: names without a "/" are compared
    # as-is instead of raising IndexError, and multi-segment names keep their
    # full remainder (e.g. "meta/llama-3") so patterns containing "/" can match.
    return [
        wc_model
        for wc_model in wildcard_models
        if wc_model.split("/", 1)[-1].startswith(model_prefix)
    ]
def _get_wildcard_models(
    unique_models: Set[str],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
) -> List[str]:
    """
    Expand every wildcard entry of `unique_models` into known model names.

    Args:
        unique_models: Model names to scan. Modified in place: wildcard
            entries expanded without a router are removed from it.
        return_wildcard_routes: When truthy, the wildcard route itself
            (eg: anthropic/*) is also added to the returned list.
        llm_router: When given, each wildcard is expanded once per router
            deployment found for it, using that deployment's litellm params.
            Wildcards handled this way stay in `unique_models`.

    Returns:
        The expanded model names (plus wildcard routes, if requested).
    """
    expanded_directly = set()
    expanded_models: List[str] = []

    for candidate in unique_models:
        if not _check_wildcard_routing(model=candidate):
            continue

        if return_wildcard_routes:
            # keep the wildcard route itself in the output, eg: anthropic/*.
            expanded_models.append(candidate)

        if llm_router is None:
            # no router -> fall back to all known provider models
            known_models = get_known_models_from_wildcard(wildcard_model=candidate)
            if known_models is not None:
                expanded_directly.add(candidate)
                expanded_models.extend(known_models)
            continue

        ## get litellm params from the router deployments of this wildcard
        deployments = llm_router.get_model_list(model_name=candidate)
        if deployments is None:
            continue
        for deployment in deployments:
            deployment_params = LiteLLM_Params(
                **deployment["litellm_params"]  # type: ignore
            )
            expanded_models.extend(
                get_known_models_from_wildcard(
                    wildcard_model=candidate,
                    litellm_params=deployment_params,
                )
            )

    # Removal happens after the loop so the set is not resized while iterating.
    unique_models.difference_update(expanded_directly)
    return expanded_models