# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gpt_oss/modular_gpt_oss.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gpt_oss.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union

import torch
from torch import nn
from torch.nn import functional as F

from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations.hub_kernels import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_layers import (
    GenericForSequenceClassification,
    GenericForTokenClassification,
    GradientCheckpointingLayer,
)
from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import OutputRecorder, check_model_inputs

from .configuration_gpt_oss import GptOssConfig


@use_kernel_forward_from_hub("RMSNorm")
class GptOssRMSNorm(nn.Module):
    """Root-mean-square layer norm with a learnable per-feature scale and no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        GptOssRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: size of the last (normalized) dimension.
            eps: constant added to the variance before the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize `hidden_states` over its last dimension in float32 and return it in the input dtype."""
        input_dtype = hidden_states.dtype
        # The statistics are computed in float32 regardless of the input dtype.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # The scale is applied before casting back to the input dtype.
        return (self.weight * hidden_states).to(input_dtype)  # main diff with Llama

    def extra_repr(self):
        """Return the weight shape and epsilon for the module's printed representation."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class GptOssExperts(nn.Module):
    """
    Bank of `num_local_experts` gated MLPs stored as batched parameters.

    Each expert owns a fused gate/up projection (gate and up features are interleaved on the last dimension)
    and a down projection, both with a bias.
    """

    def __init__(self, config):
        super().__init__()
        self.intermediate_size = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size
        # Parameters are allocated uninitialized here; `GptOssPreTrainedModel._init_weights` fills them.
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
        self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
        self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
        self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
        # `alpha` scales the gate inside the sigmoid; `limit` is the clamp bound applied to gate and up.
        self.alpha = 1.702
        self.limit = 7.0

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        """
        When training it is more efficient to just loop over the experts and compute the output for each expert
        as otherwise the memory would explode.

        For inference we can sacrifice some memory and compute the output for all experts at once, by repeating
        the inputs. The looped path is also the one taken whenever the input lives on CPU.

        Args:
            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
            router_indices (torch.Tensor): (batch_size * token_num, top_k), only read by the looped path
            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
        Returns:
            torch.Tensor: the routing-weighted sum of expert outputs, viewed as (batch_size, -1, hidden_size)
        """
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
        num_experts = routing_weights.shape[1]
        if hidden_states.device.type == "cpu" or self.training:
            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
            with torch.no_grad():
                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
                # (num_tokens, top_k, num_experts) -> (num_experts, top_k, num_tokens)
                expert_mask = expert_mask.permute(2, 1, 0)
                # we sum on the top_k and on the sequence length to get which experts
                # are hit this time around
                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
            for expert_idx in expert_hit:
                # expert_idx only has 1 element, so we can use a scalar for fast indexing
                expert_idx = expert_idx[0]
                with torch.no_grad():
                    _, token_idx = torch.where(expert_mask[expert_idx])
                current_state = hidden_states[token_idx]
                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
                # gate and up features are interleaved: even columns are gate, odd columns are up
                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                gate = gate.clamp(min=None, max=self.limit)
                up = up.clamp(min=-self.limit, max=self.limit)
                glu = gate * torch.sigmoid(gate * self.alpha)
                gated_output = (up + 1) * glu
                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
                weighted_output = out * routing_weights[token_idx, expert_idx, None]
                # A token routed to several experts accumulates one contribution per expert.
                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
            next_states = next_states.view(batch_size, -1, self.hidden_size)
        else:
            # Dense path: every expert processes every token; unselected experts are zeroed out below by
            # their routing weight.
            hidden_states = hidden_states.repeat(num_experts, 1)
            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
            gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
            gate = gate.clamp(min=None, max=self.limit)
            up = up.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
            next_states = next_states + self.down_proj_bias[..., None, :]
            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
            next_states = next_states.sum(dim=0)
        return next_states


class GptOssTopKRouter(nn.Module):
    """Linear router that keeps the `top_k` experts per token and softmax-normalizes their logits."""

    def __init__(self, config):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
        self.bias = nn.Parameter(torch.empty(self.num_experts))

    def forward(self, hidden_states):
        """
        Route every token to its `top_k` experts.

        Returns:
            `(router_scores, router_indices)`: `router_scores` has the shape of the full logits and is zero
            everywhere except at the selected experts, where it holds the softmax over the selected logits;
            `router_indices` holds the selected expert ids per token.
        """
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
        # The softmax runs over the selected logits only, so the kept weights sum to 1 per token.
        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices


@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
class GptOssMLP(nn.Module):
    """Mixture-of-experts feed-forward block: a top-k router followed by the expert bank."""

    def __init__(self, config):
        super().__init__()
        self.router = GptOssTopKRouter(config)
        self.experts = GptOssExperts(config)

    def forward(self, hidden_states):
        """Return `(routed_out, router_scores)` for `hidden_states`."""
        # router_scores: (num_tokens, num_experts), router_indices: (num_tokens, top_k)
        router_scores, router_indices = self.router(hidden_states)
        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
        return routed_out, router_scores


class GptOssRotaryEmbedding(nn.Module):
    """Computes the rotary position embedding cos/sin tables for a batch of position ids."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: GptOssConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: the buffer is recomputed from the config and is not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """
        Build the rotary tables for `position_ids`.

        Args:
            x: tensor used only for its device and dtype.
            position_ids: position indices, indexed here as (batch, seq_len).

        Returns:
            `(cos, sin)`, both cast to `x.dtype` and scaled by `attention_scaling`.
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # The frequencies are not duplicated: `_apply_rotary_emb` rotates the two halves of the head
            # dimension with the same half-sized cos/sin.
            emb = freqs
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(x.dtype), sin.to(x.dtype)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate the two halves of the last dimension of `x` against each other using `cos` and `sin`."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    first_ = first_half * cos - second_half * sin
    second_ = second_half * cos + first_half * sin
    return torch.cat((first_, second_), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Apply rotary position embeddings to the query and key tensors.

    Args:
        q: query tensor.
        k: key tensor.
        cos: cosine table from the rotary embedding.
        sin: sine table from the rotary embedding.
        position_ids: unused by this function.
        unsqueeze_dim: dimension along which `cos` and `sin` are unsqueezed so they broadcast against `q`/`k`.

    Returns:
        `(q_embed, k_embed)`, the rotated query and key tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = _apply_rotary_emb(q, cos, sin)
    k_embed = _apply_rotary_emb(k, cos, sin)
    return q_embed, k_embed


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Reference attention implementation with per-head attention sinks.

    `module.sinks` contributes one extra logit per head that takes part in the softmax normalization and is
    dropped afterwards, so the returned attention weights of a query may sum to less than 1.

    Args:
        module: attention module providing `num_key_value_groups`, `sinks` and `training`.
        query: query states; the second-to-last dimension is the query length.
        key: key states, repeated `module.num_key_value_groups` times along dim 1.
        value: value states, repeated like `key`.
        attention_mask: optional additive mask, cropped to the key length on its last dimension.
        scaling: factor applied to the query/key scores.
        dropout: dropout probability applied to the attention weights while `module.training` is set.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        `(attn_output, attn_weights)`; `attn_output` has dims 1 and 2 swapped relative to `query`.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # One sink logit per head, broadcast over batch and query positions and appended as an extra key column.
    sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
    combined_logits = torch.cat([attn_weights, sinks], dim=-1)

    # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16
    # when training with bsz>1 we clamp max values.
    combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
    probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
    scores = probs[..., :-1]  # we drop the sink here
    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights


class GptOssAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: GptOssConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # Only layers typed "sliding_attention" carry a window; all other layers pass `None`.
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
        # One learnable sink logit per attention head (see `eager_attention_forward`).
        self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Project, rotate, optionally cache, and attend.

        Args:
            hidden_states: layer input; all dimensions but the last are kept for the output.
            position_embeddings: `(cos, sin)` rotary tables.
            attention_mask: mask forwarded to the attention implementation.
            past_key_values: optional cache updated in place with this layer's key/value states.
            cache_position: positions forwarded to the cache update.
            **kwargs: forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)` where `attn_output` has gone through the output projection.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            s_aux=self.sinks,  # diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class GptOssDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block: self-attention then mixture-of-experts MLP, each with a residual connection."""

    def __init__(self, config: GptOssConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx)
        self.mlp = GptOssMLP(config)
        self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Read by `GptOssModel.forward` to pick the matching attention mask for this layer.
        self.attention_type = config.layer_types[layer_idx]

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run the block on `hidden_states` and return the updated hidden states only."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, _ = self.mlp(hidden_states)  # diff with llama: router scores
        hidden_states = residual + hidden_states
        return hidden_states


@auto_docstring
class GptOssPreTrainedModel(PreTrainedModel):
    config: GptOssConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GptOssDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = False
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # NOTE(review): presumably `index=0` selects the first element of the router output, which is the
    # scattered top-k softmax scores rather than the raw logits -- confirm against `OutputRecorder`.
    _can_record_outputs = {
        "router_logits": OutputRecorder(GptOssTopKRouter, index=0),
        "hidden_states": GptOssDecoderLayer,
        "attentions": GptOssAttention,
    }
    _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
    # NOTE(review): these two flags use a different spelling than `_supports_flash_attn` /
    # `_supports_flex_attn` above and carry the opposite value -- confirm which set is actually read.
    _supports_flash_attention = False
    _supports_flex_attention = False

    def _init_weights(self, module):
        # Initializes one module in place: normal(0, initializer_range) for weights, zeros for expert and
        # linear biases, ones for norm scales. Plain comment on purpose: the class is under `@auto_docstring`.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        # NOTE(review): this branch only triggers if a bare `nn.Parameter` is passed in -- confirm callers
        # ever do so before removing it.
        elif isinstance(module, nn.Parameter):
            module.data.normal_(mean=0.0, std=std)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, GptOssRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, GptOssExperts):
            module.gate_up_proj.data.normal_(mean=0.0, std=std)
            module.gate_up_proj_bias.data.zero_()
            module.down_proj.data.normal_(mean=0.0, std=std)
            module.down_proj_bias.data.zero_()
        elif isinstance(module, GptOssAttention):
            # The q/k/v/o projections are `nn.Linear` children and are handled by the first branch.
            module.sinks.data.normal_(mean=0.0, std=std)
        elif isinstance(module, GptOssTopKRouter):
            module.weight.data.normal_(mean=0.0, std=std)
            module.bias.data.normal_(mean=0.0, std=std)


@auto_docstring
class GptOssModel(GptOssPreTrainedModel):
    _no_split_modules = ["GptOssDecoderLayer"]

    def __init__(self, config: GptOssConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [GptOssDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = GptOssRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        # Annotated as `Cache`: the body calls `get_seq_length()` on it and creates a `DynamicCache`.
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            # Continue counting from whatever the cache already holds.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
            }
            # One mask per layer type; each decoder layer picks its own through `attention_type`.
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds
        # The rotary tables are computed once and shared by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = self.norm(hidden_states)
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits:
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        num_experts:
            Number of experts
        top_k:
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss as a tensor, or the plain integer `0` when `gate_logits` is `None` or not a tuple.
    """
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    # Past the guard above `gate_logits` is always a tuple; gather every layer on the first layer's device.
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts


@auto_docstring
class GptOssForCausalLM(GptOssPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = GptOssModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, GptOssForCausalLM

        >>> model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # An int keeps the last `logits_to_keep` positions (0 keeps all); a tensor is used as an index.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                # `load_balancing_loss_func` returns the plain int 0 when no router logits were recorded,
                # so only move the term across devices when it is an actual tensor.
                aux_term = aux_loss.to(loss.device) if isinstance(aux_loss, torch.Tensor) else aux_loss
                loss += self.router_aux_loss_coef * aux_term  # make sure to reside in the same device

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )


class GptOssForSequenceClassification(GenericForSequenceClassification, GptOssPreTrainedModel):
    pass


class GptOssForTokenClassification(GenericForTokenClassification, GptOssPreTrainedModel):
    pass


__all__ = [
    "GptOssForCausalLM",
    "GptOssForSequenceClassification",
    "GptOssForTokenClassification",
    "GptOssModel",
    "GptOssPreTrainedModel",
]