ekwek committed on
Commit 9402b2c · verified · 1 Parent(s): 55b3415

Delete soprano/soprano

soprano/soprano/__init__.py DELETED
@@ -1 +0,0 @@
- from .tts import SopranoTTS
 
 
soprano/soprano/backends/base.py DELETED
@@ -1,20 +0,0 @@
- class BaseModel:
-     def infer(self,
-               prompts,
-               top_p=0.95,
-               temperature=0.3,
-               repetition_penalty=1.2):
-         '''
-         Takes a list of prompts and returns the output hidden states
-         '''
-         pass
-
-     def stream_infer(self,
-                      prompt,
-                      top_p=0.95,
-                      temperature=0.3,
-                      repetition_penalty=1.2):
-         '''
-         Takes a prompt and returns an iterator of the output hidden states
-         '''
-         pass
 
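For reference, these two methods are the entire backend contract: `infer` maps a list of prompts to per-prompt dicts carrying a `finish_reason` and a `hidden_state` tensor, and `stream_infer` yields the same dicts one token at a time. A minimal sketch of a conforming backend, assuming the package layout above; the class name, shapes, and random tensors are illustrative only:

```python
import torch

from soprano.backends.base import BaseModel


class DummyModel(BaseModel):
    """Illustrative backend that returns random hidden states instead of running an LM."""

    def infer(self, prompts, top_p=0.95, temperature=0.3, repetition_penalty=1.2):
        # One dict per prompt, in the shape SopranoTTS.infer_batch expects.
        return [{'finish_reason': 'stop',
                 'hidden_state': torch.randn(16, 512)}  # (num_tokens, hidden_dim)
                for _ in prompts]

    def stream_infer(self, prompt, top_p=0.95, temperature=0.3, repetition_penalty=1.2):
        # Yield per-token dicts; a non-None finish_reason marks the final token.
        for _ in range(16):
            yield {'finish_reason': None, 'hidden_state': torch.randn(1, 512)}
        yield {'finish_reason': 'stop', 'hidden_state': torch.randn(1, 512)}
```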
 
soprano/soprano/backends/lmdeploy.py DELETED
@@ -1,54 +0,0 @@
- import torch
- from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
- from .base import BaseModel
-
-
- class LMDeployModel(BaseModel):
-     def __init__(self,
-                  device='cuda',
-                  cache_size_mb=100,
-                  **kwargs):
-         assert device == 'cuda', "lmdeploy only supports cuda devices; change the device or use a different backend instead."
-         cache_size_ratio = cache_size_mb * 1024**2 / torch.cuda.get_device_properties('cuda').total_memory
-         backend_config = TurbomindEngineConfig(cache_max_entry_count=cache_size_ratio)
-         self.pipeline = pipeline('ekwek/Soprano-80M',
-                                  log_level='ERROR',
-                                  backend_config=backend_config)
-
-     def infer(self,
-               prompts,
-               top_p=0.95,
-               temperature=0.3,
-               repetition_penalty=1.2):
-         gen_config = GenerationConfig(output_last_hidden_state='generation',
-                                       do_sample=True,
-                                       top_p=top_p,
-                                       temperature=temperature,
-                                       repetition_penalty=repetition_penalty,
-                                       max_new_tokens=512)
-         responses = self.pipeline(prompts, gen_config=gen_config)
-         res = []
-         for response in responses:
-             res.append({
-                 'finish_reason': response.finish_reason,
-                 'hidden_state': response.last_hidden_state
-             })
-         return res
-
-     def stream_infer(self,
-                      prompt,
-                      top_p=0.95,
-                      temperature=0.3,
-                      repetition_penalty=1.2):
-         gen_config = GenerationConfig(output_last_hidden_state='generation',
-                                       do_sample=True,
-                                       top_p=top_p,
-                                       temperature=temperature,
-                                       repetition_penalty=repetition_penalty,
-                                       max_new_tokens=512)
-         responses = self.pipeline.stream_infer([prompt], gen_config=gen_config)
-         for response in responses:
-             yield {
-                 'finish_reason': response.finish_reason,
-                 'hidden_state': response.last_hidden_state
-             }
 
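lmdeploy's `cache_max_entry_count` is a fraction of total GPU memory, so the constructor converts an absolute budget in MB into a ratio. A quick sanity check of that arithmetic (the 24 GiB figure is just an example):

```python
cache_size_mb = 100
total_memory = 24 * 1024**3  # e.g. a 24 GiB GPU
cache_size_ratio = cache_size_mb * 1024**2 / total_memory
print(f"{cache_size_ratio:.4f}")  # ~0.0041, i.e. ~0.4% of VRAM reserved for the KV cache
```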
 
soprano/soprano/backends/transformers.py DELETED
@@ -1,68 +0,0 @@
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from .base import BaseModel
-
-
- class TransformersModel(BaseModel):
-     def __init__(self,
-                  device='cuda',
-                  **kwargs):
-         self.device = device
-
-         self.model = AutoModelForCausalLM.from_pretrained(
-             'ekwek/Soprano-80M',
-             torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
-             device_map=device
-         )
-         self.tokenizer = AutoTokenizer.from_pretrained('ekwek/Soprano-80M')
-         self.model.eval()
-
-     def infer(self,
-               prompts,
-               top_p=0.95,
-               temperature=0.3,
-               repetition_penalty=1.2):
-         inputs = self.tokenizer(
-             prompts,
-             return_tensors='pt',
-             padding=True,
-             truncation=True,
-             max_length=512,
-         ).to(self.device)
-
-         with torch.no_grad():
-             outputs = self.model.generate(
-                 input_ids=inputs['input_ids'],
-                 attention_mask=inputs['attention_mask'],
-                 max_new_tokens=512,
-                 do_sample=True,
-                 top_p=top_p,
-                 temperature=temperature,
-                 repetition_penalty=repetition_penalty,
-                 pad_token_id=self.tokenizer.pad_token_id,
-                 return_dict_in_generate=True,
-                 output_hidden_states=True,
-             )
-         res = []
-         eos_token_id = self.model.config.eos_token_id
-         for i in range(len(prompts)):
-             seq = outputs.sequences[i]
-             hidden_states = []
-             num_output_tokens = len(outputs.hidden_states)
-             for j in range(num_output_tokens):
-                 token = seq[j + seq.size(0) - num_output_tokens]
-                 if token != eos_token_id: hidden_states.append(outputs.hidden_states[j][-1][i, -1, :])
-             last_hidden_state = torch.stack(hidden_states).squeeze()
-             finish_reason = 'stop' if seq[-1].item() == eos_token_id else 'length'
-             res.append({
-                 'finish_reason': finish_reason,
-                 'hidden_state': last_hidden_state
-             })
-         return res
-
-     def stream_infer(self,
-                      prompt,
-                      top_p=0.95,
-                      temperature=0.3,
-                      repetition_penalty=1.2):
-         raise NotImplementedError("The transformers backend does not currently support streaming; consider using the lmdeploy backend instead.")
 
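The indexing in `infer` relies on how `generate(..., output_hidden_states=True, return_dict_in_generate=True)` reports hidden states: `outputs.hidden_states` holds one entry per generated token, each a per-layer tuple, so `outputs.hidden_states[j][-1][i, -1, :]` is the last layer's final-position vector for batch item `i` at generation step `j`. A sketch of that layout (shapes stated from the Hugging Face convention, for illustration):

```python
# outputs.hidden_states                   -> tuple, one entry per generated token
# outputs.hidden_states[j]                -> tuple, one tensor per layer (embeddings first, last layer at [-1])
# outputs.hidden_states[j][-1]            -> (batch, seq_len_at_step_j, hidden); seq_len is 1 after the first step
# outputs.hidden_states[j][-1][i, -1, :]  -> hidden vector for batch item i at generation step j
```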
 
soprano/soprano/tts.py DELETED
@@ -1,188 +0,0 @@
- from .vocos.decoder import SopranoDecoder
- from .utils.text import clean_text
- import torch
- import re
- from unidecode import unidecode
- from scipy.io import wavfile
- from huggingface_hub import hf_hub_download
- import os
- import time
-
-
- class SopranoTTS:
-     def __init__(self,
-                  backend='auto',
-                  device='cuda',
-                  cache_size_mb=10,
-                  decoder_batch_size=1):
-         RECOGNIZED_DEVICES = ['cuda']
-         RECOGNIZED_BACKENDS = ['auto', 'lmdeploy', 'transformers']
-         assert device in RECOGNIZED_DEVICES, f"unrecognized device {device}, device must be in {RECOGNIZED_DEVICES}"
-         if backend == 'auto':
-             if device == 'cpu':
-                 backend = 'transformers'
-             else:
-                 try:
-                     import lmdeploy
-                     backend = 'lmdeploy'
-                 except ImportError:
-                     backend = 'transformers'
-         print(f"Using backend {backend}.")
-         assert backend in RECOGNIZED_BACKENDS, f"unrecognized backend {backend}, backend must be in {RECOGNIZED_BACKENDS}"
-
-         if backend == 'lmdeploy':
-             from .backends.lmdeploy import LMDeployModel
-             self.pipeline = LMDeployModel(device=device, cache_size_mb=cache_size_mb)
-         elif backend == 'transformers':
-             from .backends.transformers import TransformersModel
-             self.pipeline = TransformersModel(device=device)
-
-         self.decoder = SopranoDecoder().cuda()
-         decoder_path = hf_hub_download(repo_id='ekwek/Soprano-80M', filename='decoder.pth')
-         self.decoder.load_state_dict(torch.load(decoder_path))
-         self.decoder_batch_size = decoder_batch_size
-         self.RECEPTIVE_FIELD = 4  # Decoder receptive field
-         self.TOKEN_SIZE = 2048  # Number of samples per audio token
-
-         self.infer("Hello world!")  # warmup
-
-     def _preprocess_text(self, texts, min_length=30):
-         '''
-         Adds the prompt format and per-text sentence indices.
-         Enforces a minimum sentence length by merging short sentences.
-         '''
-         res = []
-         for text_idx, text in enumerate(texts):
-             text = text.strip()
-             cleaned_text = clean_text(text)
-             sentences = re.split(r"(?<=[.!?])\s+", cleaned_text)
-             processed = []
-             for sentence in sentences:
-                 processed.append({
-                     "text": sentence,
-                     "text_idx": text_idx,
-                 })
-
-             if min_length > 0 and len(processed) > 1:
-                 merged = []
-                 i = 0
-                 while i < len(processed):
-                     cur = processed[i]
-                     if len(cur["text"]) < min_length:
-                         if merged: merged[-1]["text"] = (merged[-1]["text"] + " " + cur["text"]).strip()
-                         else:
-                             if i + 1 < len(processed): processed[i + 1]["text"] = (cur["text"] + " " + processed[i + 1]["text"]).strip()
-                             else: merged.append(cur)
-                     else: merged.append(cur)
-                     i += 1
-                 processed = merged
-             sentence_idxes = {}
-             for item in processed:
-                 if item['text_idx'] not in sentence_idxes: sentence_idxes[item['text_idx']] = 0
-                 res.append((f'[STOP][TEXT]{item["text"]}[START]', item["text_idx"], sentence_idxes[item['text_idx']]))
-                 sentence_idxes[item['text_idx']] += 1
-         return res
-
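So each input text becomes a list of `(prompt, text index, sentence index)` tuples in the `[STOP][TEXT]...[START]` format, with sentences shorter than `min_length` characters merged into a neighbor. An illustrative (not verified) run:

```python
# Assuming an already-initialized SopranoTTS instance `tts`; output shown is a plausible example only.
print(tts._preprocess_text(["Hi. This is a much longer second sentence."]))
# e.g. [('[STOP][TEXT]hi. this is a much longer second sentence.[START]', 0, 0)]
# "Hi." is under 30 characters, so it is merged forward into the next sentence.
```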
-     def infer(self,
-               text,
-               out_path=None,
-               top_p=0.95,
-               temperature=0.3,
-               repetition_penalty=1.2):
-         results = self.infer_batch([text],
-                                    top_p=top_p,
-                                    temperature=temperature,
-                                    repetition_penalty=repetition_penalty,
-                                    out_dir=None)[0]
-         if out_path:
-             wavfile.write(out_path, 32000, results.cpu().numpy())
-         return results
-
-     def infer_batch(self,
-                     texts,
-                     out_dir=None,
-                     top_p=0.95,
-                     temperature=0.3,
-                     repetition_penalty=1.2):
-         sentence_data = self._preprocess_text(texts)
-         prompts = list(map(lambda x: x[0], sentence_data))
-         responses = self.pipeline.infer(prompts,
-                                         top_p=top_p,
-                                         temperature=temperature,
-                                         repetition_penalty=repetition_penalty)
-         hidden_states = []
-         for i, response in enumerate(responses):
-             if response['finish_reason'] != 'stop':
-                 print("Warning: some sentences did not complete generation, likely due to hallucination.")
-             hidden_state = response['hidden_state']
-             hidden_states.append(hidden_state)
-         combined = list(zip(hidden_states, sentence_data))
-         combined.sort(key=lambda x: -x[0].size(0))
-         hidden_states, sentence_data = zip(*combined)
-
-         num_texts = len(texts)
-         audio_concat = [[] for _ in range(num_texts)]
-         for sentence in sentence_data:
-             audio_concat[sentence[1]].append(None)
-         for idx in range(0, len(hidden_states), self.decoder_batch_size):
-             batch_hidden_states = []
-             lengths = list(map(lambda x: x.size(0), hidden_states[idx:idx+self.decoder_batch_size]))
-             N = len(lengths)
-             for i in range(N):
-                 batch_hidden_states.append(torch.cat([
-                     torch.zeros((1, 512, lengths[0]-lengths[i]), device='cuda'),
-                     hidden_states[idx+i].unsqueeze(0).transpose(1,2).cuda().to(torch.float32),
-                 ], dim=2))
-             batch_hidden_states = torch.cat(batch_hidden_states)
-             with torch.no_grad():
-                 audio = self.decoder(batch_hidden_states)
-
-             for i in range(N):
-                 text_id = sentence_data[idx+i][1]
-                 sentence_id = sentence_data[idx+i][2]
-                 audio_concat[text_id][sentence_id] = audio[i].squeeze()[-(lengths[i]*self.TOKEN_SIZE-self.TOKEN_SIZE):]
-         audio_concat = [torch.cat(x).cpu() for x in audio_concat]
-
-         if out_dir:
-             os.makedirs(out_dir, exist_ok=True)
-             for i in range(len(audio_concat)):
-                 wavfile.write(f"{out_dir}/{i}.wav", 32000, audio_concat[i].cpu().numpy())
-         return audio_concat
-
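Note the decoder batching trick above: sentences are sorted longest-first, and each batch is left-padded with zeros to the length of its longest member, so trimming to the last `(L_i - 1) * TOKEN_SIZE` samples recovers each sentence's own audio. A sketch of the shapes involved (dimensions assumed from the surrounding code):

```python
# hidden_states[idx+i]           : (L_i, 512)    token hidden states, L_0 >= L_1 >= ...
# after unsqueeze/transpose      : (1, 512, L_i)
# after left-padding with zeros  : (1, 512, L_0)
# batch_hidden_states            : (N, 512, L_0)
# audio[i]                       : waveform for row i; keep the trailing (L_i - 1) * TOKEN_SIZE samples
```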
-     def infer_stream(self,
-                      text,
-                      chunk_size=1,
-                      top_p=0.95,
-                      temperature=0.3,
-                      repetition_penalty=1.2):
-         start_time = time.time()
-         sentence_data = self._preprocess_text([text])
-
-         first_chunk = True
-         for sentence, _, _ in sentence_data:
-             responses = self.pipeline.stream_infer(sentence,
-                                                    top_p=top_p,
-                                                    temperature=temperature,
-                                                    repetition_penalty=repetition_penalty)
-             hidden_states_buffer = []
-             chunk_counter = chunk_size
-             for token in responses:
-                 finished = token['finish_reason'] is not None
-                 if not finished: hidden_states_buffer.append(token['hidden_state'][-1])
-                 hidden_states_buffer = hidden_states_buffer[-(2*self.RECEPTIVE_FIELD+chunk_size):]
-                 if finished or len(hidden_states_buffer) >= self.RECEPTIVE_FIELD + chunk_size:
-                     if finished or chunk_counter == chunk_size:
-                         batch_hidden_states = torch.stack(hidden_states_buffer)
-                         inp = batch_hidden_states.unsqueeze(0).transpose(1, 2).cuda().to(torch.float32)
-                         with torch.no_grad():
-                             audio = self.decoder(inp)[0]
-                         if finished:
-                             audio_chunk = audio[-((self.RECEPTIVE_FIELD+chunk_counter-1)*self.TOKEN_SIZE-self.TOKEN_SIZE):]
-                         else:
-                             audio_chunk = audio[-((self.RECEPTIVE_FIELD+chunk_size)*self.TOKEN_SIZE-self.TOKEN_SIZE):-(self.RECEPTIVE_FIELD*self.TOKEN_SIZE-self.TOKEN_SIZE)]
-                         chunk_counter = 0
-                         if first_chunk:
-                             print(f"Streaming latency: {1000*(time.time()-start_time):.2f} ms")
-                             first_chunk = False
-                         yield audio_chunk.cpu()
-                 chunk_counter += 1
 
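A hedged end-to-end usage sketch of the three entry points (file paths, texts, and the `play` sink are examples, not part of the library):

```python
from soprano import SopranoTTS

tts = SopranoTTS(backend='auto', device='cuda')

# Single text -> one waveform tensor (32 kHz), optionally written to disk.
audio = tts.infer("Hello world!", out_path='hello.wav')

# Batch -> one waveform per input text, optionally saved as out/0.wav, out/1.wav, ...
waves = tts.infer_batch(["First text.", "Second text."], out_dir='out')

# Streaming -> audio chunks yielded as they are decoded.
for chunk in tts.infer_stream("This is streamed as it is generated."):
    play(chunk)  # placeholder for your audio sink
```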
 
soprano/soprano/utils/text.py DELETED
@@ -1,388 +0,0 @@
- """
- Normalize input text to a format that Soprano recognizes.
- Adapted from https://github.com/neonbjb/tortoise-tts/blob/main/tortoise/utils/tokenizer.py
- """
- import os
- import re
-
- import inflect
- from unidecode import unidecode
-
-
- _inflect = inflect.engine()
-
- ####################################################################################################
- # Abbreviations
-
- _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
-     ('mrs', 'misuss'),
-     ('ms', 'miss'),
-     ('mr', 'mister'),
-     ('dr', 'doctor'),
-     ('st', 'saint'),
-     ('co', 'company'),
-     ('jr', 'junior'),
-     ('maj', 'major'),
-     ('gen', 'general'),
-     ('drs', 'doctors'),
-     ('rev', 'reverend'),
-     ('lt', 'lieutenant'),
-     ('hon', 'honorable'),
-     ('sgt', 'sergeant'),
-     ('capt', 'captain'),
-     ('esq', 'esquire'),
-     ('ltd', 'limited'),
-     ('col', 'colonel'),
-     ('ft', 'fort'),
- ]]
- _cased_abbreviations = [(re.compile('\\b%s\\b' % x[0]), x[1]) for x in [
-     ('TTS', 'text to speech'),
-     ('Hz', 'hertz'),
-     ('kHz', 'kilohertz'),
-     ('KBs', 'kilobytes'),
-     ('KB', 'kilobyte'),
-     ('MBs', 'megabytes'),
-     ('MB', 'megabyte'),
-     ('GBs', 'gigabytes'),
-     ('GB', 'gigabyte'),
-     ('TBs', 'terabytes'),
-     ('TB', 'terabyte'),
-     ('APIs', 'a p i\'s'),
-     ('API', 'a p i'),
-     ('CLIs', 'c l i\'s'),
-     ('CLI', 'c l i'),
-     ('CPUs', 'c p u\'s'),
-     ('CPU', 'c p u'),
-     ('GPUs', 'g p u\'s'),
-     ('GPU', 'g p u'),
-     ('Ave', 'avenue'),
- ]]
-
- def expand_abbreviations(text):
-     for regex, replacement in _abbreviations + _cased_abbreviations:
-         text = re.sub(regex, replacement, text)
-     return text
-
- ####################################################################################################
- # Numbers
-
- _num_prefix_re = re.compile(r'#\d')
- _num_suffix_re = re.compile(r'\d(K|M|B|T)', re.IGNORECASE)
- _num_letter_split_re = re.compile(r'(\d[a-z]|[a-z]\d)', re.IGNORECASE)
-
- _comma_number_re = re.compile(r'(\d[\d\,]+\d)')
- _date_re = re.compile(r'(^|[^/])(\d\d?[/-]\d\d?[/-]\d\d(?:\d\d)?)($|[^/])')
- _phone_number_re = re.compile(r'(\(?\d{3}\)?[-.\s]\d{3}[-.\s]?\d{4})')
- _time_re = re.compile(r'(\d\d?:\d\d(?::\d\d)?)')
- _pounds_re = re.compile(r'£([\d\,]*\d+)')
- _dollars_re = re.compile(r'\$([\d\.\,]*\d+)')
- _decimal_number_re = re.compile(r'(\d+(?:\.\d+)+)')
- _multiply_re = re.compile(r'(\d\s?\*\s?\d)')
- _divide_re = re.compile(r'(\d\s?/\s?\d)')
- _add_re = re.compile(r'(\d\s?\+\s?\d)')
- _subtract_re = re.compile(r'(\d?\s?-\s?\d)')  # also handles negative numbers
- _fraction_re = re.compile(r'(\d+(?:/\d+)+)')
- _ordinal_re = re.compile(r'\d+(st|nd|rd|th)')
- _number_re = re.compile(r'\d+')
-
- def _expand_num_prefix(m):
-     match = m.group(0)
-     return f"number {match[1]}"
-
- def _expand_num_suffix(m):
-     match = m.group(0)
-     if match[1].upper() == 'K': return f"{match[0]} thousand"
-     elif match[1].upper() == 'M': return f"{match[0]} million"
-     elif match[1].upper() == 'B': return f"{match[0]} billion"
-     elif match[1].upper() == 'T': return f"{match[0]} trillion"
-     return match  # unexpected format
-
- def _split_alphanumeric(m):
-     match = m.group(1)
-     return f"{match[0]} {match[1]}"
-
- def _remove_commas(m):
-     return m.group(1).replace(',', '')
-
- def _expand_date(m):
-     match = m.group(2)
-     match = re.split('[./-]', match)
-     return m.group(1) + ' dash '.join(match) + m.group(3)
-
- def _expand_phone_number(m):
-     match = m.group(1)
-     match = re.sub(r'\D', '', match)
-     assert len(match) == 10
-     match = f"{' '.join(list(match[:3]))}, {' '.join(list(match[3:6]))}, {' '.join(list(match[6:]))}"
-     return match
-
- def _expand_time(m):
-     match = m.group(1)
-     match = match.split(':')
-     if len(match) == 2:
-         hours, minutes = match
-         if minutes == '00':
-             if int(hours) == 0:
-                 return '0'
-             elif int(hours) > 12: return f"{hours} minutes"
-             return f"{hours} o'clock"
-         elif minutes.startswith('0'):
-             minutes = f'oh {minutes[1:]}'
-         return f"{hours} {minutes}"
-     else:
-         hours, minutes, seconds = match
-         if int(hours) != 0:
-             return f"{hours} {'oh oh' if minutes == '00' else f'oh {minutes}' if minutes.startswith('0') else minutes} {'' if seconds == '00' else f'oh {seconds}' if seconds.startswith('0') else seconds}"
-         elif minutes != '00':
-             return f"{minutes} {'oh oh' if seconds == '00' else f'oh {seconds}' if seconds.startswith('0') else seconds}"
-         else:
-             return seconds
-
- def _expand_dollars(m):
-     match = m.group(1)
-     parts = match.split('.')
-     if len(parts) > 2:
-         return match + ' dollars'  # unexpected format
-     dollars = int(parts[0]) if parts[0] else 0
-     cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
-     if dollars and cents:
-         dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-         cent_unit = 'cent' if cents == 1 else 'cents'
-         return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
-     elif dollars:
-         dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-         return '%s %s' % (dollars, dollar_unit)
-     elif cents:
-         cent_unit = 'cent' if cents == 1 else 'cents'
-         return '%s %s' % (cents, cent_unit)
-     else:
-         return 'zero dollars'
-
- def _expand_decimal_point(m):
-     match = m.group(1)
-     match = match.split('.')
-     return match[0] + ' point ' + ' point '.join(' '.join(list(match[i])) for i in range(1, len(match)))
-
- def _expand_fraction(m):
-     match = m.group(1)
-     match = match.split('/')
-     return ' over '.join(match) if len(match) == 2 else ' slash '.join(match)
-
- def _expand_multiply(m):
-     return ' times '.join(m.group(1).split('*'))
-
- def _expand_divide(m):
-     return ' over '.join(m.group(1).split('/'))
-
- def _expand_add(m):
-     return ' plus '.join(m.group(1).split('+'))
-
- def _expand_subtract(m):
-     return ' minus '.join(m.group(1).split('-'))
-
- def _expand_ordinal(m):
-     return _inflect.number_to_words(m.group(0), andword='')
-
- def _expand_number(m):
-     num = int(m.group(0))
-     if num > 1000 and num < 3000:
-         if num == 2000:
-             return 'two thousand'
-         elif num > 2000 and num < 2010:
-             return 'two thousand ' + _inflect.number_to_words(num % 100)
-         elif num % 100 == 0:
-             return _inflect.number_to_words(num // 100) + ' hundred'
-         else:
-             return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
-     else:
-         return _inflect.number_to_words(num, andword='')
-
- def normalize_numbers(text):
-     text = re.sub(_num_prefix_re, _expand_num_prefix, text)
-     text = re.sub(_num_suffix_re, _expand_num_suffix, text)
-     for _ in range(2):  # run twice to find all matches
-         text = re.sub(_num_letter_split_re, _split_alphanumeric, text)
-     text = re.sub(_comma_number_re, _remove_commas, text)
-     text = re.sub(_date_re, _expand_date, text)
-     text = re.sub(_phone_number_re, _expand_phone_number, text)
-     text = re.sub(_time_re, _expand_time, text)
-     text = re.sub(_pounds_re, r'\1 pounds', text)
-     text = re.sub(_dollars_re, _expand_dollars, text)
-     text = re.sub(_decimal_number_re, _expand_decimal_point, text)
-     text = re.sub(_multiply_re, _expand_multiply, text)
-     text = re.sub(_divide_re, _expand_divide, text)
-     text = re.sub(_add_re, _expand_add, text)
-     text = re.sub(_subtract_re, _expand_subtract, text)
-
-     text = re.sub(_fraction_re, _expand_fraction, text)
-     text = re.sub(_ordinal_re, _expand_ordinal, text)
-     text = re.sub(_number_re, _expand_number, text)
-     return text
-
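The year-style branch in `_expand_number` is the subtle part: values between 1000 and 3000 are read in two-digit groups rather than as plain cardinals. Expected behaviour, assuming standard `inflect` output (not verified):

```python
from soprano.utils.text import normalize_numbers

print(normalize_numbers('2000'))  # 'two thousand'
print(normalize_numbers('2005'))  # 'two thousand five'
print(normalize_numbers('1900'))  # 'nineteen hundred'
print(normalize_numbers('1987'))  # 'nineteen eighty-seven' (read in pairs; zero -> 'oh')
```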
- ####################################################################################################
- # Special characters & other patterns
-
- _special_characters = [(re.compile(x[0]), x[1]) for x in [
-     ('@', ' at '),
-     ('&', ' and '),
-     ('%', ' percent '),
-     (':', '.'),
-     (';', ','),
-     (r'\+', ' plus '),
-     (r'\\', ' backslash '),
-     ('~', ' about '),
-     ('(^| )<3', ' heart '),
-     ('<=', ' less than or equal to '),
-     ('>=', ' greater than or equal to '),
-     ('<', ' less than '),
-     ('>', ' greater than '),
-     ('=', ' equals '),
-     ('/', ' slash '),
-     ('_', ' '),
- ]]
- _link_header_re = re.compile(r'(https?://)')
- _dash_re = re.compile(r'(. - .)')
- _dot_re = re.compile(r'([A-Z]\.[A-Z])', re.IGNORECASE)
- _parentheses_re = re.compile(r'[\(\[\{].*[\)\]\}](.|$)')
-
- def expand_special_characters(text):
-     for regex, replacement in _special_characters:
-         text = re.sub(regex, replacement, text)
-     return text
-
- def _expand_link_header(m):
-     return 'h t t p s colon slash slash '
-
- def _expand_dash(m):
-     match = m.group(0)
-     return f"{match[0]}, {match[4]}"
-
- def _expand_dot(m):
-     match = m.group(0)
-     return f"{match[0]} dot {match[2]}"
-
- def _expand_parentheses(m):
-     match = m.group(0)
-     match = re.sub(r'[\(\[\{]', ', ', match)
-     match = re.sub(r'[\)\]\}][^$.!?,]', ', ', match)
-     match = re.sub(r'[\)\]\}]', '', match)
-     return match
-
- def normalize_special(text):
-     text = re.sub(_link_header_re, _expand_link_header, text)
-     text = re.sub(_dash_re, _expand_dash, text)
-     text = re.sub(_dot_re, _expand_dot, text)
-     text = re.sub(_parentheses_re, _expand_parentheses, text)
-     return text
-
- ####################################################################################################
- # Misc
-
- def lowercase(text):
-     return text.lower()
-
- def convert_to_ascii(text):
-     return unidecode(text)
-
- def normalize_newlines(text):
-     text = text.split('\n')
-     for i in range(len(text)):
-         if not text[i]: continue
-         text[i] = text[i].strip()
-         if text[i][-1] not in '.!?':
-             text[i] = f"{text[i]}."
-     return ' '.join(text)
-
- def remove_unknown_characters(text):
-     text = re.sub(r"[^A-Za-z !\$%&'\*\+,-./0123456789<>\?_]", "", text)
-     text = re.sub(r"[<>/_+]", "", text)
-     return text
-
- def collapse_whitespace(text):
-     text = re.sub(r'\s+', ' ', text)
-     text = re.sub(r' [.\?!,]', lambda m: m.group(0)[1], text)
-     return text
-
- def dedup_punctuation(text):
-     text = re.sub(r"\.\.\.+", "[ELLIPSIS]", text)
-     text = re.sub(r",+", ",", text)
-     text = re.sub(r"[\.,]*\.[\.,]*", ".", text)
-     text = re.sub(r"[\.,!]*![\.,!]*", "!", text)
-     text = re.sub(r"[\.,!\?]*\?[\.,!\?]*", "?", text)
-     text = re.sub(r"\[ELLIPSIS\]", "...", text)  # brackets escaped so the pattern is not a character class
-     return text
-
- def clean_text(text):
-     text = convert_to_ascii(text)
-     text = normalize_newlines(text)
-     text = normalize_numbers(text)
-     text = normalize_special(text)
-     text = expand_abbreviations(text)
-     text = expand_special_characters(text)
-     text = lowercase(text)
-     text = remove_unknown_characters(text)
-     text = collapse_whitespace(text)
-     text = dedup_punctuation(text)
-     return text
-
-
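A hedged example of the full `clean_text` pipeline; the exact string depends on the `inflect` and `unidecode` versions, so the output shown is a plausible illustration only:

```python
from soprano.utils.text import clean_text

print(clean_text("Dr. Smith paid $2.47!"))
# -> roughly: 'doctor smith paid two dollars, forty-seven cents!'
```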
- if __name__ == '__main__':
-     print(normalize_numbers('1,2,3,456,176'))
-     print(normalize_numbers('123,456,789'))
-     print(normalize_numbers('123,456,789th'))
-     print(normalize_numbers('123-456-7890'))
-     print(normalize_numbers('111-111-1111'))
-     print(normalize_numbers('(111) 111-1111'))
-     print(normalize_numbers('A(111) 111-1111'))
-     print(normalize_numbers('A (111) 111-1111'))
-     print(normalize_numbers('$2.47'))
-     print(normalize_numbers('$247'))
-     print(normalize_numbers('$0.27'))
-     print(normalize_numbers('$1.00'))
-     print(normalize_numbers('£20'))
-     for i in range(1990, 2030):
-         print(normalize_numbers(str(i)))
-     print(normalize_numbers('2656'))
-     print(normalize_numbers('1024'))
-     print(normalize_numbers('2.47023'))
-     print(normalize_numbers('20.47023'))
-     print(normalize_numbers('1.17.1.1'))
-     print(normalize_numbers('111.111.1111'))
-     print(normalize_numbers('1/1/2025'))
-     print(normalize_numbers('1-1-2025'))
-     print(normalize_numbers('1-1-25'))
-     print(normalize_numbers('A 1/1/11 A'))
-     print(normalize_numbers('A 1/1 A'))
-     print(normalize_numbers('1/1'))
-     print(normalize_numbers('1/10'))
-     print(normalize_numbers('1/1/10'))
-     print(normalize_numbers('11/1/1/10'))
-
-     print(normalize_numbers('0:00'))
-     print(normalize_numbers('12:00'))
-     print(normalize_numbers('13:00'))
-     print(normalize_numbers('8:00'))
-     print(normalize_numbers('8:05'))
-     print(normalize_numbers('8:15'))
-     print(normalize_numbers('0:00:00'))
-     print(normalize_numbers('00:01:10'))
-     print(normalize_numbers('00:10:01'))
-     print(normalize_numbers('01:01:01'))
-     print(normalize_numbers('00:01:00'))
-     print(normalize_numbers('01:00:00'))
-
-     print(normalize_numbers('-1 + 2 * 3 - 4 / 5'))
-     print(normalize_numbers('-1+2*3-5/4/25'))
-
-     print(normalize_numbers('100x1'))
-     print(normalize_numbers('100k'))
-     print(normalize_numbers('100m'))
-     print(normalize_numbers('100b'))
-     print(normalize_numbers('100t'))
-
-     print(normalize_numbers('#1'))
-
-     print(normalize_numbers('12:00'))
-     print(normalize_numbers('11:59'))
-     print(normalize_numbers('01:00'))
-     print(normalize_numbers('0100'))
 
 
soprano/soprano/vocos/decoder.py DELETED
@@ -1,45 +0,0 @@
- import torch
- from torch import nn
-
- from .models import VocosBackbone
- from .heads import ISTFTHead
-
-
- class SopranoDecoder(nn.Module):
-     def __init__(self,
-                  num_input_channels=512,
-                  decoder_num_layers=8,
-                  decoder_dim=512,
-                  decoder_intermediate_dim=None,
-                  hop_length=512,
-                  n_fft=2048,
-                  upscale=4,
-                  dw_kernel=3,
-                  ):
-         super().__init__()
-         self.decoder_initial_channels = num_input_channels
-         self.num_layers = decoder_num_layers
-         self.dim = decoder_dim
-         self.intermediate_dim = decoder_intermediate_dim if decoder_intermediate_dim else decoder_dim*3
-         self.hop_length = hop_length
-         self.n_fft = n_fft
-         self.upscale = upscale
-         self.dw_kernel = dw_kernel
-
-         self.decoder = VocosBackbone(input_channels=self.decoder_initial_channels,
-                                      dim=self.dim,
-                                      intermediate_dim=self.intermediate_dim,
-                                      num_layers=self.num_layers,
-                                      input_kernel_size=dw_kernel,
-                                      dw_kernel_size=dw_kernel,
-                                      )
-         self.head = ISTFTHead(dim=self.dim,
-                               n_fft=self.n_fft,
-                               hop_length=self.hop_length)
-
-     def forward(self, x):
-         T = x.size(2)
-         x = torch.nn.functional.interpolate(x, size=self.upscale*(T-1)+1, mode='linear', align_corners=True)
-         x = self.decoder(x)
-         reconstructed = self.head(x)
-         return reconstructed
 
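The forward pass upsamples T input frames to 4(T-1)+1 frames before the backbone, and the ISTFT head with hop length 512 then produces roughly 4*512 = 2048 samples per input frame after the first, which matches `TOKEN_SIZE` in tts.py. A quick shape check (random weights, CUDA required since the ISTFT pins its window to cuda; a sketch, not a verified run):

```python
import torch
from soprano.vocos.decoder import SopranoDecoder

decoder = SopranoDecoder().cuda()          # untrained weights; shapes only
x = torch.randn(1, 512, 10, device='cuda')  # 10 hidden-state frames
with torch.no_grad():
    audio = decoder(x)
print(audio.shape)  # expected roughly (1, 4*(10-1)*512) = (1, 18432), i.e. 2048 samples per frame after the first
```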
 
 
soprano/soprano/vocos/heads.py DELETED
@@ -1,50 +0,0 @@
- import torch
- from torch import nn
- from .spectral_ops import ISTFT
-
-
- class ISTFTHead(nn.Module):
-     """
-     ISTFT Head module for predicting STFT complex coefficients.
-
-     Args:
-         dim (int): Hidden dimension of the model.
-         n_fft (int): Size of Fourier transform.
-         hop_length (int): The distance between neighboring sliding window frames, which should align with
-             the resolution of the input features.
-         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "center".
-     """
-
-     def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "center"):
-         super().__init__()
-         out_dim = n_fft + 2
-         self.out = torch.nn.Linear(dim, out_dim)
-         self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
-
-     @torch.compiler.disable
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         Forward pass of the ISTFTHead module.
-
-         Args:
-             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
-                 L is the sequence length, and H denotes the model dimension.
-
-         Returns:
-             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
-         """
-         x = self.out(x.transpose(1, 2)).transpose(1, 2)
-         mag, p = x.chunk(2, dim=1)
-         mag = torch.exp(mag)
-         mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
-         # wrapping happens here; these two lines produce the real and imaginary values
-         x = torch.cos(p)
-         y = torch.sin(p)
-         # recalculating the phase here does not produce anything new
-         # and only costs time:
-         #     phase = torch.atan2(y, x)
-         #     S = mag * torch.exp(phase * 1j)
-         # so we directly produce the complex value instead
-         S = mag * (x + 1j * y)
-         audio = self.istft(S)
-         return audio
 
 
soprano/soprano/vocos/models.py DELETED
@@ -1,61 +0,0 @@
- from typing import Optional
-
- import torch
- from torch import nn
-
- from .modules import ConvNeXtBlock
-
- class VocosBackbone(nn.Module):
-     """
-     Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization.
-
-     Args:
-         input_channels (int): Number of input feature channels.
-         dim (int): Hidden dimension of the model.
-         intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
-         num_layers (int): Number of ConvNeXtBlock layers.
-         layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers**0.5`.
-     """
-
-     def __init__(
-         self,
-         input_channels: int,
-         dim: int,
-         intermediate_dim: int,
-         num_layers: int,
-         input_kernel_size: int = 9,
-         dw_kernel_size: int = 9,
-         layer_scale_init_value: Optional[float] = None,
-         pad: str = 'zeros',
-     ):
-         super().__init__()
-         self.embed = nn.Conv1d(input_channels, dim, kernel_size=input_kernel_size, padding=input_kernel_size//2, padding_mode=pad)
-         self.norm = nn.LayerNorm(dim, eps=1e-6)
-         self.convnext = nn.ModuleList(
-             [
-                 ConvNeXtBlock(
-                     dim=dim,
-                     intermediate_dim=intermediate_dim,
-                     dw_kernel_size=dw_kernel_size,
-                     layer_scale_init_value=layer_scale_init_value or 1 / num_layers**0.5,
-                 )
-                 for _ in range(num_layers)
-             ]
-         )
-         self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
-         self.apply(self._init_weights)
-
-     def _init_weights(self, m):
-         if isinstance(m, (nn.Conv1d, nn.Linear)):
-             nn.init.trunc_normal_(m.weight, std=0.02)
-             if m.bias is not None: nn.init.constant_(m.bias, 0)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = self.embed(x)  # (B, C, L)
-         x = self.norm(x.transpose(1, 2))
-         x = x.transpose(1, 2)
-         for conv_block in self.convnext:
-             x = conv_block(x)
-         x = self.final_layer_norm(x.transpose(1, 2))
-         x = x.transpose(1, 2)
-         return x
 
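The backbone is shape-preserving: (B, input_channels, L) in, (B, dim, L) out, since both the embedding conv and every ConvNeXt block use "same"-style padding. A smoke-test sketch (the dimensions mirror the SopranoDecoder defaults):

```python
import torch
from soprano.vocos.models import VocosBackbone

backbone = VocosBackbone(input_channels=512, dim=512, intermediate_dim=1536, num_layers=8)
x = torch.randn(2, 512, 37)
print(backbone(x).shape)  # torch.Size([2, 512, 37]) -- the length L is preserved
```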
 
soprano/soprano/vocos/modules.py DELETED
@@ -1,47 +0,0 @@
- import torch
- from torch import nn
-
-
- class ConvNeXtBlock(nn.Module):
-     """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to a 1D audio signal.
-
-     Args:
-         dim (int): Number of input channels.
-         intermediate_dim (int): Dimensionality of the intermediate layer.
-         layer_scale_init_value (float): Initial value for the layer scale. Values <= 0 mean no scaling.
-         dw_kernel_size (int, optional): Kernel size of the depthwise convolution. Defaults to 9.
-     """
-
-     def __init__(
-         self,
-         dim: int,
-         intermediate_dim: int,
-         layer_scale_init_value: float,
-         dw_kernel_size: int = 9,
-     ):
-         super().__init__()
-         self.dwconv = nn.Conv1d(dim, dim, kernel_size=dw_kernel_size, padding=dw_kernel_size//2, groups=dim)  # depthwise conv
-         self.norm = nn.LayerNorm(dim, eps=1e-6)
-         self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
-         self.act = nn.GELU()
-         self.pwconv2 = nn.Linear(intermediate_dim, dim)
-         self.gamma = (
-             nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
-             if layer_scale_init_value > 0
-             else None
-         )
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         residual = x
-         x = self.dwconv(x)
-         x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
-         x = self.norm(x)
-         x = self.pwconv1(x)
-         x = self.act(x)
-         x = self.pwconv2(x)
-         if self.gamma is not None:
-             x = self.gamma * x
-         x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
-
-         x = residual + x
-         return x
 
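Each block is a residual unit: a depthwise convolution over time, then a pointwise MLP in channel space, scaled by the learnable per-channel `gamma` before being added back to the input. A sketch of its shape contract (the `layer_scale_init_value` shown matches VocosBackbone's `1 / num_layers**0.5` default for 8 layers):

```python
import torch
from soprano.vocos.modules import ConvNeXtBlock

block = ConvNeXtBlock(dim=512, intermediate_dim=1536, layer_scale_init_value=1 / 8**0.5)
x = torch.randn(2, 512, 37)  # (B, C, T)
print(block(x).shape)        # torch.Size([2, 512, 37]) -- the residual connection keeps the shape
```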
 
soprano/soprano/vocos/spectral_ops.py DELETED
@@ -1,74 +0,0 @@
- import torch
- from torch import nn
-
- class ISTFT(nn.Module):
-     """
-     Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
-     windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
-     See issue: https://github.com/pytorch/pytorch/issues/62323
-     Specifically, in the context of neural vocoding we are interested in "same" padding, analogous to CNNs.
-     The NOLA constraint is met as we trim the padded samples anyway.
-
-     Args:
-         n_fft (int): Size of Fourier transform.
-         hop_length (int): The distance between neighboring sliding window frames.
-         win_length (int): The size of window frame and STFT filter.
-         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
-     """
-
-     def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
-         super().__init__()
-         if padding not in ["center", "same"]:
-             raise ValueError("Padding must be 'center' or 'same'.")
-         self.padding = padding
-         self.n_fft = n_fft
-         self.hop_length = hop_length
-         self.win_length = win_length
-         window = torch.hann_window(win_length).to('cuda')
-         self.register_buffer("window", window)
-
-     def forward(self, spec: torch.Tensor) -> torch.Tensor:
-         """
-         Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
-
-         Args:
-             spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
-                 N is the number of frequency bins, and T is the number of time frames.
-
-         Returns:
-             Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
-         """
-         if self.padding == "center":
-             spec[:, 0] = 0  # fixes a strange bug where the first/last freqs don't matter when bs<16, causing exploding gradients
-             spec[:, -1] = 0
-             # Fall back to the PyTorch native implementation
-             return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
-         elif self.padding == "same":
-             pad = (self.win_length - self.hop_length) // 2
-         else:
-             raise ValueError("Padding must be 'center' or 'same'.")
-
-         assert spec.dim() == 3, "Expected a 3D tensor as input"
-         B, N, T = spec.shape
-
-         # Inverse FFT
-         ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
-         ifft = ifft * self.window[None, :, None]
-
-         # Overlap and Add
-         output_size = (T - 1) * self.hop_length + self.win_length
-         y = torch.nn.functional.fold(
-             ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
-         )[:, 0, 0, pad:-pad]
-
-         # Window envelope
-         window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
-         window_envelope = torch.nn.functional.fold(
-             window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
-         ).squeeze()[pad:-pad]
-
-         # Normalize
-         assert (window_envelope > 1e-11).all()
-         y = y / window_envelope
-
-         return y
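
The "same" branch hand-rolls overlap-add via `fold`, so the output is exactly T * hop_length samples after trimming (win_length - hop_length) // 2 from each side, which `torch.istft` refuses to do because of its NOLA check. A hedged length check (CUDA required, as the module pins its window to cuda):

```python
import torch
from soprano.vocos.spectral_ops import ISTFT

istft = ISTFT(n_fft=2048, hop_length=512, win_length=2048, padding="same").cuda()
spec = torch.randn(1, 1025, 37, dtype=torch.complex64, device='cuda')  # n_fft//2 + 1 frequency bins
audio = istft(spec)
print(audio.shape)  # expected torch.Size([1, 18944]) == (1, 37 * 512)
```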