ChickenMcSwag commited on
Commit
27a259f
·
1 Parent(s): f4e3c36

take in images now

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -0
  2. server.py +69 -10
requirements.txt CHANGED
@@ -4,5 +4,6 @@ torch
4
  transformers>=4.43.0
5
  accelerate
6
  sentencepiece
 
7
  pillow
8
  torchvision
 
4
  transformers>=4.43.0
5
  accelerate
6
  sentencepiece
7
+ Pillow
8
  pillow
9
  torchvision
server.py CHANGED
@@ -1,7 +1,9 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
4
- import torch, os
 
 
5
 
6
  MODEL_ID = "osunlp/UGround-V1-72B"
7
  CACHE_DIR = (
@@ -27,21 +29,78 @@ app = FastAPI()
27
  async def root():
28
  return {"status": "ok"}
29
 
30
- class Message(BaseModel):
31
- role: str
32
- content: str
33
-
34
  class ChatCompletionRequest(BaseModel):
35
  model: str
36
- messages: list[Message]
37
  max_tokens: int = 128
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  @app.post("/v1/chat/completions")
40
  async def chat_completions(req: ChatCompletionRequest):
41
- prompt = "\n".join([m.content for m in req.messages])
42
- inputs = processor(text=prompt, return_tensors="pt").to(model.device)
43
- outputs = model.generate(**inputs, max_new_tokens=req.max_tokens)
44
- text = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  return {
47
  "id": "chatcmpl-uground72b",
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
4
+ import torch, os, base64, io
5
+ from typing import Any, Dict, List, Tuple
6
+ from PIL import Image
7
 
8
  MODEL_ID = "osunlp/UGround-V1-72B"
9
  CACHE_DIR = (
 
29
  async def root():
30
  return {"status": "ok"}
31
 
 
 
 
 
32
  class ChatCompletionRequest(BaseModel):
33
  model: str
34
+ messages: List[Dict[str, Any]]
35
  max_tokens: int = 128
36
 
37
+ def _decode_base64_image(data_url: str) -> Image.Image:
38
+ if data_url.startswith("data:"):
39
+ header, b64data = data_url.split(",", 1)
40
+ else:
41
+ b64data = data_url
42
+ img_bytes = base64.b64decode(b64data)
43
+ return Image.open(io.BytesIO(img_bytes)).convert("RGB")
44
+
45
+ def _to_qwen_messages_and_images(messages: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Any]]:
46
+ qwen_msgs: List[Dict[str, Any]] = []
47
+ images: List[Any] = []
48
+ for msg in messages:
49
+ role = msg.get("role", "user")
50
+ content = msg.get("content")
51
+ q_content: List[Dict[str, Any]] = []
52
+
53
+ if isinstance(content, str):
54
+ q_content.append({"type": "text", "text": content})
55
+ elif isinstance(content, list):
56
+ for part in content:
57
+ ptype = part.get("type")
58
+ if ptype == "text":
59
+ text_val = part.get("text") or part.get("content") or ""
60
+ q_content.append({"type": "text", "text": text_val})
61
+ elif ptype in ("image", "image_url"):
62
+ # OpenAI style: {type:"image_url", image_url:{url:"..."}}
63
+ url = part.get("image")
64
+ if url is None and isinstance(part.get("image_url"), dict):
65
+ url = part["image_url"].get("url")
66
+ if isinstance(url, str) and url.startswith("data:image"):
67
+ img = _decode_base64_image(url)
68
+ images.append(img)
69
+ q_content.append({"type": "image", "image": img})
70
+ else:
71
+ # URL or non-base64 string
72
+ images.append(url)
73
+ q_content.append({"type": "image", "image": url})
74
+ else:
75
+ # Unknown content; coerce to text
76
+ q_content.append({"type": "text", "text": str(content)})
77
+
78
+ qwen_msgs.append({"role": role, "content": q_content})
79
+
80
+ return qwen_msgs, images
81
+
82
  @app.post("/v1/chat/completions")
83
  async def chat_completions(req: ChatCompletionRequest):
84
+ qwen_messages, image_inputs = _to_qwen_messages_and_images(req.messages)
85
+ prompt_text = processor.apply_chat_template(
86
+ qwen_messages, tokenize=False, add_generation_prompt=True
87
+ )
88
+ inputs = processor(
89
+ text=[prompt_text],
90
+ images=image_inputs if image_inputs else None,
91
+ padding=True,
92
+ return_tensors="pt",
93
+ )
94
+ inputs = inputs.to(model.device)
95
+
96
+ generated_ids = model.generate(**inputs, max_new_tokens=req.max_tokens)
97
+ trimmed = [
98
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
99
+ ]
100
+ output_texts = processor.batch_decode(
101
+ trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
102
+ )
103
+ text = output_texts[0] if output_texts else ""
104
 
105
  return {
106
  "id": "chatcmpl-uground72b",