# Takane — Gradio client frontend for the Takane Japanese TTS Space.
import gradio as gr
from gradio_client import Client
import os
import csv
import numpy as np
import scipy.io.wavfile as wavfile
import tempfile

# Backend Space address comes from the 'src' env var / Space secret.
# NOTE: this raises KeyError at import time if the variable is missing.
client = Client(os.environ['src'])

# Custom CSS: darker placeholder text, highlighted <code> spans, centered
# settings-accordion header, amber example labels, and the warning/error
# styles used by validate_audio_duration()'s HTML messages.
css = """
.gradio-container input::placeholder,
.gradio-container textarea::placeholder {
    color: #333333 !important;
}
code {
    background-color: #ffde9f;
    padding: 2px 4px;
    border-radius: 3px;
}
#settings-accordion summary {
    justify-content: center;
}
.examples-holder > .label {
    color: #b45309 !important;
    font-weight: 600;
}
.audio-warning {
    color: #ff6b35 !important;
    font-weight: 600;
    margin: 10px 0;
}
.audio-error {
    color: #dc2626 !important;
    font-weight: 600;
    margin: 10px 0;
}
"""
def validate_audio_duration(audio_data):
    """Check that an uploaded audio prompt fits the 10-second limit.

    Args:
        audio_data: Gradio numpy-audio tuple ``(sample_rate, samples)``,
            or ``None`` when no prompt was supplied.

    Returns:
        tuple[bool, str]: ``(is_valid, html_message)``. The message is an
        HTML snippet — an error above 10s, a warning above 8.9s, else "".
    """
    # No prompt at all is always acceptable.
    if audio_data is None:
        return True, ""

    rate, samples = audio_data
    seconds = len(samples) / rate

    # Hard limit: reject anything longer than 10 seconds outright.
    if seconds > 10:
        return False, f"""
<div class="audio-error">
❌ Error: Audio is {seconds:.1f} seconds long. Maximum allowed is 10 seconds.<br>
❌ エラー: 音声が{seconds:.1f}秒です。最大10秒まで許可されています。
</div>
"""

    # Soft limit: accept, but warn that the prompt eats into context.
    if seconds > 8.9:
        return True, f"""
<div class="audio-warning">
⚠️ Warning: Your audio is {seconds:.1f} seconds, it will eat up precious context and may result in poor generation.<br>
⚠️ 警告: 音声が{seconds:.1f}秒を超えています。貴重なコンテキストを消費し、生成品質が低下する可能性があります。
</div>
"""

    return True, ""
def _parse_temperature(raw):
    """Parse an optional temperature cell, clamping to [0.0, 1.3]; default 0.7."""
    try:
        value = raw.strip()
        if value and value.lower() != 'none':
            return max(0.0, min(1.3, float(value)))
    except (ValueError, TypeError):
        print(f"Warning: Invalid temperature value '{raw}', using default 0.7")
    return 0.7


def _parse_chained(raw):
    """Parse an optional chained-longform cell into a bool; default False."""
    value = raw.strip().lower()
    if value in ('true', '1', 'yes', 'on'):
        return True
    if value in ('false', '0', 'no', 'off', 'none', ''):
        return False
    print(f"Warning: Invalid chained longform value '{raw}', using default False")
    return False


def _resolve_audio_path(raw_path, base_dir):
    """Resolve an example's audio path against *base_dir*.

    Returns None for the literal 'none' or for files that do not exist
    (previously only relative paths were existence-checked; absolute paths
    are now verified too). An empty cell is passed through unchanged.
    """
    if raw_path.lower() == "none":
        return None
    if not raw_path:
        # Preserve the original empty-string passthrough behavior.
        return raw_path
    path = raw_path if os.path.isabs(raw_path) else os.path.join(base_dir, raw_path)
    if not os.path.exists(path):
        print(f"Warning: Audio file not found: {path}")
        return None
    return path


