Ctrl/Cmd + P and select "Save as PDF".
Model architecture (high level):

  Token IDs (1D)                          [n_tokens]
        |
        v
  Embedding                               [n_tokens, 2880]
        |
        v
  TransformerBlock x 36
    +- AttentionBlock:
    |     RMSNorm -> QKV Proj -> RoPE (YaRN)
    |     -> SDPA (sinks, sliding window) -> Out Proj -> + Residual
    +- MLPBlock (MoE):
          RMSNorm -> Router (top-4 of 128 experts)
          -> Expert MLP1 -> SwiGLU -> Expert MLP2
          -> Weighted Sum + Residual
        |
        v
  Final RMSNorm                           [n_tokens, 2880]
        |
        v
  Unembedding (Linear)                    [n_tokens, 201088]
@dataclass
class ModelConfig:
    """Hyperparameters of the model; defaults describe the reference config.

    Field order is part of the public interface (it defines the positional
    ``__init__`` signature), so it must not be reordered.
    """

    num_hidden_layers: int = 36          # number of TransformerBlocks
    num_experts: int = 128               # total experts per MoE layer
    experts_per_token: int = 4           # active experts per token (top-k)
    vocab_size: int = 201088             # tokenizer vocabulary size
    hidden_size: int = 2880              # residual-stream width (d_model)
    intermediate_size: int = 2880        # expert FFN intermediate dimension
    swiglu_limit: float = 7.0            # activation clamping threshold
    head_dim: int = 64                   # dimension per attention head
    num_attention_heads: int = 64        # query heads (total)
    num_key_value_heads: int = 8         # KV heads (GQA groups)
    sliding_window: int = 128            # local attention window (tokens)
    initial_context_length: int = 4096   # training context length (for YaRN)
    rope_theta: float = 150000.0         # RoPE base frequency
    rope_scaling_factor: float = 32.0    # YaRN context-extension factor
    rope_ntk_alpha: float = 1.0          # NTK-by-parts lower bound
    rope_ntk_beta: float = 32.0          # NTK-by-parts upper bound
def __init__(
self, num_features: int, eps: float = 1e-05, device: torch.device | None = None
):
super().__init__()
self.num_features = num_features
self.eps = eps
self.scale = torch.nn.Parameter(
torch.ones(num_features, device=device, dtype=torch.float32) # โ Always float32
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.shape[-1] == self.num_features
t, dtype = x.float(), x.dtype # โ Upcast input to float32
t = t * torch.rsqrt( # โก rsqrt = 1/sqrt (fused, faster)
torch.mean(t**2, dim=-1, keepdim=True) # โข RMSยฒ over last dim
+ self.eps
)
return (t * self.scale).to(dtype) # โฃ Scale by ฮณ, downcast to bfloat16
# Shape trace (no batch dim!):
# Input x: (n_tokens, 2880) in bfloat16
# t = x.float(): (n_tokens, 2880) in float32
# mean(tยฒ, dim=-1, keepdim=True): (n_tokens, 1)
# rsqrt(...): (n_tokens, 1)
# t * rsqrt: (n_tokens, 2880) โ broadcasted
# * self.scale: (n_tokens, 2880) ร (2880,) โ (n_tokens, 2880)
# .to(dtype): back to bfloat16def _compute_concentration_and_inv_freq(self) -> torch.Tensor:
"""See YaRN paper: https://arxiv.org/abs/2309.00071"""
# Base frequencies: ฮธ_i = base^(2i/d) for i in [0, d/2)
freq = self.base ** (
torch.arange(0, self.head_dim, 2, dtype=torch.float, device=self.device)
/ self.head_dim
) # shape: (d/2,) = (32,) for head_dim=64
if self.scaling_factor > 1.0:
concentration = (
0.1 * math.log(self.scaling_factor) + 1.0
) # YaRN concentration (โ 1.347 for factor=32)
d_half = self.head_dim / 2 # = 32
# NTK by parts โ compute boundary indices for the three regions
low = (
d_half
* math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi))
/ math.log(self.base)
)
high = (
d_half
* math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi))
/ math.log(self.base)
)
assert 0 < low < high < d_half - 1 # Must be valid indices
interpolation = 1.0 / (self.scaling_factor * freq) # Scaled-down freqs
extrapolation = 1.0 / freq # Original freqs
# Linear ramp: 0 below 'low', 1 above 'high', linear blend in between
ramp = (
torch.arange(d_half, dtype=torch.float32, device=freq.device) - low
) / (high - low)
mask = 1 - ramp.clamp(0, 1) # 1 โ extrapolate, 0 โ interpolate
# Blend: low-i dims use extrapolation, high-i dims use interpolation
inv_freq = interpolation * (1 - mask) + extrapolation * mask
else:
concentration = 1.0
inv_freq = 1.0 / freq # Standard RoPE
return concentration, inv_freq# --- RotaryEmbedding methods ---
# --- RotaryEmbedding method (enclosing class header not in this excerpt) ---
def _compute_cos_sin(self, num_tokens: int):
    """Precompute cos/sin tables for positions [0, num_tokens); each (T, d/2)."""
    concentration, inv_freq = self._compute_concentration_and_inv_freq()
    positions = torch.arange(num_tokens, dtype=torch.float32, device=self.device)
    angles = torch.einsum("i,j->ij", positions, inv_freq)  # outer product: (T, d/2)
    # YaRN scales the rotation amplitudes by the concentration factor.
    return angles.cos() * concentration, angles.sin() * concentration
