from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
import torch, os, base64, io, logging, time
from typing import Any, Dict, List, Tuple
from PIL import Image

MODEL_ID = "osunlp/UGround-V1-72B"
CACHE_DIR = (
    os.environ.get("HF_HUB_CACHE")
    or os.environ.get("HF_HOME")
    or "/data/huggingface"
)

# PyTorch performance settings
# 1) Ensure the CUDA kernel cache directory is writable/persistent to avoid recompilation stalls
KERNEL_CACHE_DIR = os.environ.get("PYTORCH_KERNEL_CACHE_PATH", "/tmp/torch_kernels")
os.environ["PYTORCH_KERNEL_CACHE_PATH"] = KERNEL_CACHE_DIR
try:
    os.makedirs(KERNEL_CACHE_DIR, exist_ok=True)
except Exception:
    pass

# 2) Enable TF32 for faster matmul on Ampere+ GPUs (minimal quality impact)
try:
    torch.backends.cuda.matmul.allow_tf32 = True  # type: ignore[attr-defined]
    torch.backends.cudnn.allow_tf32 = True  # type: ignore[attr-defined]
    torch.set_float32_matmul_precision("high")  # type: ignore[attr-defined]
except Exception:
    pass

processor = AutoProcessor.from_pretrained(
    MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR, use_fast=False
)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    cache_dir=CACHE_DIR,
)
model.eval()
try:
    torch.set_grad_enabled(False)
except Exception:
    pass

app = FastAPI()

# Configure basic logging for debugging
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Health-check endpoint (root path assumed)
@app.get("/")
async def root():
    return {"status": "ok"}

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, Any]]
    max_tokens: int = 256

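# Illustrative request body accepted by the message converter below (values are
# examples only; the image part may be an OpenAI-style base64 data URL or a plain URL string):
# {
#   "model": "osunlp/UGround-V1-72B",
#   "max_tokens": 128,
#   "messages": [
#     {"role": "user", "content": [
#       {"type": "text", "text": "Where is the search box?"},
#       {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
#     ]}
#   ]
# }
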
MAX_IMAGE_WIDTH = 512
MAX_IMAGE_HEIGHT = 388

def _decode_base64_image(data_url: str) -> Image.Image:
    try:
        is_data_url = data_url.startswith("data:")
        if is_data_url:
            header, b64data = data_url.split(",", 1)
            logger.debug("Decoding image from data URL; header prefix=%r", header[:50])
        else:
            b64data = data_url
            logger.debug("Decoding image from raw base64 string; length=%d", len(b64data))
        img_bytes = base64.b64decode(b64data)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        orig_w, orig_h = img.width, img.height
        # Downscale if larger than bounds, preserving aspect ratio
        if orig_w > MAX_IMAGE_WIDTH or orig_h > MAX_IMAGE_HEIGHT:
            target = (MAX_IMAGE_WIDTH, MAX_IMAGE_HEIGHT)
            img = img.copy()
            img.thumbnail(target, Image.LANCZOS)
            logger.debug(
                "Resized image from %sx%s to %sx%s (bounds %sx%s)",
                orig_w,
                orig_h,
                img.width,
                img.height,
                MAX_IMAGE_WIDTH,
                MAX_IMAGE_HEIGHT,
            )
        try:
            logger.debug("Decoded image: size=%sx%s mode=%s", img.width, img.height, img.mode)
        except Exception:
            logger.debug("Decoded image but could not log image metadata")
        return img
    except Exception:
        logger.exception("Failed to decode base64 image")
        raise

def _to_qwen_messages_and_images(messages: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Any]]:
    qwen_msgs: List[Dict[str, Any]] = []
    images: List[Any] = []
    logger.debug("Begin parsing messages: count=%d", len(messages) if messages else 0)
    for idx, msg in enumerate(messages):
        role = msg.get("role", "user")
        content = msg.get("content")
        logger.debug("Processing message #%d role=%s content_type=%s", idx, role, type(content).__name__)
        q_content: List[Dict[str, Any]] = []
        if isinstance(content, str):
            logger.debug("Message #%d text length=%d", idx, len(content))
            q_content.append({"type": "text", "text": content})
        elif isinstance(content, list):
            logger.debug("Message #%d has %d content parts", idx, len(content))
            for pidx, part in enumerate(content):
                ptype = part.get("type")
                logger.debug("Part #%d type=%s", pidx, ptype)
                if ptype == "text":
                    text_val = part.get("text") or part.get("content") or ""
                    logger.debug("Part #%d text length=%d", pidx, len(text_val))
                    q_content.append({"type": "text", "text": text_val})
                elif ptype in ("image", "image_url"):
                    # OpenAI style: {type:"image_url", image_url:{url:"..."}}
                    url = part.get("image")
                    if url is None and isinstance(part.get("image_url"), dict):
                        url = part["image_url"].get("url")
                    if isinstance(url, str) and url.startswith("data:image"):
                        logger.debug("Part #%d image provided as base64 data URL", pidx)
                        img = _decode_base64_image(url)
                        images.append(img)
                        q_content.append({"type": "image", "image": img})
                    else:
                        # URL or non-base64 string
                        logger.debug("Part #%d image provided as URL or non-base64 string: %s", pidx, str(url)[:200])
                        images.append(url)
                        q_content.append({"type": "image", "image": url})
        else:
            # Unknown content; coerce to text
            logger.debug("Message #%d unknown content type; coercing to text", idx)
            q_content.append({"type": "text", "text": str(content)})
        qwen_msgs.append({"role": role, "content": q_content})
    logger.debug("Finished parsing messages: qwen_msgs=%d images=%d", len(qwen_msgs), len(images))
    return qwen_msgs, images

def _make_tiny_base64_png(size: Tuple[int, int] = (64, 48), color: Tuple[int, int, int] = (128, 128, 128)) -> str:
    buf = io.BytesIO()
    Image.new("RGB", size, color).save(buf, format="PNG")
    data = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/png;base64,{data}"

# Startup hook (assumed) so kernels are warmed before the first real request
@app.on_event("startup")
async def _startup_warmup():
| if os.environ.get("DISABLE_WARMUP", "0") == "1": | |
| logger.info("Warmup disabled via DISABLE_WARMUP=1") | |
| return | |
| try: | |
| logger.info("Warmup start: compiling kernels (text + tiny image)") | |
| # Text-only warmup | |
| text_msgs: List[Dict[str, Any]] = [ | |
| {"role": "user", "content": "Hello"} | |
| ] | |
| qmsgs_t, _ = _to_qwen_messages_and_images(text_msgs) | |
| prompt_t = processor.apply_chat_template(qmsgs_t, tokenize=False, add_generation_prompt=True) | |
| inputs_t = processor(text=[prompt_t], images=None, padding=True, return_tensors="pt") | |
| inputs_t = inputs_t.to(model.device) | |
| _t0 = time.perf_counter() | |
| with torch.no_grad(): | |
| _ = model.generate(**inputs_t, max_new_tokens=int(os.environ.get("WARMUP_MAX_NEW_TOKENS", "4")), max_time=float(os.environ.get("WARMUP_MAX_TIME_SECONDS", "3"))) | |
| logger.info("Text warmup done in %.1f ms", (time.perf_counter() - _t0) * 1000.0) | |
| # Tiny image + text warmup | |
| tiny_url = _make_tiny_base64_png() | |
| viz_msgs: List[Dict[str, Any]] = [ | |
| {"role": "user", "content": [ | |
| {"type": "text", "text": "Describe the image"}, | |
| {"type": "image_url", "image_url": {"url": tiny_url}}, | |
| ]} | |
| ] | |
| qmsgs_v, images_v = _to_qwen_messages_and_images(viz_msgs) | |
| prompt_v = processor.apply_chat_template(qmsgs_v, tokenize=False, add_generation_prompt=True) | |
| inputs_v = processor(text=[prompt_v], images=images_v, padding=True, return_tensors="pt") | |
| inputs_v = inputs_v.to(model.device) | |
| _t1 = time.perf_counter() | |
| with torch.no_grad(): | |
| _ = model.generate(**inputs_v, max_new_tokens=int(os.environ.get("WARMUP_MAX_NEW_TOKENS", "4")), max_time=float(os.environ.get("WARMUP_MAX_TIME_SECONDS", "3"))) | |
| logger.info("Vision warmup done in %.1f ms", (time.perf_counter() - _t1) * 1000.0) | |
| logger.info("Warmup complete") | |
| except Exception: | |
| logger.exception("Warmup failed") | |
# OpenAI-compatible chat endpoint (route path assumed)
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest):
    logger.debug(
        "Request received: model=%s, max_tokens=%s, message_count=%d",
        req.model,
        req.max_tokens,
        len(req.messages) if req.messages is not None else 0,
    )
    if req.messages:
        logger.debug("First message preview: %s", str(req.messages[0])[:300])
    qwen_messages, image_inputs = _to_qwen_messages_and_images(req.messages)
    logger.debug(
        "Converted messages: qwen_count=%d, images_count=%d",
        len(qwen_messages),
        len(image_inputs) if image_inputs is not None else 0,
    )
    if qwen_messages:
        logger.debug("First qwen message preview: %s", str(qwen_messages[0])[:300])
    prompt_text = processor.apply_chat_template(
        qwen_messages, tokenize=False, add_generation_prompt=True
    )
    logger.debug("Prompt length (chars)=%d; preview=%r", len(prompt_text), prompt_text[:200])
    inputs = processor(
        text=[prompt_text],
        images=image_inputs if image_inputs else None,
        padding=True,
        return_tensors="pt",
    )
    try:
        tensor_info_pre = {
            k: (tuple(v.shape), str(getattr(v, "dtype", "<na>")))
            for k, v in inputs.items()
            if hasattr(v, "shape")
        }
        logger.debug("Processor outputs (pre .to): %s", tensor_info_pre)
    except Exception:
        logger.debug("Could not summarize processor outputs before device move")
    inputs = inputs.to(model.device)
    try:
        tensor_info_post = {
            k: (
                tuple(v.shape),
                str(getattr(v, "dtype", "<na>")),
                str(getattr(v, "device", "<na>")),
            )
            for k, v in inputs.items()
            if torch.is_tensor(v)
        }
        logger.debug("Inputs moved to device=%s; tensor_info=%s", getattr(model, "device", "<unknown>"), tensor_info_post)
    except Exception:
        logger.debug("Could not summarize inputs after device move")
    logger.debug("Starting generation: max_new_tokens=%d", req.max_tokens)
    _t0 = time.perf_counter()
    generated_ids = model.generate(**inputs, max_new_tokens=req.max_tokens)
    _elapsed_ms = (time.perf_counter() - _t0) * 1000.0
    try:
        logger.debug(
            "Generation done in %.1f ms; generated_ids shape=%s dtype=%s device=%s",
            _elapsed_ms,
            tuple(generated_ids.shape) if hasattr(generated_ids, "shape") else "<na>",
            str(getattr(generated_ids, "dtype", "<na>")),
            str(getattr(generated_ids, "device", "<na>")),
        )
    except Exception:
        logger.debug("Could not summarize generated_ids")
    trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    try:
        lengths_in = [row.size(0) for row in inputs.input_ids]
        lengths_out = [row.size(0) for row in generated_ids]
        logger.debug("Token lengths: input=%s, output=%s", lengths_in, lengths_out)
    except Exception:
        logger.debug("Could not compute token length summaries")
    output_texts = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    text = output_texts[0] if output_texts else ""
    logger.debug(
        "Decoded %d sequences; first_text_len=%d",
        len(output_texts),
        len(text) if text else 0,
    )
    if text:
        logger.debug("Output preview: %r", text[:500])
    return {
        "id": "chatcmpl-uground72b",
        "object": "chat.completion",
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": "stop",
        }],
    }
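
# Minimal local entry point (a sketch; on a managed Space the server is usually
# launched externally, and the port below is an assumption). Example request
# against the assumed route:
#   curl -X POST http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "osunlp/UGround-V1-72B", "messages": [{"role": "user", "content": "Hello"}]}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))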