Simonlob commited on
Commit
29847a4
·
verified ·
1 Parent(s): 073a996

Update util.py

Browse files
Files changed (1) hide show
  1. util.py +9 -3
util.py CHANGED
@@ -73,8 +73,8 @@ class NemoAudioPlayer:
73
  ).eval()
74
 
75
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
76
- print(f"Moving NeMo codec to device: {self.device}")
77
- self.nemo_codec_model.to(self.device)
78
 
79
  self.text_tokenizer_name = text_tokenizer_name
80
  if self.text_tokenizer_name:
@@ -145,7 +145,12 @@ class NemoAudioPlayer:
145
  return text
146
 
147
  def get_waveform(self, out_ids):
 
148
  """Convert model output to audio waveform"""
 
 
 
 
149
  out_ids = out_ids.flatten()
150
 
151
  # Validate output
@@ -223,7 +228,7 @@ class KaniModel:
223
  self.model = AutoModelForCausalLM.from_pretrained(
224
  self.conf.model_name,
225
  dtype=torch.bfloat16,
226
- device_map=self.conf.device_map,
227
  trust_remote_code=True
228
  )
229
 
@@ -264,6 +269,7 @@ class KaniModel:
264
  rp: float,
265
  max_tok: int) -> torch.tensor:
266
  """Generate tokens using the model"""
 
267
  input_ids = input_ids.to(self.device)
268
  attention_mask = attention_mask.to(self.device)
269
 
 
73
  ).eval()
74
 
75
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
76
+ # print(f"Moving NeMo codec to device: {self.device}")
77
+ # self.nemo_codec_model.to(self.device)
78
 
79
  self.text_tokenizer_name = text_tokenizer_name
80
  if self.text_tokenizer_name:
 
145
  return text
146
 
147
  def get_waveform(self, out_ids):
148
+
149
  """Convert model output to audio waveform"""
150
+
151
+ print(f"Moving NeMo codec to device: {self.device}")
152
+ self.nemo_codec_model.to(self.device)
153
+
154
  out_ids = out_ids.flatten()
155
 
156
  # Validate output
 
228
  self.model = AutoModelForCausalLM.from_pretrained(
229
  self.conf.model_name,
230
  dtype=torch.bfloat16,
231
+ # device_map=self.conf.device_map,
232
  trust_remote_code=True
233
  )
234
 
 
269
  rp: float,
270
  max_tok: int) -> torch.tensor:
271
  """Generate tokens using the model"""
272
+ self.model.to(self.device)
273
  input_ids = input_ids.to(self.device)
274
  attention_mask = attention_mask.to(self.device)
275