def load_examples(csv_path):
    """Load UI examples from a pipe-delimited CSV file.

    Each row is ``text|audio_path[|temperature[|use_chained]]``. Rows with
    fewer than two columns are skipped. ``audio_path`` may be the literal
    ``none``; relative paths are resolved against the CSV's directory, and
    missing files are replaced with ``None``.

    Args:
        csv_path: Path to the examples file.

    Returns:
        list[list]: rows of ``[text, audio_path_or_None, temperature,
        use_chained]``; empty if the file is missing or unreadable.
    """
    examples = []
    if not os.path.exists(csv_path):
        print(f"Warning: Examples file not found at {csv_path}")
        return examples
    base_dir = os.path.dirname(csv_path)
    try:
        with open(csv_path, 'r', encoding='utf-8') as f:
            for row in csv.reader(f, delimiter='|'):
                if len(row) < 2:
                    continue
                text = row[0].strip()
                temperature = _parse_temperature(row[2]) if len(row) >= 3 else 0.7
                use_chained = _parse_chained(row[3]) if len(row) >= 4 else False
                audio_path = _resolve_audio_path(row[1].strip(), base_dir)
                examples.append([text, audio_path, temperature, use_chained])
    except Exception as e:
        # Best-effort loading: a malformed file yields whatever parsed so far.
        print(f"Error loading examples: {e}")
    return examples
def run_generation_pipeline_client(
    raw_text,
    audio_prompt,
    num_candidates,
    cfg_scale,
    top_k,
    temperature,
    use_chained_longform,
    seed,
    audio_warning_display
):
    """Proxy a standard generation request to the backend Space.

    Args:
        raw_text: Text to synthesize (may include a ``<spk_...>`` tag).
        audio_prompt: Optional ``(sample_rate, samples)`` numpy audio tuple.
        num_candidates, cfg_scale, top_k, temperature, use_chained_longform,
            seed: Sampling parameters forwarded verbatim to the backend.
        audio_warning_display: Dummy input from the HTML warning component;
            unused, present only to match the Gradio event wiring.

    Returns:
        tuple: ``((sample_rate, audio), status_message)`` on success, or
        ``(None, status_message)`` on rejection or failure.
    """
    tmp_path = None
    try:
        # Enforce the 10-second prompt limit before doing any work.
        is_valid, _ = validate_audio_duration(audio_prompt)
        if not is_valid:
            return None, "Status: Audio too long. Please use audio under 10 seconds."

        # The backend expects a file reference, not raw samples, so persist
        # the prompt to a temporary WAV file when one was provided.
        audio_prompt_for_api = None
        if audio_prompt is not None:
            sample_rate, audio_data = audio_prompt
            if isinstance(audio_data, list):
                audio_data = np.array(audio_data)
            # WAV needs integer PCM; scale float samples into int16 range.
            if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
                audio_data = (audio_data * 32767).astype(np.int16)
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
                tmp_path = tmp_file.name
                wavfile.write(tmp_path, sample_rate, audio_data)
            audio_prompt_for_api = {"path": tmp_path, "meta": {"_type": "gradio.FileData"}}

        result = client.predict(
            raw_text,
            audio_prompt_for_api,
            num_candidates,
            cfg_scale,
            top_k,
            temperature,
            use_chained_longform,
            seed,
            api_name="/run_generation_pipeline"
        )

        # Backend returns (sr, audio, status) on success, (None, status) on failure.
        if len(result) == 3:
            sample_rate, audio_data, status_message = result
            if audio_data is None:
                return None, status_message
            if isinstance(audio_data, list):
                audio_data = np.array(audio_data)
            return (sample_rate, audio_data), status_message
        if len(result) == 2:
            return result[0], result[1]
        return None, "Status: Unexpected response format from server"
    except Exception as e:
        return None, f"Status: Connection error: {str(e)}"
    finally:
        # Always remove the temp WAV — previously it leaked whenever
        # client.predict() raised, since cleanup ran only on success.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
