xiezhe22 committed on
Commit 8d3a706 · 1 Parent(s): 3c70d69
Files changed (1)
  1. app.py +13 -12
app.py CHANGED
@@ -63,13 +63,12 @@ def load_model_by_name(name: str):
     model.eval()
 
     print(f"Activated model: {name}")
-    return f"Active model: {name}"
+    return name
 
 
 def switch_model(selected_model_name: str):
     """Wrapper for Gradio to switch models via radio selection."""
-    load_model_by_name(selected_model_name)
-    return None
+    return load_model_by_name(selected_model_name)
 
 
 # ─── HELPER FUNCTIONS ──────────────────────────────────────────────────────────
 
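Two related fixes in this hunk: `load_model_by_name` now returns the bare model name instead of a display string, and `switch_model` propagates that return value rather than swallowing it with `return None`. That makes the name usable as a plain value at the new call site in `infer_chatts_stream` further down. A minimal sketch of the pattern, assuming a module-level model cache; `CURRENT_MODEL_NAME` is referenced by the new `[LOG]` line later in this diff, while `MODEL_PATHS` and the loader details are illustrative, not the app's actual code:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical label -> checkpoint registry; the real app presumably
# maps the radio choices to Hugging Face repo ids.
MODEL_PATHS = {"ChatTS-14B": "bytedance-research/ChatTS-14B"}

model = None
tokenizer = None
CURRENT_MODEL_NAME = None  # read by the [LOG] line in infer_chatts_stream


def load_model_by_name(name: str):
    """Activate the checkpoint behind `name`, loading it if necessary."""
    global model, tokenizer, CURRENT_MODEL_NAME
    if name != CURRENT_MODEL_NAME:
        path = MODEL_PATHS[name]
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)
        model.eval()
        CURRENT_MODEL_NAME = name
    print(f"Activated model: {name}")
    return name  # bare name: reusable as a value, not just a status string


def switch_model(selected_model_name: str):
    """Wrapper for Gradio to switch models via radio selection."""
    return load_model_by_name(selected_model_name)
```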
@@ -100,7 +99,7 @@ def process_csv_file(csv_file):
     df.columns = [str(c).strip() for c in df.columns]
     df = df.loc[:, [c for c in df.columns if c]]
     df = df.dropna(axis=1, how="all")
-    print(f"File {csv_file.name} loaded. {df.columns=}")
+    print(f"[LOG] File {csv_file.name} loaded. {df.columns=}")
 
     if df.shape[1] == 0:
         return None, "No valid time-series columns found."
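For reference, the three cleanup lines this hunk surrounds behave like this on a toy frame (toy data only, not from the app):

```python
import numpy as np
import pandas as pd

# Messy CSV stand-in: padded header, blank column name, all-NaN column.
df = pd.DataFrame({" temp ": [1.0, 2.0], "": [5, 6], "empty": [np.nan, np.nan]})

df.columns = [str(c).strip() for c in df.columns]  # " temp " -> "temp"
df = df.loc[:, [c for c in df.columns if c]]       # drop blank-named columns
df = df.dropna(axis=1, how="all")                  # drop all-NaN columns

print(df.columns.tolist())  # ['temp']
```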
@@ -131,7 +130,7 @@ def process_csv_file(csv_file):
 
     if not ts_list:
         return None, "All time series are empty after trimming NaNs."
-    print(f"Successfully loaded {len(ts_names)} time series: {', '.join(ts_names)}")
+    # print(f"Successfully loaded {len(ts_names)} time series: {', '.join(ts_names)}")
 
     return df, f"Successfully loaded {len(ts_names)} time series: {', '.join(ts_names)}"
@@ -169,7 +168,7 @@ def preview_csv(csv_file, use_default):
         label="Select a Column to Visualize"
     )
 
-    print("Successfully generated preview!")
+    # print("Successfully generated preview!")
 
     return plot, message, dropdown, False  # Set use_default to False when file is uploaded
@@ -233,11 +232,13 @@ def initialize_interface():
 # ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────
 
 @spaces.GPU  # dynamically allocate & release a ZeroGPU device on each call
-def infer_chatts_stream(prompt: str, csv_file, use_default):
+def infer_chatts_stream(prompt: str, csv_file, use_default, model_name):
     """
     Streaming version of ChatTS inference
     """
-    print("Start inferring!!!")
+    # Activate the selected model
+    switch_model(model_name)
+    # print("Start inferring!!!")
 
     if not prompt.strip():
         yield "Please enter a prompt"
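This is the substantive change: model activation moves inside the `@spaces.GPU`-scoped function, driven by a new `model_name` parameter. Since the click wiring in the last hunk forwards the radio's current value on every call, the GPU-scoped function can guarantee the active checkpoint matches the UI selection regardless of what any CPU-side `.change` handler did. A sketch of the shape, assuming the `switch_model` above; `spaces` is the real Hugging Face ZeroGPU package, the rest is illustrative:

```python
import spaces


@spaces.GPU  # the GPU is attached only for the duration of this call
def infer_sketch(prompt: str, model_name: str):
    # Re-activate the selected checkpoint inside the GPU-scoped call, so
    # the model in use always matches the radio value passed with this click.
    switch_model(model_name)
    if not prompt.strip():
        yield "Please enter a prompt"
        return
    yield f"[{model_name}] analyzing: {prompt}"
```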
@@ -279,7 +280,7 @@ def infer_chatts_stream(prompt: str, csv_file, use_default):
 
     full_prompt = f"<|im_start|>system\nYou are a helpful assistant. Your name is ChatTS. You can analyze time series data and provide insights. If user asks who you are, you should give your name and capabilities in the language of the prompt. If user has no format requirement, always output a step-by-step analysis about the time series attributes that mentioned in the question first, and then give a detailed result about the given question.<|im_end|><|im_start|>user\n{prefix}{clean_prompt}<|im_end|><|im_start|>assistant\n"
 
-    print(f"[debug] {full_prompt}. {len(ts_list)=}, {[len(item) for item in ts_list]=}")
+    print(f"[LOG] model={CURRENT_MODEL_NAME}, {clean_prompt=}, {len(ts_list)=}")
 
     # Encode inputs
     inputs = processor(
@@ -290,8 +291,8 @@ def infer_chatts_stream(prompt: str, csv_file, use_default):
     )
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
-    if inputs['timeseries'] is not None:
-        print(f"[debug] {inputs['timeseries'].shape=}")
+    # if inputs['timeseries'] is not None:
+    #     print(f"[debug] {inputs['timeseries'].shape=}")
 
     # Generate with streaming
     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
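The streamer line is the standard transformers recipe: `generate()` blocks, so it runs on a worker thread while the caller iterates the streamer and yields partial text. A self-contained version of that recipe, with gpt2 as a stand-in for the app's model and tokenizer globals:

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Time series analysis", return_tensors="pt")
streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so run it on a worker thread while the
# main thread consumes decoded chunks from the streamer as they arrive.
thread = Thread(target=model.generate,
                kwargs={**inputs, "streamer": streamer, "max_new_tokens": 32})
thread.start()

text = ""
for chunk in streamer:
    text += chunk  # in the app this would be `yield text` back to Gradio
thread.join()
print(text)
```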
@@ -399,7 +400,7 @@ with gr.Blocks(title="ChatTS Demo") as demo:
 
     run_btn.click(
         fn=infer_chatts_stream,
-        inputs=[prompt_input, upload, use_default_state],
+        inputs=[prompt_input, upload, use_default_state, model_radio],
         outputs=[text_out]
     )
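Finally, the click wiring gains `model_radio` as a fourth input, which is how `model_name` reaches `infer_chatts_stream`. A minimal sketch of this wiring; the component names mirror the diff, while the choices list and the handler body are placeholders:

```python
import gradio as gr


def run(prompt, model_name):
    # Generator handler: Gradio streams each yielded value to the output box.
    yield f"[{model_name}] {prompt}"


with gr.Blocks(title="ChatTS Demo") as demo:
    model_radio = gr.Radio(["ChatTS-14B"], value="ChatTS-14B", label="Model")
    prompt_input = gr.Textbox(label="Prompt")
    run_btn = gr.Button("Run")
    text_out = gr.Textbox(label="Output")
    # Listing model_radio in `inputs` passes its current value on every
    # click; no separate .change() handler or gr.State is needed.
    run_btn.click(fn=run, inputs=[prompt_input, model_radio], outputs=[text_out])

demo.launch()
```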