File size: 4,222 Bytes
7d6af4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94c03bd
dd2e2b8
7d6af4d
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import dashscope
from argparse import ArgumentParser

import gradio as gr

# NOTE(review): presumably these must be set before vLLM is imported/initialized
# downstream so they take effect — TODO confirm.
os.environ['VLLM_USE_V1'] = '0'
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

# DashScope credentials come from the environment; this is None when unset.
dashscope.api_key = os.environ.get("DASHSCOPE_API_KEY")

def _launch_demo(args):
    """Build and launch the Gradio audio-captioning demo.

    Args:
        args: Parsed CLI namespace; only ``share``, ``inbrowser``,
            ``server_port`` and ``server_name`` are used here.
    """

    def generate_caption_from_audio(audio_path, temperature, top_p, top_k):
        """Call the DashScope captioner model on a local audio file.

        Args:
            audio_path: Filesystem path to the uploaded/recorded audio.
            temperature, top_p, top_k: Sampling parameters forwarded to the API.

        Returns:
            The generated caption text.

        Raises:
            gr.Error: If the DashScope API call did not succeed.
        """
        messages = [
            {
                "role": "user",
                # DashScope expects local files as file:// URIs.
                "content": [{"audio": "file://" + audio_path}],
            }
        ]
        response = dashscope.MultiModalConversation.call(
                    model="qwen3-omni-30b-a3b-captioner",
                    top_p=top_p,
                    top_k=top_k,
                    temperature=temperature,
                    messages=messages)

        # Surface API failures in the UI instead of crashing with an opaque
        # TypeError/KeyError when response["output"] is missing.
        if response.status_code != 200:
            raise gr.Error(
                f"DashScope API error {response.code}: {response.message}")

        return response["output"]["choices"][0]["message"].content[0]["text"]

    def on_submit(audio_path, temperature, top_p, top_k):
        """Gradio handler: validate input, produce a caption, re-enable button."""
        if not audio_path:
            yield None, gr.update(interactive=True)
            return

        caption = generate_caption_from_audio(audio_path, temperature, top_p, top_k)

        yield caption, gr.update(interactive=True)

    with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"])) as demo:
        gr.Markdown("# Qwen3-Omni-30B-A3B-Captioner Demo")

        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(sources=['upload', 'microphone'], type="filepath", label="Upload or record an audio")

                with gr.Accordion("Generation Parameters", open=True):
                    temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.6, step=0.1)
                    top_p = gr.Slider(label="Top P", minimum=0.05, maximum=1.0, value=0.95, step=0.05)
                    top_k = gr.Slider(label="Top K", minimum=1, maximum=100, value=20, step=1)

                with gr.Row():
                    submit_btn = gr.Button("Submit", variant="primary")
                    clear_btn = gr.Button("Clear")

            with gr.Column(scale=2):
                output_caption = gr.Textbox(
                    label="Caption Result",
                    lines=15,
                    interactive=False
                )

        def clear_fields():
            # Reset both the audio input and the caption textbox.
            return None, ""

        submit_btn.click(
            fn=on_submit,
            inputs=[audio_input, temperature, top_p, top_k],
            outputs=[output_caption, submit_btn]
        )

        clear_btn.click(fn=clear_fields, inputs=None, outputs=[audio_input, output_caption])

    demo.queue(100, max_size=100).launch(max_threads=100,
                                        ssr_mode=False,
                                        share=args.share,
                                        inbrowser=args.inbrowser,
                                        server_port=args.server_port,
                                        server_name=args.server_name,)


DEFAULT_CKPT_PATH = "Qwen/Qwen3-Omni-30B-A3B-Captioner"

def _get_args():
    parser = ArgumentParser()

    parser.add_argument('-c', '--checkpoint-path', type=str, default=DEFAULT_CKPT_PATH,
                        help='Checkpoint name or path, default to %(default)r')
    parser.add_argument('--flash-attn2', action='store_true', default=False,
                        help='Enable flash_attention_2 when loading the model.')
    parser.add_argument('--use-transformers', action='store_true', default=False,
                        help='Use transformers for inference instead of vLLM.')
    parser.add_argument('--share', action='store_true', default=False,
                        help='Create a publicly shareable link for the interface.')
    parser.add_argument('--inbrowser', action='store_true', default=False,
                        help='Automatically launch the interface in a new tab on the default browser.')
    parser.add_argument('--server-port', type=int, help='Demo server port.')
    parser.add_argument('--server-name', type=str, default='0.0.0.0', help='Demo server name.')

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = _get_args()
    _launch_demo(args)