import torch.nn as nn
# --- POST-NORM (Original 2017 Transformer) ---
class PostNormBlock(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, n_heads)
        self.ffn = FFN(embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.norm1(x + self.attention(x))  # Norm AFTER residual
        x = self.norm2(x + self.ffn(x))        # Gradient must flow through norm
        return x                               # at EVERY layer → unstable at depth

# --- PRE-NORM (Modern; common in 2025–2026 SOTA codebases) ---
class PreNormBlock(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, n_heads)
        self.ffn = FFN(embed_dim)
        self.norm1 = RMSNorm(embed_dim)  # RMSNorm, not LayerNorm!
        self.norm2 = RMSNorm(embed_dim)

    def forward(self, x):
        x = x + self.attention(self.norm1(x))  # Norm BEFORE sublayer
        x = x + self.ffn(self.norm2(x))        # Clean residual highway
        return x                               # Stable even at 100+ layers

import torch
import torch.nn as nn
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (Zhang & Sennrich, 2019).
    Drop-in replacement for nn.LayerNorm: fewer ops, same quality."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # Learnable scale γ
        # Note: NO bias parameter β (unlike LayerNorm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch, seq_len, dim)
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        x_normed = x / rms                 # Normalize by RMS
        return x_normed * self.weight      # Scale by learned γ

# Compare parameter counts:
dim = 4096
layer_norm = nn.LayerNorm(dim)  # 2 × 4096 = 8,192 params (γ + β)
rms_norm = RMSNorm(dim)         # 1 × 4096 = 4,096 params (γ only)
print(f"LayerNorm params: {sum(p.numel() for p in layer_norm.parameters()):,}")
print(f"RMSNorm params:   {sum(p.numel() for p in rms_norm.parameters()):,}")
# Per model (32 layers × 2 norms each = 64 norm layers):
# LayerNorm: 64 × 8,192 = 524,288 params
# RMSNorm:   64 × 4,096 = 262,144 params → half the norm parameters
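
# Quick sanity check (a minimal sketch reusing the dim and RMSNorm defined above):
# with the default scale γ = 1, every vector coming out of RMSNorm has a
# root-mean-square of approximately 1.
x = torch.randn(2, 16, dim)
y = RMSNorm(dim)(x)
print(y.pow(2).mean(dim=-1).sqrt()[0, :4])  # ≈ tensor([1.0000, 1.0000, 1.0000, 1.0000])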
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """SwiGLU Feed-Forward Network (Shazeer, 2020).
    Replaces the standard ReLU/GELU FFN in modern transformers."""
    def __init__(self, embed_dim: int, hidden_dim: int = None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = int(8 * embed_dim / 3)             # Compensate for the 3rd matrix
            hidden_dim = 256 * ((hidden_dim + 255) // 256)  # Round up to a multiple of 256
        self.w1 = nn.Linear(embed_dim, hidden_dim, bias=False)  # Gate projection
        self.w3 = nn.Linear(embed_dim, hidden_dim, bias=False)  # Up projection
        self.w2 = nn.Linear(hidden_dim, embed_dim, bias=False)  # Down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU(x) = (Swish(xW1) ⊗ xW3) · W2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))  # F.silu = Swish

# Parameter comparison at embed_dim = 4096:
class OriginalFFN(nn.Module):
    def __init__(self, d, h=None):
        super().__init__()
        h = h or 4 * d  # h = 16384
        self.w1 = nn.Linear(d, h, bias=False)
        self.act = nn.ReLU()
        self.w2 = nn.Linear(h, d, bias=False)
    def forward(self, x):
        return self.w2(self.act(self.w1(x)))

original = OriginalFFN(4096)  # 2 matrices: 4096×16384 × 2 = 134,217,728
swiglu = SwiGLUFFN(4096)      # 3 matrices: 4096×11008 × 2 + 11008×4096 = 135,266,304
print(f"Original: {sum(p.numel() for p in original.parameters()):>12,} params")
print(f"SwiGLU:   {sum(p.numel() for p in swiglu.parameters()):>12,} params")
# Nearly identical parameter count, but SwiGLU performs significantly better!

import torch
def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    """Precompute rotation frequencies as complex exponentials."""
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len)
    angles = torch.outer(positions, freqs)  # (max_seq_len, head_dim/2)
    return torch.polar(torch.ones_like(angles), angles)  # e^(jθ) = cos θ + j sin θ

def apply_rope(q, k, freqs_cis):
    """Apply rotary embeddings to query and key tensors.
    q, k: (batch, seq_len, n_heads, head_dim)
    freqs_cis: (seq_len, head_dim/2), precomputed complex exponentials
    """
    # Reshape to pairs → view as complex numbers
    q_complex = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
    # Broadcast freqs_cis: (seq_len, head_dim/2) → (1, seq_len, 1, head_dim/2)
    freqs = freqs_cis.unsqueeze(0).unsqueeze(2)
    # Complex multiplication = 2D rotation!
    q_rotated = torch.view_as_real(q_complex * freqs).flatten(-2)
    k_rotated = torch.view_as_real(k_complex * freqs).flatten(-2)
    return q_rotated.type_as(q), k_rotated.type_as(k)

# Precompute once at model init:
freqs_cis = precompute_freqs_cis(head_dim=128, max_seq_len=8192)
# Then in each attention layer:
#   q, k = apply_rope(q, k, freqs_cis[:seq_len])
# Followed by standard attention: softmax(QK^T / √d) V
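
# A quick property check (a minimal sketch using apply_rope and the freqs_cis computed
# above): after RoPE, the score between a query at position m and a key at position n
# depends only on the relative offset m - n, not on the absolute positions.
torch.manual_seed(0)
q_vec = torch.randn(128)  # One head's query content, repeated at every position
k_vec = torch.randn(128)  # One head's key content, repeated at every position
q = q_vec.expand(1, 32, 1, 128).clone()  # (batch=1, seq_len=32, n_heads=1, head_dim=128)
k = k_vec.expand(1, 32, 1, 128).clone()
q_rot, k_rot = apply_rope(q, k, freqs_cis[:32])
score = lambda m, n: (q_rot[0, m, 0] * k_rot[0, n, 0]).sum()
print(score(10, 7), score(20, 17))  # Same offset (3) → (nearly) identical scores
print(score(10, 7), score(10, 5))   # Different offsets → different scores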
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """GQA: Multiple Q heads share fewer KV heads (Ainslie et al. 2023)."""
    def __init__(self, embed_dim: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads  # Q heads per KV group
        self.head_dim = embed_dim // n_heads
        self.wq = nn.Linear(embed_dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(embed_dim, n_kv_heads * self.head_dim, bias=False)  # Smaller!
        self.wv = nn.Linear(embed_dim, n_kv_heads * self.head_dim, bias=False)  # Smaller!
        self.wo = nn.Linear(n_heads * self.head_dim, embed_dim, bias=False)

    def forward(self, x, freqs_cis, kv_cache=None):
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)
        q, k = apply_rope(q, k, freqs_cis)  # RoPE on Q and K only
        if kv_cache is not None:
            # Decode step: prepend the cached K,V from earlier positions
            past_k, past_v = kv_cache
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)
        new_kv = (k, v)  # Cache stores only the n_kv_heads: the GQA savings
        # Repeat KV heads to match Q head count
        k = k.repeat_interleave(self.n_rep, dim=2)  # (B, T_total, n_heads, head_dim)
        v = v.repeat_interleave(self.n_rep, dim=2)
        # Standard scaled dot-product attention (uses fast SDPA kernels automatically).
        # Causal masking is only needed when processing a full prompt; a single decoded
        # token may attend to the entire cache (assumes one new token per decode step).
        q, k, v = [t.transpose(1, 2) for t in (q, k, v)]  # (B, n_heads, T, head_dim)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=kv_cache is None)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.wo(out), new_kv

# Parameter savings (embed_dim=4096, 32 Q heads, 8 KV heads, head_dim=128):
# MHA K proj: 4096 × 4096 = 16,777,216    GQA K proj: 4096 × 1024 = 4,194,304 → 4× smaller!
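
# The same ratio shows up in KV-cache memory at inference time. A rough back-of-the-envelope
# sketch (the numbers are assumptions: fp16, batch 1, 32 layers, 8192-token context,
# head_dim 128):
bytes_per_elem, n_layers_est, ctx_len, hd = 2, 32, 8192, 128
def kv_cache_bytes(n_kv_heads):
    return 2 * n_layers_est * ctx_len * n_kv_heads * hd * bytes_per_elem  # 2 → K and V
print(f"MHA cache (32 KV heads): {kv_cache_bytes(32) / 2**30:.1f} GiB")  # ≈ 4.0 GiB
print(f"GQA cache (8 KV heads):  {kv_cache_bytes(8) / 2**30:.1f} GiB")   # ≈ 1.0 GiB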
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens, temperature=0.7, top_p=0.95,
             eos_token_id=None):
    """Autoregressive generation with a KV-cache."""
    token_ids = prompt_ids  # (1, prompt_len)
    # Phase 1 (PREFILL): process the entire prompt, build the initial cache
    logits, _, kv_cache = model(token_ids, kv_cache=None)
    # kv_cache: list of (K, V) tuples, one per layer
    # Each K, V: (batch, seq_len, n_kv_heads, head_dim)
    for step in range(max_new_tokens):
        # Sample the next token from the last position's logits
        next_logits = logits[:, -1, :] / temperature
        probs = top_p_sample(F.softmax(next_logits, dim=-1), top_p)
        next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)
        # Phase 2 (DECODE): process ONLY the new token, reuse cached K,V
        logits, _, kv_cache = model(next_token, kv_cache=kv_cache)
        # Internally: new K,V appended to the cache; Q attends to the full cache.
        # Only 1 token is processed instead of the entire sequence!
        token_ids = torch.cat([token_ids, next_token], dim=1)
        if eos_token_id is not None and next_token.item() == eos_token_id:
            break
    return token_ids

# Without a KV-cache, generating 100 tokens from a 1000-token prompt
# reprocesses 1000 + 1001 + 1002 + ... + 1099 ≈ 105,000 total tokens.
# With a KV-cache: 1000 + 1 + 1 + ... + 1 = 1,100 total tokens → ~95× less!
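
# generate() calls a top_p_sample helper that is not defined in this listing. One plausible
# implementation (a sketch of standard nucleus sampling, not necessarily the original's):
# keep the smallest set of highest-probability tokens whose cumulative mass exceeds top_p,
# zero out the rest, and renormalize.
def top_p_sample(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop a token if the cumulative mass *before* it already exceeds top_p
    # (so the first token crossing the threshold is always kept).
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    # Scatter the filtered probabilities back into vocabulary order.
    return torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)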
import torch
import torch.nn.functional as F
import math

# Assume q, k, v: (batch, n_heads, seq_len, head_dim)

# === OLD WAY: Standard attention, O(T²) memory ===
def standard_attention(q, k, v, is_causal=True):
    T = q.size(-2)
    scale = math.sqrt(q.size(-1))
    attn_weights = torch.matmul(q, k.transpose(-2, -1)) / scale  # (B, H, T, T) → O(T²)!
    if is_causal:
        mask = torch.triu(torch.ones(T, T, device=q.device), diagonal=1).bool()
        attn_weights = attn_weights.masked_fill(mask, float('-inf'))  # Explicit mask
    attn_weights = F.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, v)

# === NEW WAY: Flash / memory-efficient SDPA, O(T) memory when a flash kernel is used ===
def sdpa_attention(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# PyTorch automatically chooses an optimized kernel on CUDA.
# NOTE: exact bitwise equality is not guaranteed across kernels due to floating-point math,
# but the results should be numerically very close.
q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
out_standard = standard_attention(q, k, v)
out_sdpa = sdpa_attention(q, k, v)
print(f"Max difference: {(out_standard - out_sdpa).abs().max():.6f}")
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModernTransformerBlock(nn.Module):
    """A single transformer block with all modern upgrades:
    pre-RMSNorm + GQA with RoPE + SwiGLU FFN."""
    def __init__(self, embed_dim: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        self.norm1 = RMSNorm(embed_dim)
        self.attn = GroupedQueryAttention(embed_dim, n_heads, n_kv_heads)
        self.norm2 = RMSNorm(embed_dim)
        self.ffn = SwiGLUFFN(embed_dim)

    def forward(self, x, freqs_cis, kv_cache=None):
        # Pre-norm: normalize BEFORE the sublayer, keep the residual path clean
        h = self.norm1(x)
        attn_out, new_kv = self.attn(h, freqs_cis, kv_cache)  # GQA + RoPE
        x = x + attn_out                 # Residual
        x = x + self.ffn(self.norm2(x))  # Pre-norm + SwiGLU + residual
        return x, new_kv

# Compare to the 2017 original block:
# - nn.LayerNorm      → RMSNorm      ✓
# - nn.MultiheadAttn  → GQA + RoPE   ✓
# - ReLU FFN          → SwiGLU       ✓
# - Post-norm         → Pre-norm     ✓
# Same structure, better components!
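
# Toy-sized smoke test (a sketch using the RMSNorm, GroupedQueryAttention, SwiGLUFFN, and
# precompute_freqs_cis defined above; the small dims are arbitrary choices for a CPU run).
torch.manual_seed(0)
block = ModernTransformerBlock(embed_dim=256, n_heads=8, n_kv_heads=2)
freqs = precompute_freqs_cis(head_dim=256 // 8, max_seq_len=64)
x = torch.randn(2, 16, 256)
y, kv = block(x, freqs[:16])
print(y.shape)      # torch.Size([2, 16, 256])
print(kv[0].shape)  # Cached K: torch.Size([2, 16, 2, 32]), only the 2 KV heads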
import torch
import torch.nn as nn
import torch.nn.functional as F

class GPTModel(nn.Module):
    """Modern GPT with all the upgrades: the skeleton for GPT-OSS.
    Design choices (roughly matching the modern dense ~8B tier):
    - vocab_size = 200,000 (byte-level BPE via tiktoken)
    - embed_dim = 4,096
    - n_layers = 32
    - n_heads = 32 (Q heads) → head_dim = 128
    - n_kv_heads = 8 (GQA 4:1) → 4× KV-cache savings
    - FFN hidden ≈ 11,008 (SwiGLU, 8/3 × 4096, rounded)
    - No biases (often), no dropout (pre-training), weight tying
    - Estimated total: ~7B parameters
    Note: many 2026 SOTA stacks add MoE + hybrid/sparse attention + MLA-style KV compression.
    This skeleton intentionally teaches the "classic" modern dense backbone first.
    """
    def __init__(self, vocab_size=200_000, max_seq_len=8192,
                 embed_dim=4096, n_layers=32, n_heads=32, n_kv_heads=8):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        # No positional embedding table; RoPE is applied inside attention.
        self.layers = nn.ModuleList([
            ModernTransformerBlock(embed_dim, n_heads, n_kv_heads)
            for _ in range(n_layers)
        ])
        self.norm = RMSNorm(embed_dim)  # Final norm (required by pre-norm)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.lm_head.weight = self.token_embed.weight  # Weight tying → 0 extra params
        # Precompute RoPE frequencies once
        head_dim = embed_dim // n_heads
        self.register_buffer(
            'freqs_cis',
            precompute_freqs_cis(head_dim, max_seq_len),
            persistent=False
        )

    def forward(self, token_ids, targets=None, kv_cache=None):
        B, T = token_ids.shape
        x = self.token_embed(token_ids)  # (B, T, embed_dim)
        # RoPE frequencies for the positions being processed
        # (offset by the number of positions already in the cache during decode)
        past_len = kv_cache[0][0].size(1) if kv_cache else 0
        freqs = self.freqs_cis[past_len:past_len + T]
        new_kv_cache = []
        for i, layer in enumerate(self.layers):
            cache_i = kv_cache[i] if kv_cache else None
            x, kv = layer(x, freqs, cache_i)
            new_kv_cache.append(kv)
        x = self.norm(x)          # Final RMSNorm
        logits = self.lm_head(x)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )
        return logits, loss, new_kv_cache
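
# Minimal usage sketch with a toy config so it runs on CPU in seconds
# (the defaults above are the full ~7B-scale configuration; these small values are arbitrary).
torch.manual_seed(0)
model = GPTModel(vocab_size=1000, max_seq_len=128, embed_dim=256,
                 n_layers=2, n_heads=8, n_kv_heads=2)
tokens = torch.randint(0, 1000, (2, 16))
targets = torch.randint(0, 1000, (2, 16))   # Dummy next-token targets
logits, loss, kv_cache = model(tokens, targets=targets)
print(logits.shape)                          # torch.Size([2, 16, 1000])
print(f"loss at random init: {loss.item():.2f}")
print(len(kv_cache), kv_cache[0][0].shape)   # 2 layers, cached K: (2, 16, 2, 32)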