LokeZhou committed on
Commit
72ae5ae
·
1 Parent(s): 6d89bbe

use openai client

Browse files
Files changed (2) hide show
  1. app.py +81 -214
  2. requirements.txt +1 -4
app.py CHANGED
@@ -1,231 +1,98 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoProcessor,TextStreamer,TextIteratorStreamer
4
- from PIL import Image
5
  import base64
6
- import io
7
- import re
8
- from typing import Generator, List, Tuple, Optional
9
- import spaces
10
- import threading
11
-
12
# Number of most recent turns kept in the chatbot display.
MAX_HISTORY = 5

# Model identifier shared by the model and its processor.
model_path = 'baidu/ERNIE-4.5-VL-28B-A3B-Thinking'

# Loading options gathered in one place and unpacked into from_pretrained.
_model_load_options = {
    "device_map": "auto",
    "torch_dtype": torch.bfloat16,
    "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(model_path, **_model_load_options)

# The processor is loaded from the same path and then registered on the model
# as its image preprocessor.
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
processor.eval()
model.add_image_preprocess(processor)
24
-
25
-
26
def encode_image(image: Image.Image) -> str:
    """Serialize an image to PNG and return it as a base64 string.

    Args:
        image: Image object to encode, or ``None``.

    Returns:
        The base64 text of the PNG bytes, or ``""`` when ``image`` is ``None``.
    """
    if image is not None:
        png_buffer = io.BytesIO()
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
        return base64.b64encode(png_bytes).decode("utf-8")
    return ""
32
-
33
def extract_text_from_html(html: str) -> str:
    """Return the plain text of a chat display string.

    HTML tags are removed first; a leading ``"user: "`` or ``"assistant: "``
    role label is then dropped and the remainder is stripped of surrounding
    whitespace.

    Args:
        html: Display string, possibly containing ``<img>`` or other tags.

    Returns:
        The cleaned text without tags or role label.
    """
    text = re.sub(r'<img.*?>', '', html)
    text = re.sub(r'<.*?>', '', text)
    # Slice by the label's real length: "assistant: " is 11 characters, so a
    # hard-coded offset of 8 left "nt: " on the front of the result.
    for role_label in ("user: ", "assistant: "):
        if text.startswith(role_label):
            return text[len(role_label):].strip()
    return text.strip()
41
-
42
@spaces.GPU(duration=120)
def process_chat(
    message: str,
    image: Optional[Image.Image],
    chat_history: List[Tuple[str, str, Optional[str]]],
    max_new_tokens: int,
    temperature: float
) -> Generator[List[Tuple[str, str]], None, None]:
    """Process one chat input and stream the generated reply.

    Args:
        message: Text entered by the user for this turn.
        image: Optional image attached to this turn.
        chat_history: Earlier turns as ``(user_html, assistant_text,
            image_b64)`` tuples.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.

    Yields:
        The last ``MAX_HISTORY`` turns as ``(user_html, assistant_text)``
        pairs, yielded again each time the streamer produces more text.
    """
    assistant_label = "assistant: "

    current_image_b64 = encode_image(image) if image else None
    image_html = ""
    if current_image_b64:
        image_html = f'<br><img src="data:image/png;base64,{current_image_b64}" style="max-width:300px; border-radius:4px;">'

    user_text = message
    user_message_html = f"user: {user_text}{image_html}"

    temp_history = chat_history + [(user_message_html, "", current_image_b64)]

    model_messages = []

    # Rebuild the earlier turns for the model.
    for user_html, assistant_text, hist_image_b64 in temp_history[:-1]:
        user_text_clean = extract_text_from_html(user_html)

        # Always send the user's earlier text; it was previously computed but
        # dropped, so past turns reached the model as image-only or empty.
        user_content = [{"type": "text", "text": user_text_clean}]
        if hist_image_b64:
            user_content.insert(0, {"type": "image_url", "image_url": {"url": hist_image_b64}})
        model_messages.append({"role": "user", "content": user_content})

        # Stored replies carry the "assistant: " display label; remove it so
        # the model sees only its own earlier output.
        if assistant_text.startswith(assistant_label):
            assistant_text = assistant_text[len(assistant_label):]
        assistant_content = [{"type": "text", "text": assistant_text}]
        model_messages.append({"role": "bot", "content": assistant_content})

    current_user_content = [{"type": "text", "text": user_text}]
    if current_image_b64:
        current_user_content.insert(0, {"type": "image_url", "image_url": {"url": current_image_b64}})
    model_messages.append({"role": "user", "content": current_user_content})

    text = processor.tokenizer.apply_chat_template(
        model_messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
    )

    image_inputs, video_inputs = processor.process_vision_info(model_messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    # Move the inputs to the device holding the model's parameters.
    device = next(model.parameters()).device
    inputs = inputs.to(device)

    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "use_cache": False
    }

    # Generation runs in a background thread while this generator consumes
    # the streamer and yields the growing reply.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_token in streamer:
        generated_text += new_token

        temp_history[-1] = (user_message_html, f"assistant: {generated_text}", current_image_b64)

        display_history = [(h[0], h[1]) for h in temp_history[-MAX_HISTORY:]]
        yield display_history

    thread.join()
132
-
133
-
134
def chat_interface(
    message: str,
    image: Optional[Image.Image],
    chat_history: List[Tuple[str, str, Optional[str]]],
    max_new_tokens: int,
    temperature: float
) -> Generator[tuple, None, None]:
    """Stream UI updates for one chat submission.

    Args:
        message: Text entered by the user.
        image: Optional image attached to this turn.
        chat_history: Stored turns as ``(user_html, assistant_text,
            image_b64)`` tuples.
        max_new_tokens: Forwarded to ``process_chat``.
        temperature: Forwarded to ``process_chat``.

    Yields:
        ``("", None, full_history, display_history)`` tuples, one for every
        update produced by ``process_chat``.
    """
    # The image does not change while the reply streams, so encode it once
    # rather than on every streamed update.
    current_image_b64 = encode_image(image) if image else None

    for updated_display_history in process_chat(message, image, chat_history, max_new_tokens, temperature):
        updated_full_history = []
        last_index = len(updated_display_history) - 1
        for i, display_item in enumerate(updated_display_history):
            # Reuse the stored 3-tuple when this displayed pair is already known.
            full_item = next((h for h in chat_history if h[0] == display_item[0] and h[1] == display_item[1]), None)
            if full_item:
                updated_full_history.append(full_item)
            elif i == last_index:
                # The newest turn carries the image submitted with it.
                updated_full_history.append((display_item[0], display_item[1], current_image_b64))
            else:
                updated_full_history.append((display_item[0], display_item[1], None))

        yield "", None, updated_full_history, updated_display_history
161
-
162
-
163
with gr.Blocks(title="ERNIE-4.5-VL-28B-A3B-Thinking", theme=gr.themes.Soft()) as demo:
    # Per-session history of (user_html, assistant_text, image_b64) tuples.
    full_chat_history = gr.State([])

    with gr.Row():
        with gr.Column(scale=3):
            chat_display = gr.Chatbot(label="chat_bot", height=500, bubble_full_width=False)

        with gr.Column(scale=1):
            gr.Markdown("generation kwargs")
            max_new_tokens = gr.Slider(
                label="max_new_token",
                minimum=128, maximum=32768, value=8192, step=255,
            )
            temperature = gr.Slider(
                label="temperature",
                minimum=0.1, maximum=2.0, value=0.7, step=0.1,
            )
            clear_btn = gr.Button("clear", variant="destructive")

    with gr.Row():
        text_input = gr.Textbox(label="input text", placeholder="input text messages...", lines=2)
        image_input = gr.Image(label="input image", placeholder="upload image...", type="pil", height=100)
        submit_btn = gr.Button("submit", variant="primary")

    # The submit button and pressing Enter in the textbox run the same handler
    # with the same inputs and outputs.
    chat_inputs = [text_input, image_input, full_chat_history, max_new_tokens, temperature]
    chat_outputs = [text_input, image_input, full_chat_history, chat_display]
    for register_event in (submit_btn.click, text_input.submit):
        register_event(fn=chat_interface, inputs=chat_inputs, outputs=chat_outputs)

    def clear_chat():
        """Return empty values for the stored history and the chat display."""
        return [], []

    clear_btn.click(fn=clear_chat, inputs=[], outputs=[full_chat_history, chat_display])


if __name__ == "__main__":
    demo.launch()
 
 
 
 
 
1
  import base64
2
+ import mimetypes
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ import gradio as gr
8
+ from openai import OpenAI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
# Model name sent with every request; the DEFAULT_MODEL environment variable
# overrides the built-in value.
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "ERNIE-4.5-VL-28B-A3B-Thinking")

