Cosmobillian committed 2349c43 (verified) · 1 parent: 1f02129

Update README.md

Files changed (1): README.md (+287, −1)

README.md:

This llama model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Huggingface's TRL library.

[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)

inference.py (please install the necessary libraries first):

```bash
# install the torch build that matches your CUDA version, see https://pytorch.org/
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install snac transformers huggingface_hub librosa numpy scipy Flask
```
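The script below loads the fine-tuned Orpheus model, the SNAC 24 kHz codec, and a set of reference clips, then serves a small Flask endpoint. As `get_ref_audio_and_transcript` shows, the reference-data folder (here `D:\AI_APPS\Orpheus-TTS\data`) is expected to contain one subfolder per speaker, each holding a single `.wav` reference clip and a `.txt` transcript.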

```python
from pathlib import Path
from datetime import datetime

import torch
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
import librosa
import numpy as np
from scipy.io.wavfile import write
from flask import Flask, jsonify, request

# Local path to the fine-tuned Orpheus-TTS-Turkish checkpoint.
modelLocalPath = "D:\\...\\Karayakar\\Orpheus-TTS-Turkish-PT-5000"


def load_orpheus_tokenizer(model_id: str = modelLocalPath) -> AutoTokenizer:
    tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
    return tokenizer


def load_snac():
    # SNAC 24 kHz codec used to turn audio codes back into a waveform.
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    return snac_model


def load_orpheus_auto_model(model_id: str = modelLocalPath):
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, local_files_only=True, device_map="cuda"
    )
    model.cuda()
    return model


def tokenize_audio(audio_file_path, snac_model):
    # Encode a reference audio clip into the flattened SNAC token stream.
    # (Not wired into prepare_inputs in this version; kept for reference-audio prompting.)
    audio_array, sample_rate = librosa.load(audio_file_path, sr=24000)
    waveform = torch.from_numpy(audio_array).unsqueeze(0)
    waveform = waveform.to(dtype=torch.float32)

    waveform = waveform.unsqueeze(0)

    with torch.inference_mode():
        codes = snac_model.encode(waveform)

    # Interleave the three SNAC codebook layers into 7 tokens per frame,
    # offset into the Orpheus audio-token id range starting at 128266.
    all_codes = []
    for i in range(codes[0].shape[1]):
        all_codes.append(codes[0][0][i].item() + 128266)
        all_codes.append(codes[1][0][2 * i].item() + 128266 + 4096)
        all_codes.append(codes[2][0][4 * i].item() + 128266 + (2 * 4096))
        all_codes.append(codes[2][0][(4 * i) + 1].item() + 128266 + (3 * 4096))
        all_codes.append(codes[1][0][(2 * i) + 1].item() + 128266 + (4 * 4096))
        all_codes.append(codes[2][0][(4 * i) + 2].item() + 128266 + (5 * 4096))
        all_codes.append(codes[2][0][(4 * i) + 3].item() + 128266 + (6 * 4096))

    return all_codes


def prepare_inputs(
    fpath_audio_ref,
    audio_ref_transcript: str,
    text_prompts: list[str],
    snac_model,
    tokenizer,
):
    # Special tokens wrapped around each text prompt.
    start_tokens = torch.tensor([[128259]], dtype=torch.int64)
    end_tokens = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
    final_tokens = torch.tensor([[128258, 128262]], dtype=torch.int64)

    all_modified_input_ids = []
    for prompt in text_prompts:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        # second_input_ids = torch.cat([zeroprompt_input_ids, start_tokens, input_ids, end_tokens], dim=1)
        second_input_ids = torch.cat([start_tokens, input_ids, end_tokens], dim=1)
        all_modified_input_ids.append(second_input_ids)

    # Left-pad all prompts to the same length (pad id 128263) and build matching attention masks.
    all_padded_tensors = []
    all_attention_masks = []
    max_length = max([modified_input_ids.shape[1] for modified_input_ids in all_modified_input_ids])

    for modified_input_ids in all_modified_input_ids:
        padding = max_length - modified_input_ids.shape[1]
        padded_tensor = torch.cat([torch.full((1, padding), 128263, dtype=torch.int64), modified_input_ids], dim=1)
        attention_mask = torch.cat([torch.zeros((1, padding), dtype=torch.int64),
                                    torch.ones((1, modified_input_ids.shape[1]), dtype=torch.int64)], dim=1)
        all_padded_tensors.append(padded_tensor)
        all_attention_masks.append(attention_mask)

    all_padded_tensors = torch.cat(all_padded_tensors, dim=0)
    all_attention_masks = torch.cat(all_attention_masks, dim=0)

    input_ids = all_padded_tensors.to("cuda")
    attention_mask = all_attention_masks.to("cuda")
    return input_ids, attention_mask


def inference(model, input_ids, attention_mask):
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            top_k=10,
            top_p=0.9,
            repetition_penalty=1.9,
            num_return_sequences=1,
            eos_token_id=128258,
        )

    # Append the end-of-audio token (128262) so downstream parsing sees a closed sequence.
    generated_ids = torch.cat([generated_ids, torch.tensor([[128262]]).to("cuda")], dim=1)

    return generated_ids


def convert_tokens_to_speech(generated_ids, snac_model):
    token_to_find = 128257    # marker that precedes the generated audio codes
    token_to_remove = 128258  # eos token, stripped from the code stream
    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)

    # Keep only what follows the last marker token.
    if len(token_indices[1]) > 0:
        last_occurrence_idx = token_indices[1][-1].item()
        cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
    else:
        cropped_tensor = generated_ids

    processed_rows = []
    for row in cropped_tensor:
        masked_row = row[row != token_to_remove]
        processed_rows.append(masked_row)

    # Trim each row to a whole number of 7-token frames and shift the ids
    # back out of the Orpheus audio-token range.
    code_lists = []
    for row in processed_rows:
        row_length = row.size(0)
        new_length = (row_length // 7) * 7
        trimmed_row = row[:new_length]
        trimmed_row = [(t - 128266).item() for t in trimmed_row]
        code_lists.append(trimmed_row)

    my_samples = []
    for code_list in code_lists:
        samples = redistribute_codes(code_list, snac_model)
        my_samples.append(samples)

    return my_samples


def redistribute_codes(code_list, snac_model):
    # Undo the 7-token interleaving back into the three SNAC codebook layers.
    layer_1 = []
    layer_2 = []
    layer_3 = []

    for i in range(len(code_list) // 7):
        layer_1.append(code_list[7 * i])
        layer_2.append(code_list[7 * i + 1] - 4096)
        layer_3.append(code_list[7 * i + 2] - (2 * 4096))
        layer_3.append(code_list[7 * i + 3] - (3 * 4096))
        layer_2.append(code_list[7 * i + 4] - (4 * 4096))
        layer_3.append(code_list[7 * i + 5] - (5 * 4096))
        layer_3.append(code_list[7 * i + 6] - (6 * 4096))

    codes = [
        torch.tensor(layer_1).unsqueeze(0),
        torch.tensor(layer_2).unsqueeze(0),
        torch.tensor(layer_3).unsqueeze(0)
    ]
    audio_hat = snac_model.decode(codes)
    return audio_hat


def to_wav_from(samples: list) -> list[np.ndarray]:
    """Converts a list of PyTorch tensors (or NumPy arrays) to NumPy arrays."""
    processed_samples = []

    for s in samples:
        if isinstance(s, torch.Tensor):
            s = s.detach().squeeze().to("cpu").numpy()
        else:
            s = np.squeeze(s)

        processed_samples.append(s)

    return processed_samples


def zero_shot_tts(fpath_audio_ref, audio_ref_transcript, texts: list[str], model, snac_model, tokenizer):
    print(f"fpath_audio_ref {fpath_audio_ref}")
    print(f"audio_ref_transcript {audio_ref_transcript}")
    print(f"texts {texts}")
    inp_ids, attn_mask = prepare_inputs(fpath_audio_ref, audio_ref_transcript, texts, snac_model, tokenizer)
    print(f"input_id_len:{len(inp_ids)}")
    gen_ids = inference(model, inp_ids, attn_mask)
    samples = convert_tokens_to_speech(gen_ids, snac_model)
    wav_forms = to_wav_from(samples)
    return wav_forms


def save_wav(samples: list[np.ndarray], sample_rate: int, filenames: list[str]):
    """Saves a list of audio samples as .wav files.

    Args:
        samples (list): List of audio tensors or NumPy arrays.
        sample_rate (int): Sample rate in Hz.
        filenames (list[str]): List of filenames to save to.
    """
    wav_data = to_wav_from(samples)

    for data, filename in zip(wav_data, filenames):
        write(filename, sample_rate, data.astype(np.float32))
        print(f"saved to {filename}")


def get_ref_audio_and_transcript(root_folder: str):
    root_path = Path(root_folder)
    print(f"root_path {root_path}")
    out = []
    for speaker_folder in root_path.iterdir():
        if speaker_folder.is_dir():  # ensure it's a directory
            wav_files = list(speaker_folder.glob("*.wav"))
            txt_files = list(speaker_folder.glob("*.txt"))

            if wav_files and txt_files:
                ref_audio = wav_files[0]  # assume only one .wav file per speaker folder
                transcript = txt_files[0].read_text(encoding="utf-8").strip()
                out.append((ref_audio, transcript))

    return out


app = Flask(__name__)


@app.route("/generate", methods=["POST"])
def generate():
    content = request.json
    process_data(content)
    response = jsonify({
        "received": content,
        "status": "success"
    })
    response.headers["Content-Type"] = "application/json; charset=utf-8"
    return response


def process_data(jsonText):
    texts = [f"{jsonText['text']}"]
    for fpath_audio, audio_transcript in prompt_pairs:
        print(f"zero shot: {fpath_audio} {audio_transcript}")
        wav_forms = zero_shot_tts(fpath_audio, audio_transcript, texts, model, snac_model, tokenizer)

        # Write the generated audio next to the reference clip, under an "inference" subfolder.
        out_dir = Path(fpath_audio).parent / "inference"
        out_dir.mkdir(parents=True, exist_ok=True)
        timestamp_str = str(int(datetime.now().timestamp()))
        file_names = [f"{out_dir.as_posix()}/{Path(fpath_audio).stem}_{i}_{timestamp_str}.wav" for i, t in enumerate(texts)]
        save_wav(wav_forms, 24000, file_names)


if __name__ == "__main__":
    # Load everything once at startup, then serve the /generate endpoint.
    tokenizer = load_orpheus_tokenizer()
    model = load_orpheus_auto_model()
    snac_model = load_snac()
    prompt_pairs = get_ref_audio_and_transcript("D:\\AI_APPS\\Orpheus-TTS\\data")
    print("snac_model loaded")
    app.run(debug=True, port=5400)
```
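
To try the endpoint once the server is running, here is a minimal client sketch (the host, port, and example sentence are placeholders; the request body only needs the `text` field that `process_data` reads). The generated `.wav` files are saved server-side, next to each reference clip under an `inference/` subfolder.

```python
# Minimal client sketch for the /generate endpoint above.
# Assumes the Flask server is running locally on port 5400 (see app.run above).
import json
import urllib.request

payload = json.dumps({"text": "Merhaba, bu bir deneme cümlesidir."}).encode("utf-8")
req = urllib.request.Request(
    "http://127.0.0.1:5400/generate",
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    # The server echoes the received JSON plus a status; audio is written to disk on the server.
    print(resp.read().decode("utf-8"))
```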