
GPT-OSS: Model Implementation

Mixture of Experts, Attention Sinks, Sliding Window, YaRN RoPE


Where We Are


Part 1: Architecture Overview


The Full Model at a Glance

text
Token IDs (1D) [n_tokens]
 │
 ▼
Embedding [n_tokens, 2880]
 │
 ▼
┌─ TransformerBlock × 36 ────────────────────────────┐
│                                                    │
│  ┌─ AttentionBlock ─────────────────────────────┐  │
│  │ RMSNorm → QKV Proj → RoPE(YaRN) →            │  │
│  │ SDPA(sinks, sliding_window) → Out Proj       │  │
│  │ + Residual                                   │  │
│  └──────────────────────────────────────────────┘  │
│                                                    │
│  ┌─ MLPBlock (MoE) ─────────────────────────────┐  │
│  │ RMSNorm → Router(top-4 of 128 experts) →     │  │
│  │ Expert MLP1 → SwiGLU → Expert MLP2 →         │  │
│  │ Weighted Sum + Residual                      │  │
│  └──────────────────────────────────────────────┘  │
│                                                    │
└────────────────────────────────────────────────────┘
 │
 ▼
Final RMSNorm [n_tokens, 2880]
 │
 ▼
Unembedding (Linear) [n_tokens, 201088]

What's Different from the "Generic Modern Transformer"?


Part 2: ModelConfig — Every Hyperparameter


ModelConfig: The Blueprint

python
@dataclass
class ModelConfig:
    num_hidden_layers: int = 36       # Number of TransformerBlocks
    num_experts: int = 128            # Total experts per MoE layer
    experts_per_token: int = 4        # Active experts per token (top-k)
    vocab_size: int = 201088          # Tokenizer vocabulary size
    hidden_size: int = 2880           # Residual stream width (d_model)
    intermediate_size: int = 2880     # Expert FFN intermediate dimension
    swiglu_limit: float = 7.0         # Activation clamping threshold
    head_dim: int = 64                # Dimension per attention head
    num_attention_heads: int = 64     # Query heads (total)
    num_key_value_heads: int = 8      # KV heads (GQA groups)
    sliding_window: int = 128         # Local attention window (tokens)
    initial_context_length: int = 4096  # Training context length (for YaRN)
    rope_theta: float = 150000.0      # RoPE base frequency
    rope_scaling_factor: float = 32.0 # YaRN context extension factor
    rope_ntk_alpha: float = 1.0       # NTK-by-parts lower bound
    rope_ntk_beta: float = 32.0       # NTK-by-parts upper bound
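These defaults imply a few derived sizes that recur throughout the rest of the code. A quick sanity check (my arithmetic, not part of the source):

```python
# Derived dimensions implied by the ModelConfig defaults above.
head_dim = 64
num_attention_heads = 64
num_key_value_heads = 8

q_dim = head_dim * num_attention_heads                 # all query heads
kv_dim = head_dim * num_key_value_heads                # one K (or V) set
qkv_dim = q_dim + 2 * kv_dim                           # fused QKV output width
q_per_kv = num_attention_heads // num_key_value_heads  # GQA group size

print(q_dim, kv_dim, qkv_dim, q_per_kv)  # 4096 512 5120 8
```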

ModelConfig: Derived Dimensions


Part 3: RMSNorm — Quick Review


RMSNorm in GPT-OSS: Two Implementation Details


RMSNorm: The GPT-OSS Implementation

python
class RMSNorm(torch.nn.Module):
    def __init__(
        self, num_features: int, eps: float = 1e-05, device: torch.device | None = None
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.scale = torch.nn.Parameter(
            torch.ones(num_features, device=device, dtype=torch.float32)  # ← Always float32
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[-1] == self.num_features
        t, dtype = x.float(), x.dtype       # ① Upcast input to float32
        t = t * torch.rsqrt(                 # ② rsqrt = 1/sqrt (fused, faster)
            torch.mean(t**2, dim=-1, keepdim=True)  # ③ RMS² over last dim
            + self.eps
        )
        return (t * self.scale).to(dtype)    # ④ Scale by γ, downcast to bfloat16

# Shape trace (no batch dim!):
# Input x:              (n_tokens, 2880) in bfloat16
# t = x.float():        (n_tokens, 2880) in float32
# mean(t², dim=-1, keepdim=True): (n_tokens, 1)
# rsqrt(...):           (n_tokens, 1)
# t * rsqrt:            (n_tokens, 2880) — broadcasted
# * self.scale:         (n_tokens, 2880) × (2880,) → (n_tokens, 2880)
# .to(dtype):           back to bfloat16
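As a standalone check (not in the source), the normalization step can be exercised in isolation: with the scale γ at its initial value of ones, every row of the output has RMS ≈ 1.

```python
import torch

# Re-statement of the normalization math above on random data.
torch.manual_seed(0)
x = torch.randn(5, 2880)
eps = 1e-5
t = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
rms = t.pow(2).mean(dim=-1).sqrt()  # per-token RMS of the normalized output
print(rms)  # each entry is very close to 1.0
```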

Part 4: YaRN RoPE — Context Length Extension


Why YaRN? The Context Extension Problem


YaRN: Three Frequency Regions (NTK-by-Parts)


YaRN: Computing Inverse Frequencies

python
def _compute_concentration_and_inv_freq(self) -> tuple[float, torch.Tensor]:
    """See YaRN paper: https://arxiv.org/abs/2309.00071"""
    # Base frequencies: θ_i = base^(2i/d) for i in [0, d/2)
    freq = self.base ** (
        torch.arange(0, self.head_dim, 2, dtype=torch.float, device=self.device)
        / self.head_dim
    )  # shape: (d/2,) = (32,) for head_dim=64

    if self.scaling_factor > 1.0:
        concentration = (
            0.1 * math.log(self.scaling_factor) + 1.0
        )  # YaRN concentration  (≈ 1.347 for factor=32)

        d_half = self.head_dim / 2  # = 32
        # NTK by parts — compute boundary indices for the three regions
        low = (
            d_half
            * math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi))
            / math.log(self.base)
        )
        high = (
            d_half
            * math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi))
            / math.log(self.base)
        )
        assert 0 < low < high < d_half - 1  # Must be valid indices

        interpolation = 1.0 / (self.scaling_factor * freq)  # Scaled-down freqs
        extrapolation = 1.0 / freq                          # Original freqs

        # Linear ramp: 0 below 'low', 1 above 'high', linear blend in between
        ramp = (
            torch.arange(d_half, dtype=torch.float32, device=freq.device) - low
        ) / (high - low)
        mask = 1 - ramp.clamp(0, 1)  # 1 → extrapolate, 0 → interpolate

        # Blend: low-i dims use extrapolation, high-i dims use interpolation
        inv_freq = interpolation * (1 - mask) + extrapolation * mask
    else:
        concentration = 1.0
        inv_freq = 1.0 / freq  # Standard RoPE

    return concentration, inv_freq
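Plugging the GPT-OSS defaults into the `low`/`high` formulas (my arithmetic, not from the source) shows where the three regions fall for head_dim=64: dimensions below ≈8 keep their original frequencies (extrapolation), dimensions above ≈17 are fully interpolated, and the band in between is blended.

```python
import math

# Boundary indices for the NTK-by-parts regions with the GPT-OSS defaults:
# head_dim=64 (so d_half=32), base=150000, initial_context_length=4096,
# ntk_beta=32, ntk_alpha=1.
d_half = 32
base = 150000.0
L = 4096
low = d_half * math.log(L / (32 * 2 * math.pi)) / math.log(base)
high = d_half * math.log(L / (1 * 2 * math.pi)) / math.log(base)
print(round(low, 2), round(high, 2))  # roughly 8.09 and 17.4
```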

YaRN: Computing cos/sin and Applying RoPE

python
# --- RotaryEmbedding methods ---
def _compute_cos_sin(self, num_tokens: int):
    concentration, inv_freq = self._compute_concentration_and_inv_freq()
    t = torch.arange(num_tokens, dtype=torch.float32, device=self.device)
    freqs = torch.einsum("i,j->ij", t, inv_freq)  # Outer product: (T, d/2)
    cos = freqs.cos() * concentration               # Scale by YaRN concentration
    sin = freqs.sin() * concentration
    return cos, sin  # Both shape: (num_tokens, d/2)

def forward(self, query, key):
    num_tokens = query.shape[0]
    cos, sin = self._compute_cos_sin(num_tokens)  # (T, 32)

    query_shape = query.shape                      # Save original: (T, 8, 8, 64)
    query = query.view(num_tokens, -1, self.head_dim)  # Flatten heads: (T, 64, 64)
    query = _apply_rotary_emb(query, cos, sin)
    query = query.reshape(query_shape)             # Restore: (T, 8, 8, 64)

    key_shape = key.shape                          # Save original: (T, 8, 64)
    key = key.view(num_tokens, -1, self.head_dim)  # (T, 8, 64) — unchanged
    key = _apply_rotary_emb(key, cos, sin)
    key = key.reshape(key_shape)                   # Restore: (T, 8, 64)
    return query, key

# --- The rotation itself (module-level function) ---
def _apply_rotary_emb(x, cos, sin):
    cos = cos.unsqueeze(-2).to(x.dtype)  # (T, 1, d/2) — broadcast over heads
    sin = sin.unsqueeze(-2).to(x.dtype)
    x1, x2 = torch.chunk(x, 2, dim=-1)  # Split last dim in half: each (T, H, d/2)
    o1 = x1 * cos - x2 * sin            # 2D rotation formula
    o2 = x2 * cos + x1 * sin
    return torch.cat((o1, o2), dim=-1)   # Reassemble: (T, H, d)

# This is the "split-half" RoPE variant (not interleaved pairs).
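Setting the concentration factor aside (it uniformly rescales cos/sin), the split-half rotation is a pure rotation in each 2-D plane, so it must preserve vector norms. A standalone property check (my names; same math as `_apply_rotary_emb`):

```python
import torch

def apply_rotary_emb(x, cos, sin):  # same rotation as in the slide above
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

torch.manual_seed(0)
T, H, d = 6, 4, 64
x = torch.randn(T, H, d)
t = torch.arange(T, dtype=torch.float32)
inv_freq = 1.0 / (150000.0 ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
freqs = torch.einsum("i,j->ij", t, inv_freq)       # (T, d/2)
out = apply_rotary_emb(x, freqs.cos(), freqs.sin())
# Rotation preserves the norm of every (token, head) vector:
print(torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-4))  # True
```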

Part 5: Custom Attention — Sinks and Sliding Window


Attention Sinks: Why They Exist


Sliding Window Attention: Local vs Global


Quick Reference: Einstein Summation (einsum)


Custom SDPA: Full Implementation

python
def sdpa(Q, K, V, S, sm_scale, sliding_window=0):
    # sliding_window == 0 means no sliding window
    # Q: (T, n_kv_heads, q_mult, d_head) — q_mult = n_q_heads / n_kv_heads = 8
    # K: (T, n_kv_heads, d_head)          — one K per KV-head group
    # V: (T, n_kv_heads, d_head)          — one V per KV-head group
    # S: (n_q_heads,) = (64,)             — attention sink logits
    n_tokens, n_heads, q_mult, d_head = Q.shape  # n_heads = n_kv_heads = 8
    assert K.shape == (n_tokens, n_heads, d_head)
    assert V.shape == (n_tokens, n_heads, d_head)

    # ① Broadcast K, V to match Q's group dimension
    K = K[:, :, None, :].expand(-1, -1, q_mult, -1)  # (T, 8, 8, 64)
    V = V[:, :, None, :].expand(-1, -1, q_mult, -1)  # (T, 8, 8, 64)

    # ② Reshape sinks: one scalar per (kv_head, q_group_member)
    S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1)  # (8, 8, T, 1)

    # ③ Causal mask (upper triangle = -inf)
    mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
    if sliding_window > 0:  # Add sliding window mask (lower triangle beyond window)
        mask += torch.tril(
            mask.new_full((n_tokens, n_tokens), -float("inf")),
            diagonal=-sliding_window
        )

    # ④ Attention scores + sinks
    QK = torch.einsum("qhmd,khmd->hmqk", Q, K)  # (8, 8, T, T)
    QK *= sm_scale                                # Scale by 1/√d
    QK += mask[None, None, :, :]                  # Apply causal + window mask
    QK = torch.cat([QK, S], dim=-1)               # (8, 8, T, T+1) ← sink column!

    # ⑤ Softmax over T+1 positions (including sink)
    W = torch.softmax(QK, dim=-1)                 # (8, 8, T, T+1)
    W = W[..., :-1]                               # (8, 8, T, T) ← remove sink
    # Now W sums to ≤ 1 per query (some mass went to the sink)

    # ⑥ Weighted sum of values
    attn = torch.einsum("hmqk,khmd->qhmd", W, V)  # (T, 8, 8, 64)
    return attn.reshape(n_tokens, -1)              # (T, 4096)

# The sink absorbs "unused" attention — after removing the sink column,
# attention weights per query sum to ≤ 1.0 instead of exactly 1.0.
# This lets the model gracefully ignore irrelevant context.
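A toy single-head illustration of the sink mechanism (my example, not the production kernel): appending one strongly-attractive sink logit before the softmax drains probability mass from every row of the attention matrix.

```python
import torch

torch.manual_seed(0)
T = 4
scores = torch.randn(T, T)                         # toy pre-softmax scores
mask = torch.triu(torch.full((T, T), -float("inf")), diagonal=1)  # causal mask
sink = torch.full((T, 1), 5.0)                     # one large sink logit per query
qk = torch.cat([scores + mask, sink], dim=-1)      # (T, T+1)
w = torch.softmax(qk, dim=-1)[..., :-1]            # drop the sink column, as in sdpa()
print(w.sum(-1))                                   # every row sums to < 1.0
```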

Part 6: AttentionBlock — Putting It Together


AttentionBlock: __init__

python
class AttentionBlock(torch.nn.Module):
    def __init__(self, config: ModelConfig, layer_idx: int, device: torch.device | None = None):
        super().__init__()
        self.config = config
        self.norm = RMSNorm(config.hidden_size, device=device)

        # QKV fused projection: 2880 → 5120
        # = head_dim × (num_attention_heads + 2 × num_key_value_heads)
        # = 64 × (64 + 16) = 5120
        self.qkv = torch.nn.Linear(
            config.hidden_size,
            config.head_dim * (config.num_attention_heads + 2 * config.num_key_value_heads),
            device=device,
        )

        # Output projection: 4096 → 2880
        self.out = torch.nn.Linear(
            config.head_dim * config.num_attention_heads,  # 64 × 64 = 4096
            config.hidden_size,                            # 2880
            device=device,
        )

        # Rotary embeddings (YaRN-extended)
        self.rotary_emb = RotaryEmbedding(config, device=device)

        # Attention sinks: one learnable scalar per Q-head
        self.S = torch.nn.Parameter(torch.zeros(config.num_attention_heads, device=device))

        # Even layers → sliding window (128); odd layers → full causal (0)
        self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else 0
        self.sm_scale = config.head_dim ** -0.5  # 1/√64 = 0.125
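The alternation rule is easy to tabulate (a trivial sketch of the `layer_idx % 2` test above, over the first few layers):

```python
# Even-indexed layers get the 128-token sliding window; odd-indexed layers
# attend over the full causal context (window = 0 means "no window").
sliding_window = 128
pattern = [sliding_window if i % 2 == 0 else 0 for i in range(6)]
print(pattern)  # [128, 0, 128, 0, 128, 0]
```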

AttentionBlock: forward — Shape Trace

python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    n_tokens = x.shape[0]  # x: (T, 2880)
    cfg = self.config
    n_q = cfg.num_attention_heads      # 64
    n_kv = cfg.num_key_value_heads     # 8
    d = cfg.head_dim                   # 64
    q_per_kv = n_q // n_kv             # 8 (GQA group size)

    # ① Pre-norm + fused QKV projection
    qkv = self.qkv(self.norm(x))       # (T, 2880) → (T, 5120)

    # ② Slice into Q, K, V
    q = qkv[:, : n_q * d]                          # (T, 4096)
    k = qkv[:, n_q * d : (n_q + n_kv) * d]         # (T, 512)
    v = qkv[:, (n_q + n_kv) * d :]                  # (T, 512)

    # ③ Reshape for GQA: Q gets an extra group dimension
    q = q.view(n_tokens, n_kv, q_per_kv, d)  # (T, 8, 8, 64)
    k = k.view(n_tokens, n_kv, d)             # (T, 8, 64)
    v = v.view(n_tokens, n_kv, d)             # (T, 8, 64)

    # ④ Apply YaRN RoPE to Q and K
    q, k = self.rotary_emb(q, k)      # Shapes unchanged

    # ⑤ Custom SDPA with sinks + sliding window
    attn = sdpa(q, k, v, self.S, self.sm_scale, self.sliding_window)  # (T, 4096)

    # ⑥ Output projection + residual connection
    return x + self.out(attn)          # (T, 4096) → (T, 2880), then + residual

AttentionBlock: Key Observations


Part 7: SwiGLU Activation — The Gating Mechanism


SwiGLU in GPT-OSS: Not Quite Textbook


SwiGLU: The GPT-OSS Implementation

python
def swiglu(x: torch.Tensor, limit: float) -> torch.Tensor:
    # x shape: (..., intermediate_size * 2)  e.g., (T, 4, 5760)
    # Split into gate and linear paths via interleaving
    x = x.unflatten(-1, (-1, 2))     # (..., 2880, 2)  — reshape last dim
    x0 = x[..., 0]                   # (..., 2880) — gate path (even indices)
    x1 = x[..., 1]                   # (..., 2880) — linear path (odd indices)

    x0 = x0.clamp(-limit, limit)     # Clamp gate to [-7.0, 7.0]
    x1 = x1.clamp(-limit, limit)     # Clamp linear to [-7.0, 7.0]

    return (
        x0
        * torch.sigmoid(x0 * 1.702)  # Scaled sigmoid ≈ GELU gate
        * (x1 + 1)                   # Linear path with +1 bias
    )

# Shape trace:
# Input:  (T, 4, 5760)   — from MLP1 (per expert, top-4)
# unflatten: (T, 4, 2880, 2)  — split interleaved gate/linear
# x0, x1: (T, 4, 2880)   — each is half the intermediate_size
# Output: (T, 4, 2880)   — ready for MLP2
#
# Note: x0 appears twice — both as the sigmoid input AND
# multiplied by the sigmoid output. This is the "Swish" pattern:
# Swish(x) = x · σ(αx)
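The constant 1.702 makes x·σ(1.702x) a close sigmoid approximation of GELU; a quick numeric comparison (my check, not from the source):

```python
import torch

# Compare the sigmoid-gated form against PyTorch's exact GELU on a small range.
x = torch.linspace(-3, 3, 7)
approx = x * torch.sigmoid(1.702 * x)
exact = torch.nn.functional.gelu(x)
print((approx - exact).abs().max())  # small everywhere (well under 0.05)
```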

Part 8: Mixture of Experts (MoE) — The MLPBlock


MoE: Why Sparse Experts?


MLPBlock: __init__ — Expert Parameters

python
class MLPBlock(torch.nn.Module):
    def __init__(self, config: ModelConfig, device: torch.device | None = None):
        super().__init__()
        self.config = config
        self.norm = RMSNorm(config.hidden_size, device=device)  # Pre-norm

        # Router: linear layer, no bias, scores each of 128 experts
        self.gate = torch.nn.Linear(
            config.hidden_size,     # 2880
            config.num_experts,     # 128
            device=device,
        )

        # Expert parameters — raw tensors, not nn.Linear!
        E = config.num_experts           # 128
        I = config.intermediate_size * 2 # 5760 (×2 for interleaved gate+linear)
        H = config.hidden_size           # 2880

        # MLP1: up-projection per expert (2880 → 5760)
        self.mlp1_weight = torch.nn.Parameter(torch.empty(E, I, H, device=device))
        self.mlp1_bias   = torch.nn.Parameter(torch.empty(E, I, device=device))

        # MLP2: down-projection per expert (2880 → 2880)
        self.mlp2_weight = torch.nn.Parameter(torch.empty(E, H, H, device=device))
        self.mlp2_bias   = torch.nn.Parameter(torch.empty(E, H, device=device))

# Note: These are (128, ...) tensors — all experts stored together.
# At forward time, we index into them to extract only the top-4 experts.
# This is more memory-efficient than 128 separate nn.Linear modules
# and allows batched operations via einsum.
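From these shapes, a back-of-envelope count (my arithmetic, weights plus biases only) shows why sparse routing pays off: each layer stores roughly 3.19B expert parameters, but only about 100M of them run for any given token.

```python
# Expert parameter count per MoE layer, from the shapes above.
E, I, H = 128, 5760, 2880
per_layer = E * (I * H + I) + E * (H * H + H)  # mlp1 + mlp2, weights + biases
active = 4 * ((I * H + I) + (H * H + H))       # only top-4 experts run per token
print(per_layer, active)  # ~3.19B stored vs ~99.6M active
```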

MLPBlock: forward — Routing and Expert Computation

python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    t = self.norm(x)                    # (T, 2880) — pre-norm

    # ① Route: score all 128 experts per token
    gate = torch.softmax(               # Softmax over experts
        self.gate(t), dim=-1            # (T, 2880) → (T, 128)
    )

    # ② Select top-4 experts per token
    expert_weights, expert_indices = torch.topk(
        gate, self.config.experts_per_token, dim=-1  # → (T, 4) each
    )
    expert_weights = expert_weights / expert_weights.sum(-1, keepdim=True)  # Renormalize

    # ③ Gather selected expert parameters
    mlp1_w = self.mlp1_weight[expert_indices]  # (T, 4, 5760, 2880)
    mlp1_b = self.mlp1_bias[expert_indices]    # (T, 4, 5760)
    mlp2_w = self.mlp2_weight[expert_indices]  # (T, 4, 2880, 2880)
    mlp2_b = self.mlp2_bias[expert_indices]    # (T, 4, 2880)

    # ④ MLP1: up-projection per expert
    t = torch.einsum("beck,bk->bec", mlp1_w, t) + mlp1_b  # (T, 4, 5760)

    # ⑤ SwiGLU activation
    t = swiglu(t, self.config.swiglu_limit)                # (T, 4, 2880)

    # ⑥ MLP2: down-projection per expert
    t = torch.einsum("beck,bec->bek", mlp2_w, t) + mlp2_b  # (T, 4, 2880)

    # ⑦ Weighted combination of expert outputs + residual
    t = torch.einsum("bec,be->bc", t, expert_weights)       # (T, 2880)
    return x + t  # Residual connection
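The routing steps can be exercised standalone on random data (a sketch with toy shapes, T=3; the variable names are mine):

```python
import torch

torch.manual_seed(0)
gate = torch.softmax(torch.randn(3, 128), dim=-1)  # (T=3, 128) router probabilities
w, idx = torch.topk(gate, 4, dim=-1)               # top-4 per token: (3, 4) each
w = w / w.sum(-1, keepdim=True)                    # renormalize over the selected 4
print(idx.shape, w.sum(-1))  # torch.Size([3, 4]); each row of w sums to 1.0
```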

MLPBlock: Detailed Shape Trace


Part 9: The Full Transformer


TransformerBlock and Transformer: Assembly

python
class TransformerBlock(torch.nn.Module):
    """One layer: attention + MoE, with residual connections."""
    def __init__(self, config: ModelConfig, layer_idx: int, device: torch.device | None = None):
        super().__init__()
        self.attention = AttentionBlock(config, layer_idx, device=device)
        self.mlp = MLPBlock(config, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.attention(x)  # Residual is inside AttentionBlock
        x = self.mlp(x)        # Residual is inside MLPBlock
        return x


class Transformer(torch.nn.Module):
    def __init__(self, config: ModelConfig, device: torch.device | None = None):
        super().__init__()
        self.config = config

        # Token embedding: 201088 → 2880
        self.embedding = torch.nn.Embedding(
            config.vocab_size, config.hidden_size, device=device
        )

        # 36 transformer blocks
        self.block = torch.nn.ModuleList([
            TransformerBlock(config, layer_idx=i, device=device)
            for i in range(config.num_hidden_layers)
        ])

        # Final layer norm
        self.norm = RMSNorm(config.hidden_size, device=device)

        # Unembedding: 2880 → 201088 (NOT tied with embedding)
        self.unembed = torch.nn.Linear(
            config.hidden_size, config.vocab_size, device=device
        )

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        x = self.embedding(token_ids)  # (T,) → (T, 2880)
        for block in self.block:       # 36 layers
            x = block(x)               # (T, 2880) → (T, 2880)
        return self.unembed(self.norm(x))  # (T, 2880) → (T, 201088)
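Summing the pieces (my arithmetic, approximate, using the bias and norm shapes listed on these slides) lands close to the ~117B parameters usually quoted for gpt-oss-120b:

```python
# Rough total-parameter estimate from the config.
V, H, n_layers = 201088, 2880, 36
E, I = 128, 5760
moe = E * (I * H + I + H * H + H) + (H * 128 + 128)      # experts + router
attn = (H * 5120 + 5120) + (4096 * H + H) + 64 + 2 * H   # qkv, out, sinks, 2 norms
total = n_layers * (moe + attn) + 2 * V * H + V + H      # + embed/unembed/final norm
print(f"{total / 1e9:.1f}B parameters")
```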

Transformer: End-to-End Data Flow


Part 10: Weight Loading — from_checkpoint


Weight Loading: Overview


from_checkpoint: The Loading Code

python
@classmethod
def from_checkpoint(
    cls, checkpoint_path: str, device: str, world_size: int = 1, rank: int = 0
) -> "Transformer":
    config = ModelConfig()
    if not isinstance(device, torch.device):
        device = torch.device(device)
    model = cls(config, device=device)  # ① Create model (random params on device)
    model.eval()                         # ② Set to eval mode

    per_rank = config.intermediate_size * 2 // world_size  # 5760 // ws
    offset = rank * per_rank                                # Start index for this GPU

    with Checkpoint(checkpoint_path) as cp:  # ③ Open MXFP4 checkpoint
        for name, param in model.named_parameters():
            full_tensor = cp[name]          # Load & upcast from checkpoint

            if "mlp1" in name:              # Shard MLP1 along intermediate dim
                full_tensor = full_tensor[:, offset : offset + per_rank, ...]
                # mlp1_weight: (128, 5760, 2880) → (128, per_rank, 2880)
                # mlp1_bias:   (128, 5760)       → (128, per_rank)
                # The trailing ellipsis lets one slice handle both the weight
                # (3D) and the bias (2D): dim 1 is the intermediate dim in both

            elif "mlp2_weight" in name:     # Shard MLP2 weight along last dim
                full_tensor = full_tensor[..., offset : offset + per_rank]
                # (128, 2880, 2880) → (128, 2880, per_rank)

            # mlp2_bias is NOT sharded — needed in full after all_reduce

            try:
                param.data.copy_(full_tensor)
            except Exception as e:
                print(f"Error: {name}: {e} ({param.shape} vs {full_tensor.shape})")
                raise
    return model
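The sharding arithmetic above can be tabulated for a few world sizes (a sketch; variable names are mine). Each rank keeps a contiguous slice of the doubled intermediate dimension:

```python
# Per-rank slice of the 5760-wide (interleaved gate+linear) intermediate dim.
intermediate_size = 2880
for world_size in (1, 2, 4):
    per_rank = intermediate_size * 2 // world_size
    slices = [(r * per_rank, (r + 1) * per_rank) for r in range(world_size)]
    print(world_size, per_rank, slices)
# e.g. world_size=2 → per_rank=2880: rank 0 owns [0, 2880), rank 1 owns [2880, 5760)
```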

Expert Parallelism: How Sharding Works


Part 11: Token Generation


TokenGenerator: __init__ — Model Initialization

python
class TokenGenerator:
    @torch.inference_mode()  # ← Disables gradient tracking entirely
    def __init__(self, checkpoint_path: str, device: str):
        self.device = device
        self.model = Transformer.from_checkpoint(checkpoint_path, device)

    # That's it! The constructor:
    #   1. Loads the model from an MXFP4 checkpoint
    #   2. Moves everything to the specified device (e.g., "cuda:0")
    #   3. @torch.inference_mode() means no autograd overhead during loading
    #
    # Note what's NOT here:
    #   - No tokenizer! TokenGenerator works with raw token IDs
    #   - No KV-cache initialization
    #   - No batch size configuration
    #   - Tokenization is handled externally before calling generate()

TokenGenerator: generate — The Autoregressive Loop

python
@torch.inference_mode()
def generate(
    self,
    prompt_tokens: list[int],     # Already-tokenized input
    stop_tokens: list[int],       # Token IDs that signal end-of-generation
    temperature: float = 1.0,     # Sampling temperature (0 = greedy)
    max_tokens: int = 0,          # 0 means unlimited
    return_logprobs: bool = False,
):
    tokens = list(prompt_tokens)
    num_generated = 0

    while max_tokens == 0 or num_generated < max_tokens:
        # ① Run FULL model on ALL tokens (no KV-cache!)
        input_ids = torch.as_tensor(tokens, dtype=torch.int32, device=self.device)
        logits = self.model(input_ids)  # (T, 201088) — logits for every position
        logits = logits[-1]              # (201088,)  — only last position matters

        # ② Sample or argmax
        if temperature == 0:            # Greedy decoding
            token = logits.argmax(-1).item()
        else:                           # Temperature sampling
            probs = torch.softmax(logits * (1.0 / temperature), dim=-1)
            token = torch.multinomial(probs, 1).item()

        tokens.append(token)
        num_generated += 1

        # ③ Yield the token (and optionally log-prob)
        if return_logprobs:
            log_probs = torch.log_softmax(logits, dim=-1)
            yield token, log_probs[token].item()
        else:
            yield token

        # ④ Stop if we hit a stop token
        if token in stop_tokens:
            break
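The sampling branch can be illustrated standalone (toy logits; my example): dividing the logits by the temperature sharpens or flattens the distribution, and as temperature approaches 0 the sampler approaches argmax.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.0])  # toy next-token logits
top_prob = []
for temp in (2.0, 1.0, 0.5):
    probs = torch.softmax(logits * (1.0 / temp), dim=-1)
    top_prob.append(probs[0].item())     # probability of the highest-logit token
print([round(p, 3) for p in top_prob])   # rises as temperature falls
```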

TokenGenerator: Design Decisions


Generation Trace: Step by Step


Part 12: End-to-End Architecture Diagram


Complete Data Flow: Token to Token

text
External Tokenizer (not in model.py)
 │
 ▼
Token IDs: [101, 2003, 1037, 6251, 102]     shape: (5,)
 │
 ▼
┌─ Transformer.forward() ────────────────────────────────────────────┐
│                                                                    │
│  Embedding lookup                                (5,) → (5, 2880)  │
│  │                                                                 │
│  ▼                                                                 │
│  TransformerBlock 0 (sliding window W=128)                         │
│   ├─ AttentionBlock: RMSNorm → QKV(5120) → RoPE → SDPA → Out       │
│   └─ MLPBlock:       RMSNorm → Route(top-4/128) → Experts → Σ      │
│  │                                                                 │
│  TransformerBlock 1 (full causal attention)                        │
│   ├─ AttentionBlock: RMSNorm → QKV(5120) → RoPE → SDPA → Out       │
│   └─ MLPBlock:       RMSNorm → Route(top-4/128) → Experts → Σ      │
│  │                                                                 │
│  ... (36 blocks total, alternating window/full) ...                │
│  │                                                                 │
│  TransformerBlock 35 (full causal)                                 │
│  │                                                                 │
│  ▼                                                                 │
│  Final RMSNorm                             (5, 2880) → (5, 2880)   │
│  │                                                                 │
│  ▼                                                                 │
│  Unembedding (Linear)                      (5, 2880) → (5, 201088) │
│                                                                    │
└────────────────────────────────────────────────────────────────────┘
 │
 ▼
Logits[-1] → softmax → sample → next token ID

Part 13: Summary


What We Covered Today


Key Takeaways