MegaContext: End-to-End Training Strategy
Status: Plan of record (POR). This PRD defines the active JT + nanochat integration contract; legacy alternating-optimization notes are historical reference only.
This document describes the end-to-end training procedure for MegaContext, integrating GistNet, LensNet, the frozen base LLM, and lightweight LoRA adapters into a unified optimization loop. The objective is to co-train summarization, focus, and decoding so the system learns how to summarize, focus, and reason with compressed context.
Terminology: LOD0 = raw tokens, LOD1 = 32→1 gists, LOD2 = 32 LOD1 gists (i.e., 1024→1), etc.
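As a quick sanity check on the terminology, the level counts for a uniform 32-way tree can be computed in a few lines (the function name and constant are illustrative, not from the codebase):

```python
# Sketch of the LOD arity arithmetic, assuming a uniform 32-way tree.
ARITY = 32

def gist_counts(num_lod0_tokens: int) -> dict:
    """Number of nodes at each level for a context of raw tokens."""
    counts = {0: num_lod0_tokens}
    level = 0
    while counts[level] >= ARITY:
        counts[level + 1] = counts[level] // ARITY  # 32 children -> 1 gist
        level += 1
    return counts

# A 4k-token training context yields 128 LOD1 gists and 4 LOD2 gists.
print(gist_counts(4096))  # {0: 4096, 1: 128, 2: 4}
```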
Motivation
Early MegaContext prototypes used independent training phases (train GistNet, then LensNet, then the base/LoRA). That induces:
- Distribution shift: each component is optimized on stale or mismatched inputs.
- Weak coupling: improvements in one part don’t propagate to the others.
- Wasted compute: repeated forward passes with no joint gradient signal.
End-to-end training aligns GistNet, LensNet, and the base model under a single objective (next-token NLL), so the system co-learns how to summarize, focus, and reason with compressed context.
Overview
Each training sequence (context C) is used to create many working contexts (WCs) — mixtures of LOD0 tokens and GistNet summaries (LOD1, LOD2). Each WC represents a possible “focus” of the same full training context. These are evaluated, compared, and used to refine the models jointly.
Full training context C (e.g. 4k tokens)
├── Build Full [[MegaContext Tree]] - [[GistNet]] summaries (LOD1, LOD2, etc.) for C
├── Build N1 sampled Working Contexts of length C2 (e.g., 1k tokens)
│ ├── Always include a WC with the C2 most recent LOD0 tokens (baseline)
│ ├── Sample the rest from the set of valid summary combinations of C that fit a WC of length C2
│ └── The sampling strategy should ensure diversity (penalize summary-WCs whose gist selections are too similar)
├── Run [[LensNet]] + [[Focus Allocator]] on each of the N1 WCs → expand to N2 WCs
│ ├── Up to 4 simultaneous edits per N1 WC (expand/collapse local spans)
│ ├── Include sibling perturbations for additional ΔNLL labels
│ └── Dedup
├── Evaluate base model on all N2 WCs (next-LOD0 token loss from training data)
│ ├── Teacher-forced multi-position horizon (H tokens; see section below)
│ └── Backprop through Base + [[GistNet]]
└── Update [[LensNet]] using ΔNLL supervision
    ├── Identify ideal WC(s) for supervision (regularized argmin)
    ├── Derive the ideal target focus scores needed for each WC to move towards its ideal WC
    └── Backprop through [[LensNet]] based on the loss of actual focus scores vs. ideal target
Detailed Procedure
Step 1. Build MegaContext Tree
- Input: tokenized training context C.
- Goal: produce multi-level summaries using GistNet (e.g., LOD1 = 32→1, LOD2 = 32 LOD1 → 1).
- Store the hierarchical MegaContext Tree for re-use across WCs.
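A minimal sketch of the tree build, with mean pooling standing in for the learned GistNet 32→1 summarizer (all names and the toy embeddings are illustrative):

```python
# Sketch of Step 1; `fake_gist` is a stand-in for the learned GistNet.
ARITY = 32

def fake_gist(child_vectors):
    """Stand-in for GistNet: collapse up to 32 child embeddings into one."""
    dim = len(child_vectors[0])
    return [sum(v[d] for v in child_vectors) / len(child_vectors)
            for d in range(dim)]

def build_tree(lod0_embeddings):
    """Return {level: vectors}; level 0 is the raw token embeddings."""
    tree, level = {0: lod0_embeddings}, 0
    while len(tree[level]) >= ARITY:
        cur = tree[level]
        tree[level + 1] = [fake_gist(cur[i:i + ARITY])
                           for i in range(0, len(cur), ARITY)]
        level += 1
    return tree

tokens = [[float(i % 7)] * 8 for i in range(1024)]  # toy 8-d LOD0 embeddings
tree = build_tree(tokens)                           # 32 LOD1 gists, 1 LOD2 gist
```

Because the tree is stored once per training context, every WC sampled in Step 2 reads from the same cached gists.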
Step 2. Sample Working Contexts (N1)
- Goal: create N1 diverse WCs of a fixed length C2 (e.g., 1k tokens).
- Always include: the C2 most recent LOD0 tokens (baseline WC).
- Others: sample valid combinations of {LOD0, LOD1, LOD2, …} that fit C2.
- Apply diversity penalties to avoid near-duplicate WCs.
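The budget and baseline logic can be sketched at block granularity, assuming each 32-token block is either kept raw (cost 32) or replaced by its LOD1 gist (cost 1). LOD2 spans and the diversity penalty are omitted, and every helper name here is hypothetical:

```python
import random

# Toy Step-2 sampler: per-block LOD assignment under a fixed token budget.
BLOCK = 32

def cost(lods):
    """Token cost of a WC: raw block = 32 tokens, LOD1 gist = 1 token."""
    return sum(BLOCK if lod == 0 else 1 for lod in lods)

def baseline_wc(num_blocks, budget_tokens):
    """Most-recent-LOD0 baseline: newest blocks raw, older blocks as gists."""
    lods = [1] * num_blocks
    budget = budget_tokens - num_blocks        # every block costs at least 1
    for i in reversed(range(num_blocks)):      # spend leftover budget on recency
        if budget >= BLOCK - 1:
            lods[i] = 0
            budget -= BLOCK - 1
    return lods

def sample_wc(num_blocks, budget_tokens, rng):
    """Random legal WC: start all-gist, upgrade random blocks while budget allows."""
    lods = [1] * num_blocks
    for i in rng.sample(range(num_blocks), num_blocks):
        if cost(lods) + (BLOCK - 1) <= budget_tokens and rng.random() < 0.5:
            lods[i] = 0
    return lods

rng = random.Random(0)
wcs = [baseline_wc(128, 1024)] + [sample_wc(128, 1024, rng) for _ in range(7)]
assert all(cost(wc) <= 1024 for wc in wcs)     # every WC fits the C2 budget
```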
Step 3. LensNet Expansion and Focus Allocation (N2)
- Run LensNet on each WC to produce focus scores per span.
- The Focus Allocator applies up to four simultaneous edits (expand/collapse).
- Generate N2 new WCs; add sibling perturbations for extra ΔNLL labels.
- Deduplicate and keep the slate diverse.
Step 4. Evaluate Base + GistNet (Teacher-Forced Horizon)
Each WC in N2 is evaluated for next-LOD0 token prediction using teacher forcing over a short horizon H.
What “teacher forcing over a horizon” means (ELI5)
- You give the model one WC (a mix of LOD0/1/2 embeddings).
- You then ask it to predict the next real tokens from the dataset:
- For step 1, you score how well it predicts the true next token.
- For step 2, you still feed the true token (not the model’s guess) and score the next one.
- Repeat for H steps.
- The WC stays fixed during these H steps (no refocusing).
- The horizon loss for a WC is the average NLL across those H predictions.
Efficient ways to compute horizon loss
You do not need iterative rollouts to compute an H-step teacher-forced loss:
- Packed forward (preferred): build one sequence = [WC || ground-truth continuation of length H], run a single forward pass, and compute loss only on the last H tokens. This exploits full-sequence training efficiency (FlashAttention etc.).
- Incremental with KV cache (alternative): precompute KV for the WC once, then step through the H positions teacher-forced, reusing the cache. This is cheaper if |WC| ≫ H, but the packed forward is usually simpler and fast on modern kernels.
- Choose based on memory vs. speed: if the packed forward exceeds memory, use the KV-cached incremental method for that WC.
- Backpropagate losses through the base model and any GistNet summaries present in the WC.
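In miniature, packed-forward horizon scoring looks like the sketch below. `toy_model` is a stub standing in for the frozen base LLM over a 4-token vocabulary; the indexing shows which positions are scored:

```python
import math

# Miniature packed-forward horizon scoring; everything here is illustrative.
VOCAB = 4

def toy_model(seq):
    """Stub LLM: at each position, predict a repeat of the current token
    with probability 0.7, and 0.1 for each of the other three tokens."""
    probs = []
    for tok in seq:
        row = [0.1] * VOCAB
        row[tok] = 0.7
        probs.append(row)
    return probs

def horizon_nll(wc_tokens, future_tokens):
    """Average teacher-forced NLL over the H continuation tokens, computed
    from one forward over [WC || continuation], scoring only the last H."""
    H = len(future_tokens)
    packed = wc_tokens + future_tokens
    probs = toy_model(packed)
    # position i predicts the token at i+1, so target j is scored at
    # packed position len(packed) - H + j - 1
    losses = [-math.log(probs[len(packed) - H + j - 1][future_tokens[j]])
              for j in range(H)]
    return sum(losses) / H

nll = horizon_nll(wc_tokens=[0, 1, 2], future_tokens=[2, 3])
```

The real loop swaps `toy_model` for the base model and backpropagates `nll` into the GistNet embeddings present in the WC.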
Step 5. Update LensNet (ΔNLL Supervision + Regularized Argmin)
LensNet learns from observed horizon NLLs and from the target focus scores implied by the regularized argmin.
1) Identify ideal WC(s)
For each WC, find its ideal target WC (WC*) using a distance-regularized argmin over the evaluated slate:
WC* = argmin_i [ L(WC_i) + λ * D(WC_i, WC_current) ]
- L(WC_i): the horizon NLL of WC_i.
- D(WC_i, WC_current): cost (number/weight) of edits required to reach WC_i from the current WC.
- λ: stability coefficient (discourages large focus jumps between near-tie WCs).

Each WC may have a different WC* (the argmin is relative to its own current configuration).
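Concretely, the selection is a few lines of Python; the `(nll, edit-distance)` slate values below are made up:

```python
# Regularized argmin: pick the slate entry minimizing L(WC_i) + λ·D(WC_i, WC_current).
def ideal_wc(candidates, lam=0.2):
    """candidates: list of (horizon_nll, edit_distance). Returns argmin index."""
    scores = [nll + lam * dist for nll, dist in candidates]
    return min(range(len(candidates)), key=scores.__getitem__)

# The raw-NLL winner (index 1) needs 4 edits, so the regularized argmin
# prefers staying put (index 0): stability over a marginal NLL gain.
slate = [(2.00, 0), (1.85, 4), (1.95, 1)]
assert ideal_wc(slate, lam=0.2) == 0
assert ideal_wc(slate, lam=0.0) == 1   # λ = 0 recovers the plain argmin
```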
2) Compute target focus scores
From WC → WC*, let the Allocator compute the minimal edit set (bounded) that moves WC toward WC*.
Translate those edits into target focus scores per span (expand = positive, collapse = negative).
3) LensNet objective
- Score-matching loss: compare LensNet’s actual focus scores on WC vs. the target focus scores derived above.
- Action-level ΔNLL hints: if sibling perturbations were evaluated, include the measured ΔNLL per edit as auxiliary regression labels (normalized by token cost).
- Regularizers:
- Zero-sum budget: net focus must conserve total WC tokens.
- Legality mask: penalize invalid operations (e.g., expanding LOD0, collapsing beyond root).
- Hysteresis: discourage immediate reversals of recent edits.
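A toy sketch of how a bounded edit set (from step 2 above) could be turned into zero-sum target scores. The real budget regularizer weighs actual token costs per edit; the mean-shift here is a deliberate simplification, and all names are hypothetical:

```python
# Translate expand/collapse edits into per-span target focus scores,
# then mean-shift so the net focus is zero-sum (simplified budget constraint).
def target_scores(num_spans, expand_idx, collapse_idx):
    scores = [0.0] * num_spans
    for i in expand_idx:
        scores[i] += 1.0      # expand = positive target
    for i in collapse_idx:
        scores[i] -= 1.0      # collapse = negative target
    mean = sum(scores) / num_spans
    return [s - mean for s in scores]

scores = target_scores(8, expand_idx=[2], collapse_idx=[5, 6])
assert abs(sum(scores)) < 1e-9   # zero-sum: total WC tokens conserved
```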
Gist-Level Auxiliary Losses (LOD1 / LOD2)
You can add semantic “gist” losses that align predicted content with the ground-truth summaries, without doing a full rollout.
LOD1 (32-token) gist loss
- When H is a multiple of 32 (e.g., H=32, 64, 96, …), split the H future tokens into 32-token blocks.
- For each 32-token block:
- Compute the ground-truth LOD1 gist via GistNet on the true tokens.
- Compute the predicted LOD1 gist without feeding predictions back into the model inputs, e.g., by running GistNet on the expected token embeddings (softmax over the logits times the embedding table) or on the hard argmax tokens.
- Add a similarity loss (e.g., cosine or MSE) between predicted and ground-truth gists.
- Weight this loss with a small coefficient so it supplements token NLL.
LOD2 (1024-token) gist loss
- If you want a coarser semantic target, set H=1024 (or accumulate across steps) so that a full LOD2 gist is defined.
Compute considerations: H=1024 is heavy. Prefer packed forward if memory allows; otherwise compute in chunks and combine gists (GistNet is cheap). Start with LOD1; add LOD2 once everything is stable.
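The chunked computation is exact for any per-block summarizer: LOD1 gists computed chunk-by-chunk match a single full pass, so only GistNet's 32-way combine ever touches the whole window. A sketch with mean pooling standing in for GistNet (all names illustrative):

```python
# Chunked LOD2 computation; `fake_gist` is a stand-in for GistNet.
def fake_gist(vectors):
    dim = len(vectors[0])
    return [sum(v[d] for v in vectors) / len(vectors) for d in range(dim)]

def lod1_gists(tokens, block=32):
    return [fake_gist(tokens[i:i + block]) for i in range(0, len(tokens), block)]

tokens = [[float((i * 7) % 13)] for i in range(1024)]  # toy 1-d embeddings

# Full pass vs. four 256-token chunks: identical 32 LOD1 gists.
full = lod1_gists(tokens)
chunked = [g for s in range(0, 1024, 256) for g in lod1_gists(tokens[s:s + 256])]
assert full == chunked

lod2 = fake_gist(full)   # combine the 32 LOD1 gists -> one LOD2 gist
```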
Why This Design Works
| Design choice | Reason |
|---|---|
| Regularized argmin | Stabilizes focus targets; prefers small, consistent improvements over chaotic switches. |
| ΔNLL supervision | Ties LensNet directly to measured utility of focus changes. |
| Equal token budget | Ensures we learn allocation, not “more tokens helps.” |
| Diversity in WCs | Avoids myopic strategies (e.g., pure recency). |
| Gist losses (LOD1/LOD2) | Adds semantic pressure that complements token NLL, without rollout. |
| Packed forward or KV-cached horizon | Efficient horizon scoring; choose per-memory profile. |
Key Hyperparameters
| Symbol | Meaning | Typical Value |
|---|---|---|
| N1 | # of initial sampled WCs | 8–32 |
| N2 | # of post-Allocator WCs | 32–128 |
| C2 | WC token budget (subset of C) | 512–2048 |
| H | Horizon tokens for teacher-forced loss | 32–64 (start), 128–256 (later), 1024 (optional LOD2) |
| λ | Stability regularization for argmin | 0.1–0.3 |
| α1 | LOD1 gist loss weight | 0.01–0.05 |
| α2 | LOD2 gist loss weight | 0.005–0.02 |
Pseudocode (Horizon + Gist Losses)
for each WC in N2_WCs:
    # Horizon loss (packed forward)
    seq = concat(WC, ground_truth_next_H_tokens)
    logits = BaseModel(seq)  # single forward pass
    token_loss = NLL(logits[-H:], targets=ground_truth_next_H_tokens)

    lod1_loss = lod2_loss = 0  # defaults when the optional losses are disabled

    # Optional LOD1 gist loss (if H is a multiple of 32)
    if H % 32 == 0:
        gt_blocks = split(ground_truth_next_H_tokens, size=32)
        pred_probs = softmax(logits[-H:])           # [H, vocab]
        pred_embeds = pred_probs @ TokenEmbedding   # [H, d] expected embeddings
        pred_blocks = split(pred_embeds, size=32)
        gt_lod1 = [ [[GistNet]](Embed(b)) for b in gt_blocks ]
        pred_lod1 = [ [[GistNet]](b) for b in pred_blocks ]
        lod1_loss = mean( 1 - cosine(p, g) for p, g in zip(pred_lod1, gt_lod1) )

    # Optional LOD2 gist loss (requires the LOD1 gists above and H == 1024)
    if H == 1024:
        gt_lod2 = [[GistNet]]( stack(gt_lod1) )     # 32 LOD1 gists -> 1
        pred_lod2 = [[GistNet]]( stack(pred_lod1) )
        lod2_loss = 1 - cosine(pred_lod2, gt_lod2)

    total_loss = token_loss + α1 * lod1_loss + α2 * lod2_loss
    backprop(total_loss)  # through Base + [[GistNet]]
Expected Dynamics
- Early: The baseline LOD0-only WC wins; GistNet learns substitutability; LOD1 gist loss helps stabilize semantics.
- Mid: Mixed WCs beat LOD0; LensNet discovers where summaries help; ΔNLL supervision sharpens span scoring.
- Late: Stable focus patterns; occasional LOD2 loss improves global semantic consistency; strong gains vs. LOD0 baseline at fixed compute.
Summary
This version adds precise, efficient horizon scoring (teacher-forced) and gist-level auxiliary losses that align predicted semantics with ground-truth summaries, without autoregressive rollouts. Together with regularized argmin and ΔNLL supervision, MegaContext learns to manage its own focus and detail adaptively while keeping compute fixed.