OmniAICreator commited on
Commit
e00b5e2
·
verified ·
1 Parent(s): ecd5a2a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +166 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import soundfile as sf
5
+ import librosa
6
+ import gradio as gr
7
+ import spaces # For ZeroGPU
8
+ from xcodec2.modeling_xcodec2 import XCodec2Model
9
+
10
+ # ====== Settings ======
11
+ BASE_REPO = os.getenv("BASE_REPO", "HKUSTAudio/xcodec2") # Baseline (pretrained)
12
+ FT_REPO = os.getenv("FT_REPO", "NandemoGHS/Anime-XCodec2") # Fine-tuned (yours)
13
+ TARGET_SR = 16000 # XCodec2 expects 16 kHz
14
+ MAX_SECONDS_DEFAULT = 30 # Default max duration (seconds)
15
+
16
+ def _ensure_models():
17
+ """Load both models to CPU once, and reuse across requests."""
18
+ global _model_base, _model_ft
19
+ if _model_base is None:
20
+ _model_base = XCodec2Model.from_pretrained(BASE_REPO).eval().to("cpu")
21
+ if _model_ft is None:
22
+ _model_ft = XCodec2Model.from_pretrained(FT_REPO).eval().to("cpu")
23
+
24
+ # ====== Globals (lazy CPU load; move to GPU only during inference) ======
25
+ _model_base = None
26
+ _model_ft = None
27
+
28
+ _ensure_models()
29
+
30
+
31
+ def _load_audio(filepath: str, max_seconds: int):
32
+ """
33
+ Load audio (wav/flac/ogg/mp3), convert to mono, resample to 16 kHz,
34
+ trim to the given max length (from the beginning), and return torch.Tensor (1, T).
35
+ """
36
+ # Try soundfile first, then fall back to librosa
37
+ try:
38
+ wav, sr = sf.read(filepath, dtype="float32", always_2d=False)
39
+ except Exception:
40
+ wav, sr = librosa.load(filepath, sr=None, mono=False)
41
+ wav = np.asarray(wav, dtype=np.float32)
42
+
43
+ # Mono
44
+ if wav.ndim == 2:
45
+ # soundfile often returns (frames, channels)
46
+ if wav.shape[1] in (1, 2): # (frames, ch)
47
+ wav = wav.mean(axis=1)
48
+ else: # Possibly (ch, frames)
49
+ wav = wav.mean(axis=0)
50
+ elif wav.ndim > 2:
51
+ wav = np.mean(wav, axis=tuple(range(1, wav.ndim)))
52
+
53
+ # Resample to 16 kHz
54
+ if sr != TARGET_SR:
55
+ wav = librosa.resample(wav, orig_sr=sr, target_sr=TARGET_SR)
56
+ sr = TARGET_SR
57
+
58
+ # Length cap
59
+ if max_seconds is None or max_seconds <= 0:
60
+ max_seconds = MAX_SECONDS_DEFAULT
61
+ max_len = int(sr * max_seconds)
62
+ if wav.shape[0] > max_len:
63
+ wav = wav[:max_len]
64
+
65
+ # Light safety normalization
66
+ peak = np.max(np.abs(wav))
67
+ if peak > 1.0:
68
+ wav = wav / (peak + 1e-8)
69
+
70
+ wav_tensor = torch.from_numpy(wav).float().unsqueeze(0) # (1, T)
71
+ return wav_tensor, sr
72
+
73
+ def _codes_to_tensor(codes, device):
74
+ """
75
+ Normalize the output of xcodec2.encode_code to a tensor with shape (1, 1, N).
76
+ Handles version differences where the return type/shape may vary.
77
+ """
78
+ if isinstance(codes, torch.Tensor):
79
+ return codes.to(device)
80
+ try:
81
+ t = torch.as_tensor(codes[0][0], device=device)
82
+ return t.unsqueeze(0).unsqueeze(0) if t.ndim == 1 else t
83
+ except Exception:
84
+ return torch.as_tensor(codes, device=device)
85
+
86
+ def _reconstruct(model: XCodec2Model, waveform: torch.Tensor, device: str) -> np.ndarray:
87
+ """Encode→decode with XCodec2 to get a reconstructed waveform (np.float32, clipped to [-1, 1])."""
88
+ with torch.inference_mode():
89
+ wave = waveform.to(device)
90
+ codes = model.encode_code(input_waveform=wave)
91
+ codes_t = _codes_to_tensor(codes, device=device)
92
+ recon = model.decode_code(codes_t) # (1, 1, T')
93
+ recon_np = recon.squeeze().detach().cpu().numpy().astype(np.float32)
94
+ recon_np = np.clip(recon_np, -1.0, 1.0)
95
+ return recon_np
96
+
97
+ @spaces.GPU(duration=60) # ZeroGPU: reserve GPU only during this function call
98
+ def run(audio_path, max_seconds):
99
+ if audio_path is None:
100
+ raise gr.Error("Please upload an audio file.")
101
+
102
+ _ensure_models()
103
+ waveform, sr = _load_audio(audio_path, max_seconds)
104
+
105
+ device = "cuda" if torch.cuda.is_available() else "cpu"
106
+
107
+ # Baseline (pretrained)
108
+ base = _model_base.to(device)
109
+ recon_base = _reconstruct(base, waveform, device)
110
+
111
+ # Fine-tuned
112
+ ft = _model_ft.to(device)
113
+ recon_ft = _reconstruct(ft, waveform, device)
114
+
115
+ # Gradio Audio expects (sample_rate, np.ndarray)
116
+ return (sr, recon_base), (sr, recon_ft)
117
+
118
+ # ====== UI ======
119
+ DESCRIPTION = """
120
+ # Anime‑XCodec2 / XCodec2 Reconstruction Demo
121
+ Compare **Baseline (HKUSTAudio/xcodec2)** and **Fine‑tuned (NandemoGHS/Anime‑XCodec2)** reconstructions side by side.
122
+
123
+ - Supported inputs: wav / flac / ogg / mp3
124
+ - Input is automatically converted to **16 kHz** (as required by XCodec2).
125
+ - ZeroGPU ready. If no GPU is available, it falls back to CPU (slower).
126
+ """
127
+
128
+ with gr.Blocks(theme=gr.themes.Soft(), css="footer {visibility: hidden}") as demo:
129
+ gr.Markdown(DESCRIPTION)
130
+
131
+ with gr.Row():
132
+ with gr.Column(scale=1):
133
+ inp = gr.Audio(
134
+ sources=["upload"],
135
+ type="filepath",
136
+ label="Upload an audio file",
137
+ waveform_options={"show_controls": True}
138
+ )
139
+ max_sec = gr.Slider(
140
+ 3, 60, value=MAX_SECONDS_DEFAULT, step=1,
141
+ label="Max length (seconds)",
142
+ info="If the input is longer, only the first N seconds will be processed."
143
+ )
144
+ run_btn = gr.Button("Run", variant="primary")
145
+ gr.Markdown(
146
+ f"**Baseline model**: `{BASE_REPO}` \n"
147
+ f"**Fine‑tuned model**: `{FT_REPO}` \n"
148
+ f"**Inference device**: auto (GPU on ZeroGPU)"
149
+ )
150
+
151
+ with gr.Column(scale=1):
152
+ with gr.Row():
153
+ out_base = gr.Audio(
154
+ label="Baseline reconstruction (HKUSTAudio/xcodec2)",
155
+ show_download_button=True, format="wav"
156
+ )
157
+ out_ft = gr.Audio(
158
+ label="Fine‑tuned reconstruction (NandemoGHS/Anime‑XCodec2)",
159
+ show_download_button=True, format="wav"
160
+ )
161
+
162
+ run_btn.click(run, inputs=[inp, max_sec], outputs=[out_base, out_ft])
163
+
164
+ # In Spaces, explicit launch is optional
165
+ if __name__ == "__main__":
166
+ demo.queue(max_size=8).launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio>=4.44.0,<6
2
+ xcodec2==0.1.3
3
+ soundfile>=0.12.1
4
+ librosa>=0.10.2.post1
5
+ numpy>=1.23