# GistNet Architecture Details
This document captures the plan-of-record implementation of GistNet, the 32→1 compressor that transforms short token spans into substitutable gists for the MegaContext Tree. The Phase 1 code path (see `mc/gistnet.py`) standardises on a mini transformer backbone plus interchangeable pooling heads.
## Architectural Overview
```mermaid
flowchart LR
    subgraph Window
        T["Token embeddings<br/>(32 × d)"]
    end
    CLS[(optional CLS token)]
    subgraph Backbone
        B1["RoPE + LayerNorm"]
        B2["Transformer Block × N<br/>(MHA + MLP + residuals)"]
    end
    Head[["Pooling Head<br/>(mean · query · CLS)"]]
    G["gist ∈ ℝ^d"]
    T --> B1 --> B2 --> Head --> G
    T -. prepend .-> CLS --> B1
```
- **Inputs.** The base model’s embedding layer provides `[B, 32, d]` slices. Local RoPE metadata is applied directly to those embeddings.
- **Mini transformer stack.** We reuse nanochat’s attention/MLP blocks (pre-LN, residual connections, rotary queries/keys) with `n_layer = 1–4` and `n_head = 8` by default. GistNet never touches KV caches because each window is processed independently.
- **Pooling head.** The final `[B, 32, d]` activations are collapsed into one gist via one of three heads:
  - Mean pooling (baseline) → linear or two-layer MLP projection.
  - Query pooling → a learned query attends once over the block (Structured Self-Attentive pooling [1]).
  - CLS pooling → prepend a learnable `[CLS]` token before the backbone and read it out afterwards (BERT-style [2]).
The result is a d-dimensional vector aligned with the base LLM’s embedding space so we can splice it directly into the context.
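To make the three heads concrete, here is a minimal NumPy sketch of how each one collapses a `[B, 32, d]` activation block into a single `[B, d]` gist. This is illustrative only: the real heads in `mc/gistnet.py` add learned projections and run in PyTorch, so the function names here are hypothetical — the array shapes and the pooling math are the point.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mean_pool(h):
    # h: [B, 32, d] backbone activations -> [B, d] gist (before projection)
    return h.mean(axis=1)

def query_pool(h, q):
    # q: [d] learned query; a single round of attention over the window
    scores = h @ q / np.sqrt(h.shape[-1])       # [B, 32]
    w = softmax(scores, axis=-1)[..., None]     # [B, 32, 1] attention weights
    return (w * h).sum(axis=1)                  # [B, d] weighted sum

def cls_pool(h_with_cls):
    # the [CLS] token was prepended before the backbone; read position 0
    return h_with_cls[:, 0, :]
```

Each variant then feeds its `[B, d]` output through the configured linear or MLP projection to land in the base LLM’s embedding space.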
## Transformer Backbone
| Parameter | Value (run10 default) | Notes |
|---|---|---|
| Window length | 32 tokens | Fixed to match MegaContext block size. |
| Embedding dim `d` | Matches base model (`depth * 64`) | Pulled directly from nanochat config. |
| Layers `n_layer` | 2 | Can be tuned per script. |
| Heads `n_head` | 8 | Full multi-head attention (no MQA). |
| RoPE base | 10 000 | Shared with nanochat GPT. |
Each block performs:

```
x = x + MHA(norm(x), rotary=True)
x = x + MLP(norm(x))
```

with bf16 activations and fp32 parameters. There is no slot/cross-attention anymore; all mixing happens through self-attention inside the window.
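The two residual updates above can be sketched as a toy single-head NumPy block. The production block uses 8 rotary heads, bf16 activations, and nanochat’s kernels; this sketch only shows the pre-LN residual wiring, and the parameter dictionary `p` is a stand-in for the real module’s weights.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # pre-LN: normalise over the feature dimension before each sub-block
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def block(x, p):
    # x: [B, 32, d]; p holds weight matrices (single head, RoPE omitted)
    h = layer_norm(x)
    q, k, v = h @ p["wq"], h @ p["wk"], h @ p["wv"]
    att = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(x.shape[-1]))
    x = x + (att @ v) @ p["wo"]                      # x = x + MHA(norm(x))
    h = layer_norm(x)
    x = x + np.maximum(h @ p["w1"], 0.0) @ p["w2"]   # x = x + MLP(norm(x))
    return x
```

Because the window is processed in isolation, stacking `n_layer` such blocks is the whole backbone — no KV cache, no cross-window masking.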
## Pooling Heads
| Head name | Summary | Projection | Notes |
|---|---|---|---|
| `mean_linear` | Mean of the block | Linear | “Dumb” control baseline (`--gistnet_type mean --gistnet_head linear`). |
| `mean_mlp` | Mean + 2-layer MLP | Linear→ReLU→Linear | Default for transformer families (`--gistnet_head mlp`). |
| `query_*` | Learned query attends once (Lin et al. 2017 [1]) | Linear/MLP | Controlled via `--gistnet_pooling query` and `--gistnet_head`. |
| `cls_*` | Prepend `[CLS]`, read directly (Devlin et al. 2018 [2]) | Linear/MLP | Controlled via `--gistnet_pooling cls` and `--gistnet_head`. |
Every head reports whether it requires a `[CLS]` token so the runtime can prepend it exactly once.
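The “prepend exactly once” contract can be expressed as a tiny registry. The class names and the `needs_cls` attribute below are illustrative, not the actual `mc/gistnet.py` API; the point is that the runtime, not the head, owns the prepend.

```python
import numpy as np

class MeanHead:
    needs_cls = False
    def __call__(self, h):
        return h.mean(axis=1)

class CLSHead:
    needs_cls = True
    def __call__(self, h):
        return h[:, 0, :]  # read the prepended [CLS] position

def run_gistnet(tokens, backbone, head, cls_embedding):
    # the runtime prepends [CLS] exactly once, only when the head asks for it
    if head.needs_cls:
        b, _, d = tokens.shape
        cls = np.broadcast_to(cls_embedding, (b, 1, d))
        tokens = np.concatenate([cls, tokens], axis=1)
    return head(backbone(tokens))
```

Keeping the decision in one place avoids the double-prepend bug class where both a head and its caller try to add the token.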
## FLOPs Estimate (run10)
| Depth | Head | FLOPs / block (×10⁶) | Comment |
|---|---|---|---|
| 2 layers | Linear | ~10 | Cheapest transformer head. |
| 2 layers | MLP | ~13 | Default accuracy/cost point. |
| 4 layers | Linear | ~20 | Roughly 2× cost vs depth‑2. |
| 4 layers | MLP | ~25 | Highest quality before diminishing returns. |
| Mean baseline | Linear/MLP | ~1 | No transformer; for sanity checks. |
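A back-of-envelope estimator behind this kind of table can be written in a few lines. The constant factors below are the standard 2-FLOPs-per-MAC accounting for QKVO projections, attention, and a 4×-expansion MLP — an assumption, since the table’s exact counting convention isn’t specified — so treat the absolute numbers as rough and the depth ratios as the useful part.

```python
def gistnet_flops(n_layer, d, window=32):
    """Rough forward-pass FLOPs for one window (2 FLOPs per MAC)."""
    per_token_per_layer = (
        8 * d * d          # Q, K, V, O projections
        + 4 * window * d   # attention scores + weighted value sum
        + 16 * d * d       # 4x-expansion MLP (two matmuls)
    )
    return n_layer * window * per_token_per_layer
```

The estimator is linear in depth, which reproduces the table’s observation that 4 layers cost roughly 2× the 2-layer configuration; the pooling-head projection adds only a small constant on top.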
## Hierarchical Application
- **LOD0 (tokens → gists).** Run the transformer head on raw token embeddings.
- **LOD1 (gists → super-gists).** Take groups of 32 LOD0 gists and feed them back through the same module (weights can be shared) to get LOD1 gists.
- **LOD≥2.** Repeat as required: every extra level adds another 32× factor of compression and still uses the identical code path, only swapping the input tensor source.
This produces a MegaContext tree where every node is generated by the same transformer head, vastly simplifying maintenance compared to the old slot/cross-attention pipeline.
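The level-by-level recursion is short enough to sketch end-to-end. Here `gist_fn` stands in for the full backbone-plus-head (a plain mean in the test, so the tree arithmetic is checkable), and `build_tree` is a hypothetical helper, not the `mc/runtime` API.

```python
import numpy as np

def build_tree(embeddings, gist_fn, block=32):
    """embeddings: [n, d] with n a power of `block`.
    Returns one array per level: tokens, LOD0 gists, LOD1 gists, ..."""
    levels = [embeddings]
    while len(levels[-1]) > 1:
        x = levels[-1]
        # group into [n/block, block, d] windows and gist each one
        windows = x.reshape(-1, block, x.shape[-1])
        levels.append(gist_fn(windows))
    return levels
```

Each pass through the loop is another 32× compression step, and because every level calls the same `gist_fn`, shared weights fall out for free.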
## Implementation Notes
- **File:** `mc/gistnet.py` (imported by `mc/runtime.MCController`).
- **Config knobs:** `--block_size`, `--gistnet_type`, `--gistnet_layers`, `--gistnet_pooling`, `--gistnet_head` (see run scripts).
- **Device placement:** GistNet runs on the same device as the base model embeddings (CUDA). Metadata (positions, LOD) stays on-device to avoid PCIe traffic.
- **Telemetry:** Use `nanochat.report` counters to log gist-substitution ΔNLL and pooling head choice per run.
## References
1. Zhouhan Lin, Minwei Feng, Cicero Nogueira dos Santos, Mo Yu, Bing Xiang, Bowen Zhou, and Yoshua Bengio. “A Structured Self-Attentive Sentence Embedding.” ICLR 2017.
2. Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” arXiv:1810.04805, 2018.
3. Jesse Mu, Xiang Lisa Li, and Noah Goodman. “Learning to Compress Prompts with Gist Tokens.” 2023. Original substitutability proposal.
4. Jack W. Rae, Anna Potapenko, Siddhant M. Jayakumar, Chloe Hillier, and Timothy P. Lillicrap. “Compressive Transformers for Long-Range Sequence Modelling.” arXiv:1911.05507, 2019. Hierarchical compression inspiration.
5. Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. “Distilling the Knowledge in a Neural Network.” arXiv:1503.02531, 2015. Teacher–student training framework used for ΔNLL supervision.