Joel Lundgren committed on
Commit dc90ed9 · 1 Parent(s): f32efcc
Files changed (2):
  1. app.py +22 -19
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,8 +1,9 @@
 import gradio as gr
 from PIL import Image, ImageDraw
 from ultralytics import YOLO
+from transformers import AutoTokenizer
+from optimum.onnxruntime import ORTModelForCausalLM
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
 
 # Load a pre-trained YOLO model
 model = YOLO('yolov8n.pt')
@@ -59,13 +60,18 @@ def get_llm(model_name):
         return llm_cache[model_name]
 
     model_map = {
-        "qwen3:0.6b": "Qwen/Qwen3-0.6B-Instruct",
-        "gemma3:1b": "google/gemma-3-1b-it"
+        "qwen3:0.6b": "onnx-community/Qwen3-0.6B-ONNX",
+        "gemma3:1b": "onnx-community/gemma-3-1b-it-ONNX-GQA"
     }
     hf_model_name = model_map[model_name]
 
-    tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
-    model = AutoModelForCausalLM.from_pretrained(hf_model_name)
+    # Tokenizer is loaded from the original model's repo to ensure correct chat templates
+    original_model_map = {
+        "qwen3:0.6b": "Qwen/Qwen3-0.6B-Instruct",
+        "gemma3:1b": "google/gemma-3-1b-it"
+    }
+    tokenizer = AutoTokenizer.from_pretrained(original_model_map[model_name])
+    model = ORTModelForCausalLM.from_pretrained(hf_model_name)
 
     llm_cache[model_name] = (model, tokenizer)
     return model, tokenizer
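Note: the hunk above only shows the lines changed inside get_llm. For readability, here is a minimal sketch of what the whole cached loader might look like after this change; the llm_cache dict and the initial cache check are assumed from the context lines and are not spelled out in the diff.

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

llm_cache = {}  # assumed module-level cache, keyed by the dropdown value

def get_llm(model_name):
    # Reuse an already-loaded (model, tokenizer) pair when possible (assumed cache check)
    if model_name in llm_cache:
        return llm_cache[model_name]

    # ONNX export repos, run through onnxruntime via optimum
    model_map = {
        "qwen3:0.6b": "onnx-community/Qwen3-0.6B-ONNX",
        "gemma3:1b": "onnx-community/gemma-3-1b-it-ONNX-GQA",
    }
    hf_model_name = model_map[model_name]

    # Tokenizer comes from the original model repo so the chat template matches
    original_model_map = {
        "qwen3:0.6b": "Qwen/Qwen3-0.6B-Instruct",
        "gemma3:1b": "google/gemma-3-1b-it",
    }
    tokenizer = AutoTokenizer.from_pretrained(original_model_map[model_name])
    model = ORTModelForCausalLM.from_pretrained(hf_model_name)

    llm_cache[model_name] = (model, tokenizer)
    return model, tokenizer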
 
@@ -88,24 +94,21 @@ def generate_text(model_name, system_prompt, user_prompt):
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": user_prompt},
     ]
-
-    text = tokenizer.apply_chat_template(
+
+    inputs = tokenizer.apply_chat_template(
         messages,
-        tokenize=False,
-        add_generation_prompt=True
+        add_generation_prompt=True,
+        return_tensors="pt",
     )
-
-    model_inputs = tokenizer([text], return_tensors="pt")
 
-    generated_ids = model.generate(
-        model_inputs.input_ids,
-        max_new_tokens=512
-    )
-    generated_ids = [
-        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-    ]
+    generated_ids = model.generate(inputs, max_new_tokens=512)
 
     response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+    # The response might include the prompt, so we remove it.
+    # This is a common pattern when decoding from a generation.
+    prompt_plus_response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+    response = prompt_plus_response[len(tokenizer.decode(inputs[0], skip_special_tokens=True)):]
 
     return response
 
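Note: the new decode step strips the prompt from the output by string slicing. Since generate() on a causal LM returns the prompt ids followed by the newly generated ids, an equivalent pattern is to drop the prompt tokens before decoding. A minimal sketch, assuming inputs is the id tensor returned by apply_chat_template above:

# generate() output contains the prompt ids followed by the new ids
generated_ids = model.generate(inputs, max_new_tokens=512)

# Keep only the tokens produced after the prompt, then decode just those
new_tokens = generated_ids[:, inputs.shape[-1]:]
response = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]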
 
 
@@ -120,7 +123,7 @@ with gr.Blocks() as demo:
         detected_objects_output = gr.Textbox(label="Detected Objects")
 
     with gr.Tab("LLM Chat"):
-        model_selector = gr.Dropdown(choices=["qwen2:0.5b", "gemma2:2b"], label="Select LLM Model")
+        model_selector = gr.Dropdown(choices=["qwen3:0.6b", "gemma3:1b"], label="Select LLM Model")
         system_prompt_input = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
         user_prompt_input = gr.Textbox(label="User Prompt")
         llm_output = gr.Textbox(label="LLM Response")
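Note: the hunk ends at the output textbox and does not show how the tab is wired to generate_text. A hypothetical sketch of that wiring with a standard Gradio click handler; the button variable and label are assumptions, not part of this commit:

# Hypothetical event wiring for the "LLM Chat" tab (not shown in the diff)
generate_button = gr.Button("Generate")
generate_button.click(
    fn=generate_text,
    inputs=[model_selector, system_prompt_input, user_prompt_input],
    outputs=llm_output,
)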
requirements.txt CHANGED
@@ -3,3 +3,5 @@ ultralytics
 torch
 transformers
 pillow
+bitsandbytes
+optimum[onnxruntime]
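Note: optimum[onnxruntime] provides the ORTModelForCausalLM class now used in app.py, while bitsandbytes is added to the requirements but is not referenced anywhere in the diff above. A quick import check after pip install -r requirements.txt, assuming nothing beyond the package names listed here:

# Confirm the new dependencies resolve
from optimum.onnxruntime import ORTModelForCausalLM  # from optimum[onnxruntime]
import bitsandbytes  # listed in requirements.txt; unused in the diffed code

print(ORTModelForCausalLM.__name__)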