ChickenMcSwag committed
Commit b33a74e · Parent(s): 0a902de

code cleanup

Files changed (1):
  1. server.py +12 -40
server.py CHANGED
@@ -1,6 +1,6 @@
 from fastapi import FastAPI
 from pydantic import BaseModel
-from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoProcessor
+from transformers import AutoProcessor
 import os
 import torch
 
@@ -11,40 +11,17 @@ CACHE_DIR = (
     or "/data/huggingface"
 )
 
-# Inspect config and load appropriate stack
-config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR)
+from transformers import Qwen2VLForConditionalGeneration  # type: ignore
 
-is_qwen2_vl = getattr(config, "model_type", None) == "qwen2_vl" or (
-    config.__class__.__name__.lower().startswith("qwen2vl")
+processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR)
+model = Qwen2VLForConditionalGeneration.from_pretrained(
+    MODEL_ID,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    trust_remote_code=True,
+    cache_dir=CACHE_DIR,
 )
 
-if is_qwen2_vl:
-    try:
-        from transformers import Qwen2VLForConditionalGeneration  # type: ignore
-    except Exception as e:
-        raise RuntimeError(
-            "Transformers version does not support Qwen2-VL. Please upgrade transformers to >=4.43."
-        ) from e
-    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR)
-    model = Qwen2VLForConditionalGeneration.from_pretrained(
-        MODEL_ID,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-        trust_remote_code=True,
-        cache_dir=CACHE_DIR,
-    )
-    _use_processor = True
-else:
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR)
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",  # automatically shards across GPUs
-        trust_remote_code=True,
-        cache_dir=CACHE_DIR
-    )
-    _use_processor = False
-
 app = FastAPI()
 
 # OpenAI-style request schema
@@ -61,14 +38,9 @@ class ChatCompletionRequest(BaseModel):
 async def chat_completions(req: ChatCompletionRequest):
     # Concatenate messages into one prompt
     prompt = "\n".join([m.content for m in req.messages])
-    if _use_processor:
-        inputs = processor(text=prompt, return_tensors="pt").to(model.device)
-        outputs = model.generate(**inputs, max_new_tokens=req.max_tokens)
-        text = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)
-    else:
-        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-        outputs = model.generate(**inputs, max_new_tokens=req.max_tokens)
-        text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    inputs = processor(text=prompt, return_tensors="pt").to(model.device)
+    outputs = model.generate(**inputs, max_new_tokens=req.max_tokens)
+    text = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
     return {
         "id": "chatcmpl-uground72b",