Spaces: Running on Zero
Update util.py
util.py CHANGED

@@ -73,8 +73,8 @@ class NemoAudioPlayer:
         ).eval()
 
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        print(f"Moving NeMo codec to device: {self.device}")
-        self.nemo_codec_model.to(self.device)
+        # print(f"Moving NeMo codec to device: {self.device}")
+        # self.nemo_codec_model.to(self.device)
 
         self.text_tokenizer_name = text_tokenizer_name
         if self.text_tokenizer_name:
@@ -145,7 +145,12 @@ class NemoAudioPlayer:
         return text
 
     def get_waveform(self, out_ids):
+
         """Convert model output to audio waveform"""
+
+        print(f"Moving NeMo codec to device: {self.device}")
+        self.nemo_codec_model.to(self.device)
+
         out_ids = out_ids.flatten()
 
         # Validate output
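
The two NemoAudioPlayer hunks defer the codec's device placement: the .to(self.device) call and its log line are commented out of __init__ and re-issued at the top of get_waveform. Given the "Running on Zero" badge, the likely reason is that on a ZeroGPU Space the GPU is only attached while a spaces.GPU-decorated function is executing, so a move to CUDA at construction time happens before any device exists. Below is a minimal sketch of that pattern, assuming the Space wraps inference in spaces.GPU; the class and function names are illustrative, not taken from util.py.

# Hedged sketch of the pattern the hunks follow, not the Space's actual code:
# keep the codec on CPU while the app boots, and move it to the GPU only inside
# the request path, which on ZeroGPU runs under a @spaces.GPU-decorated function.
import torch
import spaces  # Hugging Face ZeroGPU helper, available on "Running on Zero" Spaces


class LazyPlayer:
    def __init__(self, codec: torch.nn.Module):
        self.codec = codec.eval()  # no .to() here; the GPU is not attached yet
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def get_waveform(self, out_ids: torch.Tensor) -> torch.Tensor:
        # Deferred device move, mirroring the diff: safe because this method is
        # only reached from inside the GPU-scoped call below.
        self.codec.to(self.device)
        out_ids = out_ids.flatten().to(self.device)
        with torch.inference_mode():
            return self.codec(out_ids)  # placeholder for the real codec decode


@spaces.GPU  # a GPU is allocated only for the duration of this call
def synthesize(player: LazyPlayer, out_ids: torch.Tensor) -> torch.Tensor:
    return player.get_waveform(out_ids)

Repeating the .to() on every call is effectively free once the weights already live on the device, since Module.to leaves parameters that are already on the target device untouched.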

@@ -223,7 +228,7 @@ class KaniModel:
         self.model = AutoModelForCausalLM.from_pretrained(
             self.conf.model_name,
             dtype=torch.bfloat16,
-            device_map=self.conf.device_map,
+            # device_map=self.conf.device_map,
             trust_remote_code=True
         )
 
@@ -264,6 +269,7 @@ class KaniModel:
                         rp: float,
                         max_tok: int) -> torch.tensor:
         """Generate tokens using the model"""
+        self.model.to(self.device)
         input_ids = input_ids.to(self.device)
         attention_mask = attention_mask.to(self.device)
 
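
The KaniModel hunks apply the same idea to the language model: device_map is dropped from from_pretrained, so the checkpoint loads without accelerate assigning devices at load time (before ZeroGPU has attached one), and the model is moved onto self.device at the start of the token-generation helper instead. A sketch under the same ZeroGPU assumption; the model name and generation arguments are placeholders, not values from this diff.

# Hedged sketch: load the LM with no device_map, then move it (and the inputs)
# to the device inside the GPU-scoped generate call, as the diff does.
import torch
import spaces
from transformers import AutoModelForCausalLM

MODEL_NAME = "your-org/your-model"  # placeholder; the diff reads it from self.conf.model_name

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,  # as in the diff; no device_map at load time
    trust_remote_code=True,
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'


@spaces.GPU
def generate(input_ids: torch.Tensor, attention_mask: torch.Tensor, max_tok: int) -> torch.Tensor:
    model.to(device)  # deferred placement, mirroring self.model.to(self.device)
    return model.generate(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        max_new_tokens=max_tok,
    )

This keeps all device handling in one place, inside the GPU-scoped call, instead of splitting it between load time and inference time.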