GistNet Architecture Details

This document captures the plan-of-record implementation of GistNet, the 32→1 compressor that transforms short token spans into substitutable gists for the MegaContext Tree. The Phase 1 code path (see mc/gistnet.py) standardises on a mini transformer backbone plus interchangeable pooling heads.

Architectural Overview

flowchart LR
    subgraph Window
        T["Token embeddings<br/>(32 × d)"]
    end
    CLS[(optional CLS token)]
    subgraph Backbone
        B1[RoPE + LayerNorm]
        B2["Transformer Block × N<br/>(MHA + MLP + residuals)"]
    end
    Head[["Pooling Head<br/>(mean · query · CLS)"]]
    G[gist ∈ ℝ^d]

    T --> B1 --> B2 --> Head --> G
    T -. prepend .-> CLS --> B1
  1. Inputs. The base model’s embedding layer provides [B, 32, d] slices. Local RoPE metadata is applied directly to those embeddings.
  2. Mini transformer stack. We reuse nanochat’s attention/MLP blocks (pre-LN, residual connections, rotary queries/keys) with n_layer = 1–4 (2 by default) and n_head = 8. GistNet never touches KV caches because each window is processed independently.
  3. Pooling head. The final [B, 32, d] activations are collapsed into one gist via one of three heads:
    • Mean pooling (baseline) → linear or two-layer MLP projection.
    • Query pooling → learned query attends once over the block (Structured Self-Attentive pooling [1]).
    • CLS pooling → prepend a learnable [CLS] token before the backbone and read it afterwards (BERT-style [2]).

The result is a d-dimensional vector aligned with the base LLM’s embedding space so we can splice it directly into the context.
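The pipeline above can be sketched end-to-end in PyTorch. This is an illustrative stand-in, not the mc/gistnet.py API: the class and argument names are hypothetical, `nn.TransformerEncoder` substitutes for nanochat’s blocks, and RoPE is omitted for brevity.

```python
# Hedged sketch of GistNet: mini transformer backbone + interchangeable pooling
# head collapsing a [B, 32, d] window into one gist vector in R^d.
import torch
import torch.nn as nn

class GistNetSketch(nn.Module):
    def __init__(self, d: int, n_layer: int = 2, n_head: int = 8, pooling: str = "mean"):
        super().__init__()
        self.pooling = pooling
        # Learnable [CLS] token, prepended before the backbone (CLS pooling only).
        self.cls = nn.Parameter(torch.zeros(1, 1, d)) if pooling == "cls" else None
        # Learned query for one-shot attention pooling (Lin et al. 2017 style).
        self.query = nn.Parameter(torch.zeros(1, 1, d)) if pooling == "query" else None
        # Stand-in backbone: pre-LN encoder blocks (RoPE omitted in this sketch).
        layer = nn.TransformerEncoderLayer(d, n_head, 4 * d,
                                           batch_first=True, norm_first=True)
        self.backbone = nn.TransformerEncoder(layer, n_layer)
        self.proj = nn.Linear(d, d)  # "linear" head; the MLP variant adds a hidden layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B, 32, d]
        if self.cls is not None:
            x = torch.cat([self.cls.expand(x.size(0), -1, -1), x], dim=1)
        h = self.backbone(x)                       # [B, 32 (+1), d]
        if self.pooling == "cls":
            pooled = h[:, 0]                       # read the [CLS] position
        elif self.pooling == "query":
            q = self.query.expand(h.size(0), -1, -1)
            attn = torch.softmax(q @ h.transpose(1, 2) / h.size(-1) ** 0.5, dim=-1)
            pooled = (attn @ h).squeeze(1)         # one attention pass over the window
        else:
            pooled = h.mean(dim=1)                 # mean pooling baseline
        return self.proj(pooled)                   # gist in R^d
```

Because the gist lives in the base model’s embedding space, the output of `forward` can be spliced into the context wherever the 32 source embeddings would have gone.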

Transformer Backbone

| Parameter | Value (run10 default) | Notes |
| --- | --- | --- |
| Window length | 32 tokens | Fixed to match MegaContext block size. |
| Embedding dim d | Matches base model (depth * 64) | Pulled directly from nanochat config. |
| Layers n_layer | 2 | Can be tuned per script. |
| Heads n_head | 8 | Full multi-head attention (no MQA). |
| RoPE base | 10 000 | Shared with nanochat GPT. |

Each block performs:

x = x + MHA( norm(x), rotary=True )
x = x + MLP( norm(x) )

with bf16 activations and fp32 parameters. There is no slot/cross-attention anymore; all mixing happens through self-attention inside the window.
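To make the rotary step concrete, here is a minimal RoPE sketch for per-window positions. It uses the common rotate-half formulation with the base 10 000 from the table above; it is a simplified stand-in for the rotary path shared with nanochat, not the actual implementation.

```python
# Minimal rotary position embedding (RoPE) sketch: rotate channel pairs
# (i, i + d/2) of x by a position- and frequency-dependent angle.
import torch

def apply_rope(x: torch.Tensor, base: float = 10_000.0) -> torch.Tensor:
    """x: [B, T, d] with d even; returns the rotated tensor, same shape."""
    B, T, d = x.shape
    half = d // 2
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)  # [d/2]
    angles = torch.arange(T, dtype=torch.float32)[:, None] * freqs     # [T, d/2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    # 2-D rotation applied pairwise; position 0 is left unchanged.
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```

Since each pair is rotated, the per-position norm is preserved, which is why RoPE can be applied directly to the embeddings without a separate scale correction.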

Pooling Heads

| Head name | Summary | Projection | Notes |
| --- | --- | --- | --- |
| mean_linear | Mean of the block | Linear | “Dumb” control baseline (--gistnet_type mean --gistnet_head linear). |
| mean_mlp | Mean + 2-layer MLP | Linear → ReLU → Linear | Default for transformer families (--gistnet_head mlp). |
| query_* | Learned query attends once (Lin et al. 2017 [1]) | Linear/MLP | Controlled via --gistnet_pooling query and --gistnet_head. |
| cls_* | Prepend [CLS], read directly (Devlin et al. 2018 [2]) | Linear/MLP | Controlled via --gistnet_pooling cls and --gistnet_head. |

Every head reports whether it requires a [CLS] token so the runtime can prepend it exactly once.
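That handshake can be sketched as a small spec object. The names `HeadSpec` and `window_length` are hypothetical, shown only to illustrate how a head declares its [CLS] requirement so the runtime prepends the token exactly once.

```python
# Illustrative head/runtime handshake: the head declares whether it needs a
# [CLS] token; the runtime derives the backbone's window length from that.
from dataclasses import dataclass

@dataclass(frozen=True)
class HeadSpec:
    pooling: str     # "mean", "query", or "cls"
    projection: str  # "linear" or "mlp"

    @property
    def requires_cls(self) -> bool:
        # Only CLS pooling needs the runtime to prepend the learnable token.
        return self.pooling == "cls"

def window_length(spec: HeadSpec, block_size: int = 32) -> int:
    # The backbone sees 33 positions for CLS heads, 32 otherwise.
    return block_size + (1 if spec.requires_cls else 0)
```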

FLOPs Estimate (run10)

| Depth | Head | FLOPs / block (×10⁶) | Comment |
| --- | --- | --- | --- |
| 2 layers | Linear | ~10 | Cheapest transformer head. |
| 2 layers | MLP | ~13 | Default accuracy/cost point. |
| 4 layers | Linear | ~20 | Roughly 2× cost vs depth‑2. |
| 4 layers | MLP | ~25 | Highest quality before diminishing returns. |
| Mean baseline | Linear/MLP | ~1 | No transformer; for sanity checks. |
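The scaling in the table (depth 4 ≈ 2× depth 2) follows from a standard per-layer FLOPs count. The sketch below is a back-of-envelope estimator, not the run10 accounting; the function name is illustrative and the example hidden size in the usage note is an assumption.

```python
# Back-of-envelope FLOPs for one 32-token window: QKVO projections, attention
# scores/weighted sum, and a 4d-hidden MLP, times the number of layers.
def gistnet_flops(n_layer: int, d: int, window: int = 32, head: str = "linear") -> int:
    per_layer = (
        4 * 2 * window * d * d          # Q, K, V, output projections
        + 2 * 2 * window * window * d   # attention scores and weighted sum
        + 2 * 2 * window * d * 4 * d    # MLP up- and down-projection (4d hidden)
    )
    # The pooling projection is a rounding error next to the backbone.
    head_flops = 2 * d * d if head == "linear" else 2 * 2 * d * d
    return n_layer * per_layer + head_flops
```

Because the head cost is negligible, doubling `n_layer` almost exactly doubles the total, matching the “Roughly 2×” comment above.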

Hierarchical Application

  1. LOD0 (tokens → gists). Run the transformer head on raw token embeddings.
  2. LOD1 (gists → super-gists). Take groups of 32 LOD0 gists, feed them back through the same module (weights can be shared) to get LOD1 gists.
  3. LOD≥2. Repeat as required: every extra level multiplies the overall compression ratio by another 32× and still uses the identical code path, only swapping the input tensor source.

This produces a MegaContext tree where every node is generated by the same transformer head, vastly simplifying maintenance compared to the old slot/cross-attention pipeline.
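The recursion above reduces to a short loop. This is a schematic sketch (the function name and list-of-vectors representation are illustrative): the same `compress` callable maps each group of 32 entries to one gist, level after level, until a single root remains.

```python
# Build the MegaContext LOD hierarchy by repeatedly applying one shared
# 32-to-1 compressor to the previous level's outputs.
from typing import Callable, List, Sequence

def build_lods(items: Sequence, compress: Callable[[Sequence], object],
               block: int = 32) -> List[list]:
    """Return [LOD0 inputs, LOD1 gists, LOD2 gists, ...] down to one root."""
    levels = [list(items)]
    while len(levels[-1]) > 1:
        prev = levels[-1]
        # Group into 32-wide blocks and compress each with the shared module.
        levels.append([compress(prev[i:i + block]) for i in range(0, len(prev), block)])
    return levels
```

With 1024 inputs this yields levels of size 1024 → 32 → 1, i.e. one extra level per factor of 32.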

Implementation Notes

  • File: mc/gistnet.py (imported by mc/runtime.MCController).
  • Config knobs: --block_size, --gistnet_type, --gistnet_layers, --gistnet_pooling, --gistnet_head (see run scripts).
  • Device placement: GistNet runs on the same device as the base model embeddings (CUDA). Metadata (positions, LOD) stays on-device to avoid PCIe traffic.
  • Telemetry: Use nanochat.report counters to log gist substitution ΔNLL and pooling head choice per run.
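For reference, the config knobs listed above could be wired up as follows. The flag names come from this document; the defaults mirror the run10 table, but the `choices` lists and the `"transformer"` type name are assumptions, not the run scripts’ actual values.

```python
# Hypothetical argparse wiring for the GistNet config knobs.
import argparse

parser = argparse.ArgumentParser(description="GistNet config knobs (illustrative)")
parser.add_argument("--block_size", type=int, default=32)       # MegaContext block size
parser.add_argument("--gistnet_type", type=str, default="transformer",
                    choices=["mean", "transformer"])            # mean = sanity baseline
parser.add_argument("--gistnet_layers", type=int, default=2)    # n_layer, tunable 1-4
parser.add_argument("--gistnet_pooling", type=str, default="mean",
                    choices=["mean", "query", "cls"])
parser.add_argument("--gistnet_head", type=str, default="mlp",
                    choices=["linear", "mlp"])
```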

References

  1. Zhouhan Lin, Minwei Feng, Cicero Nogueira dos Santos, Mo Yu, Bing Xiang, Bowen Zhou, and Yoshua Bengio. “A Structured Self-Attentive Sentence Embedding.” ICLR 2017.
  2. Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” arXiv:1810.04805, 2018.
  3. Jesse Mu, Xiang Lisa Li, and Noah Goodman. “Learning to Compress Prompts with Gist Tokens.” arXiv:2304.08467, 2023. Original substitutability proposal.
  4. Jack W. Rae, Anna Potapenko, Siddhant M. Jayakumar, Chloe Hillier, and Timothy P. Lillicrap. “Compressive Transformers for Long-Range Sequence Modelling.” arXiv:1911.05507, 2019. Hierarchical compression inspiration.
  5. Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. “Distilling the Knowledge in a Neural Network.” arXiv:1503.02531, 2015. Teacher–student framework used for ΔNLL supervision.