Nirupam Biswas committed on
Commit
72efa30
·
1 Parent(s): f583b83

Enhance device compatibility by adding MPS support and ensuring tensor operations respect device context

Browse files
Files changed (1) hide show
  1. modeling_deepseekv2.py +16 -1
modeling_deepseekv2.py CHANGED
@@ -109,10 +109,11 @@ class DeepseekV2RMSNorm(nn.Module):
109
 
110
  def forward(self, hidden_states):
111
  input_dtype = hidden_states.dtype
 
112
  hidden_states = hidden_states.to(torch.float32)
113
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
114
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
115
- return self.weight * hidden_states.to(input_dtype)
116
 
117
 
118
  ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm)
@@ -1468,6 +1469,8 @@ class DeepseekV2Model(DeepseekV2PreTrainedModel):
1468
  super().__init__(config)
1469
  self.padding_idx = config.pad_token_id
1470
  self.vocab_size = config.vocab_size
 
 
1471
 
1472
  self.embed_tokens = nn.Embedding(
1473
  config.vocab_size, config.hidden_size, self.padding_idx
@@ -1652,6 +1655,9 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1652
  self.model = DeepseekV2Model(config)
1653
  self.vocab_size = config.vocab_size
1654
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
 
 
1655
 
1656
  # Initialize weights and apply final processing
1657
  self.post_init()
@@ -1782,6 +1788,15 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1782
  inputs_embeds=None,
1783
  **kwargs,
1784
  ):
 
 
 
 
 
 
 
 
 
1785
  past_length = 0
1786
  if past_key_values is not None:
1787
  if isinstance(past_key_values, Cache):
 
109
 
110
  def forward(self, hidden_states):
111
  input_dtype = hidden_states.dtype
112
+ device = hidden_states.device
113
  hidden_states = hidden_states.to(torch.float32)
114
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
115
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
116
+ return (self.weight.to(device) * hidden_states).to(input_dtype)
117
 
118
 
119
  ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm)
 
1469
  super().__init__(config)
1470
  self.padding_idx = config.pad_token_id
1471
  self.vocab_size = config.vocab_size
1472
+ # Set device to MPS if available, otherwise fallback to CPU
1473
+ self.device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
1474
 
1475
  self.embed_tokens = nn.Embedding(
1476
  config.vocab_size, config.hidden_size, self.padding_idx
 
1655
  self.model = DeepseekV2Model(config)
1656
  self.vocab_size = config.vocab_size
1657
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1658
+ # Move model to MPS if available
1659
+ if torch.backends.mps.is_available():
1660
+ self.to("mps")
1661
 
1662
  # Initialize weights and apply final processing
1663
  self.post_init()
 
1788
  inputs_embeds=None,
1789
  **kwargs,
1790
  ):
1791
+ # Move inputs to MPS device if available
1792
+ if torch.backends.mps.is_available():
1793
+ if input_ids is not None:
1794
+ input_ids = input_ids.to("mps")
1795
+ if attention_mask is not None:
1796
+ attention_mask = attention_mask.to("mps")
1797
+ if inputs_embeds is not None:
1798
+ inputs_embeds = inputs_embeds.to("mps")
1799
+
1800
  past_length = 0
1801
  if past_key_values is not None:
1802
  if isinstance(past_key_values, Cache):