ekwek committed
Commit 55b3415 · verified · 1 Parent(s): 278ac29

Upload 11 files

soprano/soprano/__init__.py ADDED
@@ -0,0 +1 @@
+ from .tts import SopranoTTS
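
With the package importable as `soprano`, end-to-end use looks roughly like this (a sketch based on the signatures in tts.py below; output is 32 kHz mono):

    from soprano import SopranoTTS

    tts = SopranoTTS()  # backend='auto' picks lmdeploy when installed, else transformers
    audio = tts.infer("Hello world!", out_path="hello.wav")
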
soprano/soprano/backends/base.py ADDED
@@ -0,0 +1,20 @@
+ class BaseModel:
+     def infer(self,
+               prompts,
+               top_p=0.95,
+               temperature=0.3,
+               repetition_penalty=1.2):
+         '''
+         Takes a list of prompts and returns the output hidden states.
+         '''
+         pass
+
+     def stream_infer(self,
+                      prompt,
+                      top_p=0.95,
+                      temperature=0.3,
+                      repetition_penalty=1.2):
+         '''
+         Takes a prompt and returns an iterator over the output hidden states.
+         '''
+         pass
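
The two backends below implement this interface. The contract SopranoTTS relies on is a list of dicts carrying a 'finish_reason' string and a 'hidden_state' tensor of shape (T, 512); a minimal conforming stub (hypothetical, for illustration only):

    import torch

    class StubModel(BaseModel):
        def infer(self, prompts, **kwargs):
            # Pretend every prompt generated 8 tokens of 512-dim hidden states.
            return [{'finish_reason': 'stop', 'hidden_state': torch.zeros(8, 512)}
                    for _ in prompts]
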
soprano/soprano/backends/lmdeploy.py ADDED
@@ -0,0 +1,55 @@
+ import torch
+ from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
+ from .base import BaseModel
+
+
+ class LMDeployModel(BaseModel):
+     def __init__(self,
+                  device='cuda',
+                  cache_size_mb=100,
+                  **kwargs):
+         assert device == 'cuda', "lmdeploy only supports cuda devices; change the device or use a different backend."
+         # Convert the requested KV-cache size in MB to the fraction of total GPU memory lmdeploy expects.
+         cache_size_ratio = cache_size_mb * 1024**2 / torch.cuda.get_device_properties('cuda').total_memory
+         backend_config = TurbomindEngineConfig(cache_max_entry_count=cache_size_ratio)
+         self.pipeline = pipeline('ekwek/Soprano-80M',
+                                  log_level='ERROR',
+                                  backend_config=backend_config)
+
+     def infer(self,
+               prompts,
+               top_p=0.95,
+               temperature=0.3,
+               repetition_penalty=1.2):
+         gen_config = GenerationConfig(output_last_hidden_state='generation',
+                                       do_sample=True,
+                                       top_p=top_p,
+                                       temperature=temperature,
+                                       repetition_penalty=repetition_penalty,
+                                       max_new_tokens=512)
+         responses = self.pipeline(prompts, gen_config=gen_config)
+         res = []
+         for response in responses:
+             res.append({
+                 'finish_reason': response.finish_reason,
+                 'hidden_state': response.last_hidden_state
+             })
+         return res
+
+     def stream_infer(self,
+                      prompt,
+                      top_p=0.95,
+                      temperature=0.3,
+                      repetition_penalty=1.2):
+         gen_config = GenerationConfig(output_last_hidden_state='generation',
+                                       do_sample=True,
+                                       top_p=top_p,
+                                       temperature=temperature,
+                                       repetition_penalty=repetition_penalty,
+                                       max_new_tokens=512)
+         responses = self.pipeline.stream_infer([prompt], gen_config=gen_config)
+         for response in responses:
+             yield {
+                 'finish_reason': response.finish_reason,
+                 'hidden_state': response.last_hidden_state
+             }
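
Driving this backend directly is possible, though SopranoTTS normally constructs the prompt (the [STOP][TEXT]...[START] format used in tts.py below). A sketch:

    model = LMDeployModel(cache_size_mb=100)
    out = model.infer(['[STOP][TEXT]hello world.[START]'])
    print(out[0]['finish_reason'], out[0]['hidden_state'].shape)  # e.g. 'stop', (T, 512)
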
soprano/soprano/backends/transformers.py ADDED
@@ -0,0 +1,72 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from .base import BaseModel
+
+
+ class TransformersModel(BaseModel):
+     def __init__(self,
+                  device='cuda',
+                  **kwargs):
+         self.device = device
+
+         self.model = AutoModelForCausalLM.from_pretrained(
+             'ekwek/Soprano-80M',
+             torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
+             device_map=device
+         )
+         self.tokenizer = AutoTokenizer.from_pretrained('ekwek/Soprano-80M')
+         self.model.eval()
+
+     def infer(self,
+               prompts,
+               top_p=0.95,
+               temperature=0.3,
+               repetition_penalty=1.2):
+         inputs = self.tokenizer(
+             prompts,
+             return_tensors='pt',
+             padding=True,
+             truncation=True,
+             max_length=512,
+         ).to(self.device)
+
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 input_ids=inputs['input_ids'],
+                 attention_mask=inputs['attention_mask'],
+                 max_new_tokens=512,
+                 do_sample=True,
+                 top_p=top_p,
+                 temperature=temperature,
+                 repetition_penalty=repetition_penalty,
+                 pad_token_id=self.tokenizer.pad_token_id,
+                 return_dict_in_generate=True,
+                 output_hidden_states=True,
+             )
+         res = []
+         eos_token_id = self.model.config.eos_token_id
+         for i in range(len(prompts)):
+             seq = outputs.sequences[i]
+             hidden_states = []
+             # One entry in outputs.hidden_states per generation step.
+             num_output_tokens = len(outputs.hidden_states)
+             for j in range(num_output_tokens):
+                 # Map generation step j to its token in the full (prompt + generated) sequence.
+                 token = seq[j + seq.size(0) - num_output_tokens]
+                 if token != eos_token_id:
+                     # Last layer's hidden state at the final position for sample i.
+                     hidden_states.append(outputs.hidden_states[j][-1][i, -1, :])
+             last_hidden_state = torch.stack(hidden_states).squeeze()
+             finish_reason = 'stop' if seq[-1].item() == eos_token_id else 'length'
+             res.append({
+                 'finish_reason': finish_reason,
+                 'hidden_state': last_hidden_state
+             })
+         return res
+
+     def stream_infer(self,
+                      prompt,
+                      top_p=0.95,
+                      temperature=0.3,
+                      repetition_penalty=1.2):
+         raise NotImplementedError("the transformers backend does not currently support streaming; consider using the lmdeploy backend instead.")
soprano/soprano/tts.py ADDED
@@ -0,0 +1,195 @@
+ from .vocos.decoder import SopranoDecoder
+ from .utils.text import clean_text
+ import torch
+ import re
+ from scipy.io import wavfile
+ from huggingface_hub import hf_hub_download
+ import os
+ import time
+
+
+ class SopranoTTS:
+     def __init__(self,
+                  backend='auto',
+                  device='cuda',
+                  cache_size_mb=10,
+                  decoder_batch_size=1):
+         RECOGNIZED_DEVICES = ['cuda']
+         RECOGNIZED_BACKENDS = ['auto', 'lmdeploy', 'transformers']
+         assert device in RECOGNIZED_DEVICES, f"unrecognized device {device}, device must be in {RECOGNIZED_DEVICES}"
+         if backend == 'auto':
+             if device == 'cpu':
+                 backend = 'transformers'
+             else:
+                 try:
+                     import lmdeploy
+                     backend = 'lmdeploy'
+                 except ImportError:
+                     backend = 'transformers'
+         print(f"Using backend {backend}.")
+         assert backend in RECOGNIZED_BACKENDS, f"unrecognized backend {backend}, backend must be in {RECOGNIZED_BACKENDS}"
+
+         if backend == 'lmdeploy':
+             from .backends.lmdeploy import LMDeployModel
+             self.pipeline = LMDeployModel(device=device, cache_size_mb=cache_size_mb)
+         elif backend == 'transformers':
+             from .backends.transformers import TransformersModel
+             self.pipeline = TransformersModel(device=device)
+
+         self.decoder = SopranoDecoder().cuda()
+         decoder_path = hf_hub_download(repo_id='ekwek/Soprano-80M', filename='decoder.pth')
+         self.decoder.load_state_dict(torch.load(decoder_path))
+         self.decoder_batch_size = decoder_batch_size
+         self.RECEPTIVE_FIELD = 4  # Decoder receptive field
+         self.TOKEN_SIZE = 2048  # Number of samples per audio token
+
+         self.infer("Hello world!")  # warmup
+
+     def _preprocess_text(self, texts, min_length=30):
+         '''
+         Adds the prompt format and sentence/part indices.
+         Enforces a minimum sentence length by merging short sentences.
+         '''
+         res = []
+         for text_idx, text in enumerate(texts):
+             text = text.strip()
+             cleaned_text = clean_text(text)
+             sentences = re.split(r"(?<=[.!?])\s+", cleaned_text)
+             processed = []
+             for sentence in sentences:
+                 processed.append({
+                     "text": sentence,
+                     "text_idx": text_idx,
+                 })
+
+             if min_length > 0 and len(processed) > 1:
+                 merged = []
+                 i = 0
+                 while i < len(processed):
+                     cur = processed[i]
+                     if len(cur["text"]) < min_length:
+                         # Merge short sentences into the previous one, or the next one if there is no previous.
+                         if merged:
+                             merged[-1]["text"] = (merged[-1]["text"] + " " + cur["text"]).strip()
+                         elif i + 1 < len(processed):
+                             processed[i + 1]["text"] = (cur["text"] + " " + processed[i + 1]["text"]).strip()
+                         else:
+                             merged.append(cur)
+                     else:
+                         merged.append(cur)
+                     i += 1
+                 processed = merged
+             sentence_idxes = {}
+             for item in processed:
+                 if item['text_idx'] not in sentence_idxes:
+                     sentence_idxes[item['text_idx']] = 0
+                 res.append((f'[STOP][TEXT]{item["text"]}[START]', item["text_idx"], sentence_idxes[item['text_idx']]))
+                 sentence_idxes[item['text_idx']] += 1
+         return res
+
+     def infer(self,
+               text,
+               out_path=None,
+               top_p=0.95,
+               temperature=0.3,
+               repetition_penalty=1.2):
+         results = self.infer_batch([text],
+                                    top_p=top_p,
+                                    temperature=temperature,
+                                    repetition_penalty=repetition_penalty,
+                                    out_dir=None)[0]
+         if out_path:
+             wavfile.write(out_path, 32000, results.cpu().numpy())
+         return results
+
+     def infer_batch(self,
+                     texts,
+                     out_dir=None,
+                     top_p=0.95,
+                     temperature=0.3,
+                     repetition_penalty=1.2):
+         sentence_data = self._preprocess_text(texts)
+         prompts = [x[0] for x in sentence_data]
+         responses = self.pipeline.infer(prompts,
+                                         top_p=top_p,
+                                         temperature=temperature,
+                                         repetition_penalty=repetition_penalty)
+         hidden_states = []
+         for i, response in enumerate(responses):
+             if response['finish_reason'] != 'stop':
+                 print("Warning: some sentences did not complete generation, likely due to hallucination.")
+             hidden_states.append(response['hidden_state'])
+         # Sort sentences by descending length so each decoder batch needs as little padding as possible.
+         combined = list(zip(hidden_states, sentence_data))
+         combined.sort(key=lambda x: -x[0].size(0))
+         hidden_states, sentence_data = zip(*combined)
+
+         num_texts = len(texts)
+         audio_concat = [[] for _ in range(num_texts)]
+         for sentence in sentence_data:
+             audio_concat[sentence[1]].append(None)
+         for idx in range(0, len(hidden_states), self.decoder_batch_size):
+             batch_hidden_states = []
+             lengths = [x.size(0) for x in hidden_states[idx:idx + self.decoder_batch_size]]
+             N = len(lengths)
+             for i in range(N):
+                 # Left-pad each item with zeros up to the longest length in the batch.
+                 batch_hidden_states.append(torch.cat([
+                     torch.zeros((1, 512, lengths[0] - lengths[i]), device='cuda'),
+                     hidden_states[idx + i].unsqueeze(0).transpose(1, 2).cuda().to(torch.float32),
+                 ], dim=2))
+             batch_hidden_states = torch.cat(batch_hidden_states)
+             with torch.no_grad():
+                 audio = self.decoder(batch_hidden_states)
+
+             for i in range(N):
+                 text_id = sentence_data[idx + i][1]
+                 sentence_id = sentence_data[idx + i][2]
+                 audio_concat[text_id][sentence_id] = audio[i].squeeze()[-(lengths[i] * self.TOKEN_SIZE - self.TOKEN_SIZE):]
+         audio_concat = [torch.cat(x).cpu() for x in audio_concat]
+
+         if out_dir:
+             os.makedirs(out_dir, exist_ok=True)
+             for i in range(len(audio_concat)):
+                 wavfile.write(f"{out_dir}/{i}.wav", 32000, audio_concat[i].cpu().numpy())
+         return audio_concat
+
+     def infer_stream(self,
+                      text,
+                      chunk_size=1,
+                      top_p=0.95,
+                      temperature=0.3,
+                      repetition_penalty=1.2):
+         start_time = time.time()
+         sentence_data = self._preprocess_text([text])
+
+         first_chunk = True
+         for sentence, _, _ in sentence_data:
+             responses = self.pipeline.stream_infer(sentence,
+                                                    top_p=top_p,
+                                                    temperature=temperature,
+                                                    repetition_penalty=repetition_penalty)
+             hidden_states_buffer = []
+             chunk_counter = chunk_size
+             for token in responses:
+                 finished = token['finish_reason'] is not None
+                 if not finished:
+                     hidden_states_buffer.append(token['hidden_state'][-1])
+                 # Keep only the context the decoder needs: receptive field on both sides plus the chunk itself.
+                 hidden_states_buffer = hidden_states_buffer[-(2 * self.RECEPTIVE_FIELD + chunk_size):]
+                 if finished or len(hidden_states_buffer) >= self.RECEPTIVE_FIELD + chunk_size:
+                     if finished or chunk_counter == chunk_size:
+                         batch_hidden_states = torch.stack(hidden_states_buffer)
+                         inp = batch_hidden_states.unsqueeze(0).transpose(1, 2).cuda().to(torch.float32)
+                         with torch.no_grad():
+                             audio = self.decoder(inp)[0]
+                         if finished:
+                             audio_chunk = audio[-((self.RECEPTIVE_FIELD + chunk_counter - 1) * self.TOKEN_SIZE - self.TOKEN_SIZE):]
+                         else:
+                             audio_chunk = audio[-((self.RECEPTIVE_FIELD + chunk_size) * self.TOKEN_SIZE - self.TOKEN_SIZE):-(self.RECEPTIVE_FIELD * self.TOKEN_SIZE - self.TOKEN_SIZE)]
+                         chunk_counter = 0
+                         if first_chunk:
+                             print(f"Streaming latency: {1000 * (time.time() - start_time):.2f} ms")
+                             first_chunk = False
+                         yield audio_chunk.cpu()
+                 chunk_counter += 1
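
A streaming usage sketch: chunks arrive as 32 kHz mono float tensors as soon as enough tokens are buffered, so playback can begin before synthesis finishes (requires the lmdeploy backend; play_samples is a hypothetical audio sink):

    tts = SopranoTTS()
    for chunk in tts.infer_stream("Streaming starts before the full text is synthesized."):
        play_samples(chunk.numpy())  # hypothetical; any 32 kHz PCM sink works
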
soprano/soprano/utils/text.py ADDED
@@ -0,0 +1,387 @@
+ """
+ Normalize input text to a format that Soprano recognizes.
+ Adapted from https://github.com/neonbjb/tortoise-tts/blob/main/tortoise/utils/tokenizer.py
+ """
+ import re
+
+ import inflect
+ from unidecode import unidecode
+
+
+ _inflect = inflect.engine()
+
+ ####################################################################################################
+ # Abbreviations
+
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
+     ('mrs', 'missus'),
+     ('ms', 'miss'),
+     ('mr', 'mister'),
+     ('dr', 'doctor'),
+     ('st', 'saint'),
+     ('co', 'company'),
+     ('jr', 'junior'),
+     ('maj', 'major'),
+     ('gen', 'general'),
+     ('drs', 'doctors'),
+     ('rev', 'reverend'),
+     ('lt', 'lieutenant'),
+     ('hon', 'honorable'),
+     ('sgt', 'sergeant'),
+     ('capt', 'captain'),
+     ('esq', 'esquire'),
+     ('ltd', 'limited'),
+     ('col', 'colonel'),
+     ('ft', 'fort'),
+ ]]
+ _cased_abbreviations = [(re.compile('\\b%s\\b' % x[0]), x[1]) for x in [
+     ('TTS', 'text to speech'),
+     ('Hz', 'hertz'),
+     ('kHz', 'kilohertz'),
+     ('KBs', 'kilobytes'),
+     ('KB', 'kilobyte'),
+     ('MBs', 'megabytes'),
+     ('MB', 'megabyte'),
+     ('GBs', 'gigabytes'),
+     ('GB', 'gigabyte'),
+     ('TBs', 'terabytes'),
+     ('TB', 'terabyte'),
+     ('APIs', 'a p i\'s'),
+     ('API', 'a p i'),
+     ('CLIs', 'c l i\'s'),
+     ('CLI', 'c l i'),
+     ('CPUs', 'c p u\'s'),
+     ('CPU', 'c p u'),
+     ('GPUs', 'g p u\'s'),
+     ('GPU', 'g p u'),
+     ('Ave', 'avenue'),
+ ]]
+
+ def expand_abbreviations(text):
+     for regex, replacement in _abbreviations + _cased_abbreviations:
+         text = re.sub(regex, replacement, text)
+     return text
+
+ ####################################################################################################
+ # Numbers
+
+ _num_prefix_re = re.compile(r'#\d')
+ _num_suffix_re = re.compile(r'\d(K|M|B|T)', re.IGNORECASE)
+ _num_letter_split_re = re.compile(r'(\d[a-z]|[a-z]\d)', re.IGNORECASE)
+
+ _comma_number_re = re.compile(r'(\d[\d\,]+\d)')
+ _date_re = re.compile(r'(^|[^/])(\d\d?[/-]\d\d?[/-]\d\d(?:\d\d)?)($|[^/])')
+ _phone_number_re = re.compile(r'(\(?\d{3}\)?[-.\s]\d{3}[-.\s]?\d{4})')
+ _time_re = re.compile(r'(\d\d?:\d\d(?::\d\d)?)')
+ _pounds_re = re.compile(r'£([\d\,]*\d+)')
+ _dollars_re = re.compile(r'\$([\d\.\,]*\d+)')
+ _decimal_number_re = re.compile(r'(\d+(?:\.\d+)+)')
+ _multiply_re = re.compile(r'(\d\s?\*\s?\d)')
+ _divide_re = re.compile(r'(\d\s?/\s?\d)')
+ _add_re = re.compile(r'(\d\s?\+\s?\d)')
+ _subtract_re = re.compile(r'(\d?\s?-\s?\d)')  # also handles negative numbers
+ _fraction_re = re.compile(r'(\d+(?:/\d+)+)')
+ _ordinal_re = re.compile(r'\d+(st|nd|rd|th)')
+ _number_re = re.compile(r'\d+')
+
+ def _expand_num_prefix(m):
+     match = m.group(0)
+     return f"number {match[1]}"
+
+ def _expand_num_suffix(m):
+     match = m.group(0)
+     if match[1].upper() == 'K': return f"{match[0]} thousand"
+     elif match[1].upper() == 'M': return f"{match[0]} million"
+     elif match[1].upper() == 'B': return f"{match[0]} billion"
+     elif match[1].upper() == 'T': return f"{match[0]} trillion"
+     return match  # unexpected format
+
+ def _split_alphanumeric(m):
+     match = m.group(1)
+     return f"{match[0]} {match[1]}"
+
+ def _remove_commas(m):
+     return m.group(1).replace(',', '')
+
+ def _expand_date(m):
+     match = m.group(2)
+     match = re.split('[./-]', match)
+     return m.group(1) + ' dash '.join(match) + m.group(3)
+
+ def _expand_phone_number(m):
+     match = m.group(1)
+     match = re.sub(r'\D', '', match)
+     assert len(match) == 10
+     match = f"{' '.join(list(match[:3]))}, {' '.join(list(match[3:6]))}, {' '.join(list(match[6:]))}"
+     return match
+
+ def _expand_time(m):
+     match = m.group(1)
+     match = match.split(':')
+     if len(match) == 2:
+         hours, minutes = match
+         if minutes == '00':
+             if int(hours) == 0:
+                 return '0'
+             elif int(hours) > 12: return f"{hours} minutes"
+             return f"{hours} o'clock"
+         elif minutes.startswith('0'):
+             minutes = f'oh {minutes[1:]}'
+         return f"{hours} {minutes}"
+     else:
+         hours, minutes, seconds = match
+         if int(hours) != 0:
+             return f"{hours} {'oh oh' if minutes == '00' else f'oh {minutes}' if minutes.startswith('0') else minutes} {'' if seconds == '00' else f'oh {seconds}' if seconds.startswith('0') else seconds}"
+         elif minutes != '00':
+             return f"{minutes} {'oh oh' if seconds == '00' else f'oh {seconds}' if seconds.startswith('0') else seconds}"
+         else:
+             return seconds
+
+ def _expand_dollars(m):
+     match = m.group(1)
+     parts = match.split('.')
+     if len(parts) > 2:
+         return match + ' dollars'  # unexpected format
+     dollars = int(parts[0]) if parts[0] else 0
+     cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+     if dollars and cents:
+         dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+         cent_unit = 'cent' if cents == 1 else 'cents'
+         return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
+     elif dollars:
+         dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+         return '%s %s' % (dollars, dollar_unit)
+     elif cents:
+         cent_unit = 'cent' if cents == 1 else 'cents'
+         return '%s %s' % (cents, cent_unit)
+     else:
+         return 'zero dollars'
+
+ def _expand_decimal_point(m):
+     match = m.group(1)
+     match = match.split('.')
+     return match[0] + ' point ' + ' point '.join(' '.join(list(match[i])) for i in range(1, len(match)))
+
+ def _expand_fraction(m):
+     match = m.group(1)
+     match = match.split('/')
+     return ' over '.join(match) if len(match) == 2 else ' slash '.join(match)
+
+ def _expand_multiply(m):
+     return ' times '.join(m.group(1).split('*'))
+
+ def _expand_divide(m):
+     return ' over '.join(m.group(1).split('/'))
+
+ def _expand_add(m):
+     return ' plus '.join(m.group(1).split('+'))
+
+ def _expand_subtract(m):
+     return ' minus '.join(m.group(1).split('-'))
+
+ def _expand_ordinal(m):
+     return _inflect.number_to_words(m.group(0), andword='')
+
+ def _expand_number(m):
+     num = int(m.group(0))
+     if num > 1000 and num < 3000:
+         if num == 2000:
+             return 'two thousand'
+         elif num > 2000 and num < 2010:
+             return 'two thousand ' + _inflect.number_to_words(num % 100)
+         elif num % 100 == 0:
+             return _inflect.number_to_words(num // 100) + ' hundred'
+         else:
+             return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
+     else:
+         return _inflect.number_to_words(num, andword='')
+
+ def normalize_numbers(text):
+     text = re.sub(_num_prefix_re, _expand_num_prefix, text)
+     text = re.sub(_num_suffix_re, _expand_num_suffix, text)
+     for _ in range(2):  # run twice to catch overlapping matches
+         text = re.sub(_num_letter_split_re, _split_alphanumeric, text)
+     text = re.sub(_comma_number_re, _remove_commas, text)
+     text = re.sub(_date_re, _expand_date, text)
+     text = re.sub(_phone_number_re, _expand_phone_number, text)
+     text = re.sub(_time_re, _expand_time, text)
+     text = re.sub(_pounds_re, r'\1 pounds', text)
+     text = re.sub(_dollars_re, _expand_dollars, text)
+     text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+     text = re.sub(_multiply_re, _expand_multiply, text)
+     text = re.sub(_divide_re, _expand_divide, text)
+     text = re.sub(_add_re, _expand_add, text)
+     text = re.sub(_subtract_re, _expand_subtract, text)
+
+     text = re.sub(_fraction_re, _expand_fraction, text)
+     text = re.sub(_ordinal_re, _expand_ordinal, text)
+     text = re.sub(_number_re, _expand_number, text)
+     return text
+
+ ####################################################################################################
+ # Special characters & other patterns
+
+ _special_characters = [(re.compile(x[0]), x[1]) for x in [
+     ('@', ' at '),
+     ('&', ' and '),
+     ('%', ' percent '),
+     (':', '.'),
+     (';', ','),
+     (r'\+', ' plus '),
+     (r'\\', ' backslash '),
+     ('~', ' about '),
+     ('(^| )<3', ' heart '),
+     ('<=', ' less than or equal to '),
+     ('>=', ' greater than or equal to '),
+     ('<', ' less than '),
+     ('>', ' greater than '),
+     ('=', ' equals '),
+     ('/', ' slash '),
+     ('_', ' '),
+ ]]
+ _link_header_re = re.compile(r'(https?://)')
+ _dash_re = re.compile(r'(. - .)')
+ _dot_re = re.compile(r'([A-Z]\.[A-Z])', re.IGNORECASE)
+ _parentheses_re = re.compile(r'[\(\[\{].*[\)\]\}](.|$)')
+
+ def expand_special_characters(text):
+     for regex, replacement in _special_characters:
+         text = re.sub(regex, replacement, text)
+     return text
+
+ def _expand_link_header(m):
+     return 'h t t p s colon slash slash '
+
+ def _expand_dash(m):
+     match = m.group(0)
+     return f"{match[0]}, {match[4]}"
+
+ def _expand_dot(m):
+     match = m.group(0)
+     return f"{match[0]} dot {match[2]}"
+
+ def _expand_parentheses(m):
+     match = m.group(0)
+     match = re.sub(r'[\(\[\{]', ', ', match)
+     match = re.sub(r'[\)\]\}][^$.!?,]', ', ', match)
+     match = re.sub(r'[\)\]\}]', '', match)
+     return match
+
+ def normalize_special(text):
+     text = re.sub(_link_header_re, _expand_link_header, text)
+     text = re.sub(_dash_re, _expand_dash, text)
+     text = re.sub(_dot_re, _expand_dot, text)
+     text = re.sub(_parentheses_re, _expand_parentheses, text)
+     return text
+
+ ####################################################################################################
+ # Misc
+
+ def lowercase(text):
+     return text.lower()
+
+ def convert_to_ascii(text):
+     return unidecode(text)
+
+ def normalize_newlines(text):
+     text = text.split('\n')
+     for i in range(len(text)):
+         text[i] = text[i].strip()
+         if not text[i]: continue
+         if text[i][-1] not in '.!?':
+             text[i] = f"{text[i]}."
+     return ' '.join(text)
+
+ def remove_unknown_characters(text):
+     text = re.sub(r"[^A-Za-z !\$%&'\*\+,-./0123456789<>\?_]", "", text)
+     text = re.sub(r"[<>/_+]", "", text)
+     return text
+
+ def collapse_whitespace(text):
+     text = re.sub(r'\s+', ' ', text)
+     text = re.sub(r' [.\?!,]', lambda m: m.group(0)[1], text)
+     return text
+
+ def dedup_punctuation(text):
+     text = re.sub(r"\.\.\.+", "[ELLIPSIS]", text)
+     text = re.sub(r",+", ",", text)
+     text = re.sub(r"[\.,]*\.[\.,]*", ".", text)
+     text = re.sub(r"[\.,!]*![\.,!]*", "!", text)
+     text = re.sub(r"[\.,!\?]*\?[\.,!\?]*", "?", text)
+     text = text.replace("[ELLIPSIS]", "...")  # literal replace; as a regex, "[ELLIPSIS]" would be a character class
+     return text
+
+ def clean_text(text):
+     text = convert_to_ascii(text)
+     text = normalize_newlines(text)
+     text = normalize_numbers(text)
+     text = normalize_special(text)
+     text = expand_abbreviations(text)
+     text = expand_special_characters(text)
+     text = lowercase(text)
+     text = remove_unknown_characters(text)
+     text = collapse_whitespace(text)
+     text = dedup_punctuation(text)
+     return text
+
+
+ if __name__ == '__main__':
+     print(normalize_numbers('1,2,3,456,176'))
+     print(normalize_numbers('123,456,789'))
+     print(normalize_numbers('123,456,789th'))
+     print(normalize_numbers('123-456-7890'))
+     print(normalize_numbers('111-111-1111'))
+     print(normalize_numbers('(111) 111-1111'))
+     print(normalize_numbers('A(111) 111-1111'))
+     print(normalize_numbers('A (111) 111-1111'))
+     print(normalize_numbers('$2.47'))
+     print(normalize_numbers('$247'))
+     print(normalize_numbers('$0.27'))
+     print(normalize_numbers('$1.00'))
+     print(normalize_numbers('£20'))
+     for i in range(1990, 2030):
+         print(normalize_numbers(str(i)))
+     print(normalize_numbers('2656'))
+     print(normalize_numbers('1024'))
+     print(normalize_numbers('2.47023'))
+     print(normalize_numbers('20.47023'))
+     print(normalize_numbers('1.17.1.1'))
+     print(normalize_numbers('111.111.1111'))
+     print(normalize_numbers('1/1/2025'))
+     print(normalize_numbers('1-1-2025'))
+     print(normalize_numbers('1-1-25'))
+     print(normalize_numbers('A 1/1/11 A'))
+     print(normalize_numbers('A 1/1 A'))
+     print(normalize_numbers('1/1'))
+     print(normalize_numbers('1/10'))
+     print(normalize_numbers('1/1/10'))
+     print(normalize_numbers('11/1/1/10'))
+
+     print(normalize_numbers('0:00'))
+     print(normalize_numbers('12:00'))
+     print(normalize_numbers('13:00'))
+     print(normalize_numbers('8:00'))
+     print(normalize_numbers('8:05'))
+     print(normalize_numbers('8:15'))
+     print(normalize_numbers('0:00:00'))
+     print(normalize_numbers('00:01:10'))
+     print(normalize_numbers('00:10:01'))
+     print(normalize_numbers('01:01:01'))
+     print(normalize_numbers('00:01:00'))
+     print(normalize_numbers('01:00:00'))
+
+     print(normalize_numbers('-1 + 2 * 3 - 4 / 5'))
+     print(normalize_numbers('-1+2*3-5/4/25'))
+
+     print(normalize_numbers('100x1'))
+     print(normalize_numbers('100k'))
+     print(normalize_numbers('100m'))
+     print(normalize_numbers('100b'))
+     print(normalize_numbers('100t'))
+
+     print(normalize_numbers('#1'))
+
+     print(normalize_numbers('12:00'))
+     print(normalize_numbers('11:59'))
+     print(normalize_numbers('01:00'))
+     print(normalize_numbers('0100'))
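
For reference, the top-level entry point chains the rules above in a fixed order; exact outputs depend on those rules (a sketch, assuming the repo installs as the `soprano` package):

    from soprano.utils.text import clean_text

    print(clean_text("Dr. Smith's GPU arrived 1/1/2025, and it cost $1,024!"))
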
soprano/soprano/vocos/decoder.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ from torch import nn
+
+ from .models import VocosBackbone
+ from .heads import ISTFTHead
+
+
+ class SopranoDecoder(nn.Module):
+     def __init__(self,
+                  num_input_channels=512,
+                  decoder_num_layers=8,
+                  decoder_dim=512,
+                  decoder_intermediate_dim=None,
+                  hop_length=512,
+                  n_fft=2048,
+                  upscale=4,
+                  dw_kernel=3,
+                  ):
+         super().__init__()
+         self.decoder_initial_channels = num_input_channels
+         self.num_layers = decoder_num_layers
+         self.dim = decoder_dim
+         self.intermediate_dim = decoder_intermediate_dim if decoder_intermediate_dim else decoder_dim * 3
+         self.hop_length = hop_length
+         self.n_fft = n_fft
+         self.upscale = upscale
+         self.dw_kernel = dw_kernel
+
+         self.decoder = VocosBackbone(input_channels=self.decoder_initial_channels,
+                                      dim=self.dim,
+                                      intermediate_dim=self.intermediate_dim,
+                                      num_layers=self.num_layers,
+                                      input_kernel_size=dw_kernel,
+                                      dw_kernel_size=dw_kernel,
+                                      )
+         self.head = ISTFTHead(dim=self.dim,
+                               n_fft=self.n_fft,
+                               hop_length=self.hop_length)
+
+     def forward(self, x):
+         # x: (B, num_input_channels, T) hidden states; upsample along time before decoding.
+         T = x.size(2)
+         x = torch.nn.functional.interpolate(x, size=self.upscale * (T - 1) + 1, mode='linear', align_corners=True)
+         x = self.decoder(x)
+         reconstructed = self.head(x)
+         return reconstructed
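
Shape-wise (with the defaults above): T input tokens are linearly upsampled to 4*(T-1)+1 frames, and the ISTFT head emits hop_length = 512 samples per frame, so every additional token contributes 4 * 512 = 2048 samples, matching the TOKEN_SIZE constant in tts.py. A quick check with random weights:

    import torch
    from soprano.vocos.decoder import SopranoDecoder

    decoder = SopranoDecoder().eval()
    x = torch.randn(1, 512, 10)  # (B, channels, T) for T = 10 tokens
    with torch.no_grad():
        audio = decoder(x)
    print(audio.shape)  # (1, (4*(10-1)+1) * 512) = (1, 18944)
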
soprano/soprano/vocos/heads.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ from torch import nn
+ from .spectral_ops import ISTFT
+
+
+ class ISTFTHead(nn.Module):
+     """
+     ISTFT Head module for predicting STFT complex coefficients.
+
+     Args:
+         dim (int): Hidden dimension of the model.
+         n_fft (int): Size of Fourier transform.
+         hop_length (int): The distance between neighboring sliding window frames, which should align with
+             the resolution of the input features.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "center".
+     """
+
+     def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "center"):
+         super().__init__()
+         out_dim = n_fft + 2
+         self.out = torch.nn.Linear(dim, out_dim)
+         self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
+
+     @torch.compiler.disable
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the ISTFTHead module.
+
+         Args:
+             x (Tensor): Input tensor of shape (B, H, L), where B is the batch size,
+                 H denotes the model dimension, and L is the sequence length.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         x = self.out(x.transpose(1, 2)).transpose(1, 2)
+         mag, p = x.chunk(2, dim=1)
+         mag = torch.exp(mag)
+         mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
+         # Phase wrapping happens here; these two lines produce the real and imaginary parts.
+         x = torch.cos(p)
+         y = torch.sin(p)
+         # Recomputing the phase (atan2) only to re-exponentiate it costs time without changing anything,
+         # so the complex value is produced directly:
+         S = mag * (x + 1j * y)
+         audio = self.istft(S)
+         return audio
soprano/soprano/vocos/models.py ADDED
@@ -0,0 +1,63 @@
+ from typing import Optional
+
+ import torch
+ from torch import nn
+
+ from .modules import ConvNeXtBlock
+
+
+ class VocosBackbone(nn.Module):
+     """
+     Vocos backbone module built with ConvNeXt blocks.
+
+     Args:
+         input_channels (int): Number of input feature channels.
+         dim (int): Hidden dimension of the model.
+         intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
+         num_layers (int): Number of ConvNeXtBlock layers.
+         layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / sqrt(num_layers)`.
+     """
+
+     def __init__(
+         self,
+         input_channels: int,
+         dim: int,
+         intermediate_dim: int,
+         num_layers: int,
+         input_kernel_size: int = 9,
+         dw_kernel_size: int = 9,
+         layer_scale_init_value: Optional[float] = None,
+         pad: str = 'zeros',
+     ):
+         super().__init__()
+         self.embed = nn.Conv1d(input_channels, dim, kernel_size=input_kernel_size, padding=input_kernel_size // 2, padding_mode=pad)
+         self.norm = nn.LayerNorm(dim, eps=1e-6)
+         self.convnext = nn.ModuleList(
+             [
+                 ConvNeXtBlock(
+                     dim=dim,
+                     intermediate_dim=intermediate_dim,
+                     dw_kernel_size=dw_kernel_size,
+                     layer_scale_init_value=layer_scale_init_value or 1 / num_layers**0.5,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+         self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.Conv1d, nn.Linear)):
+             nn.init.trunc_normal_(m.weight, std=0.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.embed(x)  # (B, C, L)
+         x = self.norm(x.transpose(1, 2))
+         x = x.transpose(1, 2)
+         for conv_block in self.convnext:
+             x = conv_block(x)
+         x = self.final_layer_norm(x.transpose(1, 2))
+         x = x.transpose(1, 2)
+         return x
soprano/soprano/vocos/modules.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ from torch import nn
+
+
+ class ConvNeXtBlock(nn.Module):
+     """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+     Args:
+         dim (int): Number of input channels.
+         intermediate_dim (int): Dimensionality of the intermediate layer.
+         layer_scale_init_value (float): Initial value for the layer scale; values <= 0 disable scaling.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         intermediate_dim: int,
+         layer_scale_init_value: float,
+         dw_kernel_size: int = 9,
+     ):
+         super().__init__()
+         self.dwconv = nn.Conv1d(dim, dim, kernel_size=dw_kernel_size, padding=dw_kernel_size // 2, groups=dim)  # depthwise conv
+         self.norm = nn.LayerNorm(dim, eps=1e-6)
+         self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
+         self.act = nn.GELU()
+         self.pwconv2 = nn.Linear(intermediate_dim, dim)
+         self.gamma = (
+             nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+             if layer_scale_init_value > 0
+             else None
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         residual = x
+         x = self.dwconv(x)
+         x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
+         x = self.norm(x)
+         x = self.pwconv1(x)
+         x = self.act(x)
+         x = self.pwconv2(x)
+         if self.gamma is not None:
+             x = self.gamma * x
+         x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
+
+         x = residual + x
+         return x
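
A quick shape check with the values SopranoDecoder passes in (dim 512, intermediate_dim 3 * 512, layer scale 1/sqrt(8) for 8 layers, kernel size 3); the residual block preserves its input shape:

    import torch
    from soprano.vocos.modules import ConvNeXtBlock

    block = ConvNeXtBlock(dim=512, intermediate_dim=1536, layer_scale_init_value=1 / 8**0.5, dw_kernel_size=3)
    x = torch.randn(2, 512, 100)  # (B, C, T)
    assert block(x).shape == x.shape
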
soprano/soprano/vocos/spectral_ops.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ from torch import nn
+
+
+ class ISTFT(nn.Module):
+     """
+     Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
+     windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
+     See issue: https://github.com/pytorch/pytorch/issues/62323
+     Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
+     The NOLA constraint is met as we trim padded samples anyway.
+
+     Args:
+         n_fft (int): Size of Fourier transform.
+         hop_length (int): The distance between neighboring sliding window frames.
+         win_length (int): The size of window frame and STFT filter.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+     """
+
+     def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
+         super().__init__()
+         if padding not in ["center", "same"]:
+             raise ValueError("Padding must be 'center' or 'same'.")
+         self.padding = padding
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.win_length = win_length
+         # Registered as a buffer so it follows the module across devices (.cuda()/.to()).
+         window = torch.hann_window(win_length)
+         self.register_buffer("window", window)
+
+     def forward(self, spec: torch.Tensor) -> torch.Tensor:
+         """
+         Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
+
+         Args:
+             spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
+                 N is the number of frequency bins, and T is the number of time frames.
+
+         Returns:
+             Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
+         """
+         if self.padding == "center":
+             spec[:, 0] = 0  # fixes a strange bug where first/last freqs don't matter when bs < 16, causing exploding gradients
+             spec[:, -1] = 0
+             # Fall back to the pytorch native implementation.
+             return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
+         elif self.padding == "same":
+             pad = (self.win_length - self.hop_length) // 2
+         else:
+             raise ValueError("Padding must be 'center' or 'same'.")
+
+         assert spec.dim() == 3, "Expected a 3D tensor as input"
+         B, N, T = spec.shape
+
+         # Inverse FFT
+         ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+         ifft = ifft * self.window[None, :, None]
+
+         # Overlap and Add
+         output_size = (T - 1) * self.hop_length + self.win_length
+         y = torch.nn.functional.fold(
+             ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+         )[:, 0, 0, pad:-pad]
+
+         # Window envelope
+         window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
+         window_envelope = torch.nn.functional.fold(
+             window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+         ).squeeze()[pad:-pad]
+
+         # Normalize
+         assert (window_envelope > 1e-11).all()
+         y = y / window_envelope
+
+         return y
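
With "same" padding the trimmed output length works out to exactly T * hop_length samples: (T-1)*hop + win from overlap-add, minus (win - hop)//2 trimmed on each side. The decoder's sample accounting in tts.py relies on this. A quick check:

    import torch
    from soprano.vocos.spectral_ops import ISTFT

    istft = ISTFT(n_fft=2048, hop_length=512, win_length=2048, padding="same")
    spec = torch.randn(1, 1025, 37, dtype=torch.complex64)  # (B, n_fft//2 + 1, T)
    assert istft(spec).shape[-1] == 37 * 512
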