# API key read from OPENAI_API_KEY; an empty string is used when it is unset.
api_key = os.getenv("OPENAI_API_KEY", "")

# base_url must stop at the API root: the OpenAI client appends
# "/chat/completions" itself, so including it here doubles the request path.
_client = OpenAI(
    base_url="https://9d4as2f4m0e8f0a6.aistudio-app.com/v1",
    api_key=api_key,
)
16
 
17
def _data_url(path: str) -> str:
    """Encode the file at ``path`` as a base64 ``data:`` URL.

    The MIME type is guessed from the file name; when no type can be guessed,
    ``application/octet-stream`` is used.
    """
    mime = mimetypes.guess_type(path)[0]
    if not mime:
        mime = "application/octet-stream"
    raw_bytes = Path(path).read_bytes()
    data = base64.b64encode(raw_bytes).decode("utf-8")
    return f"data:{mime};base64,{data}"
22
 
23
def _image_content(path: str) -> Dict[str, Any]:
    """Build an ``image_url`` content part holding the file at ``path`` as a data URL."""
    encoded_url = _data_url(path)
    image_part = {"type": "image_url", "image_url": {"url": encoded_url}}
    return image_part
 
 
 
 
 
25
 
26
def _text_content(text: str) -> Dict[str, Any]:
    """Build a ``text`` content part carrying ``text`` unchanged."""
    text_part: Dict[str, Any] = {"type": "text"}
    text_part["text"] = text
    return text_part
 
 
 
 
 
 
 
 
 
 
