Pj12 committed
Commit 35e2e4b · verified · 1 Parent(s): d0b22cd

Delete extract_feature_print.py

Files changed (1)
  1. extract_feature_print.py +0 -298
extract_feature_print.py DELETED
@@ -1,298 +0,0 @@
- import os, sys, traceback
- from transformers import HubertModel
- import librosa
- from torch import nn
- import torch
-
- import json
-
- # Allow CPU fallback for ops that are not implemented on Apple MPS.
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
- os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
-
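- # CLI: argv[1]=device, argv[2]=n_part, argv[3]=i_part, then either
- # (exp_dir, version) or (i_gpu, exp_dir, version); argv[7] and argv[8]
- # (model_path, sample_embedding) are read further below.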
- device = sys.argv[1]
- n_part = int(sys.argv[2])
- i_part = int(sys.argv[3])
- if len(sys.argv) == 6:
-     exp_dir = sys.argv[4]
-     version = sys.argv[5]
- else:
-     i_gpu = sys.argv[4]
-     exp_dir = sys.argv[5]
-     os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu)
-     version = sys.argv[6]
- import torch.nn.functional as F
- import soundfile as sf
- import numpy as np
- from fairseq import checkpoint_utils
-
- # The device argument from the CLI is overridden when CUDA or MPS is available.
- if torch.cuda.is_available():
-     device = "cuda"
- elif torch.backends.mps.is_available():
-     device = "mps"
-
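- # Training-config JSONs probed under configs/ for each sample rate (v1/v2).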
- version_config_paths = [
-     "32k.json",
-     "40k.json",
-     "48k.json",
-     "48k_v2.json",
-     "40k.json",
-     "32k_v2.json",
- ]
-
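- # Runtime device and precision configuration shared by the extraction step.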
- class Config:
-     def __init__(self):
-         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
-         self.is_half = self.device != "cpu"
-         self.gpu_name = (
-             torch.cuda.get_device_name(int(self.device.split(":")[-1]))
-             if self.device.startswith("cuda")
-             else None
-         )
-         self.json_config = self.load_config_json()
-         self.gpu_mem = None
-         self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
-
-     def load_config_json(self) -> dict:
-         configs = {}
-         for config_file in version_config_paths:
-             config_path = os.path.join("configs", config_file)
-             with open(config_path, "r") as f:
-                 configs[config_file] = json.load(f)
-         return configs
-
-     def has_mps(self) -> bool:
-         # Check if Metal Performance Shaders are available (macOS 12.3+).
-         return torch.backends.mps.is_available()
-
-     def has_xpu(self) -> bool:
-         # Check if an Intel XPU device is available.
-         return hasattr(torch, "xpu") and torch.xpu.is_available()
-
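-     # set_precision rewrites fp16_run in every config JSON and patches the
-     # version marker in preprocess.py so both agree on the chosen precision.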
-     def set_precision(self, precision):
-         if precision not in ["fp32", "fp16"]:
-             raise ValueError("Invalid precision type. Must be 'fp32' or 'fp16'.")
-
-         fp16_run_value = precision == "fp16"
-         preprocess_target_version = "3.7" if precision == "fp16" else "3.0"
-         preprocess_path = os.path.join(
-             os.path.dirname(__file__),
-             os.pardir,
-             "preprocess.py",
-         )
-
-         for config_path in version_config_paths:
-             full_config_path = os.path.join("configs", config_path)
-             try:
-                 with open(full_config_path, "r") as f:
-                     config = json.load(f)
-                 config["train"]["fp16_run"] = fp16_run_value
-                 with open(full_config_path, "w") as f:
-                     json.dump(config, f, indent=4)
-             except FileNotFoundError:
-                 print(f"File not found: {full_config_path}")
-
-         if os.path.exists(preprocess_path):
-             with open(preprocess_path, "r") as f:
-                 preprocess_content = f.read()
-             preprocess_content = preprocess_content.replace(
-                 "3.0" if precision == "fp16" else "3.7", preprocess_target_version
-             )
-             with open(preprocess_path, "w") as f:
-                 f.write(preprocess_content)
-
-         return f"Overwrote preprocess.py and the config JSONs to use {precision}."
-
-     def get_precision(self):
-         if not version_config_paths:
-             raise FileNotFoundError("No configuration paths provided.")
-
-         full_config_path = os.path.join("configs", version_config_paths[0])
-         try:
-             with open(full_config_path, "r") as f:
-                 config = json.load(f)
-             fp16_run_value = config["train"].get("fp16_run", False)
-             return "fp16" if fp16_run_value else "fp32"
-         except FileNotFoundError:
-             print(f"File not found: {full_config_path}")
-             return None
-
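-     # Pick the device and precision, then size the processing windows
-     # (x_pad/x_query/x_center/x_max) to the available GPU memory.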
-     def device_config(self) -> tuple:
-         if self.device.startswith("cuda"):
-             self.set_cuda_config()
-         elif self.has_mps():
-             self.device = "mps"
-             self.is_half = False
-             self.set_precision("fp32")
-         else:
-             self.device = "cpu"
-             self.is_half = False
-             self.set_precision("fp32")
-
-         # Defaults sized for roughly 6 GB of GPU memory.
-         x_pad, x_query, x_center, x_max = (
-             (3, 10, 60, 65) if self.is_half else (1, 6, 38, 41)
-         )
-         if self.gpu_mem is not None and self.gpu_mem <= 4:
-             # Tighter settings for GPUs with 4 GB of memory or less.
-             x_pad, x_query, x_center, x_max = (1, 5, 30, 32)
-
-         return x_pad, x_query, x_center, x_max
-
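-     # Force fp32 on GPUs with weak fp16 support and record total VRAM in GB.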
-     def set_cuda_config(self):
-         i_device = int(self.device.split(":")[-1])
-         self.gpu_name = torch.cuda.get_device_name(i_device)
-         low_end_gpus = ["16", "P40", "P10", "1060", "1070", "1080"]
-         if (
-             any(gpu in self.gpu_name for gpu in low_end_gpus)
-             and "V100" not in self.gpu_name.upper()
-         ):
-             self.is_half = False
-             self.set_precision("fp32")
-
-         self.gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (
-             1024**3
-         )
-
-
- config = Config()
-
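- # Read a wav with soundfile, downmix to mono, and resample to sample_rate.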
- def load_audio(file, sample_rate):
-     try:
-         file = file.strip(' "\n')
-         audio, sr = sf.read(file)
-         if len(audio.shape) > 1:
-             audio = librosa.to_mono(audio.T)
-         if sr != sample_rate:
-             audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate)
-     except Exception as error:
-         raise RuntimeError(f"An error occurred loading the audio: {error}")
-
-     return audio.flatten()
-
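- # HubertModel variant that keeps the final projection layer, used for
- # transformers-format checkpoints loaded from the "Custom/" directory.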
- # HuggingFacePlaceHolder = None
- class HubertModelWithFinalProj(HubertModel):
-     def __init__(self, config):
-         super().__init__(config)
-         self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
-         print(config.hidden_size, config.classifier_proj_size)
-
- f = open("%s/extract_f0_feature.log" % exp_dir, "a+")
-
-
- def printt(strr):
-     print(strr)
-     f.write("%s\n" % strr)
-     f.flush()
-
-
- printt(sys.argv)
- model_path = sys.argv[7]
- Custom_Embed = False
- sample_embedding = sys.argv[8]
- # When the "Custom" embedder is selected, point model_path at a known fairseq
- # checkpoint (its task config is still needed below); the actual weights are
- # loaded from "Custom/" further down.
- if os.path.split(model_path)[-1] == "Custom" and sample_embedding == "hubert_base":
-     model_path = "hubert_base.pt"
-     Custom_Embed = True
- elif os.path.split(model_path)[-1] == "Custom" and sample_embedding == "contentvec_base":
-     model_path = "contentvec_base.pt"
-     Custom_Embed = True
- elif os.path.split(model_path)[-1] == "Custom" and sample_embedding == "hubert_base_japanese":
-     model_path = "japanese_hubert_base.pt"
-     Custom_Embed = True
-
- printt(exp_dir)
- wavPath = "%s/1_16k_wavs" % exp_dir
- outPath = (
-     "%s/3_feature256" % exp_dir if version == "v1" else "%s/3_feature768" % exp_dir
- )
- os.makedirs(outPath, exist_ok=True)
-
-
- # Input wavs must be 16 kHz (HuBERT uses a hop size of 320 samples).
- def readwave(wav_path, normalize=False):
-     wav, sr = sf.read(wav_path)
-     assert sr == 16000
-     if not Custom_Embed:
-         feats = torch.from_numpy(wav).float()
-     else:
-         feats = torch.from_numpy(load_audio(wav_path, sr)).to(dtype).to(device)
-     if feats.dim() == 2:  # stereo: average the two channels
-         feats = feats.mean(-1)
-     assert feats.dim() == 1, feats.dim()
-     if normalize:
-         with torch.no_grad():
-             feats = F.layer_norm(feats, feats.shape)
-     feats = feats.view(1, -1)
-     return feats
-
-
- # HuBERT model
- printt("load model(s) from {}".format(model_path))
- # Abort early if the checkpoint file does not exist.
- if not os.access(model_path, os.F_OK):
-     printt(
-         "Error: extraction stopped because %s does not exist; you can download it from https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main"
-         % model_path
-     )
-     exit(0)
- models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
-     [model_path],
-     suffix="",
- )
- if not Custom_Embed:
-     model = models[0]
-     if device not in ["mps", "cpu"]:
-         model = model.half()
- else:
-     dtype = torch.float16 if config.is_half and "cuda" in device else torch.float32
-     model = HubertModelWithFinalProj.from_pretrained("Custom/").to(dtype).to(device)
- model = model.to(device)
- printt("move model to %s" % device)
- model.eval()
-
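- # Shard the wav list across workers: this process takes every n_part-th
- # file starting at offset i_part and skips files that are already done.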
- todo = sorted(os.listdir(wavPath))[i_part::n_part]
- n = max(1, len(todo) // 10)  # progress-report interval
- if len(todo) == 0:
-     printt("no-feature-todo")
- else:
-     printt("all-feature-%s" % len(todo))
-     for idx, file in enumerate(todo):
-         try:
-             if file.endswith(".wav"):
-                 wav_path = "%s/%s" % (wavPath, file)
-                 out_path = "%s/%s" % (outPath, file.replace(".wav", ".npy"))
-
-                 if os.path.exists(out_path):
-                     continue
-
-                 feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
-                 padding_mask = torch.BoolTensor(feats.shape).fill_(False)
-                 inputs = {
-                     "source": feats.half().to(device)
-                     if device not in ["mps", "cpu"]
-                     else feats.to(device),
-                     "padding_mask": padding_mask.to(device),
-                     "output_layer": 9 if version == "v1" else 12,  # layer 9 for v1, 12 for v2
-                 }
-                 with torch.no_grad():
-                     if not Custom_Embed:
-                         logits = model.extract_features(**inputs)
-                         feats = (
-                             model.final_proj(logits[0]) if version == "v1" else logits[0]
-                         )
-                     else:
-                         feats = model(feats)["last_hidden_state"]
-                         feats = (
-                             model.final_proj(feats[0]).unsqueeze(0) if version == "v1" else feats
-                         )
-
-                 feats = feats.squeeze(0).float().cpu().numpy()
-                 if np.isnan(feats).sum() == 0:
-                     np.save(out_path, feats, allow_pickle=False)
-                 else:
-                     printt("%s-contains nan" % file)
-                 if idx % n == 0:
-                     printt("now-%s,all-%s,%s,%s" % (idx, len(todo), file, feats.shape))
-         except Exception:
-             printt(traceback.format_exc())
- printt("all-feature-done")
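-
- # Example invocation (illustrative paths, not from the original file):
- #   python extract_feature_print.py cuda 2 0 0 logs/my_exp v2 hubert_base.pt hubert_base
- # i.e. shard 0 of 2 on GPU 0, writing features to logs/my_exp/3_feature768.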