Spaces:
Sleeping
Sleeping
Added a dropdown for the user to select the model
Browse files
app.py
CHANGED
|
@@ -22,8 +22,8 @@ dotenv_path = find_dotenv()
|
|
| 22 |
|
| 23 |
load_dotenv(dotenv_path)
|
| 24 |
|
| 25 |
-
model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-
|
| 26 |
-
model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-
|
| 27 |
|
| 28 |
input_processor = Gemma3Processor.from_pretrained(model_12_id)
|
| 29 |
|
|
@@ -138,6 +138,7 @@ def run(
|
|
| 138 |
message: dict,
|
| 139 |
history: list[dict],
|
| 140 |
system_prompt: str,
|
|
|
|
| 141 |
max_new_tokens: int,
|
| 142 |
max_images: int,
|
| 143 |
temperature: float,
|
|
@@ -148,9 +149,11 @@ def run(
|
|
| 148 |
|
| 149 |
logger.debug(
|
| 150 |
f"\n message: {message} \n history: {history} \n system_prompt: {system_prompt} \n "
|
| 151 |
-
f"max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
|
| 152 |
)
|
| 153 |
|
|
|
|
|
|
|
| 154 |
messages = []
|
| 155 |
if system_prompt:
|
| 156 |
messages.append(
|
|
@@ -167,7 +170,7 @@ def run(
|
|
| 167 |
tokenize=True,
|
| 168 |
return_dict=True,
|
| 169 |
return_tensors="pt",
|
| 170 |
-
).to(device=
|
| 171 |
|
| 172 |
streamer = TextIteratorStreamer(
|
| 173 |
input_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
|
|
@@ -182,7 +185,7 @@ def run(
|
|
| 182 |
repetition_penalty=repetition_penalty,
|
| 183 |
do_sample=True,
|
| 184 |
)
|
| 185 |
-
t = Thread(target=
|
| 186 |
t.start()
|
| 187 |
|
| 188 |
output = ""
|
|
@@ -201,6 +204,11 @@ demo = gr.ChatInterface(
|
|
| 201 |
multimodal=True,
|
| 202 |
additional_inputs=[
|
| 203 |
gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
gr.Slider(
|
| 205 |
label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700
|
| 206 |
),
|
|
|
|
| 22 |
|
| 23 |
load_dotenv(dotenv_path)
|
| 24 |
|
| 25 |
+
model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-12b-it")
|
| 26 |
+
model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3n-E4B-it")
|
| 27 |
|
| 28 |
input_processor = Gemma3Processor.from_pretrained(model_12_id)
|
| 29 |
|
|
|
|
| 138 |
message: dict,
|
| 139 |
history: list[dict],
|
| 140 |
system_prompt: str,
|
| 141 |
+
model_choice: str,
|
| 142 |
max_new_tokens: int,
|
| 143 |
max_images: int,
|
| 144 |
temperature: float,
|
|
|
|
| 149 |
|
| 150 |
logger.debug(
|
| 151 |
f"\n message: {message} \n history: {history} \n system_prompt: {system_prompt} \n "
|
| 152 |
+
f"model_choice: {model_choice} \n max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
|
| 153 |
)
|
| 154 |
|
| 155 |
+
selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n
|
| 156 |
+
|
| 157 |
messages = []
|
| 158 |
if system_prompt:
|
| 159 |
messages.append(
|
|
|
|
| 170 |
tokenize=True,
|
| 171 |
return_dict=True,
|
| 172 |
return_tensors="pt",
|
| 173 |
+
).to(device=selected_model.device, dtype=torch.bfloat16)
|
| 174 |
|
| 175 |
streamer = TextIteratorStreamer(
|
| 176 |
input_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
|
|
|
|
| 185 |
repetition_penalty=repetition_penalty,
|
| 186 |
do_sample=True,
|
| 187 |
)
|
| 188 |
+
t = Thread(target=selected_model.generate, kwargs=generate_kwargs)
|
| 189 |
t.start()
|
| 190 |
|
| 191 |
output = ""
|
|
|
|
| 204 |
multimodal=True,
|
| 205 |
additional_inputs=[
|
| 206 |
gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
|
| 207 |
+
gr.Dropdown(
|
| 208 |
+
label="Model",
|
| 209 |
+
choices=["Gemma 3 12B", "Gemma 3n E4B"],
|
| 210 |
+
value="Gemma 3 12B"
|
| 211 |
+
),
|
| 212 |
gr.Slider(
|
| 213 |
label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700
|
| 214 |
),
|