MegaContext: End-to-End Training Strategy

Status: Plan of record (POR). This PRD defines the active JT + nanochat integration contract; legacy alternating-optimization notes are historical reference only.

This document describes the end-to-end training procedure for MegaContext, integrating GistNet, LensNet, the frozen base LLM, and lightweight LoRA adapters into a unified optimization loop. The objective is to co-train summarization, focus, and decoding so the system learns to summarize, focus, and reason with compressed context.

Terminology: LOD0 = raw tokens, LOD1 = 32→1 gists, LOD2 = 32 LOD1 gists (i.e., 1024→1), etc.


Motivation

Early MegaContext prototypes used independent training phases (train GistNet, then LensNet, then the base/LoRA). That induces:

  • Distribution shift: each component is optimized on stale or mismatched inputs.
  • Weak coupling: improvements in one part don’t propagate to the others.
  • Wasted compute: repeated forward passes with no joint gradient signal.

End-to-end training aligns GistNet, LensNet, and the base model under a single objective (next-token NLL), so the system co-learns how to summarize, focus, and reason with compressed context.


Overview

Each training sequence (context C) is used to create many working contexts (WCs) — mixtures of LOD0 tokens and GistNet summaries (LOD1, LOD2). Each WC represents a possible “focus” of the same full training context. These are evaluated, compared, and used to refine the models jointly.

Full training context C (e.g. 4k tokens)
 ├── Build full [[MegaContext Tree]]: [[GistNet]] summaries (LOD1, LOD2, etc.) for C
 ├── Build N1 sampled Working Contexts of length C2 (e.g., 1k tokens)
 │    ├── Always include a WC with the C2 most recent LOD0 tokens (baseline)
 │    ├── Sample the rest from the set of valid summary combinations of C that fit a WC of length C2
 │    └── The sampling strategy should ensure diversity (penalize summary WCs whose gist selections are too similar)
 ├── Run [[LensNet]] + [[Focus Allocator]] on each of the N1 WCs → expand to N2 WCs
 │    ├── Up to 4 simultaneous edits per N1 WC (expand/compress local spans)
 │    ├── Include sibling perturbations for additional ΔNLL labels
 │    └── Dedup
 ├── Evaluate base model on all N2 WCs (next-LOD0 token loss from training data)
 │    ├── Teacher-forced multi-position horizon (H tokens; see section below)
 │    └── Backprop through Base + [[GistNet]]
 └── Update [[LensNet]] using ΔNLL supervision
        ├── Identify ideal WC(s) for supervision (regularized argmin)
        ├── Derive the ideal target focus scores needed for each WC to move towards its ideal WC
        └── Backprop through [[LensNet]] based on the loss of actual focus scores vs. ideal target

Detailed Procedure

Step 1. Build MegaContext Tree

  • Input: tokenized training context C.
  • Goal: produce multi-level summaries using GistNet (e.g., LOD1 = 32→1, LOD2 = 32 LOD1 → 1).
  • Store the hierarchical MegaContext Tree for re-use across WCs.
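
The bottom-up tree construction can be sketched as follows. This is a minimal sketch under stated assumptions: `build_tree` is an illustrative name, and mean-pooling stands in for the learned GistNet 32→1 summarizer.

```python
import numpy as np

def build_tree(embeds, gist_fn, ratio=32):
    """Build LOD levels bottom-up: each level summarizes `ratio` entries below it."""
    levels = [embeds]                      # LOD0: [n, d] token embeddings
    while len(levels[-1]) >= ratio:
        below = levels[-1]
        n_groups = len(below) // ratio     # drop any ragged tail for simplicity
        gists = np.stack([gist_fn(below[i * ratio:(i + 1) * ratio])
                          for i in range(n_groups)])
        levels.append(gists)
    return levels                          # [LOD0, LOD1, LOD2, ...]

# Stand-in "gist": mean-pooling. The real GistNet is a learned 32->1 summarizer.
tree = build_tree(np.random.default_rng(0).normal(size=(1024, 8)),
                  lambda block: block.mean(axis=0))
```

With a 1024-token context this yields 1024 LOD0 embeddings, 32 LOD1 gists, and 1 LOD2 gist, which the WC sampler can then mix and match.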

Step 2. Sample Working Contexts (N1)

  • Goal: create N1 diverse WCs of a fixed length C2 (e.g., 1k tokens).
  • Always include: the C2 most recent LOD0 tokens (baseline WC).
  • Others: sample valid combinations of {LOD0, LOD1, LOD2, …} that fit C2.
  • Apply diversity penalties to avoid near-duplicate WCs.
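
A minimal sketch of this sampling step, assuming each 1024-token LOD2 block is assigned one whole LOD (the real sampler can mix levels at finer granularity, and the always-included recency baseline is omitted here). `sample_slate`, the slot costs, and the hard Hamming-distance form of the diversity penalty are all illustrative:

```python
import itertools
import random

COST = {0: 1024, 1: 32, 2: 1}  # WC slots consumed per 1024-token block at each LOD

def sample_slate(n_blocks, budget, n1, min_dist=2, seed=0):
    """Enumerate valid LOD assignments, then greedily keep mutually diverse WCs."""
    valid = [wc for wc in itertools.product((0, 1, 2), repeat=n_blocks)
             if sum(COST[lod] for lod in wc) <= budget]
    random.Random(seed).shuffle(valid)
    slate = []
    for wc in valid:
        # diversity penalty (hard form): reject WCs within Hamming distance min_dist
        if all(sum(a != b for a, b in zip(wc, kept)) >= min_dist for kept in slate):
            slate.append(wc)
    return slate[:n1]

# 4 blocks, budget = one LOD0 block plus three LOD1 blocks
slate = sample_slate(n_blocks=4, budget=1024 + 3 * 32, n1=6)
```

In practice the penalty would be a soft score rather than a hard rejection, but the intent is the same: no two WCs in the slate should differ by only a single gist choice.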

Step 3. LensNet Expansion and Focus Allocation (N2)

  • Run LensNet on each WC to produce focus scores per span.
  • The Focus Allocator applies up to four simultaneous edits (expand/collapse).
  • Generate N2 new WCs; add sibling perturbations for extra ΔNLL labels.
  • Deduplicate and keep the slate diverse.

Step 4. Evaluate Base + GistNet (Teacher-Forced Horizon)

Each WC in N2 is evaluated for next-LOD0 token prediction using teacher forcing over a short horizon H.

What “teacher forcing over a horizon” means (ELI5)

  • You give the model one WC (a mix of LOD0/1/2 embeddings).
  • You then ask it to predict the next real tokens from the dataset:
    • For step 1, you score how well it predicts the true next token.
    • For step 2, you still feed the true token (not the model’s guess) and score the next one.
    • Repeat for H steps.
  • The WC stays fixed during these H steps (no refocusing).
  • The horizon loss for a WC is the average NLL across those H predictions.

Efficient ways to compute horizon loss

You do not need iterative rollouts to compute an H-step teacher-forced loss:

  • Packed forward (preferred): build one sequence = [WC || ground-truth continuation of length H], run a single forward, and compute loss only on the last H tokens. This exploits full-seq training efficiency (FlashAttention etc.).
  • Incremental with KV cache (alternative): precompute KV for the WC once, then step through H positions teacher-forced, reusing KV. This is cheaper if |WC| ≫ H, but usually the packed forward is simpler and fast on modern kernels.

Choose based on memory vs. speed: If the packed forward exceeds memory, use the KV-cached incremental method for that WC.

  • Backpropagate losses through the base model and any GistNet summaries present in the WC.
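
The packed-forward scoring can be sketched with plain arrays (shapes only; in training this runs inside the base model's autograd graph so gradients reach GistNet). `horizon_nll` is an illustrative name:

```python
import numpy as np

def horizon_nll(logits, targets, H):
    """Packed forward scoring: average NLL over only the last H positions."""
    z = logits - logits.max(axis=-1, keepdims=True)           # stable log-softmax
    logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    tail = logp[-H:]                                          # continuation only
    return -tail[np.arange(H), targets].mean()

# Toy shapes: |WC| = 6 context positions packed with an H = 3 continuation,
# vocab of 4. Loss is computed only on the 3 continuation positions.
rng = np.random.default_rng(0)
loss = horizon_nll(rng.normal(size=(9, 4)), targets=np.array([1, 3, 0]), H=3)
```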

Step 5. Update LensNet (ΔNLL Supervision + Regularized Argmin)

LensNet learns from the observed NLLs and the target focus scores implied by the regularized argmin.

1) Identify ideal WC(s)

For each WC, find its ideal target WC (WC*) using a distance-regularized argmin over the evaluated slate:

WC* = argmin_i [ L(WC_i) + λ * D(WC_i, WC_current) ]

  • L(WC_i): the horizon NLL of WC_i.
  • D(WC_i, WC_current): the cost (count or weight) of edits required to reach WC_i from the current WC.
  • λ: stability coefficient (discourages large focus jumps between near-tie WCs).

Each WC may have a different WC* (the argmin is relative to its own current configuration).
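
The regularized argmin is straightforward to express directly; `ideal_wc` and the toy numbers below are illustrative:

```python
def ideal_wc(slate_nll, edit_dist, cur, lam):
    """Pick WC* for the WC at index `cur`: argmin of NLL + lam * edit distance."""
    scores = [nll + lam * edit_dist[cur][i] for i, nll in enumerate(slate_nll)]
    return min(range(len(scores)), key=scores.__getitem__)

# Toy slate: WC 2 has the best raw NLL but sits 4 edits from WC 0, so with
# lam = 0.1 the regularized scores are [2.10, 2.40, 2.40] and WC 0 stays put.
nll = [2.10, 2.30, 2.00]
dist = [[0, 1, 4], [1, 0, 3], [4, 3, 0]]
best = ideal_wc(nll, dist, cur=0, lam=0.1)   # -> 0
```

Setting λ = 0 recovers the unregularized argmin (index 2 here), which is exactly the chaotic switching behavior the regularizer is meant to suppress.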

2) Compute target focus scores

Given WC → WC*, let the Allocator compute the minimal (bounded) edit set that moves WC toward WC*. Translate those edits into target focus scores per span (expand = positive, collapse = negative).

3) LensNet objective

  • Score-matching loss: compare LensNet’s actual focus scores on WC vs. the target focus scores derived above.
  • Action-level ΔNLL hints: if sibling perturbations were evaluated, include the measured ΔNLL per edit as auxiliary regression labels (normalized by token cost).
  • Regularizers:
    • Zero-sum budget: net focus must conserve total WC tokens.
    • Legality mask: penalize invalid operations (e.g., expanding LOD0, collapsing beyond root).
    • Hysteresis: discourage immediate reversals of recent edits.
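
A minimal sketch of the score-matching objective with the first two regularizers (hysteresis is omitted, and `lensnet_loss` plus the penalty weight are illustrative assumptions):

```python
import numpy as np

def lensnet_loss(pred, target, legal, budget_weight=0.1):
    """Score-matching MSE on legal spans plus a zero-sum budget penalty."""
    legal = np.asarray(legal, dtype=float)
    mse = (((pred - target) ** 2) * legal).mean()   # match targets on legal spans
    budget = (pred * legal).sum() ** 2              # net focus should be ~zero-sum
    return mse + budget_weight * budget

pred = np.array([0.5, -0.5, 0.3])
target = np.array([0.5, -0.5, 0.0])
loss = lensnet_loss(pred, target, legal=[1, 1, 0])  # illegal span masked out -> 0.0
```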

Gist-Level Auxiliary Losses (LOD1 / LOD2)

You can add semantic “gist” losses that align predicted content with the ground-truth summaries, without doing a full rollout.

LOD1 (32-token) gist loss

  • When H is a multiple of 32 (e.g., H=32, 64, 96, …), split the H future tokens into 32-token blocks.
  • For each 32-token block:
    1. Compute the ground-truth LOD1 gist via GistNet on the true tokens.
    2. Compute the predicted LOD1 gist in one of two ways (no feedback into inputs):
      • Hard: take argmax tokens from the model’s logits for those 32 positions; embed them; run GistNet.
      • Soft (preferred): compute expected embeddings by multiplying softmax probs by the token embedding matrix at each position; feed those embeddings to GistNet to get a soft gist.
    3. Add a similarity loss (e.g., cosine or MSE) between predicted and ground-truth gists.
  • Weight this loss with a small coefficient so it supplements token NLL.
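
The soft (preferred) variant can be sketched as follows, again with mean-pooling standing in for the learned GistNet and `soft_gist` as an illustrative name:

```python
import numpy as np

def soft_gist(logits, embedding, gist_fn):
    """Soft predicted gist: expected embeddings (probs @ E) summarized by GistNet."""
    z = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)  # [32, vocab]
    expected = probs @ embedding                               # [32, d] soft embeds
    return gist_fn(expected)

# Toy sizes: vocab 100, embedding dim 16, one 32-token block of logits.
rng = np.random.default_rng(0)
E = rng.normal(size=(100, 16))
g = soft_gist(rng.normal(size=(32, 100)), E, lambda block: block.mean(axis=0))
```

Because the expected embeddings are differentiable in the logits, this path lets the gist loss backpropagate into the base model without sampling.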

LOD2 (1024-token) gist loss

  • If you want a coarser semantic target, set H=1024 (or accumulate across steps) so a full LOD2 gist is defined:
    1. Ground-truth LOD2 gist = GistNet(32 LOD1 gists from the true 1024 tokens).
    2. Predicted LOD2 gist = GistNet over the model’s soft or hard predicted embeddings for those 1024 positions (structured in the same 32×32 tree).
    3. Add a small similarity loss at LOD2.

Compute considerations: H=1024 is heavy. Prefer packed forward if memory allows; otherwise compute in chunks and combine gists (GistNet is cheap). Start with LOD1; add LOD2 once everything is stable.


Why This Design Works

| Design choice | Reason |
| --- | --- |
| Regularized argmin | Stabilizes focus targets; prefers small, consistent improvements over chaotic switches. |
| ΔNLL supervision | Ties LensNet directly to the measured utility of focus changes. |
| Equal token budget | Ensures we learn allocation, not "more tokens helps." |
| Diversity in WCs | Avoids myopic strategies (e.g., pure recency). |
| Gist losses (LOD1/LOD2) | Adds semantic pressure that complements token NLL, without rollout. |
| Packed forward or KV-cached horizon | Efficient horizon scoring; choose per memory profile. |

Key Hyperparameters

| Symbol | Meaning | Typical value |
| --- | --- | --- |
| N1 | # of initial sampled WCs | 8–32 |
| N2 | # of post-Allocator WCs | 32–128 |
| C2 | WC token budget (subset of C) | 512–2048 |
| H | Horizon tokens for teacher-forced loss | 32–64 (start), 128–256 (later), 1024 (optional LOD2) |
| λ | Stability regularization for argmin | 0.1–0.3 |
| α1 | LOD1 gist loss weight | 0.01–0.05 |
| α2 | LOD2 gist loss weight | 0.005–0.02 |
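
For concreteness, these knobs could be bundled into one config object; the name `TrainConfig` and the mid-range defaults are illustrative choices, not prescribed values:

```python
from dataclasses import dataclass

@dataclass
class TrainConfig:
    n1: int = 16          # initial sampled WCs (8-32)
    n2: int = 64          # post-Allocator WCs (32-128)
    c2: int = 1024        # WC token budget (512-2048)
    h: int = 32           # teacher-forced horizon; raise to 128-256 later
    lam: float = 0.2      # argmin stability coefficient (0.1-0.3)
    alpha1: float = 0.02  # LOD1 gist loss weight (0.01-0.05)
    alpha2: float = 0.01  # LOD2 gist loss weight (0.005-0.02)

cfg = TrainConfig()
```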

Pseudocode (Horizon + Gist Losses)

for each WC in N2_WCs:
    # Horizon loss (packed forward)
    seq = concat(WC, ground_truth_next_H_tokens)
    logits = BaseModel(seq)                 # single forward pass
    token_loss = NLL(logits[-H:], targets=ground_truth_next_H_tokens)
    lod1_loss = lod2_loss = 0

    # Optional LOD1 gist loss (requires H to be a multiple of 32)
    if H % 32 == 0:
        gt_blocks = split(ground_truth_next_H_tokens, size=32)
        pred_probs = softmax(logits[-H:])               # [H, vocab]
        pred_embeds = pred_probs @ TokenEmbedding       # [H, d] soft embeddings
        pred_blocks = split(pred_embeds, size=32)

        gt_lod1 = [GistNet(Embed(b)) for b in gt_blocks]
        pred_lod1 = [GistNet(b) for b in pred_blocks]

        lod1_loss = mean(1 - cosine(p, g) for p, g in zip(pred_lod1, gt_lod1))

        # Optional LOD2 gist loss (requires a full 1024-token horizon,
        # so the 32 LOD1 gists above cover exactly one LOD2 span)
        if H == 1024:
            gt_lod2 = GistNet(stack(gt_lod1))           # 32 LOD1 gists -> 1
            pred_lod2 = GistNet(stack(pred_lod1))
            lod2_loss = 1 - cosine(pred_lod2, gt_lod2)

    total_loss = token_loss + α1 * lod1_loss + α2 * lod2_loss
    backprop(total_loss)   # through Base + GistNet

Expected Dynamics

  • Early: The baseline LOD0-only WC wins; GistNet learns substitutability; LOD1 gist loss helps stabilize semantics.
  • Mid: Mixed WCs beat LOD0; LensNet discovers where summaries help; ΔNLL supervision sharpens span scoring.
  • Late: Stable focus patterns; occasional LOD2 loss improves global semantic consistency; strong gains vs. LOD0 baseline at fixed compute.

Summary

This version adds precise, efficient horizon scoring (teacher-forced) and gist-level auxiliary losses that align predicted semantics with ground-truth summaries, without autoregressive rollouts. Together with regularized argmin and ΔNLL supervision, MegaContext learns to manage its own focus and detail adaptively while keeping compute fixed.