Nirupam Biswas
committed on
Commit
·
72efa30
1
Parent(s):
f583b83
Enhance device compatibility by adding MPS support and ensuring tensor operations respect device context
Browse files — modeling_deepseekv2.py (+16 −1)
modeling_deepseekv2.py
CHANGED
|
@@ -109,10 +109,11 @@ class DeepseekV2RMSNorm(nn.Module):
|
|
| 109 |
|
| 110 |
def forward(self, hidden_states):
    """RMS-normalize *hidden_states* over its last dimension.

    The mean-square is accumulated in float32 for numerical stability;
    the result is cast back to the caller's dtype and scaled by the
    learned per-channel weight.
    """
    original_dtype = hidden_states.dtype
    # Upcast once so variance and rsqrt run at full precision.
    upcast = hidden_states.to(torch.float32)
    mean_square = upcast.pow(2).mean(dim=-1, keepdim=True)
    normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
    # Downcast before scaling, matching the reference RMSNorm numerics.
    return self.weight * normalized.to(original_dtype)
|
| 116 |
|
| 117 |
|
| 118 |
ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm)
|
|
@@ -1468,6 +1469,8 @@ class DeepseekV2Model(DeepseekV2PreTrainedModel):
|
|
| 1468 |
super().__init__(config)
|
| 1469 |
self.padding_idx = config.pad_token_id
|
| 1470 |
self.vocab_size = config.vocab_size
|
|
|
|
|
|
|
| 1471 |
|
| 1472 |
self.embed_tokens = nn.Embedding(
|
| 1473 |
config.vocab_size, config.hidden_size, self.padding_idx
|
|
@@ -1652,6 +1655,9 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
|
|
| 1652 |
self.model = DeepseekV2Model(config)
|
| 1653 |
self.vocab_size = config.vocab_size
|
| 1654 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
|
|
|
|
|
|
|
|
| 1655 |
|
| 1656 |
# Initialize weights and apply final processing
|
| 1657 |
self.post_init()
|
|
@@ -1782,6 +1788,15 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
|
|
| 1782 |
inputs_embeds=None,
|
| 1783 |
**kwargs,
|
| 1784 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1785 |
past_length = 0
|
| 1786 |
if past_key_values is not None:
|
| 1787 |
if isinstance(past_key_values, Cache):
|
|
|
|
| 109 |
|
| 110 |
def forward(self, hidden_states):
    """Apply RMS normalization with the variance computed in float32.

    Args:
        hidden_states: floating-point tensor, normalized over the last
            dimension.

    Returns:
        Tensor of the same shape, dtype, and device as the input, scaled
        by the learned weight.
    """
    input_dtype = hidden_states.dtype
    device = hidden_states.device
    # Upcast so the variance/rsqrt are computed at full precision.
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    # Cast back to the input dtype BEFORE scaling by the weight: multiplying
    # a half-precision weight against float32 activations and downcasting
    # afterwards (as the previous edit did) changes rounding relative to the
    # reference RMSNorm and enlarges the fp32 intermediate. Keep the
    # defensive device move so the scale still works if the weight lags
    # behind the activations' device (e.g. on MPS).
    return self.weight.to(device) * hidden_states.to(input_dtype)
|
| 117 |
|
| 118 |
|
| 119 |
ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm)
|
|
|
|
| 1469 |
super().__init__(config)
|
| 1470 |
self.padding_idx = config.pad_token_id
|
| 1471 |
self.vocab_size = config.vocab_size
|
| 1472 |
+
# Set device to MPS if available, otherwise fallback to CPU
|
| 1473 |
+
self.device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
|
| 1474 |
|
| 1475 |
self.embed_tokens = nn.Embedding(
|
| 1476 |
config.vocab_size, config.hidden_size, self.padding_idx
|
|
|
|
| 1655 |
self.model = DeepseekV2Model(config)
|
| 1656 |
self.vocab_size = config.vocab_size
|
| 1657 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1658 |
+
# Move model to MPS if available
|
| 1659 |
+
if torch.backends.mps.is_available():
|
| 1660 |
+
self.to("mps")
|
| 1661 |
|
| 1662 |
# Initialize weights and apply final processing
|
| 1663 |
self.post_init()
|
|
|
|
| 1788 |
inputs_embeds=None,
|
| 1789 |
**kwargs,
|
| 1790 |
):
|
| 1791 |
+
# Move inputs to MPS device if available
|
| 1792 |
+
if torch.backends.mps.is_available():
|
| 1793 |
+
if input_ids is not None:
|
| 1794 |
+
input_ids = input_ids.to("mps")
|
| 1795 |
+
if attention_mask is not None:
|
| 1796 |
+
attention_mask = attention_mask.to("mps")
|
| 1797 |
+
if inputs_embeds is not None:
|
| 1798 |
+
inputs_embeds = inputs_embeds.to("mps")
|
| 1799 |
+
|
| 1800 |
past_length = 0
|
| 1801 |
if past_key_values is not None:
|
| 1802 |
if isinstance(past_key_values, Cache):
|