thecollabagepatch committed
Commit edc7448 · Parent: aa00058

continue added back in
Files changed (1):
  1. app.py +104 -30

app.py CHANGED
@@ -9,6 +9,8 @@ import torch
 from gradio_client import Client, handle_file
 import random
 import time
+import io
+from pydub import AudioSegment
 
 # Check if CUDA is available
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
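The new pydub import powers the splicing that continue_music does later in this diff. A minimal sketch of the three AudioSegment operations the commit relies on, with hypothetical file names:

from pydub import AudioSegment

a = AudioSegment.from_wav("original.wav")      # hypothetical input files
b = AudioSegment.from_wav("generated.wav")
combined = a + b                               # '+' concatenates segments end-to-end
combined.export("extended.wav", format="wav")

Nothing in the hunks shown here references io, so that import presumably serves code outside this diff.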
@@ -105,6 +107,7 @@ def continue_drum_sample(existing_audio_path):
 
 @spaces.GPU
 def generate_music(wav_filename, prompt_duration, musicgen_model, output_duration):
+    """Generate music using the BEGINNING of the audio as prompt"""
     if wav_filename is None:
         return None
 
@@ -138,6 +141,74 @@ def generate_music(wav_filename, prompt_duration, musicgen_model, output_duratio
 
     return filename_with_extension
 
+@spaces.GPU
+def continue_music(input_audio_path, prompt_duration, musicgen_model, output_duration):
+    """Continue music using the END of the audio as prompt - extends the audio"""
+    if input_audio_path is None:
+        return None
+
+    song, sr = torchaudio.load(input_audio_path)
+    song = song.to(device)
+
+    model_name = musicgen_model.split(" ")[0]
+    model_continue = MusicGen.get_pretrained(model_name)
+    model_continue.set_generation_params(
+        use_sampling=True,
+        top_k=250,
+        top_p=0.0,
+        temperature=1.0,
+        duration=output_duration,
+        cfg_coef=3
+    )
+
+    # Load original audio as AudioSegment for easier manipulation
+    original_audio = AudioSegment.from_wav(input_audio_path)
+    current_audio = original_audio
+    file_paths_for_cleanup = []
+
+    # Get the last `prompt_duration` seconds as the prompt
+    num_samples = int(prompt_duration * sr)
+    if song.shape[1] < num_samples:
+        raise ValueError("The prompt_duration is longer than the current audio length.")
+
+    # Extract the end portion for prompting
+    start_sample = song.shape[1] - num_samples
+    prompt_waveform = song[..., start_sample:]
+    prompt_waveform = preprocess_audio(prompt_waveform)
+
+    # Generate continuation
+    output = model_continue.generate_continuation(prompt_waveform, prompt_sample_rate=sr, progress=True)
+    output = output.cpu()
+
+    if len(output.size()) > 2:
+        output = output.squeeze()
+
+    # Save the generated audio
+    filename_without_extension = f'continue_extension_{random.randint(1000, 9999)}'
+    filename_with_extension = f'{filename_without_extension}.wav'
+    audio_write(filename_without_extension, output, model_continue.sample_rate, strategy="loudness", loudness_compressor=True)
+
+    # Handle the double .wav extension issue
+    correct_filename = f'{filename_without_extension}.wav.wav'
+    if os.path.exists(correct_filename):
+        generated_audio_segment = AudioSegment.from_wav(correct_filename)
+        file_paths_for_cleanup.append(correct_filename)
+    else:
+        generated_audio_segment = AudioSegment.from_wav(filename_with_extension)
+        file_paths_for_cleanup.append(filename_with_extension)
+
+    # Combine original + new audio
+    combined_audio = current_audio + generated_audio_segment
+    combined_audio_filename = f"extended_audio_{random.randint(1000, 9999)}.wav"
+    combined_audio.export(combined_audio_filename, format="wav")
+
+    # Cleanup temporary files
+    for file_path in file_paths_for_cleanup:
+        if os.path.exists(file_path):
+            os.remove(file_path)
+
+    return combined_audio_filename
+
 # ========== MELODYFLOW FUNCTIONS (Via Facebook Space) ==========
 
 def transform_with_melodyflow_api(audio_path, variation, custom_prompt="", solver="euler", flowstep=0.12):
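A note on the double-extension guard in continue_music: audiocraft's audio_write takes a filename stem and appends the format suffix itself, so the call above normally writes continue_extension_XXXX.wav and the '.wav.wav' branch only fires when a stem already ends in '.wav'. A sketch of that naming behavior (stem and sample rate are hypothetical):

# audio_write appends the suffix to whatever stem it is given
audio_write("take_1234", output, 32000, strategy="loudness", loudness_compressor=True)
# -> take_1234.wav
audio_write("take_1234.wav", output, 32000, strategy="loudness", loudness_compressor=True)
# -> take_1234.wav.wav, the case the guard catches

One hedged caveat: MusicGen's generate_continuation conventionally returns the prompt followed by the newly generated material; if that holds here, the pydub concatenation repeats the last prompt_duration seconds at the splice, and trimming the head of generated_audio_segment first would avoid the doubled tail.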
@@ -145,24 +216,10 @@ def transform_with_melodyflow_api(audio_path, variation, custom_prompt="", solve
     if audio_path is None:
         return None, "❌ No audio file provided"
 
-    # Initialize variables first to avoid scope issues
-    base_steps = 125
-    effective_steps = 25
-
     try:
         # Initialize client for Facebook MelodyFlow space
        client = Client("facebook/MelodyFlow")
 
-        # Set steps based on solver and the fact we're doing editing
-        # Facebook's space automatically reduces steps for editing:
-        # EULER: divides by 5, MIDPOINT: divides by 2
-        if solver == "midpoint":
-            base_steps = 128
-            effective_steps = 64  # 128 // 2
-        else:  # euler (default)
-            base_steps = 125
-            effective_steps = 25  # 125 // 5
-
         # Determine the prompt to use
         if custom_prompt.strip():
             prompt_text = custom_prompt.strip()
@@ -171,6 +228,16 @@ def transform_with_melodyflow_api(audio_path, variation, custom_prompt="", solve
             prompt_text = VARIATION_PROMPTS.get(variation, f"transform this audio to {variation} style")
             status_msg = f"✅ Transformed with {variation} style (flowstep: {flowstep}, {effective_steps} steps)"
 
+        # Set steps based on solver and the fact we're doing editing
+        # Facebook's space automatically reduces steps for editing:
+        # EULER: divides by 5, MIDPOINT: divides by 2
+        if solver == "midpoint":
+            base_steps = 128
+            effective_steps = base_steps // 2  # 64 effective steps
+        else:  # euler
+            base_steps = 125
+            effective_steps = base_steps // 5  # 25 effective steps
+
         # Call the MelodyFlow API with the base steps (it will auto-reduce)
         result = client.predict(
             model="facebook/melodyflow-t24-30secs",
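An ordering caveat in this hunk: status_msg (new line 229) interpolates effective_steps, but the solver block that assigns it now runs afterwards (new lines 231-239), and the old top-of-function defaults were deleted in the previous hunk, so the preset-variation branch will raise a NameError unless effective_steps happens to exist at module scope. A minimal reordering sketch, not part of this commit, that computes the step counts right after the client is created:

client = Client("facebook/MelodyFlow")

# compute solver steps first so the status messages can report effective_steps
if solver == "midpoint":
    base_steps = 128
    effective_steps = base_steps // 2  # 64 effective steps
else:  # euler
    base_steps = 125
    effective_steps = base_steps // 5  # 25 effective steps

# ...then build prompt_text and status_msg exactly as in the hunk above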
@@ -188,21 +255,12 @@
         # Result is a tuple of 3 audio files (variations)
         # We'll use the first variation
         if result and len(result) > 0 and result[0]:
-            # Save the result locally with loudness normalization
+            # Save the result locally
             output_filename = f"melodyflow_{variation}_{random.randint(1000, 9999)}.wav"
 
-            # Load the result and apply consistent loudness normalization
-            transformed_audio, sr = torchaudio.load(result[0])
-
-            # Re-save with same loudness strategy as your MusicGen (no headroom)
-            audio_write(
-                output_filename.replace('.wav', ''),
-                transformed_audio,
-                sr,
-                strategy="loudness",
-                loudness_compressor=True
-                # Note: no loudness_headroom_db parameter like Facebook uses
-            )
+            # Copy the result file to our local filename
+            import shutil
+            shutil.copy2(result[0], output_filename)
 
             return output_filename, status_msg
         else:
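shutil.copy2 (which also preserves file metadata, unlike shutil.copy) replaces the torchaudio reload and audio_write re-save, so the returned file keeps whatever loudness Facebook's space rendered. If level-matching against the MusicGen outputs still matters, the dropped normalization can be reapplied after the copy; a sketch reusing the calls visible in the removed lines (the audiocraft import path is an assumption, since app.py's import block isn't shown in this diff):

import torchaudio
from audiocraft.data.audio import audio_write  # assumed import; app.py already calls audio_write

transformed_audio, sr = torchaudio.load(output_filename)
audio_write(
    output_filename.replace('.wav', ''),  # audio_write re-appends the .wav suffix
    transformed_audio,
    sr,
    strategy="loudness",
    loudness_compressor=True
)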
@@ -319,8 +377,18 @@ with gr.Blocks() as iface:
                 ],
                 value="thepatch/vanya_ai_dnb_0.1 (small)"
             )
-
-            generate_music_button = gr.Button("🎼 Continue with MusicGen", variant="primary", size="lg")
+
+            # Two different continuation options with clear explanations
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("### 🔄 Continue from Beginning")
+                    gr.Markdown("*Uses the **first** X seconds as prompt. Good for reimagining/reworking from a starting point.*")
+                    generate_music_button = gr.Button("🔄 Continue from Beginning", variant="primary", size="lg")
+
+                with gr.Column():
+                    gr.Markdown("### ➡️ Extend from End")
+                    gr.Markdown("*Uses the **last** X seconds as prompt. Extends your audio by adding new content to the end.*")
+                    continue_music_button = gr.Button("➡️ Extend from End", variant="secondary", size="lg")
 
     # ========== EVENT HANDLERS ==========
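The two buttons differ only in which slice of the current audio becomes the MusicGen prompt: generate_music takes it from the head of the file (per its new docstring) and continue_music from the tail. In terms of the tensors this diff already uses:

# the slice that prompts the model in each mode
num_samples = int(prompt_duration * sr)
beginning = song[..., :num_samples]    # "Continue from Beginning" (generate_music)
ending = song[..., -num_samples:]      # "Extend from End" (continue_music)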
 
@@ -335,12 +403,18 @@ with gr.Blocks() as iface:
         outputs=[main_audio, transform_status]
     )
 
-    # Step 3: Continue
+    # Step 3: Continue (two different approaches)
     generate_music_button.click(
         generate_music,
         inputs=[main_audio, prompt_duration, musicgen_model, output_duration],
         outputs=[main_audio]
     )
+
+    continue_music_button.click(
+        continue_music,
+        inputs=[main_audio, prompt_duration, musicgen_model, output_duration],
+        outputs=[main_audio]
+    )
 
 if __name__ == "__main__":
     iface.launch()
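Both handlers write their result back into main_audio, so each generation can seed the next. A sketch of the same loop driven directly in Python, using a hypothetical seed file and the model name already listed in the UI:

# chaining "Extend from End" passes outside the UI
path = "seed_clip.wav"  # hypothetical starting clip
for _ in range(2):
    path = continue_music(
        path,
        prompt_duration=5,   # the last 5 seconds become the prompt
        musicgen_model="thepatch/vanya_ai_dnb_0.1 (small)",
        output_duration=30,
    )
# each pass exports extended_audio_XXXX.wav and returns its path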
 