import gradio as gr
import torch
from transformers import AutoProcessor, Glm4vForConditionalGeneration, TextIteratorStreamer
from pathlib import Path
import threading
import re
import argparse
import copy
import spaces
# Local filesystem path of the GLM-4V checkpoint loaded by GLM4VModel.load().
MODEL_PATH = "/model/glm-4v-9b-0529"
class GLM4VModel:
    """Wrapper around a local GLM-4V checkpoint for the Gradio demo.

    Responsibilities: load the processor/model, convert UI chat history and
    uploaded media into the chat-template message format, run streamed
    generation, and render the model's ``<think>``/``<answer>`` sections as
    HTML for the chatbot widget.
    """

    # NOTE(review): the <think>/<answer> tag literals and the HTML snippets
    # below were stripped out of the pasted source; they are reconstructed to
    # match the GLM "thinking" output convention — confirm against the model
    # card / original app.

    VIDEO_EXTS = (".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".m4v")
    IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp")

    # HTML wrappers for the collapsible "thinking" section.
    THINK_OPEN = "<details open><summary>💭 Thinking Process</summary><div>"
    THINK_CLOSE = "</div></details>"

    def __init__(self):
        self.processor = None  # set by load()
        self.model = None  # set by load()
        # Prefer the first CUDA device when available.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

    def load(self):
        """Load the processor and the bfloat16 model onto ``self.device``."""
        self.processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
        self.model = Glm4vForConditionalGeneration.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16,
            device_map=self.device,
            attn_implementation="sdpa",
        )

    def _strip_html(self, t):
        """Remove every HTML tag from *t* and trim surrounding whitespace."""
        return re.sub(r"<[^>]+>", "", t).strip()

    def _wrap_text(self, t):
        """Wrap plain text in the chat-template content-list format."""
        return [{"type": "text", "text": t}]

    def _files_to_content(self, media):
        """Map uploaded files to chat-template video/image content entries.

        Files with unrecognized extensions are silently ignored.
        """
        out = []
        for f in media or []:
            ext = Path(f.name).suffix.lower()
            if ext in self.VIDEO_EXTS:
                out.append({"type": "video", "url": f.name})
            elif ext in self.IMAGE_EXTS:
                out.append({"type": "image", "url": f.name})
        return out

    # -----------------------------------------------------------
    # 🖼️ Output formatting
    # -----------------------------------------------------------
    def _format_output(self, txt):
        """Render the completed generation.

        The thinking section goes into a collapsible ``<details>`` block, the
        answer into a plain ``<div>``.
        """
        think_pat, ans_pat = r"<think>(.*?)</think>", r"<answer>(.*?)</answer>"
        think = re.findall(think_pat, txt, re.DOTALL)
        ans = re.findall(ans_pat, txt, re.DOTALL)
        html = ""
        if think:
            html += self.THINK_OPEN + think[0].strip() + self.THINK_CLOSE
        # Fall back to the raw text (minus the thinking block) when the model
        # emitted no explicit <answer> section.
        body = ans[0] if ans else re.sub(think_pat, "", txt, flags=re.DOTALL)
        html += f"<div>{body.strip()}</div>"
        return html

    def _stream_fragment(self, buf: str) -> str:
        """Render a partial generation buffer during streaming.

        Closed <think>/<answer> sections are fully formatted; a still-open
        section is shown verbatim up to the current tail of the buffer.
        """
        think_html = ""
        if "<think>" in buf:
            if "</think>" in buf:
                think_content = re.search(r"<think>(.*?)</think>", buf, re.DOTALL)
                if think_content:
                    think_html = (
                        self.THINK_OPEN + think_content.group(1).strip() + self.THINK_CLOSE
                    )
            else:
                partial = buf.split("<think>", 1)[1]
                think_html = self.THINK_OPEN + partial
        answer_html = ""
        if "<answer>" in buf:
            if "</answer>" in buf:
                ans_content = re.search(r"<answer>(.*?)</answer>", buf, re.DOTALL)
                if ans_content:
                    answer_html = "<div>" + ans_content.group(1).strip() + "</div>"
            else:
                partial = buf.split("<answer>", 1)[1]
                answer_html = "<div>" + partial
        if not think_html and not answer_html:
            # No recognized section yet — show the raw text, tags removed.
            return self._strip_html(buf)
        return think_html + answer_html

    def _build_messages(self, hist, sys_prompt):
        """Convert UI history plus an optional system prompt into messages.

        Assistant turns are cleaned of the display HTML and thinking text so
        the model never sees its own formatting on the next turn.
        """
        msgs = []
        if sys_prompt.strip():
            msgs.append(
                {"role": "system", "content": [{"type": "text", "text": sys_prompt.strip()}]}
            )
        for h in hist:
            if h["role"] == "user":
                # Prefer the structured media payload stashed by chat().
                payload = h.get("file_info") or self._wrap_text(self._strip_html(h["content"]))
                msgs.append({"role": "user", "content": payload})
            else:
                raw = h["content"]
                # Drop the rendered thinking block and any raw <think> span.
                raw = re.sub(r"<details.*?</details>", "", raw, flags=re.DOTALL)
                raw = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL)
                clean = self._strip_html(raw).strip()
                msgs.append({"role": "assistant", "content": self._wrap_text(clean)})
        return msgs

    @spaces.GPU(duration=240)
    def stream_generate(self, hist, sys_prompt):
        """Yield progressively formatted HTML for the model's streamed reply.

        Generation runs on a background thread; this generator drains the
        streamer, yielding a partial render per token and a final full render.
        """
        msgs = self._build_messages(hist, sys_prompt)
        inputs = self.processor.apply_chat_template(
            msgs,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
        ).to(self.device)
        streamer = TextIteratorStreamer(
            self.processor.tokenizer, skip_prompt=True, skip_special_tokens=False
        )
        gen_args = dict(
            inputs,
            max_new_tokens=8192,
            repetition_penalty=1.1,
            # NOTE(review): do_sample with top_k=2 and a near-zero top_p makes
            # decoding effectively greedy; presumably intentional — confirm.
            do_sample=True,
            top_k=2,
            temperature=None,
            top_p=1e-5,
            streamer=streamer,
        )
        threading.Thread(target=self.model.generate, kwargs=gen_args).start()
        buf = ""
        for tok in streamer:
            buf += tok
            yield self._stream_fragment(buf)
        yield self._format_output(buf)
