Grokking Neural Networks
What is Grokking?
Grokking is a fascinating phenomenon where neural networks suddenly "get it" after appearing to only memorize. The network first achieves 100% training accuracy (memorization), then much later, test accuracy suddenly jumps (generalization).
The Task: Modular Arithmetic
The network learns modular addition: (a + b) mod 67. For example, (45 + 30) mod 67 = 8, because 45 + 30 = 75 and 75 - 67 = 8. This seems simple, but discovering the underlying pattern (instead of memorizing each pair) requires the network to learn the structure of modular arithmetic.
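The task above can be sketched as a small dataset builder. This is a minimal illustration, not the page's actual code: the function name `make_dataset`, the 50% train fraction, and the seed are assumptions, since the split used by the demo is not stated here.

```python
import random

P = 67  # modulus used on this page

def make_dataset(p=P, train_fraction=0.5, seed=0):
    """Enumerate every (a, b) pair with its label and split into train/test.

    train_fraction and seed are illustrative assumptions.
    """
    pairs = [(a, b, (a + b) % p) for a in range(p) for b in range(p)]
    rng = random.Random(seed)
    rng.shuffle(pairs)
    cut = int(len(pairs) * train_fraction)
    return pairs[:cut], pairs[cut:]

train, test = make_dataset()
print((45 + 30) % P)          # 8
print(len(train), len(test))  # 2244 2245 (67 * 67 = 4489 pairs in total)
```

Holding out part of the 4489 pairs is what makes the test accuracy meaningful: the network can only get held-out pairs right if it has learned the rule rather than the individual answers.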
Why Does Grokking Happen?
The key is high weight decay, a form of regularization. Without it, the network tends to just memorize the training examples. With strong weight decay, the network is pushed toward simpler solutions that generalize. This takes time: the "grokking" moment happens when the network finally discovers the generalizing solution.
- Weight Decay = 1.0 (very high!)
- Forces simpler weight patterns
- Memorization uses more "capacity"
- Generalization is more efficient
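To make the effect of weight decay concrete, here is a minimal single-weight sketch of an AdamW step. In AdamW the decay is decoupled from the gradient: each step first multiplies the weight by (1 - lr * weight_decay), then applies the usual Adam update. The learning rate and beta values below are assumptions (common defaults), not settings confirmed by this page.

```python
import math

def adamw_step(w, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=1.0):
    """One AdamW update for a single scalar weight; returns (w, m, v)."""
    # Decoupled weight decay: shrink the weight directly,
    # independent of the gradient.
    w = w * (1.0 - lr * weight_decay)
    # Standard Adam moment estimates with bias correction.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

# With a zero gradient, only the decay acts on the weight.
w, m, v = 1.0, 0.0, 0.0
for t in range(1, 1001):
    w, m, v = adamw_step(w, grad=0.0, m=m, v=v, t=t)
print(round(w, 3))  # 0.368
```

Under these assumed settings, a weight that the loss does not actively defend shrinks to about 37% of its value in 1000 steps. Large weights therefore survive only if they keep earning their place, which is the pressure toward simpler weight patterns described above.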
The Grokking Timeline
- Phase 1 - Memorization: Train accuracy quickly reaches ~100%
- Phase 2 - Plateau: Test accuracy stays low (20-40%)
- Phase 3 - Grokking: Suddenly, test accuracy jumps to ~100%!
This can take thousands of epochs. Be patient and watch the charts!
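The three phases can be told apart from the two accuracy curves alone. The sketch below is one hypothetical way to label them; the function names and the 0.99 threshold are assumptions, and the sample histories are made-up numbers for illustration, not measured results.

```python
def classify_phase(train_acc, test_acc, high=0.99):
    """Label a training step using the (assumed) accuracy threshold `high`."""
    if train_acc < high:
        return "memorizing"   # Phase 1: still fitting the training set
    if test_acc < high:
        return "plateau"      # Phase 2: memorized, not yet generalizing
    return "grokked"          # Phase 3: both accuracies are high

def grokking_epoch(train_hist, test_hist, high=0.99):
    """Return the first index where the run counts as grokked, else None."""
    for epoch, (tr, te) in enumerate(zip(train_hist, test_hist)):
        if classify_phase(tr, te, high) == "grokked":
            return epoch
    return None

# Made-up accuracy histories, purely to exercise the functions.
train_hist = [0.40, 1.00, 1.00, 1.00, 1.00]
test_hist = [0.05, 0.25, 0.30, 0.35, 1.00]
print([classify_phase(tr, te) for tr, te in zip(train_hist, test_hist)])
print(grokking_epoch(train_hist, test_hist))  # 4
```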
Network Architecture
Based on Google's Factored MLP with tied embeddings:
- Input: Two tokens (a, b) embedded to 500 dimensions
- Hidden: 64 neurons with ReLU activation
- Output: 67 classes (one per possible result)
- Tied weights: Same embedding for input and output
- Optimizer: AdamW with weight decay = 1.0
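A forward pass through an architecture of this shape might look like the sketch below, written in plain Python. Several details are assumptions not stated above: the two token embeddings are combined by summing them, the hidden layer is reached through a 500 x 64 projection and left through a 64 x 500 projection, and the weights use a small Gaussian initialization. The names `w_in` and `w_out` are hypothetical.

```python
import random

P, D_EMBED, D_HIDDEN = 67, 500, 64  # sizes taken from the list above

def init_matrix(rows, cols, rng, scale=0.02):
    """Small random Gaussian weights (the init scale is an assumption)."""
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]

def vec_mat(vec, matrix):
    """Multiply a row vector by a matrix stored as a list of rows."""
    out = [0.0] * len(matrix[0])
    for x, row in zip(vec, matrix):
        if x:  # skip zeros, e.g. units switched off by ReLU
            for j, w in enumerate(row):
                out[j] += x * w
    return out

def forward(a, b, embed, w_in, w_out):
    """Return 67 logits for the pair (a, b)."""
    # Combine the two token embeddings by summing (assumed).
    x = [ea + eb for ea, eb in zip(embed[a], embed[b])]   # 500 values
    hidden = [max(0.0, h) for h in vec_mat(x, w_in)]      # 64 values, ReLU
    y = vec_mat(hidden, w_out)                            # back to 500 values
    # Tied weights: score each class against the same embedding table.
    return [sum(yi * ei for yi, ei in zip(y, embed[c])) for c in range(len(embed))]

rng = random.Random(0)
embed = init_matrix(P, D_EMBED, rng)
w_in = init_matrix(D_EMBED, D_HIDDEN, rng)
w_out = init_matrix(D_HIDDEN, D_EMBED, rng)

logits = forward(45, 30, embed, w_in, w_out)
prediction = max(range(P), key=lambda c: logits[c])
print(len(logits), prediction)  # untrained, so the prediction is arbitrary
```

Because the inputs are summed in this sketch, `forward(a, b, ...)` and `forward(b, a, ...)` give identical logits, which matches the symmetry of addition.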
Technical: PyScript
The neural network is implemented in pure Python, running via PyScript's Pyodide runtime directly in your browser. Training uses asyncio to yield control periodically, keeping the UI responsive. No server required!
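The yielding pattern described above can be sketched with `asyncio.sleep(0)`, which suspends the coroutine for one pass of the event loop so other work can run. This is an illustrative sketch, not the page's actual code: `train`, `train_one_epoch`, `yield_every`, and `ui_heartbeat` are hypothetical names, and the heartbeat task merely stands in for browser UI work.

```python
import asyncio

async def train(num_epochs, train_one_epoch, yield_every=10):
    """Run epochs, handing control back to the event loop periodically."""
    for epoch in range(1, num_epochs + 1):
        train_one_epoch(epoch)
        if epoch % yield_every == 0:
            await asyncio.sleep(0)  # let other tasks (e.g. UI updates) run
    return num_epochs

async def main():
    log = []

    async def ui_heartbeat():
        # Stand-in for UI work that needs a turn on the event loop.
        for _ in range(3):
            log.append("ui")
            await asyncio.sleep(0)

    done, _ = await asyncio.gather(train(30, log.append), ui_heartbeat())
    return done, log

# asyncio.run is used here so the sketch runs outside the browser; inside
# PyScript an event loop is assumed to be running already, so the coroutine
# would be scheduled on that loop instead.
done, log = asyncio.run(main())
print(done, log.count("ui"))  # 30 3
```

Without the periodic `await`, the training loop would hold the single thread until it finished and the heartbeat entries would only appear after the last epoch.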
Learn more about grokking from Google's interactive article