Press Ctrl/Cmd + P and select "Save as PDF".
// Pseudocode: DSA (sparse attention) indexer scoring — selects the top-k
// key positions each query is allowed to attend to.
function DSAScore(Q, K, W_weights)
    // Q: indexer queries, K: cached indexer keys
    // 1. Calculate base dot product
    base_scores = Q * K^T / sqrt(d_index)
    // 2. Weight by head importance
    // NOTE(review): x is not a parameter here — presumably the token hidden
    // state from the surrounding context; confirm against the full listing.
    head_weights = x * W_weights / sqrt(H)
    // 3. Final score is weighted sum across heads
    final_score = sum(base_scores * head_weights)
    return top_k(final_score, k=2048)
# Simplified MLA KV projection path
class MLA_KV(nn.Module):
def __init__(self, hidden, lora_rank, full_dim):
# Down-project to low rank (e.g., 512)
self.down = nn.Linear(hidden, lora_rank)
self.norm = RMSNorm(lora_rank)
# Up-project to full heads (e.g., 64 heads * 256 dim)
self.up = nn.Linear(lora_rank, full_dim)
def forward(self, x):
# Compress and cache this tiny tensor!
c_kv = self.norm(self.down(x))
# Dynamically reconstruct full K and V when needed
full_kv = self.up(c_kv)
return c_kv, full_kvclass GlmMoeDsaMoE(nn.Module):
def forward(self, hidden_states: torch.Tensor):
    """MoE forward pass: sparse routed experts plus a dense shared expert.

    Every token passes through the shared expert; the router additionally
    dispatches each token to its top-k routed experts, and the two paths
    are summed.
    """
    # Router scores each token; route_tokens turns logits into expert
    # assignments (sigmoid + bias selection per the original comment).
    logits = self.gate(hidden_states)
    expert_indices, expert_weights = self.route_tokens(logits)
    # Sparse path: only the selected experts process each token.
    sparse_out = self.experts(hidden_states, expert_indices, expert_weights)
    # Dense path: the shared expert sees every token unconditionally.
    dense_out = self.shared_experts(hidden_states)
    return sparse_out + dense_out