def forward(self, query, key):
num_tokens = query.shape[0]
cos, sin = self._compute_cos_sin(num_tokens) # (T, 32)
query_shape = query.shape # Save original: (T, 8, 8, 64)
query = query.view(num_tokens, -1, self.head_dim) # Flatten heads: (T, 64, 64)
query = _apply_rotary_emb(query, cos, sin)
query = query.reshape(query_shape) # Restore: (T, 8, 8, 64)
key_shape = key.shape # Save original: (T, 8, 64)
key = key.view(num_tokens, -1, self.head_dim) # (T, 8, 64) โ unchanged
key = _apply_rotary_emb(key, cos, sin)
key = key.reshape(key_shape) # Restore: (T, 8, 64)
return query, key
# --- The rotation itself (module-level function) ---
def _apply_rotary_emb(x, cos, sin):
    """Rotate `x` using the "split-half" RoPE variant (not interleaved pairs).

    x: (T, H, d); cos/sin: (T, d/2), broadcast over the head dimension.
    """
    cos = cos.unsqueeze(-2).to(x.dtype)  # (T, 1, d/2)
    sin = sin.unsqueeze(-2).to(x.dtype)
    first, second = torch.chunk(x, 2, dim=-1)  # each (T, H, d/2)
    # Standard 2D rotation applied to the (first, second) half pairs.
    return torch.cat(
        (first * cos - second * sin, second * cos + first * sin), dim=-1
    )
def sdpa(Q, K, V, S, sm_scale, sliding_window=0):
    """Scaled dot-product attention with GQA, attention sinks, and an
    optional sliding window (sliding_window == 0 disables the window).

    Q: (T, n_kv_heads, q_mult, d_head) -- q_mult = n_q_heads / n_kv_heads
    K, V: (T, n_kv_heads, d_head)      -- one K/V per KV-head group
    S: (n_q_heads,)                    -- learnable attention-sink logits
    Returns: (T, n_q_heads * d_head)

    The sink absorbs "unused" attention: after the sink column is dropped,
    each query's weights sum to <= 1 instead of exactly 1, letting the model
    gracefully ignore irrelevant context.
    """
    n_tokens, n_heads, q_mult, d_head = Q.shape
    assert K.shape == (n_tokens, n_heads, d_head)
    assert V.shape == (n_tokens, n_heads, d_head)
    # Broadcast K/V across the query-group dimension (GQA).
    K = K[:, :, None, :].expand(-1, -1, q_mult, -1)  # (T, n_heads, q_mult, d)
    V = V[:, :, None, :].expand(-1, -1, q_mult, -1)
    # One sink logit per (kv_head, group member), replicated per query row.
    S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1)
    # Causal mask: -inf strictly above the diagonal.
    mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
    if sliding_window > 0:
        # Also mask keys further than `sliding_window` positions in the past.
        mask += torch.tril(
            mask.new_full((n_tokens, n_tokens), -float("inf")),
            diagonal=-sliding_window,
        )
    QK = torch.einsum("qhmd,khmd->hmqk", Q, K)  # (n_heads, q_mult, T, T)
    QK *= sm_scale                               # scale by 1/sqrt(d)
    QK += mask[None, None, :, :]                 # causal + window mask
    # Append the sink column so softmax can dump "unused" attention there.
    QK = torch.cat([QK, S], dim=-1)              # (n_heads, q_mult, T, T+1)
    W = torch.softmax(QK, dim=-1)[..., :-1]      # drop sink after softmax
    attn = torch.einsum("hmqk,khmd->qhmd", W, V)  # (T, n_heads, q_mult, d)
    return attn.reshape(n_tokens, -1)