28
 
29
def _message(role: str, content: Any) -> Dict[str, Any]:
    """Build a chat message dict with the given ``role`` and ``content``."""
    chat_message: Dict[str, Any] = {"role": role}
    chat_message["content"] = content
    return chat_message
31
 
32
def _build_user_message(message: Dict[str, Any]) -> Dict[str, Any]:
    """Turn a submitted message dict into a ``user`` chat message.

    Every entry under ``"files"`` becomes an image part; the stripped
    ``"text"`` value, when non-empty, is appended after the images as a text
    part.
    """
    attached_files = message.get("files") or []
    stripped_text = (message.get("text") or "").strip()

    parts: List[Dict[str, Any]] = []
    for file_path in attached_files:
        parts.append(_image_content(file_path))
    if stripped_text:
        parts.append(_text_content(stripped_text))
    return _message("user", parts)
39
 
40
def _convert_history(history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert chat history turns into messages for the completion request.

    Consecutive ``user`` turns are merged into a single user message: string
    content becomes a text part, and tuple content is treated as file paths
    that each become an image part. The merged message is emitted when the
    next ``bot``/``assistant`` turn is reached, followed by that reply under
    the role ``"bot"``.

    Args:
        history: Turns as dicts with ``"role"`` and ``"content"`` keys; may be
            empty or ``None``.

    Returns:
        The converted message list. User content not followed by a reply is
        not emitted.
    """
    msgs: List[Dict[str, Any]] = []
    pending_user_parts: List[Dict[str, Any]] = []

    for turn in history or []:
        role, content = turn.get("role"), turn.get("content")
        if role == "user":
            if isinstance(content, str):
                pending_user_parts.append(_text_content(content))
            elif isinstance(content, tuple):
                pending_user_parts.extend(_image_content(path) for path in content if path)
        elif role in ("bot", "assistant"):
            # Emit a user message only when there is something to send; a reply
            # with no preceding user content would otherwise produce a user
            # message whose content list is empty.
            if pending_user_parts:
                msgs.append(_message("user", pending_user_parts))
                pending_user_parts = []
            msgs.append(_message("bot", [_text_content(content)]))
    return msgs
57
+
58
+
59
def stream_response(message: Dict[str, Any], history: List[Dict[str, Any]], model_name: str = DEFAULT_MODEL):
    """Stream the model's reply to ``message`` given the earlier ``history``.

    Args:
        message: Submitted message dict, converted by ``_build_user_message``.
        history: Earlier turns, converted by ``_convert_history``.
        model_name: Model name sent with the request.

    Yields:
        The reply text accumulated so far, once per received chunk that has
        content. If the request or the stream raises, a single
        ``"Failed to get response: ..."`` string is yielded instead.
    """
    messages = _convert_history(history)
    messages.append(_build_user_message(message))

    try:
        stream = _client.chat.completions.create(
            model=model_name,
            messages=messages,
            stream=True
        )
        partial = ""
        for chunk in stream:
            # A chunk may arrive with no choices; indexing [0] on it would
            # raise and replace the partial reply with an error message.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content
            if delta:
                partial += delta
                yield partial
    except Exception as e:
        # Report the failure in the chat instead of raising out of the handler.
        yield f"Failed to get response: {e}"
77
+
78
def build_demo() -> gr.Blocks:
    """Assemble the chat UI and return it with queuing enabled.

    The interface uses ``stream_response`` as its handler, a multimodal
    textbox accepting multiple image files, and a chatbot in ``messages``
    format with the ``think`` tag allowed.
    """
    chat_window = gr.Chatbot(type="messages", allow_tags=["think"])
    input_box = gr.MultimodalTextbox(
        show_label=False,
        placeholder="Enter text, or upload one or more images...",
        file_types=["image"],
        file_count="multiple"
    )
    interface = gr.ChatInterface(
        fn=stream_response,
        type="messages",
        multimodal=True,
        chatbot=chat_window,
        textbox=input_box,
        title="ERNIE-4.5-VL-28B-A3B-Thinking",
    )
    return interface.queue(default_concurrency_limit=8)


if __name__ == "__main__":
    app = build_demo()
    app.launch(server_name="0.0.0.0", server_port=8100, share=False)
requirements.txt CHANGED
@@ -1,4 +1 @@
1
- transformers==4.57.1
2
- decord
3
- sentencepiece
4
- accelerate
 
1
+ openai