Update ultravox_processing.py
ultravox_processing.py CHANGED (+9 -9)
@@ -113,7 +113,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             tokenizer.eos_token is not None
         ), "The tokenizer has no EOS token. Cannot recover."
         self.vocab = tokenizer.get_vocab()
-        self.
+        self.audio_token_replacement = tokenizer.eos_token
         if tokenizer.pad_token_id is None:
             tokenizer.pad_token_id = tokenizer.eos_token_id
 
@@ -188,7 +188,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
                 )
                 is_continuation_list.append(is_continuation)
 
-        return {
+        data = {
             "audio_values": torch.stack(chunked_audio_values, dim=0),
             "audio_lens": torch.tensor(
                 chunked_audio_lens, dtype=torch.int64, device=audio_values.device
@@ -199,12 +199,12 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             "audio_batch_size": torch.tensor(
                 [len(chunked_audio_values)], device=audio_values.device
             ),
-            "audio_num_chunks": (
-                torch.tensor(num_chunks, dtype=torch.int64, device=audio_values.device)
-                if include_audio_num_chunks
-                else None
-            ),
         }
+        if include_audio_num_chunks:
+            data["audio_num_chunks"] = torch.tensor(
+                num_chunks, dtype=torch.int64, device=audio_values.device
+            )
+        return data
 
     def __call__(
         self,
@@ -327,7 +327,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         split_input_ids = tokenized_parts["input_ids"]
         input_ids: List[int] = []
 
-
+        audio_token_replacement_token_id = self.vocab[self.audio_token_replacement]
 
         for i, token_len in enumerate(data.get("audio_token_len", [])):
            if not audio_is_continuation[i]:
@@ -341,7 +341,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
 
            audio_token_start_idx.append(len(input_ids))
 
-           input_ids.extend([
+           input_ids.extend([audio_token_replacement_token_id] * token_len)
 
            # Include any tokens after the last audio.
            placeholder_index += 1
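The first hunk stores the tokenizer's EOS token as the string that stands in for audio tokens, next to the existing pad-token fallback. A minimal sketch of that wiring, assuming any Hugging Face tokenizer that defines an EOS token ("gpt2" below is only an example checkpoint, not what this repo ships with):

from transformers import AutoTokenizer

# Example checkpoint; any tokenizer with an EOS token works here.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Same guard as the diff: without an EOS token there is nothing to
# substitute at audio positions.
assert tokenizer.eos_token is not None, "The tokenizer has no EOS token. Cannot recover."

# The EOS token doubles as the audio placeholder...
audio_token_replacement = tokenizer.eos_token
# ...and as the pad token when none is configured.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

print(audio_token_replacement, tokenizer.pad_token_id)  # <|endoftext|> 50256 for gpt2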
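The middle hunks change how the optional audio_num_chunks entry is produced: instead of always emitting the key with a possible None value via an inline conditional, the dict is built first and the key is assigned only when include_audio_num_chunks is set. A standalone sketch of the difference; build_inline and build_conditional are hypothetical helpers, not part of the file:

import torch

def build_inline(num_chunks, include_audio_num_chunks):
    # Old shape: the key is always present, sometimes with value None.
    return {
        "audio_num_chunks": (
            torch.tensor(num_chunks, dtype=torch.int64)
            if include_audio_num_chunks
            else None
        ),
    }

def build_conditional(num_chunks, include_audio_num_chunks):
    # New shape: the key exists only when requested.
    data = {}
    if include_audio_num_chunks:
        data["audio_num_chunks"] = torch.tensor(num_chunks, dtype=torch.int64)
    return data

assert build_inline([2, 1], False)["audio_num_chunks"] is None
assert "audio_num_chunks" not in build_conditional([2, 1], False)

Presumably the motivation is that downstream code iterating over the returned mapping (batch collation, device moves) never has to special-case a None tensor when the key is simply absent.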
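The last two hunks resolve the replacement token id once, before the loop, and then splice token_len copies of it into input_ids for each audio segment. A self-contained sketch of that pattern; vocab, eos_token, and token_lens are illustrative stand-ins rather than the processor's real attributes:

from typing import Dict, List, Tuple

def splice_placeholders(
    vocab: Dict[str, int], eos_token: str, token_lens: List[int]
) -> Tuple[List[int], List[int]]:
    input_ids: List[int] = []
    audio_token_start_idx: List[int] = []

    # Look the replacement id up once, outside the loop, as the diff does.
    audio_token_replacement_token_id = vocab[eos_token]

    for token_len in token_lens:
        # Record where this audio's placeholder run begins...
        audio_token_start_idx.append(len(input_ids))
        # ...then reserve token_len positions that the model later overwrites
        # with audio embeddings.
        input_ids.extend([audio_token_replacement_token_id] * token_len)

    return input_ids, audio_token_start_idx

ids, starts = splice_placeholders({"</s>": 2}, "</s>", [3, 2])
assert ids == [2, 2, 2, 2, 2]
assert starts == [0, 3]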