mbellan commited on
Commit
d6c45f2
·
1 Parent(s): c3efd49

Fix: Gradio Progress and launch parameters for HF deployment

Browse files
Files changed (1) hide show
  1. app.py +20 -16
app.py CHANGED
@@ -69,7 +69,7 @@ class VoiceModelTrainer:
69
  learning_rate: float,
70
  algorithm: str,
71
  batch_size: int,
72
- progress=gr.Progress()
73
  ) -> Tuple[str, str, str]:
74
  """Train the model with RL."""
75
  if self.training_active:
@@ -77,7 +77,8 @@ class VoiceModelTrainer:
77
 
78
  try:
79
  self.training_active = True
80
- progress(0, desc="Initializing training...")
 
81
 
82
  # Create output directory
83
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -85,18 +86,21 @@ class VoiceModelTrainer:
85
  run_dir.mkdir(parents=True, exist_ok=True)
86
 
87
  # Load model
88
- progress(0.1, desc="Loading model...")
 
89
  model = VoiceModelWrapper(model_name=model_name, device=self.device)
90
  model.load_model()
91
 
92
  # Setup data (use sample data for demo)
93
- progress(0.2, desc="Preparing data...")
 
94
  data_manager = DataManager()
95
  # For HF Space, we'll use a small demo dataset
96
  # In production, this would load from user-provided data
97
 
98
  # Create algorithm
99
- progress(0.3, desc=f"Initializing {algorithm.upper()} algorithm...")
 
100
  rl_model = model.get_rl_model() if hasattr(model, 'get_rl_model') else model.model
101
 
102
  if algorithm.lower() == 'ppo':
@@ -122,7 +126,8 @@ class VoiceModelTrainer:
122
  metrics_tracker = MetricsTracker(log_dir=str(run_dir / 'logs'))
123
  visualizer = Visualizer(output_dir=str(run_dir / 'visualizations'))
124
 
125
- progress(0.4, desc="Starting training...")
 
126
 
127
  # For demo purposes, simulate training
128
  # In production, you'd run actual training here
@@ -144,8 +149,9 @@ class VoiceModelTrainer:
144
 
145
  # Simulate training progress
146
  for i in range(num_episodes):
147
- progress((0.4 + (i / num_episodes) * 0.5),
148
- desc=f"Training episode {i+1}/{num_episodes}")
 
149
 
150
  # Save checkpoint
151
  checkpoint_dir = run_dir / 'checkpoints'
@@ -158,7 +164,8 @@ class VoiceModelTrainer:
158
  'episode': num_episodes
159
  }, checkpoint_path)
160
 
161
- progress(1.0, desc="Training complete!")
 
162
 
163
  self.models['trained'] = model
164
 
@@ -182,14 +189,15 @@ class VoiceModelTrainer:
182
  self,
183
  checkpoint_path: str,
184
  sample_audio: str,
185
- progress=gr.Progress()
186
  ) -> Tuple[str, str, str]:
187
  """Generate audio comparison."""
188
  try:
189
  if not checkpoint_path or not Path(checkpoint_path).exists():
190
  return None, None, "❌ No checkpoint available"
191
 
192
- progress(0, desc="Loading models...")
 
193
 
194
  # For demo, return the input audio
195
  # In production, process through models
@@ -445,8 +453,4 @@ def create_app():
445
 
446
  if __name__ == "__main__":
447
  app = create_app()
448
- app.launch(
449
- server_name="0.0.0.0",
450
- server_port=7860,
451
- share=False
452
- )
 
69
  learning_rate: float,
70
  algorithm: str,
71
  batch_size: int,
72
+ progress=None
73
  ) -> Tuple[str, str, str]:
74
  """Train the model with RL."""
75
  if self.training_active:
 
77
 
78
  try:
79
  self.training_active = True
80
+ if progress:
81
+ progress(0, desc="Initializing training...")
82
 
83
  # Create output directory
84
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 
86
  run_dir.mkdir(parents=True, exist_ok=True)
87
 
88
  # Load model
89
+ if progress:
90
+ progress(0.1, desc="Loading model...")
91
  model = VoiceModelWrapper(model_name=model_name, device=self.device)
92
  model.load_model()
93
 
94
  # Setup data (use sample data for demo)
95
+ if progress:
96
+ progress(0.2, desc="Preparing data...")
97
  data_manager = DataManager()
98
  # For HF Space, we'll use a small demo dataset
99
  # In production, this would load from user-provided data
100
 
101
  # Create algorithm
102
+ if progress:
103
+ progress(0.3, desc=f"Initializing {algorithm.upper()} algorithm...")
104
  rl_model = model.get_rl_model() if hasattr(model, 'get_rl_model') else model.model
105
 
106
  if algorithm.lower() == 'ppo':
 
126
  metrics_tracker = MetricsTracker(log_dir=str(run_dir / 'logs'))
127
  visualizer = Visualizer(output_dir=str(run_dir / 'visualizations'))
128
 
129
+ if progress:
130
+ progress(0.4, desc="Starting training...")
131
 
132
  # For demo purposes, simulate training
133
  # In production, you'd run actual training here
 
149
 
150
  # Simulate training progress
151
  for i in range(num_episodes):
152
+ if progress:
153
+ progress((0.4 + (i / num_episodes) * 0.5),
154
+ desc=f"Training episode {i+1}/{num_episodes}")
155
 
156
  # Save checkpoint
157
  checkpoint_dir = run_dir / 'checkpoints'
 
164
  'episode': num_episodes
165
  }, checkpoint_path)
166
 
167
+ if progress:
168
+ progress(1.0, desc="Training complete!")
169
 
170
  self.models['trained'] = model
171
 
 
189
  self,
190
  checkpoint_path: str,
191
  sample_audio: str,
192
+ progress=None
193
  ) -> Tuple[str, str, str]:
194
  """Generate audio comparison."""
195
  try:
196
  if not checkpoint_path or not Path(checkpoint_path).exists():
197
  return None, None, "❌ No checkpoint available"
198
 
199
+ if progress:
200
+ progress(0, desc="Loading models...")
201
 
202
  # For demo, return the input audio
203
  # In production, process through models
 
453
 
454
  if __name__ == "__main__":
455
  app = create_app()
456
+ app.launch()