# Client wrapper for duration-aware generation.
def run_duration_generation_pipeline_client(
    raw_text,
    audio_prompt,
    num_candidates,
    cfg_scale,
    top_k,
    temperature,
    use_chained_longform,
    add_steps,
    use_duration_aware,
    chars_per_second,
    seed,
    audio_warning_display_dur
):
    """Proxy a duration-aware generation request to the backend Space.

    Args:
        raw_text: Text to synthesize (may include a ``<spk_...>`` tag).
        audio_prompt: Optional ``(sample_rate, samples)`` numpy audio tuple.
        num_candidates, cfg_scale, top_k, temperature, use_chained_longform,
            seed: Sampling parameters forwarded verbatim to the backend.
        add_steps: Additional generation steps for duration control.
        use_duration_aware: Legacy-compatibility flag forwarded as-is.
        chars_per_second: Speech-speed control forwarded as-is.
        audio_warning_display_dur: Dummy input from the HTML warning
            component; unused, present only to match the event wiring.

    Returns:
        tuple: ``((sample_rate, audio), status_message)`` on success, or
        ``(None, status_message)`` on rejection or failure.
    """
    tmp_path = None
    try:
        # Enforce the 10-second prompt limit before doing any work.
        is_valid, _ = validate_audio_duration(audio_prompt)
        if not is_valid:
            return None, "Status: Audio too long. Please use audio under 10 seconds."

        # The backend expects a file reference, not raw samples, so persist
        # the prompt to a temporary WAV file when one was provided.
        audio_prompt_for_api = None
        if audio_prompt is not None:
            sample_rate, audio_data = audio_prompt
            if isinstance(audio_data, list):
                audio_data = np.array(audio_data)
            # WAV needs integer PCM; scale float samples into int16 range.
            if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
                audio_data = (audio_data * 32767).astype(np.int16)
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
                tmp_path = tmp_file.name
                wavfile.write(tmp_path, sample_rate, audio_data)
            audio_prompt_for_api = {"path": tmp_path, "meta": {"_type": "gradio.FileData"}}

        result = client.predict(
            raw_text,
            audio_prompt_for_api,
            num_candidates,
            cfg_scale,
            top_k,
            temperature,
            use_chained_longform,
            add_steps,
            use_duration_aware,
            chars_per_second,
            seed,
            api_name="/run_duration_generation_pipeline"
        )

        # Backend returns (sr, audio, status) on success, (None, status) on failure.
        if len(result) == 3:
            sample_rate, audio_data, status_message = result
            if audio_data is None:
                return None, status_message
            if isinstance(audio_data, list):
                audio_data = np.array(audio_data)
            return (sample_rate, audio_data), status_message
        if len(result) == 2:
            return result[0], result[1]
        return None, "Status: Unexpected response format from server"
    except Exception as e:
        return None, f"Status: Connection error: {str(e)}"
    finally:
        # Always remove the temp WAV — previously it leaked whenever
        # client.predict() raised, since cleanup ran only on success.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
# Audio validation callback
def on_audio_upload(audio_data):
    """Change-handler for the audio prompt components.

    Runs the duration check and returns ``(audio, html_message)``; the audio
    value is replaced with ``None`` (clearing the component) when the upload
    exceeds the 10-second limit.
    """
    ok, message = validate_audio_duration(audio_data)
    kept_audio = audio_data if ok else None
    return kept_audio, message
