import argparse
import markdown2
import os
import sys
import uvicorn
import requests
import tiktoken

from pathlib import Path
from typing import Union, Optional
from fastapi import FastAPI, Depends, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer, CrossEncoder
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from tclogger import logger

from constants.models import AVAILABLE_MODELS_DICTS, PRO_MODELS
from constants.envs import CONFIG, SECRETS
from networks.exceptions import HfApiException, INVALID_API_KEY_ERROR

from messagers.message_composer import MessageComposer
from mocks.stream_chat_mocker import stream_chat_mock

from networks.huggingface_streamer import HuggingfaceStreamer
from networks.huggingchat_streamer import HuggingchatStreamer
from networks.openai_streamer import OpenaiStreamer


class EmbeddingsAPIInference:
    """Embeds text remotely via the Hugging Face Inference API."""

    def __init__(self, model_name):
        self.model_name = model_name

    def encode(self, x: Union[str, list], api_key=None):
        # Forward the user's HF token when provided, otherwise call anonymously
        if api_key:
            headers = {"Authorization": f"Bearer {api_key}"}
        else:
            headers = None
        API_URL = "https://api-inference.huggingface.co/models/" + self.model_name
        payload = {
            "inputs": x,
            "options": {"wait_for_model": True},
        }
        return requests.post(API_URL, headers=headers, json=payload).json()


class SentenceTransformerLocal(SentenceTransformer):
    """Embeds text locally; accepts (and ignores) api_key so it mirrors EmbeddingsAPIInference."""

    def encode(self, *args, **kwargs):
        kwargs.pop("api_key", None)
        return super().encode(*args, **kwargs).tolist()
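

# Both wrappers expose the same encode(texts, api_key=...) call, so the route handlers below
# stay backend-agnostic. A minimal sketch, using the model ids registered in ChatAPIApp below:
#
#   local = SentenceTransformerLocal("mixedbread-ai/mxbai-embed-large-v1")
#   remote = EmbeddingsAPIInference("mixedbread-ai/mxbai-embed-large-v1")
#   local.encode(["hello world"])                       # vectors computed locally
#   remote.encode(["hello world"], api_key="hf_...")    # vectors fetched from the Inference API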


class ChatAPIApp:
    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title=CONFIG["app_name"],
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version=CONFIG["version"],
        )
        self.setup_routes()
        # Embedding backends: short aliases run locally, full model ids go through the HF Inference API
        self.embeddings = {
            "mxbai-embed-large": SentenceTransformerLocal("mixedbread-ai/mxbai-embed-large-v1"),
            "nomic-embed-text": SentenceTransformerLocal("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True),
            "multilingual-e5-large-instruct": SentenceTransformerLocal("intfloat/multilingual-e5-large-instruct"),
            "intfloat/multilingual-e5-large-instruct": EmbeddingsAPIInference("intfloat/multilingual-e5-large-instruct"),
            "mixedbread-ai/mxbai-embed-large-v1": EmbeddingsAPIInference("mixedbread-ai/mxbai-embed-large-v1"),
        }
        self.rerank = {
            "bge-reranker-v2-m3": CrossEncoder("BAAI/bge-reranker-v2-m3"),
        }

    def get_available_models(self):
        return {"object": "list", "data": AVAILABLE_MODELS_DICTS}

    def get_available_models_ollama(self):
        # Ollama's /api/tags expects a "name" field where OpenAI-style entries use "id"
        ollama_models_dict = [
            {"name" if k == "id" else k: v for k, v in d.items()}
            for d in AVAILABLE_MODELS_DICTS.copy()
        ]
        return {"object": "list", "models": ollama_models_dict}

    def extract_api_key(
        credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer()),
    ):
        # Dependency (no self): pulls the bearer token from the Authorization header, if any
        api_key = None
        if credentials:
            api_key = credentials.credentials
        return api_key

    def auth_api_key(self, api_key: str):
        env_api_key = SECRETS["HF_LLM_API_KEY"]
        # no server-side API key configured: allow access without a key
        if not env_api_key:
            return None
        # user provides their own HF token: pass it through to upstream requests
        if api_key and api_key.startswith("hf_"):
            return api_key
        # user provides the correct server API key: authorized
        if str(api_key) == str(env_api_key):
            return None
        raise INVALID_API_KEY_ERROR

    class ChatCompletionsPostItem(BaseModel):
        model: str = Field(
            default="nous-mixtral-8x7b",
            description="(str) `nous-mixtral-8x7b`",
        )
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        temperature: Union[float, None] = Field(
            default=0.5,
            description="(float) Temperature",
        )
        top_p: Union[float, None] = Field(
            default=0.95,
            description="(float) Top-p",
        )
        max_tokens: Union[int, None] = Field(
            default=-1,
            description="(int) Max tokens",
        )
        use_cache: bool = Field(
            default=False,
            description="(bool) Use cache",
        )
        stream: bool = Field(
            default=True,
            description="(bool) Stream",
        )

    def chat_completions(
        self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
    ):
        try:
            print(item.messages)
            # map the bare "llama3" alias onto the concrete model id
            item.model = "llama3-8b" if item.model == "llama3" else item.model
            api_key = self.auth_api_key(api_key)
            if item.model == "gpt-3.5-turbo":
                streamer = OpenaiStreamer()
                stream_response = streamer.chat_response(messages=item.messages)
            elif item.model in PRO_MODELS:
                streamer = HuggingchatStreamer(model=item.model)
                stream_response = streamer.chat_response(
                    messages=item.messages,
                )
            else:
                streamer = HuggingfaceStreamer(model=item.model)
                composer = MessageComposer(model=item.model)
                composer.merge(messages=item.messages)
                stream_response = streamer.chat_response(
                    prompt=composer.merged_str,
                    temperature=item.temperature,
                    top_p=item.top_p,
                    max_new_tokens=item.max_tokens,
                    api_key=api_key,
                    use_cache=item.use_cache,
                )
            if item.stream:
                event_source_response = EventSourceResponse(
                    streamer.chat_return_generator(stream_response),
                    media_type="text/event-stream",
                    ping=2000,
                    ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
                )
                return event_source_response
            else:
                data_response = streamer.chat_return_dict(stream_response)
                return data_response
        except HfApiException as e:
            raise HTTPException(status_code=e.status_code, detail=e.detail)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
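
    # A request sketch for this handler (host/port come from CONFIG). The route is guarded by
    # HTTPBearer, so an Authorization header is required; the key is checked against
    # HF_LLM_API_KEY unless it is the user's own `hf_...` token:
    #
    #   curl -X POST http://<host>:<port>/v1/chat/completions \
    #     -H "Content-Type: application/json" \
    #     -H "Authorization: Bearer <api-key-or-hf-token>" \
    #     -d '{"model": "nous-mixtral-8x7b", "messages": [{"role": "user", "content": "Hello"}], "stream": false}'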

    def chat_completions_ollama(
        self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
    ):
        try:
            print(item.messages)
            item.model = "llama3-8b" if item.model == "llama3" else item.model
            api_key = self.auth_api_key(api_key)
            if item.model == "gpt-3.5-turbo":
                streamer = OpenaiStreamer()
                stream_response = streamer.chat_response(messages=item.messages)
            elif item.model in PRO_MODELS:
                streamer = HuggingchatStreamer(model=item.model)
                stream_response = streamer.chat_response(
                    messages=item.messages,
                )
            else:
                streamer = HuggingfaceStreamer(model=item.model)
                composer = MessageComposer(model=item.model)
                composer.merge(messages=item.messages)
                stream_response = streamer.chat_response(
                    prompt=composer.merged_str,
                    temperature=item.temperature,
                    top_p=item.top_p,
                    max_new_tokens=item.max_tokens,
                    api_key=api_key,
                    use_cache=item.use_cache,
                )
            data_response = streamer.chat_return_dict(stream_response)
            print(data_response)
            # reshape the OpenAI-style completion into Ollama's /api/chat response format
            data_response = {
                "model": data_response.get("model"),
                "created_at": data_response.get("created"),
                "message": {
                    "role": "assistant",
                    "content": data_response["choices"][0]["message"]["content"],
                },
                "done": True,
            }
            return data_response
        except HfApiException as e:
            raise HTTPException(status_code=e.status_code, detail=e.detail)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    class GenerateRequest(BaseModel):
        model: str = Field(
            default="nous-mixtral-8x7b",
            description="(str) `nous-mixtral-8x7b`",
        )
        prompt: str = Field(
            default="Hello, who are you?",
            description="(str) Prompt",
        )
        stream: bool = Field(
            default=False,
            description="(bool) Stream",
        )
        options: dict = Field(
            default={
                "temperature": 0.6,
                "top_p": 0.9,
                "max_tokens": -1,
                "use_cache": False,
            },
            description="(dict) Options",
        )

    def generate_text(
        self, item: GenerateRequest, api_key: str = Depends(extract_api_key)
    ):
        try:
            item.model = "llama3-8b" if item.model == "llama3" else item.model
            api_key = self.auth_api_key(api_key)
            # messages elsewhere in this app use the {"role": ..., "content": ...} shape,
            # so wrap the raw prompt as a single user turn
            prompt_messages = [{"role": "user", "content": item.prompt}]
            if item.model == "gpt-3.5-turbo":
                streamer = OpenaiStreamer()
                stream_response = streamer.chat_response(messages=prompt_messages)
            elif item.model in PRO_MODELS:
                streamer = HuggingchatStreamer(model=item.model)
                stream_response = streamer.chat_response(
                    messages=prompt_messages,
                )
            else:
                streamer = HuggingfaceStreamer(model=item.model)
                options = {k: v for k, v in item.options.items() if v is not None}
                # chat_response() takes `max_new_tokens` (see chat_completions above),
                # so remap the Ollama-style `max_tokens` option when present
                if "max_tokens" in options:
                    options["max_new_tokens"] = options.pop("max_tokens")
                stream_response = streamer.chat_response(
                    prompt=item.prompt,
                    api_key=api_key,
                    **options,
                )
            if item.stream:
                event_source_response = EventSourceResponse(
                    streamer.ollama_return_generator(stream_response),
                    media_type="text/event-stream",
                    ping=2000,
                    ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
                )
                return event_source_response
            else:
                data_response = streamer.chat_return_dict(stream_response)
                print(data_response)
                # reshape the OpenAI-style completion into Ollama's /api/generate response format
                data_response = {
                    "model": data_response.get("model"),
                    "created_at": data_response.get("created"),
                    "response": data_response["choices"][0]["message"]["content"],
                    "done": True,
                }
                return data_response
        except HfApiException as e:
            raise HTTPException(status_code=e.status_code, detail=e.detail)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
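
    # A request sketch for the Ollama-style generation endpoint (registered below as {prefix}/generate):
    #
    #   curl -X POST http://<host>:<port>/api/generate \
    #     -H "Content-Type: application/json" \
    #     -H "Authorization: Bearer <api-key-or-hf-token>" \
    #     -d '{"model": "nous-mixtral-8x7b", "prompt": "Hello, who are you?", "stream": false}'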

    class EmbeddingRequest(BaseModel):
        model: str
        input: list
        options: Optional[dict] = None

    class OllamaEmbeddingRequest(BaseModel):
        model: str
        prompt: str
        options: Optional[dict] = None

    def get_embeddings(self, request: EmbeddingRequest, api_key: str = Depends(extract_api_key)):
        try:
            model = request.model
            model_kwargs = request.options  # currently unused
            # OpenAI-compatible clients may send `input` pre-tokenized as lists of
            # cl100k_base token ids, so decode each item back to text before embedding
            encoding = tiktoken.get_encoding("cl100k_base")
            embeddings = self.embeddings[model].encode(
                [encoding.decode(inp) for inp in request.input], api_key=api_key
            )
            return {
                "object": "list",
                "data": [
                    {"object": "embedding", "index": i, "embedding": emb}
                    for i, emb in enumerate(embeddings)
                ],
                "model": model,
                "usage": {},
            }
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))

    def get_embeddings_ollama(self, request: OllamaEmbeddingRequest, api_key: str = Depends(extract_api_key)):
        try:
            model = request.model
            model_kwargs = request.options  # currently unused
            embeddings = self.embeddings[model].encode(request.prompt, api_key=api_key)
            return {"embedding": embeddings}
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))
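
    # Request sketches for the two embedding routes registered below. The OpenAI-style route
    # expects `input` as token-id lists (decoded with tiktoken above); the Ollama-style route
    # takes plain text:
    #
    #   curl -X POST http://<host>:<port>/v1/embeddings \
    #     -H "Authorization: Bearer <api-key-or-hf-token>" -H "Content-Type: application/json" \
    #     -d '{"model": "mxbai-embed-large", "input": [[...cl100k_base token ids...]]}'
    #
    #   curl -X POST http://<host>:<port>/api/embeddings \
    #     -H "Authorization: Bearer <api-key-or-hf-token>" -H "Content-Type: application/json" \
    #     -d '{"model": "mxbai-embed-large", "prompt": "hello world"}'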

    class RerankRequest(BaseModel):
        model: str
        input: str
        documents: list
        return_documents: bool
        top_k: Optional[int] = None

    def _score_to_list(self, x):
        # CrossEncoder.rank() returns numpy scalars; convert scores to plain floats for JSON serialization
        x["score"] = x["score"].tolist()
        return x

    def get_rerank(self, request: RerankRequest, api_key: str = Depends(extract_api_key)):
        ranks = self.rerank[request.model].rank(
            request.input,
            request.documents,
            top_k=request.top_k,
            return_documents=request.return_documents,
        )
        return [self._score_to_list(x) for x in ranks]
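
    # A request sketch for the rerank route registered below:
    #
    #   curl -X POST http://<host>:<port>/v1/rerank \
    #     -H "Authorization: Bearer <api-key-or-hf-token>" -H "Content-Type: application/json" \
    #     -d '{"model": "bge-reranker-v2-m3", "input": "what is a panda?",
    #          "documents": ["The giant panda is a bear native to China.", "Paris is in France."],
    #          "return_documents": true, "top_k": 2}'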

    def get_readme(self):
        readme_path = Path(__file__).parents[1] / "README.md"
        with open(readme_path, "r", encoding="utf-8") as rf:
            readme_str = rf.read()
        readme_html = markdown2.markdown(
            readme_str, extras=["table", "fenced-code-blocks", "highlightjs-lang"]
        )
        return readme_html

    def setup_routes(self):
        # Register each route under "", "/v1", "/api" and "/api/v1";
        # only the "/api/v1" variants are shown in the OpenAPI schema
        for prefix in ["", "/v1", "/api", "/api/v1"]:
            if prefix in ["/api/v1"]:
                include_in_schema = True
            else:
                include_in_schema = False

            self.app.get(
                prefix + "/models",
                summary="Get available models",
                include_in_schema=include_in_schema,
            )(self.get_available_models)

            self.app.post(
                prefix + "/rerank",
                summary="Rerank documents",
                include_in_schema=include_in_schema,
            )(self.get_rerank)

            self.app.post(
                prefix + "/chat/completions",
                summary="OpenAI Chat completions in conversation session",
                include_in_schema=include_in_schema,
            )(self.chat_completions)

            self.app.post(
                prefix + "/generate",
                summary="Ollama text generation",
                include_in_schema=include_in_schema,
            )(self.generate_text)

            self.app.post(
                prefix + "/chat",
                summary="Ollama Chat completions in conversation session",
                include_in_schema=include_in_schema,
            )(self.chat_completions_ollama)

            # "/api/embeddings" serves the Ollama-style handler; every other prefix gets the OpenAI-style one
            if prefix in ["/api"]:
                self.app.post(
                    prefix + "/embeddings",
                    summary="Ollama Get Embeddings with prompt",
                    include_in_schema=True,
                )(self.get_embeddings_ollama)
            else:
                self.app.post(
                    prefix + "/embeddings",
                    summary="Get Embeddings with prompt",
                    include_in_schema=include_in_schema,
                )(self.get_embeddings)

        self.app.get(
            "/api/tags",
            summary="Get Available Models Ollama",
            include_in_schema=True,
        )(self.get_available_models_ollama)

        self.app.get(
            "/readme",
            summary="README of HF LLM API",
            response_class=HTMLResponse,
            include_in_schema=False,
        )(self.get_readme)
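
    # Resulting surface (registered above): GET {prefix}/models, POST {prefix}/rerank,
    # POST {prefix}/chat/completions, POST {prefix}/generate, POST {prefix}/chat and
    # POST {prefix}/embeddings for each prefix, plus GET /api/tags and GET /readme.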


class ArgParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        super(ArgParser, self).__init__(*args, **kwargs)

        self.add_argument(
            "-s",
            "--host",
            type=str,
            default=CONFIG["host"],
            help=f"Host for {CONFIG['app_name']}",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=CONFIG["port"],
            help=f"Port for {CONFIG['app_name']}",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )

        self.args = self.parse_args(sys.argv[1:])


app = ChatAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    if args.dev:
        uvicorn.run("__main__:app", host=args.host, port=args.port, reload=True)
    else:
        uvicorn.run("__main__:app", host=args.host, port=args.port, reload=False)

    # python -m apis.chat_api     # [Docker] production mode
    # python -m apis.chat_api -d  # [Dev] development mode
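
    # Once running, OpenAI-compatible clients can point at this server. A minimal sketch
    # (base_url host/port match the --host/--port arguments above):
    #
    #   from openai import OpenAI
    #   client = OpenAI(base_url="http://<host>:<port>/v1", api_key="<api-key-or-hf-token>")
    #   client.chat.completions.create(
    #       model="nous-mixtral-8x7b",
    #       messages=[{"role": "user", "content": "Hello"}],
    #   )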