Update
Browse files
app.py
CHANGED
|
@@ -63,13 +63,12 @@ def load_model_by_name(name: str):
|
|
| 63 |
model.eval()
|
| 64 |
|
| 65 |
print(f"Activated model: {name}")
|
| 66 |
-
return
|
| 67 |
|
| 68 |
|
| 69 |
def switch_model(selected_model_name: str):
|
| 70 |
"""Wrapper for Gradio to switch models via radio selection."""
|
| 71 |
-
load_model_by_name(selected_model_name)
|
| 72 |
-
return None
|
| 73 |
|
| 74 |
# βββ HELPER FUNCTIONS ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 75 |
|
|
@@ -100,7 +99,7 @@ def process_csv_file(csv_file):
|
|
| 100 |
df.columns = [str(c).strip() for c in df.columns]
|
| 101 |
df = df.loc[:, [c for c in df.columns if c]]
|
| 102 |
df = df.dropna(axis=1, how="all")
|
| 103 |
-
print(f"File {csv_file.name} loaded. {df.columns=}")
|
| 104 |
|
| 105 |
if df.shape[1] == 0:
|
| 106 |
return None, "No valid time-series columns found."
|
|
@@ -131,7 +130,7 @@ def process_csv_file(csv_file):
|
|
| 131 |
|
| 132 |
if not ts_list:
|
| 133 |
return None, "All time series are empty after trimming NaNs."
|
| 134 |
-
print(f"Successfully loaded {len(ts_names)} time series: {', '.join(ts_names)}")
|
| 135 |
|
| 136 |
return df, f"Successfully loaded {len(ts_names)} time series: {', '.join(ts_names)}"
|
| 137 |
|
|
@@ -169,7 +168,7 @@ def preview_csv(csv_file, use_default):
|
|
| 169 |
label="Select a Column to Visualize"
|
| 170 |
)
|
| 171 |
|
| 172 |
-
print("Successfully generated preview!")
|
| 173 |
|
| 174 |
return plot, message, dropdown, False # Set use_default to False when file is uploaded
|
| 175 |
|
|
@@ -233,11 +232,13 @@ def initialize_interface():
|
|
| 233 |
# βββ INFERENCE + VALIDATION ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 234 |
|
| 235 |
@spaces.GPU # dynamically allocate & release a ZeroGPU device on each call
|
| 236 |
-
def infer_chatts_stream(prompt: str, csv_file, use_default):
|
| 237 |
"""
|
| 238 |
Streaming version of ChatTS inference
|
| 239 |
"""
|
| 240 |
-
|
|
|
|
|
|
|
| 241 |
|
| 242 |
if not prompt.strip():
|
| 243 |
yield "Please enter a prompt"
|
|
@@ -279,7 +280,7 @@ def infer_chatts_stream(prompt: str, csv_file, use_default):
|
|
| 279 |
|
| 280 |
full_prompt = f"<|im_start|>system\nYou are a helpful assistant. Your name is ChatTS. You can analyze time series data and provide insights. If user asks who you are, you should give your name and capabilities in the language of the prompt. If user has no format requirement, always output a step-by-step analysis about the time series attributes that mentioned in the question first, and then give a detailed result about the given question.<|im_end|><|im_start|>user\n{prefix}{clean_prompt}<|im_end|><|im_start|>assistant\n"
|
| 281 |
|
| 282 |
-
print(f"[
|
| 283 |
|
| 284 |
# Encode inputs
|
| 285 |
inputs = processor(
|
|
@@ -290,8 +291,8 @@ def infer_chatts_stream(prompt: str, csv_file, use_default):
|
|
| 290 |
)
|
| 291 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 292 |
|
| 293 |
-
if inputs['timeseries'] is not None:
|
| 294 |
-
|
| 295 |
|
| 296 |
# Generate with streaming
|
| 297 |
streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
|
|
@@ -399,7 +400,7 @@ with gr.Blocks(title="ChatTS Demo") as demo:
|
|
| 399 |
|
| 400 |
run_btn.click(
|
| 401 |
fn=infer_chatts_stream,
|
| 402 |
-
inputs=[prompt_input, upload, use_default_state],
|
| 403 |
outputs=[text_out]
|
| 404 |
)
|
| 405 |
|
|
|
|
| 63 |
model.eval()
|
| 64 |
|
| 65 |
print(f"Activated model: {name}")
|
| 66 |
+
return name
|
| 67 |
|
| 68 |
|
| 69 |
def switch_model(selected_model_name: str):
|
| 70 |
"""Wrapper for Gradio to switch models via radio selection."""
|
| 71 |
+
return load_model_by_name(selected_model_name)
|
|
|
|
| 72 |
|
| 73 |
# βββ HELPER FUNCTIONS ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 74 |
|
|
|
|
| 99 |
df.columns = [str(c).strip() for c in df.columns]
|
| 100 |
df = df.loc[:, [c for c in df.columns if c]]
|
| 101 |
df = df.dropna(axis=1, how="all")
|
| 102 |
+
print(f"[LOG] File {csv_file.name} loaded. {df.columns=}")
|
| 103 |
|
| 104 |
if df.shape[1] == 0:
|
| 105 |
return None, "No valid time-series columns found."
|
|
|
|
| 130 |
|
| 131 |
if not ts_list:
|
| 132 |
return None, "All time series are empty after trimming NaNs."
|
| 133 |
+
# print(f"Successfully loaded {len(ts_names)} time series: {', '.join(ts_names)}")
|
| 134 |
|
| 135 |
return df, f"Successfully loaded {len(ts_names)} time series: {', '.join(ts_names)}"
|
| 136 |
|
|
|
|
| 168 |
label="Select a Column to Visualize"
|
| 169 |
)
|
| 170 |
|
| 171 |
+
# print("Successfully generated preview!")
|
| 172 |
|
| 173 |
return plot, message, dropdown, False # Set use_default to False when file is uploaded
|
| 174 |
|
|
|
|
| 232 |
# βββ INFERENCE + VALIDATION ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 233 |
|
| 234 |
@spaces.GPU # dynamically allocate & release a ZeroGPU device on each call
|
| 235 |
+
def infer_chatts_stream(prompt: str, csv_file, use_default, model_name):
|
| 236 |
"""
|
| 237 |
Streaming version of ChatTS inference
|
| 238 |
"""
|
| 239 |
+
# Activate the selected model
|
| 240 |
+
switch_model(model_name)
|
| 241 |
+
# print("Start inferring!!!")
|
| 242 |
|
| 243 |
if not prompt.strip():
|
| 244 |
yield "Please enter a prompt"
|
|
|
|
| 280 |
|
| 281 |
full_prompt = f"<|im_start|>system\nYou are a helpful assistant. Your name is ChatTS. You can analyze time series data and provide insights. If user asks who you are, you should give your name and capabilities in the language of the prompt. If user has no format requirement, always output a step-by-step analysis about the time series attributes that mentioned in the question first, and then give a detailed result about the given question.<|im_end|><|im_start|>user\n{prefix}{clean_prompt}<|im_end|><|im_start|>assistant\n"
|
| 282 |
|
| 283 |
+
print(f"[LOG] model={CURRENT_MODEL_NAME}, {clean_prompt=}, {len(ts_list)=}")
|
| 284 |
|
| 285 |
# Encode inputs
|
| 286 |
inputs = processor(
|
|
|
|
| 291 |
)
|
| 292 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 293 |
|
| 294 |
+
# if inputs['timeseries'] is not None:
|
| 295 |
+
# print(f"[debug] {inputs['timeseries'].shape=}")
|
| 296 |
|
| 297 |
# Generate with streaming
|
| 298 |
streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
|
|
|
|
| 400 |
|
| 401 |
run_btn.click(
|
| 402 |
fn=infer_chatts_stream,
|
| 403 |
+
inputs=[prompt_input, upload, use_default_state, model_radio],
|
| 404 |
outputs=[text_out]
|
| 405 |
)
|
| 406 |
|