class AttentionBlock(torch.nn.Module):
    """Pre-norm grouped-query attention with YaRN RoPE and attention sinks.

    Even layers use a sliding window of ``config.sliding_window`` tokens;
    odd layers use full causal attention.  The residual add happens inside
    ``forward``.
    """

    def __init__(self, config: ModelConfig, layer_idx: int, device: torch.device | None = None):
        super().__init__()
        self.config = config
        self.norm = RMSNorm(config.hidden_size, device=device)
        # Fused QKV projection:
        # hidden -> head_dim * (num_attention_heads + 2 * num_key_value_heads).
        self.qkv = torch.nn.Linear(
            config.hidden_size,
            config.head_dim * (config.num_attention_heads + 2 * config.num_key_value_heads),
            device=device,
        )
        # Output projection: (head_dim * num_attention_heads) -> hidden.
        self.out = torch.nn.Linear(
            config.head_dim * config.num_attention_heads,
            config.hidden_size,
            device=device,
        )
        # Rotary embeddings (YaRN-extended).
        self.rotary_emb = RotaryEmbedding(config, device=device)
        # Attention sinks: one learnable logit per query head.
        self.S = torch.nn.Parameter(torch.zeros(config.num_attention_heads, device=device))
        # Alternate layers: even -> sliding window, odd -> full causal.
        self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else 0
        self.sm_scale = config.head_dim ** -0.5  # softmax scale = 1/sqrt(d_head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (n_tokens, hidden_size)
        cfg = self.config
        n_tokens = x.shape[0]
        n_kv = cfg.num_key_value_heads
        q_per_kv = cfg.num_attention_heads // n_kv  # GQA group size
        d = cfg.head_dim
        # Pre-norm, fused projection, then split into Q / K / V.
        projected = self.qkv(self.norm(x))
        q, k, v = torch.split(
            projected, [n_kv * q_per_kv * d, n_kv * d, n_kv * d], dim=-1
        )
        # Reshape for GQA: Q carries an extra per-group dimension.
        q = q.view(n_tokens, n_kv, q_per_kv, d)
        k = k.view(n_tokens, n_kv, d)
        v = v.view(n_tokens, n_kv, d)
        # YaRN RoPE on Q and K; shapes unchanged.
        q, k = self.rotary_emb(q, k)
        # Custom SDPA with sinks and optional sliding window.
        attn = sdpa(q, k, v, self.S, self.sm_scale, self.sliding_window)
        # Output projection plus residual connection.
        return x + self.out(attn)
def swiglu(x: torch.Tensor, limit: float) -> torch.Tensor:
    """Clamped SwiGLU over an interleaved gate/linear layout.

    The last dimension of `x` holds interleaved (gate, linear) pairs, so an
    input of shape (..., 2 * I) yields an output of shape (..., I).  Both
    paths are clamped to [-limit, limit]; the gate uses the GELU-like Swish
    approximation x * sigmoid(1.702 * x), and the linear path gets a +1 bias
    before the elementwise product.
    """
    pairs = x.unflatten(-1, (-1, 2))           # (..., I, 2)
    gate = pairs[..., 0].clamp(-limit, limit)  # even indices -> gate path
    linear = pairs[..., 1].clamp(-limit, limit)  # odd indices -> linear path
    # Swish gate times biased linear path.
    return gate * torch.sigmoid(gate * 1.702) * (linear + 1)
