
Grokking Neural Networks

🧠 What is Grokking?

Grokking is a fascinating phenomenon in which a neural network suddenly "gets it" long after it appears to have merely memorized its training data. The network first reaches 100% training accuracy (memorization); much later, its test accuracy abruptly jumps (generalization).

➕ The Task: Modular Arithmetic

The network learns modular addition: (a + b) mod 67. For example, (45 + 30) mod 67 = 75 mod 67 = 8. The task looks simple, but to discover the underlying pattern rather than memorize every pair, the network has to learn the structure of modular arithmetic.

(a + b) mod p = c
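
To make the task concrete, here is a minimal sketch of how the full dataset for p = 67 could be enumerated and split. With 67 × 67 = 4489 possible input pairs, the whole problem is a small lookup table, which is exactly why memorizing it is an option for the network. The 30% training fraction, the seed, and the name make_dataset are illustrative assumptions; the text does not say how the demo splits its data.

```python
import random

P = 67  # modulus used in the demo


def make_dataset(p=P, train_fraction=0.3, seed=0):
    """Enumerate every (a, b, (a + b) mod p) triple and split it into train/test.

    train_fraction and seed are hypothetical choices, not the demo's settings.
    """
    triples = [(a, b, (a + b) % p) for a in range(p) for b in range(p)]
    random.Random(seed).shuffle(triples)  # reproducible random split
    cut = int(len(triples) * train_fraction)
    return triples[:cut], triples[cut:]


train, test = make_dataset()
print(len(train), len(test))  # 1346 3143
print((45 + 30) % P)          # 8, the worked example above
```

Because the test pairs never appear during training, high test accuracy can only come from learning the rule rather than the table.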

⚡ Why Does Grokking Happen?

The key is HIGH weight decay (regularization). Without it, networks tend to just memorize. Strong weight decay pushes the network toward simpler solutions that generalize. This takes time: the "grokking" moment happens when the network finally discovers the generalizing solution.

  • Weight Decay = 1.0 (very high!)
  • Forces simpler weight patterns
  • Memorization uses more "capacity"
  • Generalization is more efficient
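
To see how weight decay keeps pulling the weights toward zero, here is a minimal NumPy sketch of a single AdamW update in its decoupled form, where the decay shrinks the weights directly instead of being added to the gradient. Only weight_decay = 1.0 comes from the text; the learning rate, betas, eps, and the function name adamw_step are illustrative assumptions, and the demo's own optimizer code may differ.

```python
import numpy as np


def adamw_step(w, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=1.0):
    """One AdamW update with decoupled weight decay; returns (w, m, v).

    lr, beta1, beta2 and eps are assumed values; weight_decay=1.0 matches the text.
    """
    # Standard Adam moment estimates with bias correction (t starts at 1).
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Decoupled decay: shrink the weights directly, separately from the gradient step.
    w = w - lr * weight_decay * w
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v


# With a zero gradient, only the decay term acts, so every step
# multiplies the weights by (1 - lr * weight_decay).
w, m, v = np.ones(4), np.zeros(4), np.zeros(4)
for t in range(1, 101):
    w, m, v = adamw_step(w, np.zeros(4), m, v, t)
print(w)  # each entry is 0.999 ** 100, roughly 0.905
```

Any weight the loss gradient does not actively defend decays away, so solutions that need many large weights (such as a memorized lookup table) are steadily penalized relative to compact ones.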

📊 The Grokking Timeline

  1. Phase 1 - Memorization: Train accuracy quickly reaches ~100%
  2. Phase 2 - Plateau: Test accuracy stays low (20-40%)
  3. Phase 3 - Grokking: Suddenly, test accuracy jumps to ~100%!

This can take thousands of epochs. Be patient and watch the charts!
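
One simple way to locate these phases in a training log, sketched here as an assumption rather than the demo's actual logic, is to record the first epoch at which each accuracy crosses a threshold. The 0.99 threshold and the function name find_grokking are hypothetical, and the accuracy curves below are synthetic numbers made up only to illustrate the three phases.

```python
def find_grokking(train_acc, test_acc, threshold=0.99):
    """Return (memorized_epoch, grokked_epoch); either is None if never reached.

    train_acc and test_acc are per-epoch accuracies in [0, 1].
    """
    memorized = next((i for i, acc in enumerate(train_acc) if acc >= threshold), None)
    grokked = next((i for i, acc in enumerate(test_acc) if acc >= threshold), None)
    return memorized, grokked


# Synthetic accuracy curves, made up purely to illustrate the three phases.
train_acc = [0.40, 0.85, 1.00, 1.00, 1.00, 1.00, 1.00]
test_acc = [0.05, 0.15, 0.25, 0.30, 0.30, 0.70, 1.00]

memorized, grokked = find_grokking(train_acc, test_acc)
print(memorized, grokked)                        # 2 6
print("delay:", grokked - memorized, "epochs")   # delay: 4 epochs
```

In a real run the gap between the two epochs is the plateau, which can span thousands of epochs rather than four.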

πŸ—οΈ Network Architecture

Based on Google's Factored MLP with tied embeddings:

  • Input: Two tokens (a, b) embedded to 500 dimensions
  • Hidden: 64 neurons with ReLU activation
  • Output: 67 classes (one per possible result)
  • Tied weights: Same embedding for input and output
  • Optimizer: AdamW with weight decay = 1.0
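
The list above can be sketched as a NumPy forward pass. The sizes (67 tokens, 500-dimensional embeddings, 64 ReLU neurons) come from the list; everything else is an assumption: the two token embeddings are summed (the demo might concatenate them or use separate tables), the hidden layer is projected back to the embedding width so the output logits can reuse the embedding matrix transposed, and the parameter names and initialization scale are hypothetical.

```python
import numpy as np

P, D_EMBED, D_HIDDEN = 67, 500, 64  # sizes from the architecture list

rng = np.random.default_rng(0)
# Hypothetical parameter names and initialization scale.
E = rng.normal(0.0, 0.02, size=(P, D_EMBED))           # shared token embedding
W_in = rng.normal(0.0, 0.02, size=(D_EMBED, D_HIDDEN))  # embedding -> hidden
W_out = rng.normal(0.0, 0.02, size=(D_HIDDEN, D_EMBED)) # hidden -> embedding width


def forward(a, b):
    """Logits over the 67 possible results for integer arrays a and b."""
    x = E[a] + E[b]                  # (batch, D_EMBED): combine the two tokens (assumed sum)
    h = np.maximum(0.0, x @ W_in)    # (batch, D_HIDDEN): ReLU hidden layer
    return (h @ W_out) @ E.T         # (batch, P): tied output reuses E transposed


logits = forward(np.array([45]), np.array([30]))
print(logits.shape)  # (1, 67)
```

Tying the output to E means the same 67 × 500 matrix has to serve both as the input lookup and as the output classifier, so the number of free parameters is smaller than with an untied output layer.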

🔧 Technical: PyScript

The neural network is implemented in pure Python, running via PyScript's Pyodide runtime directly in your browser. Training uses asyncio to yield control periodically, keeping the UI responsive. No server required!
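
A minimal sketch of that cooperative pattern, assuming each epoch is an ordinary synchronous function call. The names train, train_one_epoch, and yield_every are hypothetical and not taken from the demo's code.

```python
import asyncio


async def train(num_epochs, train_one_epoch, yield_every=10):
    """Run num_epochs epochs, yielding to the event loop every yield_every epochs.

    train_one_epoch is a hypothetical callback that runs one epoch and
    returns that epoch's metrics.
    """
    history = []
    for epoch in range(num_epochs):
        history.append(train_one_epoch(epoch))
        if (epoch + 1) % yield_every == 0:
            # sleep(0) suspends this coroutine so other pending tasks
            # (in a browser, UI updates and event handlers) can run.
            await asyncio.sleep(0)
    return history


# Outside the browser, the loop can be exercised with asyncio.run.
history = asyncio.run(train(25, lambda epoch: epoch * 2))
print(history[:3], len(history))  # [0, 2, 4] 25
```

The trade-off is in yield_every: yielding more often keeps the page smoother but adds scheduling overhead to every batch of epochs.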

Learn more about grokking from Google's interactive article →
