import gc
import json
import time
import requests
import base64
import uvicorn
import argparse
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedModel, PreTrainedTokenizer, \
    TextIteratorStreamer, CodeGenTokenizerFast as Tokenizer
from contextlib import asynccontextmanager
from loguru import logger
from typing import List, Literal, Union, Tuple, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from PIL import Image
from io import BytesIO
import os
import re
from threading import Thread
from moondream import Moondream, detect_device
import omnichat

# Request models
class TextContent(BaseModel):
    type: Literal["text"]
    text: str

class ImageUrl(BaseModel):
    url: str

class ImageUrlContent(BaseModel):
    type: Literal["image_url"]
    image_url: ImageUrl

ContentItem = Union[TextContent, ImageUrlContent]

class ChatMessageInput(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: Union[str, List[ContentItem]]
    name: Optional[str] = None

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessageInput]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    # Additional parameters
    repetition_penalty: Optional[float] = 1.0
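
# For reference, a request body sketch matching the models above (the image
# URL and prompt text are illustrative placeholders, not part of the code):
'''
{
    "model": "moondream",
    "messages": [{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}}
        ]
    }],
    "temperature": 0.8,
    "stream": false
}
'''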

# Response models
class ChatMessageResponse(BaseModel):
    role: Literal["assistant"]
    content: Optional[str] = None
    name: Optional[str] = None

class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessageResponse

class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None

class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage

class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0

class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None
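
# For reference, the non-streaming JSON shape the response models serialize
# to (field values are illustrative):
'''
{
    "model": "moondream",
    "object": "chat.completion",
    "created": 1716000000,
    "choices": [{"index": 0, "message": {"role": "assistant", "content": "..."}}],
    "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
}
'''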

# Image input handling
def process_img(input_data):
    if isinstance(input_data, str):
        # URL
        if input_data.startswith("http://") or input_data.startswith("https://"):
            response = requests.get(input_data)
            image_data = response.content
            pil_image = Image.open(BytesIO(image_data)).convert('RGB')
        # base64 data URI
        elif input_data.startswith("data:image/"):
            base64_data = input_data.split(",")[1]
            image_data = base64.b64decode(base64_data)
            pil_image = Image.open(BytesIO(image_data)).convert('RGB')
        # local file path
        else:
            pil_image = Image.open(input_data)
    # already a PIL image
    elif isinstance(input_data, Image.Image):
        pil_image = input_data
    else:
        raise ValueError("Unsupported image input type")
    return pil_image
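
# A quick sketch of the input forms process_img accepts (the paths and URL
# below are hypothetical):
'''
process_img("https://example.com/cat.jpg")           # fetched over HTTP
process_img("data:image/jpeg;base64,/9j/4AAQ...")    # inline base64 data URI
process_img("./samples/cat.jpg")                     # local file path
process_img(Image.open("./samples/cat.jpg"))         # already a PIL image
'''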

# Conversation history handling
def process_history_and_images(messages: List[ChatMessageInput]) -> Tuple[
    Optional[str], Optional[List[Tuple[str, str]]], Optional[List[Image.Image]]]:
    formatted_history = []
    image_list = []
    last_user_query = ''
    for i, message in enumerate(messages):
        role = message.role
        content = message.content
        if isinstance(content, list):  # text parts
            text_content = ' '.join(item.text for item in content if isinstance(item, TextContent))
        else:
            text_content = content
        if isinstance(content, list):  # image parts
            for item in content:
                if isinstance(item, ImageUrlContent):
                    image_url = item.image_url.url
                    image = process_img(image_url)
                    image_list.append(image)
        if role == 'user':
            if i == len(messages) - 1:  # last message
                last_user_query = text_content
            else:
                formatted_history.append((text_content, ''))
        elif role == 'assistant':
            if formatted_history:
                if formatted_history[-1][1] != '':
                    assert False, f"the last query is already answered: {formatted_history[-1][0]}, {formatted_history[-1][1]}, {text_content}"
                formatted_history[-1] = (formatted_history[-1][0], text_content)
            else:
                assert False, "assistant reply precedes any user message"
        else:
            assert False, f"unrecognized role: {role}"
    return last_user_query, formatted_history, image_list
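
# Illustration of the transformation (message contents are made up):
'''
messages: user("hi"), assistant("hello"), user("describe this" + image)
# returns:
#   last_user_query   = "describe this"
#   formatted_history = [("hi", "hello")]
#   image_list        = [<PIL.Image.Image>]
'''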

# Moondream inference
def generate_stream_moondream(params: dict):
    global model, tokenizer
    # Input handling: flatten the (query, response) history into a prompt
    def chat_history_to_prompt(history):
        prompt = ""
        for i, (old_query, response) in enumerate(history):
            prompt += f"Question: {old_query}\n\nAnswer: {response}\n\n"
        return prompt
    messages = params["messages"]
    prompt, formatted_history, image_list = process_history_and_images(messages)
    history = chat_history_to_prompt(formatted_history)
    # Only the latest image is processed
    img = image_list[-1]
    # Build the model inputs; for reference, answer_question's signature:
    '''
    answer_question(
        self,
        image_embeds,
        question,
        tokenizer,
        chat_history="",
        result_queue=None,
        **kwargs,
    )
    '''
    image_embeds = model.encode_image(img)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    gen_kwargs = {
        "image_embeds": image_embeds,
        "question": prompt,
        "tokenizer": tokenizer,
        "chat_history": history,
        "result_queue": None,
        "streamer": streamer,
    }
    thread = Thread(
        target=model.answer_question,
        kwargs=gen_kwargs,
    )
    # Token accounting is not implemented for Moondream, so usage stays zero
    input_echo_len = 0
    total_len = 0
    # Start inference in a background thread and consume the streamer
    thread.start()
    buffer = ""
    for new_text in streamer:
        # Strip the model's end-of-answer markers from each chunk
        clean_text = re.sub("<$|END$", "", new_text)
        buffer += clean_text
        yield {
            "text": buffer.strip("<END"),
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": total_len - input_echo_len,
                "total_tokens": total_len,
            },
        }
    generated_ret = {
        "text": buffer.strip("<END"),
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
    }
    yield generated_ret

# Moondream single-shot response: drain the stream and return the final chunk
def generate_moondream(params: dict):
    response = None
    for response in generate_stream_moondream(params):
        pass
    return response

# CogVLM inference
def generate_stream_cogvlm(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    """
    Generates a stream of responses using the CogVLM model in inference mode.
    It's optimized to handle continuous input-output interactions with the model in a streaming manner.
    """
    messages = params["messages"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))
    query, history, image_list = process_history_and_images(messages)
    logger.debug(f"==== request ====\n{query}")
    # Only the latest image is processed
    input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history,
                                                        images=[image_list[-1]])
    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]],
    }
    if 'cross_images' in input_by_model and input_by_model['cross_images']:
        inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]
    input_echo_len = len(inputs["input_ids"][0])
    streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "top_p": top_p,
        'streamer': streamer,
    }
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature
    total_len = 0
    generated_text = ""
    with torch.no_grad():
        # generate() runs to completion here; TextIteratorStreamer buffers the
        # output in a queue, so the loop below drains it after generation ends
        model.generate(**inputs, **gen_kwargs)
        for next_text in streamer:
            generated_text += next_text
            # Re-encode the running text to keep the token counts roughly accurate
            total_len = input_echo_len + len(tokenizer.encode(generated_text, add_special_tokens=False))
            yield {
                "text": generated_text,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": total_len - input_echo_len,
                    "total_tokens": total_len,
                },
            }
    ret = {
        "text": generated_text,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
    }
    yield ret
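
# Note: generate() above blocks until completion, so chunks are only drained
# afterwards. A minimal sketch of true incremental streaming, mirroring the
# Thread pattern already used in generate_stream_moondream:
'''
thread = Thread(target=model.generate, kwargs={**inputs, **gen_kwargs})
thread.start()
for next_text in streamer:  # yields while generation is still running
    ...
thread.join()
'''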

# CogVLM single-shot response: drain the stream and return the final chunk
def generate_cogvlm(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    response = None
    for response in generate_stream_cogvlm(model, tokenizer, params):
        pass
    return response

# MiniCPM inference (single-shot)
def generate_minicpm(model, params):
    messages = params["messages"]
    query, history, image_list = process_history_and_images(messages)
    # Flatten the (query, response) tuples into the role/content dicts the
    # chat format expects
    msgs = []
    for old_query, old_response in history:
        msgs.append({'role': 'user', 'content': old_query})
        msgs.append({'role': 'assistant', 'content': old_response})
    msgs.append({'role': 'user', 'content': query})
    # Only the latest image is processed; serialize it to base64 JPEG
    image = image_list[-1]
    buffer = BytesIO()
    image.save(buffer, format="JPEG")  # adjust the format as needed
    buffer.seek(0)
    image_base64_str = base64.b64encode(buffer.read()).decode("utf-8")
    chat_input = {'image': image_base64_str, 'question': json.dumps(msgs)}
    generation = model.chat(chat_input)
    response = {"text": generation, "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}}
    logger.debug(response)
    return response

# Streaming response (not implemented; returns a placeholder)
async def predict(model_id: str, params: dict):
    return "no stream"

torch.set_grad_enabled(False)

# Lifespan manager: clear GPU memory on shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

app = FastAPI(lifespan=lifespan)

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Chat completion route (OpenAI-compatible path assumed from the
# request/response models above)
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer
    # Validate the request
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")
    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
    )
    # Streaming response (placeholder; see predict above)
    if request.stream:
        return await predict(request.model, gen_params)
    # Single-shot response
    if STATE_MOD == "cog":
        response = generate_cogvlm(model, tokenizer, gen_params)
    elif STATE_MOD == "moon":
        response = generate_moondream(gen_params)
    elif STATE_MOD == "mini":
        response = generate_minicpm(model, gen_params)
    else:
        raise HTTPException(status_code=500, detail=f"unknown model state: {STATE_MOD}")
    usage = UsageInfo()
    message = ChatMessageResponse(
        role="assistant",
        content=response["text"],
    )
    logger.debug(f"==== message ====\n{message}")
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
    )
    task_usage = UsageInfo.model_validate(response["usage"])
    for usage_key, usage_value in task_usage.model_dump().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage)
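
# Usage sketch for the route above (path as registered on @app.post; the
# prompt and image URL are placeholders):
'''
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "moondream",
       "messages": [{"role": "user", "content": [
         {"type": "text", "text": "Describe this image."},
         {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}}]}]}'
'''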

# Model-switching state / route configuration
STATE_MOD = "moon"
MODEL_PATH = ""
# Model loading
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

def load_mod(model_input, mod_type):
    global model, tokenizer, language_processor_version
    if mod_type == "cog":
        tokenizer_path = os.environ.get("TOKENIZER_PATH", 'lmsys/vicuna-7b-v1.5')
        tokenizer = LlamaTokenizer.from_pretrained(
            tokenizer_path,
            trust_remote_code=True,
            signal_type=language_processor_version
        )
        if 'cuda' in DEVICE:
            # Load 4-bit quantized on GPU to fit in memory
            model = AutoModelForCausalLM.from_pretrained(
                model_input,
                trust_remote_code=True,
                load_in_4bit=True,
                torch_dtype=torch_type,
                low_cpu_mem_usage=True
            ).eval()
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_input,
                trust_remote_code=True
            ).float().to(DEVICE).eval()
    elif mod_type == "moon":
        device, dtype = detect_device()
        model = Moondream.from_pretrained(model_input).to(device=device, dtype=dtype).eval()
        tokenizer = Tokenizer.from_pretrained(model_input)
    elif mod_type == "mini":
        model, tokenizer = omnichat.OmniLMMChat(model_input), None

# Model switchers: release the current model and hot-swap in another
async def switch_vqa():
    global model, STATE_MOD, mod_vqa, language_processor_version
    STATE_MOD = "cog"
    del model
    model = None
    language_processor_version = "chat_old"
    load_mod(mod_vqa, STATE_MOD)

async def switch_chat():
    global model, STATE_MOD, mod_chat, language_processor_version
    STATE_MOD = "cog"
    del model
    model = None
    language_processor_version = "chat"
    load_mod(mod_chat, STATE_MOD)

async def switch_moon():
    global model, STATE_MOD, mod_moon
    STATE_MOD = "moon"
    del model
    model = None
    load_mod(mod_moon, STATE_MOD)

async def switch_mini():
    global model, STATE_MOD, mod_mini
    STATE_MOD = "mini"
    del model
    model = None
    load_mod(mod_mini, STATE_MOD)

# Shutdown: release the model and collect garbage
async def close():
    global model
    del model
    model = None
    gc.collect()

parser = argparse.ArgumentParser()
parser.add_argument("--mod", type=str, default="moondream")
args = parser.parse_args()
mod = args.mod
mod_vqa = './models/cogagent-vqa-hf'
mod_chat = './models/cogagent-chat-hf'
mod_moon = './models/moondream'
mod_mini = './models/MiniCPM-Llama3-V-2_5'
'''
mod_list = [
    "moondream",
    "Cog-vqa",
    "Cog-chat",
    "MiniCPM",
]
'''
if mod == "Cog-vqa":
    STATE_MOD = "cog"
    MODEL_PATH = mod_vqa
    language_processor_version = "chat_old"
elif mod == "Cog-chat":
    STATE_MOD = "cog"
    MODEL_PATH = mod_chat
    language_processor_version = "chat"
elif mod == "moondream":
    STATE_MOD = "moon"
    MODEL_PATH = mod_moon
elif mod == "MiniCPM":
    STATE_MOD = "mini"
    MODEL_PATH = mod_mini

if __name__ == "__main__":
    # Prefer bfloat16 on GPUs with compute capability >= 8, else float16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        torch_type = torch.bfloat16
    else:
        torch_type = torch.float16
    print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, DEVICE))
    load_mod(MODEL_PATH, STATE_MOD)
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
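
# Launch sketch (assuming this file is saved as app.py; the --mod values are
# the ones handled above):
'''
python app.py --mod moondream
python app.py --mod Cog-vqa
python app.py --mod Cog-chat
python app.py --mod MiniCPM
'''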