class MLPBlock(torch.nn.Module):
    """Pre-norm Mixture-of-Experts FFN.

    Each token is routed to its top-k experts; the selected expert MLPs are
    applied and their outputs combined by (renormalized) router weight.
    Expert parameters are stored as stacked tensors with a leading
    num_experts dimension — not as separate nn.Linear modules — so the
    chosen experts can be gathered and applied with batched einsums.
    """

    def __init__(self, config: ModelConfig, device: torch.device | None = None):
        super().__init__()
        self.config = config
        self.norm = RMSNorm(config.hidden_size, device=device)  # pre-norm
        # Router: linear scoring of every expert per token.
        self.gate = torch.nn.Linear(
            config.hidden_size,
            config.num_experts,
            device=device,
        )
        num_experts = config.num_experts
        inter2 = config.intermediate_size * 2  # x2 for interleaved gate+linear
        hidden = config.hidden_size
        # MLP1: per-expert up-projection (hidden -> inter2).
        self.mlp1_weight = torch.nn.Parameter(
            torch.empty(num_experts, inter2, hidden, device=device)
        )
        self.mlp1_bias = torch.nn.Parameter(
            torch.empty(num_experts, inter2, device=device)
        )
        # MLP2: per-expert down-projection back to the residual width.
        self.mlp2_weight = torch.nn.Parameter(
            torch.empty(num_experts, hidden, hidden, device=device)
        )
        self.mlp2_bias = torch.nn.Parameter(
            torch.empty(num_experts, hidden, device=device)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (n_tokens, hidden_size)
        t = self.norm(x)
        # Score every expert, then keep only the top-k per token.
        router_probs = torch.softmax(self.gate(t), dim=-1)  # (T, num_experts)
        expert_weights, expert_indices = torch.topk(
            router_probs, self.config.experts_per_token, dim=-1
        )  # both (T, k)
        # Renormalize surviving weights to sum to 1 per token.
        expert_weights = expert_weights / expert_weights.sum(-1, keepdim=True)
        # Gather the selected experts' parameters: leading dims (T, k).
        mlp1_w = self.mlp1_weight[expert_indices]
        mlp1_b = self.mlp1_bias[expert_indices]
        mlp2_w = self.mlp2_weight[expert_indices]
        mlp2_b = self.mlp2_bias[expert_indices]
        # Up-project per expert, clamped SwiGLU, then down-project.
        t = torch.einsum("beck,bk->bec", mlp1_w, t) + mlp1_b
        t = swiglu(t, self.config.swiglu_limit)
        t = torch.einsum("beck,bec->bek", mlp2_w, t) + mlp2_b
        # Weighted sum over the k experts, then residual connection.
        t = torch.einsum("bec,be->bc", t, expert_weights)
        return x + t
"""One layer: attention + MoE, with residual connections."""
def __init__(self, config: ModelConfig, layer_idx: int, device: torch.device | None = None):
super().__init__()
self.attention = AttentionBlock(config, layer_idx, device=device)
self.mlp = MLPBlock(config, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.attention(x) # Residual is inside AttentionBlock
x = self.mlp(x) # Residual is inside MLPBlock
return x
class Transformer(torch.nn.Module):
    """Full decoder stack: embedding, transformer blocks, final norm, unembedding.

    The unembedding weights are NOT tied to the embedding.
    """

    def __init__(self, config: ModelConfig, device: torch.device | None = None):
        super().__init__()
        self.config = config
        # Token embedding: vocab_size -> hidden_size.
        self.embedding = torch.nn.Embedding(
            config.vocab_size, config.hidden_size, device=device
        )
        # The stack of transformer blocks (layer parity picks window vs full).
        self.block = torch.nn.ModuleList(
            TransformerBlock(config, layer_idx=i, device=device)
            for i in range(config.num_hidden_layers)
        )
        # Final layer norm before unembedding.
        self.norm = RMSNorm(config.hidden_size, device=device)
        # Unembedding: hidden_size -> vocab_size (untied).
        self.unembed = torch.nn.Linear(
            config.hidden_size, config.vocab_size, device=device
        )

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # token_ids: (n_tokens,) -> logits: (n_tokens, vocab_size)
        x = self.embedding(token_ids)
        for layer in self.block:
            x = layer(x)
        return self.unembed(self.norm(x))
def from_checkpoint(
cls, checkpoint_path: str, device: str, world_size: int = 1, rank: int = 0
) -> "Transformer":
config = ModelConfig()
if not isinstance(device, torch.device):
device = torch.device(device)
model = cls(config, device=device) # โ Create model (random params on device)
model.eval() # โก Set to eval mode
per_rank = config.intermediate_size * 2 // world_size # 5760 // ws
offset = rank * per_rank # Start index for this GPU
with Checkpoint(checkpoint_path) as cp: # โข Open MXFP4 checkpoint
for name, param in model.named_parameters():
full_tensor = cp[name] # Load & upcast from checkpoint
if "mlp1" in name: # Shard MLP1 along intermediate dim
full_tensor = full_tensor[..., offset : offset + per_rank, :]
# mlp1_weight: (128, 5760, 2880) โ (..., per_rank, 2880)
# mlp1_bias: (128, 5760) โ (..., per_rank)
# The ellipsis handles both weight (3D) and bias (2D)
elif "mlp2_weight" in name: # Shard MLP2 weight along last dim
full_tensor = full_tensor[..., offset : offset + per_rank]
# (128, 2880, 2880) โ (128, 2880, per_rank)
# mlp2_bias is NOT sharded โ needed in full after all_reduce
try:
param.data.copy_(full_tensor)
except Exception as e:
print(f"Error: {name}: {e} ({param.shape} vs {full_tensor.shape})")
raise
return modelclass TokenGenerator:
class TokenGenerator:
    """Minimal token-by-token sampler operating on raw token IDs.

    Tokenization happens outside this class, and there is no KV cache: each
    generation step re-runs the full model over the entire sequence so far.
    """

    @torch.inference_mode()  # no autograd overhead during loading
    def __init__(self, checkpoint_path: str, device: str):
        self.device = device
        self.model = Transformer.from_checkpoint(checkpoint_path, device)

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: list[int],   # already-tokenized input
        stop_tokens: list[int],     # token IDs that signal end-of-generation
        temperature: float = 1.0,   # sampling temperature (0 = greedy)
        max_tokens: int = 0,        # 0 means unlimited
        return_logprobs: bool = False,
    ):
        """Yield generated token IDs one at a time (optionally with log-probs)."""
        tokens = list(prompt_tokens)
        num_generated = 0
        while max_tokens == 0 or num_generated < max_tokens:
            # Re-run the full model on every token so far (no KV cache).
            input_ids = torch.as_tensor(tokens, dtype=torch.int32, device=self.device)
            last_logits = self.model(input_ids)[-1]  # only the final position matters
            if temperature == 0:
                # Greedy decoding.
                next_token = last_logits.argmax(-1).item()
            else:
                # Temperature sampling.
                probs = torch.softmax(last_logits * (1.0 / temperature), dim=-1)
                next_token = torch.multinomial(probs, 1).item()
            tokens.append(next_token)
            num_generated += 1
            if return_logprobs:
                # Log-probs are reported from the UNscaled logits.
                log_probs = torch.log_softmax(last_logits, dim=-1)
                yield next_token, log_probs[next_token].item()
            else:
                yield next_token
            if next_token in stop_tokens:
                break
    # End-to-end flow: external tokenizer -> token IDs -> Transformer.forward
    # (embedding, 36 blocks alternating window/full attention, final RMSNorm,
    # unembedding) -> last-position logits -> softmax/sample -> next token ID.