πŸ–¨οΈ Printing Instructions: Press Ctrl/Cmd + P and select "Save as PDF".
1

Modern Transformer Upgrades

RoPE, RMSNorm, SwiGLU, GQA, KV-Cache & Flash Attention

2

Where We Are

3

Overview β€” From 2017 to 2026

4

Why Upgrade the Original Transformer?

5

The Upgrade Map

6

A Concrete Before/After

7

Part 1: Pre-Norm β€” Stabilizing Deep Networks

8

Post-Norm vs Pre-Norm

9

Pre-Norm vs Post-Norm in Code

python
import torch.nn as nn

# --- POST-NORM (Original 2017 Transformer) ---
class PostNormBlock(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, n_heads)
        self.ffn = FFN(embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.norm1(x + self.attention(x))  # Norm AFTER residual
        x = self.norm2(x + self.ffn(x))        # Gradient must flow through norm
        return x                                # at EVERY layer β†’ unstable at depth

# --- PRE-NORM (Modern β€” common in 2025–2026 SOTA codebases) ---
class PreNormBlock(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, n_heads)
        self.ffn = FFN(embed_dim)
        self.norm1 = RMSNorm(embed_dim)  # RMSNorm, not LayerNorm!
        self.norm2 = RMSNorm(embed_dim)

    def forward(self, x):
        x = x + self.attention(self.norm1(x))  # Norm BEFORE sublayer
        x = x + self.ffn(self.norm2(x))        # Clean residual highway
        return x                                # Stable even at 100+ layers
10

Part 2: RMSNorm β€” Faster Normalization

11

From LayerNorm to RMSNorm

12

RMSNorm Implementation

python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (Zhang & Sennrich, 2019).
    Drop-in replacement for nn.LayerNorm β€” fewer ops, same quality."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # Learnable scale Ξ³
        # Note: NO bias parameter Ξ² (unlike LayerNorm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch, seq_len, dim)
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        x_normed = x / rms                     # Normalize by RMS
        return x_normed * self.weight            # Scale by learned Ξ³

# Compare parameter counts:
dim = 4096
layer_norm = nn.LayerNorm(dim)     # 2 Γ— 4096 = 8,192 params (Ξ³ + Ξ²)
rms_norm   = RMSNorm(dim)          # 1 Γ— 4096 = 4,096 params (Ξ³ only)
print(f"LayerNorm params: {sum(p.numel() for p in layer_norm.parameters()):,}")
print(f"RMSNorm params:   {sum(p.numel() for p in rms_norm.parameters()):,}")

# Per model (32 layers Γ— 2 norms each = 64 norm layers):
# LayerNorm: 64 Γ— 8,192 = 524,288 params
# RMSNorm:   64 Γ— 4,096 = 262,144 params β€” half the norm parameters
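
As a quick sanity check (a sketch assuming PyTorch 2.4+, which ships nn.RMSNorm), the hand-rolled module should match the built-in up to floating-point rounding:

python
x = torch.randn(2, 16, dim)
builtin = nn.RMSNorm(dim, eps=1e-6)                         # requires PyTorch >= 2.4
print(torch.allclose(rms_norm(x), builtin(x), atol=1e-6))   # expect: True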
13

Part 3: SwiGLU β€” The Modern Feed-Forward Network

14

The FFN Evolution: ReLU β†’ GELU β†’ Swish β†’ SwiGLU
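
A minimal illustration of the progression (a sketch, not content from the original slide): evaluating the activations on the same inputs shows how GELU and Swish/SiLU stay smooth and slightly negative where ReLU hard-clips to zero.

python
import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
print("ReLU:", F.relu(x))    # zero for all x < 0
print("GELU:", F.gelu(x))    # smooth, small negative values
print("SiLU:", F.silu(x))    # Swish = x * sigmoid(x)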

15

SwiGLU: Gated Activation for Better FFNs

16

SwiGLU FFN Implementation

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """SwiGLU Feed-Forward Network (Shazeer 2020).
    Replaces the standard ReLU/GELU FFN in modern transformers."""
    def __init__(self, embed_dim: int, hidden_dim: int = None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = int(8 * embed_dim / 3)                # Compensate for 3rd matrix
            hidden_dim = 256 * ((hidden_dim + 255) // 256)     # Round to multiple of 256
        self.w1 = nn.Linear(embed_dim, hidden_dim, bias=False)  # Gate projection
        self.w3 = nn.Linear(embed_dim, hidden_dim, bias=False)  # Up projection
        self.w2 = nn.Linear(hidden_dim, embed_dim, bias=False)  # Down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU(x) = (Swish(xW₁) βŠ™ xW₃) Β· Wβ‚‚
        return self.w2(F.silu(self.w1(x)) * self.w3(x))  # F.silu = Swish

# Parameter comparison at embed_dim = 4096:
class OriginalFFN(nn.Module):
    def __init__(self, d, h=None):
        super().__init__()
        h = h or 4 * d  # h = 16384
        self.w1 = nn.Linear(d, h, bias=False)
        self.act = nn.ReLU()
        self.w2 = nn.Linear(h, d, bias=False)
    def forward(self, x): return self.w2(self.act(self.w1(x)))

original = OriginalFFN(4096)       # 2 matrices: 4096Γ—16384 Γ— 2 = 134,217,728
swiglu   = SwiGLUFFN(4096)         # 3 matrices: 4096Γ—11008 Γ— 2 + 11008Γ—4096 = 135,266,304
print(f"Original: {sum(p.numel() for p in original.parameters()):>12,} params")
print(f"SwiGLU:   {sum(p.numel() for p in swiglu.parameters()):>12,} params")
# Nearly identical parameter count β€” but SwiGLU performs significantly better!
17

Part 4: RoPE β€” Rotary Position Embeddings

18

Why Absolute Position Embeddings Fall Short

19

RoPE: The Core Idea β€” Rotation Encodes Position

20

RoPE: The Mathematics

21

RoPE: Key Properties and Benefits

22

RoPE Implementation

python
import torch

def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    """Precompute rotation frequencies as complex exponentials."""
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len)
    angles = torch.outer(positions, freqs)           # (max_seq_len, head_dim/2)
    return torch.polar(torch.ones_like(angles), angles)  # e^(jΞΈ) = cos ΞΈ + j sin ΞΈ

def apply_rope(q, k, freqs_cis):
    """Apply rotary embeddings to query and key tensors.
    q, k: (batch, seq_len, n_heads, head_dim)
    freqs_cis: (seq_len, head_dim/2) β€” precomputed complex exponentials
    """
    # Reshape to pairs β†’ view as complex numbers
    q_complex = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))

    # Broadcast freqs_cis: (seq_len, head_dim/2) β†’ (1, seq_len, 1, head_dim/2)
    freqs = freqs_cis.unsqueeze(0).unsqueeze(2)

    # Complex multiplication = 2D rotation!
    q_rotated = torch.view_as_real(q_complex * freqs).flatten(-2)
    k_rotated = torch.view_as_real(k_complex * freqs).flatten(-2)
    return q_rotated.type_as(q), k_rotated.type_as(k)

# Precompute once at model init:
freqs_cis = precompute_freqs_cis(head_dim=128, max_seq_len=8192)
# Then in each attention layer:
# q, k = apply_rope(q, k, freqs_cis[:seq_len])
# Followed by standard attention: softmax(QK^T / √d) V
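
To verify the key relative-position property in code (a small sketch, not part of the original slide): place the same content vectors at different absolute positions, and the rotated QΒ·K score depends only on their offset.

python
head_dim, seq_len = 128, 16
fcis = precompute_freqs_cis(head_dim, seq_len)
q = torch.randn(head_dim).view(1, 1, 1, head_dim).repeat(1, seq_len, 1, 1)
k = torch.randn(head_dim).view(1, 1, 1, head_dim).repeat(1, seq_len, 1, 1)
q_rot, k_rot = apply_rope(q, k, fcis)

score = lambda m, n: (q_rot[0, m, 0] * k_rot[0, n, 0]).sum()
print(torch.allclose(score(2, 5), score(7, 10), atol=1e-4))  # same offset (3) β†’ True
print(torch.allclose(score(2, 5), score(2, 6), atol=1e-4))   # different offset β†’ generally False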
23

Part 5: GQA β€” Grouped Query Attention

24

The KV Memory Problem at Scale

25

From MHA to MQA to GQA

26

GQA: Visual Intuition

27

GQA in Practice

28

Grouped Query Attention Implementation

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """GQA: Multiple Q heads share fewer KV heads (Ainslie et al. 2023)."""
    def __init__(self, embed_dim: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads   # Q heads per KV group
        self.head_dim = embed_dim // n_heads

        self.wq = nn.Linear(embed_dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(embed_dim, n_kv_heads * self.head_dim, bias=False)  # Smaller!
        self.wv = nn.Linear(embed_dim, n_kv_heads * self.head_dim, bias=False)  # Smaller!
        self.wo = nn.Linear(n_heads * self.head_dim, embed_dim, bias=False)

    def forward(self, x, freqs_cis):
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)

        q, k = apply_rope(q, k, freqs_cis)   # RoPE on Q and K only

        # Repeat KV heads to match Q head count
        k = k.repeat_interleave(self.n_rep, dim=2)  # (B, T, n_heads, head_dim)
        v = v.repeat_interleave(self.n_rep, dim=2)

        # Standard scaled dot-product attention (uses fast SDPA kernels automatically)
        q, k, v = [t.transpose(1, 2) for t in (q, k, v)]  # (B, n_heads, T, head_dim)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.wo(out)

# Parameter savings (embed_dim=4096, 32 Q heads, 8 KV heads, head_dim=128):
# MHA K proj: 4096 Γ— 4096 = 16,777,216   GQA K proj: 4096 Γ— 1024 = 4,194,304 β†’ 4Γ— smaller!
29

GQA: Memory Savings Quantified
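
A back-of-the-envelope version of that quantification (a sketch using the running config β€” 32 layers, head_dim 128, fp16 cache, 8K context, batch 1 β€” chosen for illustration):

python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem  # K and V

mha = kv_cache_bytes(32, 32, 128, 8192)      # MHA: one K,V pair per Q head
gqa = kv_cache_bytes(32, 8, 128, 8192)       # GQA: 8 shared KV heads
print(f"MHA cache: {mha / 2**30:.1f} GiB")   # 4.0 GiB
print(f"GQA cache: {gqa / 2**30:.1f} GiB")   # 1.0 GiB β†’ 4Γ— smaller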

30

Three KV-Memory Strategies (2025–2026): Sharing vs Compression vs Selection

31

Part 6: KV-Cache β€” Efficient Autoregressive Generation

32

The Generation Bottleneck

33

How KV-Cache Works

34

Prefill vs Decode: Two Different Bottlenecks

35

KV-Cache Memory Analysis
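
A rough sketch of how that cache grows (same assumed GQA config, fp16 cache): the footprint is linear in context length, and multiplies again by batch size when serving.

python
# Per-token KV footprint: 2 (K and V) Γ— 32 layers Γ— 8 KV heads Γ— 128 dims Γ— 2 bytes (fp16)
per_token = 2 * 32 * 8 * 128 * 2          # = 131,072 bytes = 128 KiB per token
for ctx in (8_192, 32_768, 131_072):
    print(f"context {ctx:>7,}: {per_token * ctx / 2**30:5.1f} GiB per sequence")
# 8K β†’ 1.0 GiB, 32K β†’ 4.0 GiB, 128K β†’ 16.0 GiB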

36

Autoregressive Generation with KV-Cache

python
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens, temperature=0.7, top_p=0.95):
    """Autoregressive generation with KV-cache."""
    token_ids = prompt_ids  # (1, prompt_len)

    # Phase 1: PREFILL β€” process entire prompt, build initial cache
    logits, _, kv_cache = model(token_ids, kv_cache=None)  # GPTModel returns (logits, loss, kv_cache)
    # kv_cache: list of (K, V) tuples, one per layer
    # Each K, V: (batch, n_kv_heads, seq_len, head_dim)

    for step in range(max_new_tokens):
        # Sample next token from last position's logits
        next_logits = logits[:, -1, :] / temperature
        probs = top_p_sample(F.softmax(next_logits, dim=-1), top_p)  # nucleus sampling β€” helper sketched after this block
        next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)

        # Phase 2: DECODE β€” process ONLY the new token, reuse cached K,V
        logits, _, kv_cache = model(next_token, kv_cache=kv_cache)
        # Internally: new K,V appended to cache; Q attends to full cache
        # Only 1 token processed instead of entire sequence!

        token_ids = torch.cat([token_ids, next_token], dim=1)
        if next_token.item() == tokenizer.eos_token_id:  # tokenizer assumed in scope
            break

    return token_ids

# Without KV-cache: generating 100 tokens from 1000-token prompt
#   processes 1000 + 1001 + 1002 + ... + 1099 β‰ˆ 105,000 total tokens
# With KV-cache: processes 1000 + 1 + 1 + ... + 1 = 1,100 total tokens β†’ ~95Γ— less!
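
The loop above relies on a top_p_sample helper that isn't shown on the slide; a minimal nucleus-sampling sketch (one possible implementation, not necessarily the original one) looks like this:

python
import torch

def top_p_sample(probs, top_p):
    """Keep the smallest set of tokens whose cumulative probability exceeds top_p,
    zero out the rest, and renormalize. probs: (batch, vocab), already softmaxed."""
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    mask = cumulative - sorted_probs > top_p        # drop tokens outside the nucleus
    sorted_probs[mask] = 0.0                        # (the top-1 token is always kept)
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    return torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)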
37

Part 7: Flash Attention β€” Breaking the Memory Wall

38

The Memory Wall in Standard Attention

39

Flash Attention: IO-Aware Exact Attention

40

Flash Attention Versions and Usage (Updated for 2026)

41

Flash Attention: One Line of Code

python
import torch
import torch.nn.functional as F
import math

# Assume q, k, v: (batch, n_heads, seq_len, head_dim)

# === OLD WAY: Standard attention β€” O(TΒ²) memory ===
def standard_attention(q, k, v, is_causal=True):
    T = q.size(-2)
    scale = math.sqrt(q.size(-1))
    attn_weights = torch.matmul(q, k.transpose(-2, -1)) / scale  # (B, H, T, T) ← O(TΒ²)!
    if is_causal:
        mask = torch.triu(torch.ones(T, T, device=q.device), diagonal=1).bool()
        attn_weights = attn_weights.masked_fill(mask, float('-inf'))  # Explicit mask
    attn_weights = F.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, v)

# === NEW WAY: Flash / Memory-Efficient SDPA β€” O(T) memory when flash kernel is used ===
def sdpa_attention(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # PyTorch automatically chooses an optimized kernel on CUDA.

# NOTE: Exact bitwise equality is not guaranteed across kernels due to floating-point math,
# but results should be numerically very close.
q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
out_standard = standard_attention(q, k, v)
out_sdpa = sdpa_attention(q, k, v)
print(f"Max difference: {(out_standard - out_sdpa).abs().max():.6f}")
42

Part 8: Putting It All Together

43

The Modern Transformer Block

44

The Modern Transformer Block β€” Complete Code

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModernTransformerBlock(nn.Module):
    """A single transformer block with all modern upgrades.
    Pre-RMSNorm + GQA with RoPE + SwiGLU FFN."""
    def __init__(self, embed_dim: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        self.norm1 = RMSNorm(embed_dim)
        self.attn = GroupedQueryAttention(embed_dim, n_heads, n_kv_heads)
        self.norm2 = RMSNorm(embed_dim)
        self.ffn = SwiGLUFFN(embed_dim)

    def forward(self, x, freqs_cis, kv_cache=None):
        # Pre-norm: normalize BEFORE sublayer, clean residual
        h = self.norm1(x)
        # Assumes a cache-aware GQA variant (the Part 5 module shows the core math
        # without a KV cache; extending it to return one is a small change).
        attn_out, new_kv = self.attn(h, freqs_cis, kv_cache)  # GQA + RoPE
        x = x + attn_out                                       # Residual

        x = x + self.ffn(self.norm2(x))                         # Pre-norm + SwiGLU + Residual
        return x, new_kv

# Compare to the 2017 original block:
# - nn.LayerNorm      β†’ RMSNorm         βœ“
# - nn.MultiheadAttn  β†’ GQA + RoPE      βœ“
# - ReLU FFN          β†’ SwiGLU          βœ“
# - Post-norm         β†’ Pre-norm        βœ“
# Same structure, better components!
45

GPT-OSS β€” Full Model Skeleton with Design Notes

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GPTModel(nn.Module):
    """Modern GPT with all upgrades β€” the skeleton for GPT-OSS.

    Design choices (roughly matching the modern dense ~8B tier):
      - vocab_size  = 200,000  (byte-level BPE via tiktoken)
      - embed_dim   = 4,096
      - n_layers    = 32
      - n_heads     = 32  (Q heads)  β†’  head_dim = 128
      - n_kv_heads  = 8   (GQA 4:1) β†’  4Γ— KV-cache savings
      - FFN hidden  β‰ˆ 11,008  (SwiGLU, 8/3 Γ— 4096, rounded)
      - No biases (often), no dropout (pre-training), weight tying
      - Estimated total: ~6.5B parameters

    Note: Many 2026 SOTA stacks add MoE + hybrid/sparse attention + MLA-style KV compression.
    This skeleton intentionally teaches the "classic" modern dense backbone first.
    """
    def __init__(self, vocab_size=200_000, max_seq_len=8192,
                 embed_dim=4096, n_layers=32, n_heads=32, n_kv_heads=8):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        # No positional embedding table β€” RoPE is applied inside attention.

        self.layers = nn.ModuleList([
            ModernTransformerBlock(embed_dim, n_heads, n_kv_heads)
            for _ in range(n_layers)
        ])
        self.norm = RMSNorm(embed_dim)          # Final norm (required by pre-norm)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.lm_head.weight = self.token_embed.weight  # Weight tying β†’ 0 extra params

        # Precompute RoPE frequencies once
        head_dim = embed_dim // n_heads
        self.register_buffer(
            'freqs_cis',
            precompute_freqs_cis(head_dim, max_seq_len),
            persistent=False
        )

    def forward(self, token_ids, targets=None, kv_cache=None):
        B, T = token_ids.shape
        x = self.token_embed(token_ids)        # (B, T, embed_dim)
        freqs = self.freqs_cis[:T]

        new_kv_cache = []
        for i, layer in enumerate(self.layers):
            cache_i = kv_cache[i] if kv_cache else None
            x, kv = layer(x, freqs, cache_i)
            new_kv_cache.append(kv)

        x = self.norm(x)                       # Final RMSNorm
        logits = self.lm_head(x)               # (B, T, vocab_size)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )
        return logits, loss, new_kv_cache
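
A quick usage sketch with a toy configuration (sizes chosen only for illustration, far below the docstring's numbers) to confirm the skeleton wires together and to count parameters:

python
tiny = GPTModel(vocab_size=1_000, max_seq_len=256,
                embed_dim=256, n_layers=2, n_heads=8, n_kv_heads=2)
n_params = sum(p.numel() for p in tiny.parameters())   # tied lm_head counted once
print(f"Toy model parameters: {n_params:,}")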
46

Part 9: Additional Modern Techniques

47

No Biases, No Dropout (Common Pre-Training Defaults)

48

Weight Initialization and Context Extension

49

2026 Add-ons: MoE + MLA + Sparse/Hybrid Attention (Beyond the "Classic" Block)

50

Part 10: Common Pitfalls

51

Mistakes You'll Make (and How to Avoid Them)

52

Part 11: Real Model Configurations

53

State-of-the-Art Architectures Compared (2026, Publicly Documented)

54

Parameter Count Breakdown
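
A worked breakdown for the skeleton's configuration above (pure arithmetic from those hyperparameters, not official figures for any released model):

python
embed_dim, n_layers, vocab = 4096, 32, 200_000
n_heads, n_kv_heads, head_dim, ffn_hidden = 32, 8, 128, 11008

attn  = 2 * embed_dim * n_heads * head_dim        # wq + wo
attn += 2 * embed_dim * n_kv_heads * head_dim     # wk + wv (smaller, thanks to GQA)
ffn   = 3 * embed_dim * ffn_hidden                # w1, w2, w3 (SwiGLU)
norms = 2 * embed_dim                             # two RMSNorms per block
per_layer = attn + ffn + norms

embeddings = vocab * embed_dim                    # lm_head is tied β†’ no extra params
total = n_layers * per_layer + embeddings + embed_dim  # + final RMSNorm
print(f"Per layer:  {per_layer:,}")               # 177,217,536
print(f"Embeddings: {embeddings:,}")              # 819,200,000
print(f"Total:      {total:,}")                   # 6,490,165,248 β‰ˆ 6.5B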

55

Summary

56

What You Now Understand

57

The Road to GPT-OSS

58

Interactive Demos

59

Supplementary Resources