# Module-level singleton: the model is loaded once at import time so the
# Gradio callbacks below can share it.
glm4v = GLM4VModel()
glm4v.load()
def check_files(files):
    """Validate an upload batch before it is sent to the model.

    Rules: at most one video, at most ten images, and never a mix of the
    two kinds in one batch. Unrecognized extensions are ignored.

    Returns:
        (ok, error_message): ``error_message`` is ``""`` when ``ok`` is True.
    """
    video_exts = {".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".m4v"}
    image_exts = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"}
    vids = imgs = 0
    for f in files or []:
        ext = Path(f.name).suffix.lower()
        if ext in video_exts:
            vids += 1
        elif ext in image_exts:
            imgs += 1
    if vids > 1:
        return False, "Only 1 video allowed"
    if imgs > 10:
        return False, "Max 10 images"
    if vids and imgs:
        return False, "Cannot mix video and images"
    return True, ""
def chat(files, msg, hist, sys_prompt):
    """Gradio callback: append the user turn and stream the assistant reply.

    Yields ``(history, cleared_files, cleared_textbox)`` tuples; the history
    is deep-copied on every yield so Gradio sees a fresh object each time.
    """
    ok, err = check_files(files)
    if not ok:
        # Surface the validation error as an assistant message and stop.
        hist.append({"role": "assistant", "content": err})
        yield copy.deepcopy(hist), None, ""
        return
    payload = glm4v._files_to_content(files) if files else None
    if msg.strip():
        if payload is None:
            payload = glm4v._wrap_text(msg.strip())
        else:
            payload.append({"type": "text", "text": msg.strip()})
    # What the chatbot widget displays; the structured payload is kept
    # separately under "file_info" for _build_messages().
    display = f"[{len(files)} file(s) uploaded]\n{msg}" if files else msg
    user_rec = {"role": "user", "content": display}
    if payload:
        user_rec["file_info"] = payload
    hist.append(user_rec)
    # Placeholder assistant message, filled in as chunks stream back.
    place = {"role": "assistant", "content": ""}
    hist.append(place)
    yield copy.deepcopy(hist), None, ""
    # Exclude the placeholder itself from the prompt history.
    for chunk in glm4v.stream_generate(hist[:-1], sys_prompt):
        place["content"] = chunk
        yield copy.deepcopy(hist), None, ""
    yield copy.deepcopy(hist), None, ""
def reset():
    """Clear the conversation: empty history, no files, empty textbox."""
    return [], None, ""
# Styling for the chatbot messages and the collapsible thinking section.
css = """.chatbot-container .message-wrap .message{font-size:14px!important}
details summary{cursor:pointer;font-weight:bold}
details[open] summary{margin-bottom:10px}"""

demo = gr.Blocks(title="GLM-4.1V Chat", theme=gr.themes.Soft(), css=css)
with demo:
    gr.Markdown("""
GLM-4.1V-9B Gradio Space 🤗
""")
    with gr.Row():
        with gr.Column(scale=7):
            chatbox = gr.Chatbot(
                label="Conversation",
                type="messages",
                height=600,
                elem_classes="chatbot-container",
            )
            textbox = gr.Textbox(label="💭 Message", lines=3)
            with gr.Row():
                send = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear")
        with gr.Column(scale=3):
            up = gr.File(
                label="📁 Upload",
                file_count="multiple",
                file_types=["image", "video"],
                type="filepath",
            )
            # NOTE(review): "Bay image" reads like paste corruption — confirm
            # the intended wording against the original app.
            gr.Markdown("""
Please upload the Bay image before entering text.
""")
            sys = gr.Textbox(label="⚙️ System Prompt", lines=6)
    # Both the Send button and Enter in the textbox trigger the same handler.
    send.click(chat, inputs=[up, textbox, chatbox, sys], outputs=[chatbox, up, textbox])
    textbox.submit(chat, inputs=[up, textbox, chatbox, sys], outputs=[chatbox, up, textbox])
    clear.click(reset, outputs=[chatbox, up, textbox])
if __name__ == "__main__":
    # CLI flags mirror gradio's launch() parameters.
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()
    demo.launch(
        server_port=args.port,
        server_name=args.host,
        share=args.share,
    )