# Load the examples shown in the "Examples" tab from the bundled CSV.
examples_csv_path = "./samples.csv" # Adjust path as needed for client side
example_list = load_examples(examples_csv_path)
# ---------------------------------------------------------------------------
# Create Gradio interface.
# NOTE(review): the original paste lost all indentation; the row/column
# nesting below was reconstructed from reading order and the inline comments
# — verify the exact layout against the running app.
# ---------------------------------------------------------------------------
with gr.Blocks(theme="Respair/Shiki@9.1.0", css=css) as demo:
    gr.Markdown('<h1 style="text-align: center; width: 100%; display: block;">🌸 Takane</h1>')
    # Busy-notice banner: the backend has no load balancing.
    gr.Markdown('''
<div style="text-align: center; background-color: #fff3cd; border: 1px solid #ffc107;
border-radius: 8px; padding: 12px; margin: 10px auto; max-width: 800px;">
<p style="color: #856404; margin: 0; font-weight: 500;">
⚠️ This demo doesn't have load balancing or parallel query handling.
You must wait for everyone else to finish first during busy times. Sorry!
</p>
</div>
''')
    with gr.Tabs():
        # ------------------------------------------------------------------
        # Tab 1: standard speech generation.
        # ------------------------------------------------------------------
        with gr.TabItem("Speech Generation"):
            with gr.Row():
                with gr.Column(scale=2):
                    text_input = gr.Textbox(
                        label="Text to Synthesize",
                        lines=5,
                        value="<spk_1146> はいはい、それでは、チャンネル登録よろしくお願いしまーす。じゃあみんな、また明日ねー、ばいばーい。"
                    )
                    # Settings and Generate button
                    with gr.Row(equal_height=False):
                        with gr.Accordion("----------------------------------⭐ 🛠️ ⭐", open=False, label="_"):
                            turbo_checkbox = gr.Checkbox(
                                label="⚡ Turbo Mode (Fast generation, single candidate)",
                                value=False
                            )
                            num_candidates_slider = gr.Slider(
                                label="Number of Candidates",
                                minimum=1,
                                maximum=10,
                                value=5,
                                step=1
                            )
                            cfg_scale_slider = gr.Slider(
                                label="CFG Scale",
                                minimum=1.0,
                                maximum=3.0,
                                value=1.4,
                                step=0.1
                            )
                            top_k_slider = gr.Slider(
                                label="Top K",
                                minimum=10,
                                maximum=100,
                                value=55,
                                step=5
                            )
                            temperature_slider = gr.Slider(
                                label="Temperature (below 0.6 can break)",
                                minimum=0.0,
                                maximum=1.3,
                                value=0.7,
                                step=0.1
                            )
                            seed_slider = gr.Slider(
                                label="Seed (use -1 for random)",
                                minimum=-1,
                                maximum=2700000000,
                                value=2687110803,
                                step=1
                            )
                            chained_longform_checkbox = gr.Checkbox(
                                label="Use Chained Longform (Sequential conditioning for consistency)",
                                value=False
                            )
                            audio_prompt_input = gr.Audio(
                                label="Audio Prompt (Optional - オプション) [Max 10 seconds / 最大10秒]",
                                sources=["upload", "microphone"],
                                type="numpy"
                            )
                            # Warning display for audio duration (filled by on_audio_upload).
                            audio_warning_display = gr.HTML(value="", visible=True)
                            # Validate (and possibly clear) the prompt whenever it changes.
                            audio_prompt_input.change(
                                fn=on_audio_upload,
                                inputs=[audio_prompt_input],
                                outputs=[audio_prompt_input, audio_warning_display]
                            )
                            # Turbo mode: single candidate at temperature 1.0 for speed;
                            # restores the (5, 0.7) defaults when switched off.
                            def toggle_turbo(turbo_enabled):
                                if turbo_enabled:
                                    return 1, 1.0  # num_candidates=1, temperature=1.0
                                else:
                                    return 5, 0.7  # default values
                            turbo_checkbox.change(
                                fn=toggle_turbo,
                                inputs=[turbo_checkbox],
                                outputs=[num_candidates_slider, temperature_slider]
                            )
                        with gr.Column(scale=1):
                            generate_button = gr.Button("Generate", variant="primary")
                with gr.Column(scale=1):
                    status_output = gr.Textbox(label="Status", interactive=False)
                    audio_output = gr.Audio(label="Generated Speech", interactive=False, show_download_button=True)
            # Event handler — the warning HTML is passed as a dummy input so the
            # callback signature matches; it is ignored by the function.
            generate_button.click(
                fn=run_generation_pipeline_client,
                inputs=[
                    text_input,
                    audio_prompt_input,
                    num_candidates_slider,
                    cfg_scale_slider,
                    top_k_slider,
                    temperature_slider,
                    chained_longform_checkbox,
                    seed_slider,
                    audio_warning_display  # dummy input
                ],
                outputs=[audio_output, status_output],
                concurrency_limit=4  # Limit concurrent requests
            )
        # ------------------------------------------------------------------
        # Tab 2: experimental duration-controllable generation.
        # ------------------------------------------------------------------
        with gr.TabItem("Duration-controllable mode - (Experimental / Unreliable)"):
            with gr.Row():
                with gr.Column(scale=2):
                    text_input_dur = gr.Textbox(
                        label="Text to Synthesize",
                        lines=5,
                        value="<spk_1146> はいはい、それでは、チャンネル登録よろしくお願いしまーす。じゃあみんな、また明日ねー、ばいばーい。"
                    )
                    # Settings and Generate button
                    with gr.Row(equal_height=False):
                        with gr.Accordion("----------------------------------⭐ 🛠️ ⭐", open=False, label="_"):
                            turbo_checkbox_dur = gr.Checkbox(
                                label="⚡ Turbo Mode (Fast generation, single candidate)",
                                value=False
                            )
                            gr.Markdown("WARNING! Longform Generation is not optimized for this mode")
                            num_candidates_slider_dur = gr.Slider(
                                label="Number of Candidates",
                                minimum=1,
                                maximum=10,
                                value=5,
                                step=1
                            )
                            cfg_scale_slider_dur = gr.Slider(
                                label="CFG Scale",
                                minimum=1.0,
                                maximum=3.0,
                                value=1.4,
                                step=0.1
                            )
                            top_k_slider_dur = gr.Slider(
                                label="Top K",
                                minimum=10,
                                maximum=100,
                                value=55,
                                step=5
                            )
                            temperature_slider_dur = gr.Slider(
                                label="Temperature",
                                minimum=0.0,
                                maximum=1.3,
                                value=0.7,
                                step=0.1
                            )
                            seed_slider_dur = gr.Slider(
                                label="Seed (use -1 for random)",
                                minimum=-1,
                                maximum=2700000000,
                                value=2687110803,
                                step=1
                            )
                            # Duration-specific parameters
                            gr.Markdown("### Duration Control Parameters")
                            add_steps_slider = gr.Slider(
                                label="Add Steps",
                                minimum=0,
                                maximum=10,
                                value=0,
                                step=1,
                                info="Additional generation steps"
                            )
                            use_duration_aware_checkbox = gr.Checkbox(
                                label="Use Duration Aware",
                                value=True,
                                info="don't touch - For legacy compatibility"
                            )
                            chars_per_second_slider = gr.Slider(
                                label="Characters Per Second",
                                minimum=10,
                                maximum=17,
                                value=14,
                                step=1,
                                info="Controls speech speed (10=slowest, 17=fastest)"
                            )
                            chained_longform_checkbox_dur = gr.Checkbox(
                                label="Use Chained Longform (Sequential conditioning for consistency)",
                                value=False
                            )
                            audio_prompt_input_dur = gr.Audio(
                                label="Audio Prompt (Optional - オプション) [Max 10 seconds / 最大10秒]",
                                sources=["upload", "microphone"],
                                type="numpy"
                            )
                            # Warning display for audio duration (filled by on_audio_upload).
                            audio_warning_display_dur = gr.HTML(value="", visible=True)
                            # Validate (and possibly clear) the prompt whenever it changes.
                            audio_prompt_input_dur.change(
                                fn=on_audio_upload,
                                inputs=[audio_prompt_input_dur],
                                outputs=[audio_prompt_input_dur, audio_warning_display_dur]
                            )
                            # Turbo mode handler for this tab — same values as Tab 1.
                            def toggle_turbo_dur(turbo_enabled):
                                if turbo_enabled:
                                    return 1, 1.0  # num_candidates=1, temperature=1.0
                                else:
                                    return 5, 0.7  # default values
                            turbo_checkbox_dur.change(
                                fn=toggle_turbo_dur,
                                inputs=[turbo_checkbox_dur],
                                outputs=[num_candidates_slider_dur, temperature_slider_dur]
                            )
                        with gr.Column(scale=1):
                            generate_button_dur = gr.Button("Generate", variant="primary")
                with gr.Column(scale=1):
                    status_output_dur = gr.Textbox(label="Status", interactive=False)
                    audio_output_dur = gr.Audio(label="Generated Speech", interactive=False, show_download_button=True)
            # Event handler for the duration tab — warning HTML is a dummy input.
            generate_button_dur.click(
                fn=run_duration_generation_pipeline_client,
                inputs=[
                    text_input_dur,
                    audio_prompt_input_dur,
                    num_candidates_slider_dur,
                    cfg_scale_slider_dur,
                    top_k_slider_dur,
                    temperature_slider_dur,
                    chained_longform_checkbox_dur,
                    add_steps_slider,
                    use_duration_aware_checkbox,
                    chars_per_second_slider,
                    seed_slider_dur,
                    audio_warning_display_dur  # dummy input
                ],
                outputs=[audio_output_dur, status_output_dur],
                concurrency_limit=4  # Limit concurrent requests
            )
        # ------------------------------------------------------------------
        # Tab 3: clickable examples loaded from samples.csv (if present).
        # Clicking one fills the Tab-1 components listed in `inputs`.
        # ------------------------------------------------------------------
        with gr.TabItem("Examples"):
            if example_list:
                gr.Markdown("### Sample Text and Audio Prompts")
                gr.Markdown("Click on any example below to load it into the Speech Generation tab")
                gr.Markdown("*Note: Temperature and chained longform settings are loaded from the examples file*")
                gr.Examples(
                    examples=example_list,
                    inputs=[text_input, audio_prompt_input, temperature_slider, chained_longform_checkbox],
                    label="Click to load an example"
                )
            else:
                gr.Markdown("### No examples available")
                gr.Markdown("Examples will appear here when they are configured.")
        # ------------------------------------------------------------------
        # Tab 4: static "Read Me" description of the model.
        # ------------------------------------------------------------------
        with gr.TabItem("Read Me"):
            gr.HTML("""
<div style="background-color: rgba(255, 255, 255, 0.025); padding: 30px; border-radius: 12px; backdrop-filter: blur(10px); max-width: 100%; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
<h2 style="color: #000000; margin-bottom: 20px; font-size: 28px;">About Takane</h2>
<p style="color: #1a1a1a; font-weight: 500; line-height: 1.8; margin-bottom: 20px; font-size: 16px;">
Takane is a frontier Japanese-only speech synthesis network that was trained on tens of thousands of high quality data to autoregressively generate highly compressed audio codes.
This network is powered by Kanadec, the world's only 44.1 kHz - 25 frame rate speech tokenizer which utilizes semantic and acoustic distillation to generate audio tokens as fast as possible.
</p>
<p style="color: #1a1a1a; font-weight: 500; line-height: 1.8; margin-bottom: 20px; font-size: 16px;">
There are two checkpoints in this demo, one of them utilizes a custom version of Rope to manipulate duration which is seldom seen in autoregressive settings. Please treat it as a proof of concept as its outputs are not very reliable. I'll include it to show that it can work to some levels and can be expanded upon.
Both checkpoints have been fine-tuned on a subset of the dataset with only speaker tags. This will allow us to generate high quality samples without relying on audio prompts or dealing with random speaker attributes, but at the cost of tanking the zero-shot faithfulness of the model.
</p>
<p style="color: #1a1a1a; font-weight: 500; line-height: 1.8; margin-bottom: 20px; font-size: 16px;">
Takane also comes with an Anti-Hallucination Algorithm (AHA) that generates a few candidates in parallel and automatically returns the best one at the cost of introducing a small overhead.
If you need the fastest response time possible, feel free to enable the Turbo mode. It will disable AHA and tweak the parameters internally to produce samples as fast as 2-3 seconds (though due to an influx of users coming in, you probably will be qeued and have to wait!)
</p>
<p style="color: #1a1a1a; font-weight: 500; line-height: 1.8; margin-bottom: 20px; font-size: 16px;">
There's no plan to release this model for now.
</p>
<p style="color: #1a1a1a; font-weight: 500; line-height: 1.8; margin-bottom: 20px; font-size: 16px;">
If you're not using an audio prompt or a speaker tag, or even if you do, you find the later sentences to be too different, then in that case you may want to enable the <code>Chained mode</code>, which will sequentially condition each output to ensure speaker consistency.
</p>
<h3 style="color: #000000; margin-top: 30px; margin-bottom: 15px; font-size: 20px;">Summary of Technical Properties:</h3>
<ul style="color: #1a1a1a; font-weight: 500; line-height: 1.8; font-size: 15px;">
<li style="margin: 8px 0;">Encoder-Decoder fully autoregressive Transformer</li>
<li style="margin: 8px 0;">Powered by Kanadec (44.1 kHz - 25 codes per second)</li>
<li style="margin: 8px 0;">500M parameters</li>
<li style="margin: 8px 0;">Tens of thousands of anime-esque data, everyday regular Japanese is not supported</li>
<li style="margin: 8px 0;">Experimental support for duration-controllable synthesis</li>
</ul>
<div style="margin-top: 40px; padding-top: 20px; border-top: 1px solid rgba(0,0,0,0.1);">
<p style="color: #666; font-size: 14px; text-align: center;">
🌸 Takane - Advanced Japanese Text-to-Speech System
</p>
</div>
</div>
""")

if __name__ == "__main__":
    # Queue requests (FIFO, max 15 waiting) and hide the public API surface.
    demo.queue(api_open=False, max_size=15).launch(show_api=False, share=True)