LensNet Training (Phase 1)

LensNet is a shallow transformer (2, 4, or 8 layers) that operates directly on the Working Context embeddings. Phase 1 drops the historical tail-gist extras and trains the controller via random variant sampling plus pairwise preference comparisons. This page documents the data pipeline, loss, and telemetry as they now exist in code (mc/runtime.py, scripts/base_train.py).

High-Level Loop

  1. Build variants. For each training sequence we construct:
    • lod_0_baseline: pure LOD0 window that preserves the recent tail and is trimmed to the current curriculum target length so it is directly comparable to every random variant.
    • num_random_variants stochastic compressions sampled by running the focus allocator with random scores (see _build_random_variant_set).
  2. Score variants. Run the base model on every variant to obtain next-token losses.
  3. Compute advantages. For each variant compute adv_delta = loss_variant − loss_baseline. Negative numbers mean “better than baseline”.
  4. Form preference pairs. Pair every non-baseline WC with the best (lowest-loss) variant, then add the highest-Δloss comparisons among the remaining variants. Each tuple stores (better, worse, strength) where strength = |Δloss|.
  5. Optimize LensNet. Feed the worse WC through LensNet to obtain policy scores and apply the Bradley–Terry loss + optional rank/budget penalties.
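The loop above can be sketched as follows. This is a minimal illustration, not the actual mc/runtime.py API: baseline_builder, random_builder, and score_fn are hypothetical stand-ins for variant construction and the base-model forward pass.

```python
def train_step(sequence, baseline_builder, random_builder, score_fn,
               num_random_variants=4, max_pairs=8):
    """One Phase-1 LensNet supervision step (sketch).

    score_fn(variant) -> next-token loss of the base model on that variant.
    """
    # 1. Build variants: one length-matched LOD0 baseline + random compressions.
    baseline = baseline_builder(sequence)
    variants = [random_builder(sequence) for _ in range(num_random_variants)]

    # 2-3. Score variants and compute advantages vs. the baseline.
    loss_base = score_fn(baseline)
    scored = [(v, score_fn(v) - loss_base) for v in variants]  # adv_delta

    # 4. Preference pairs: everyone vs. the best (lowest-loss) variant,
    #    keeping the largest-gap comparisons, capped at max_pairs.
    scored.sort(key=lambda x: x[1])            # best variant first
    best_variant, best_adv = scored[0]
    pairs = [(best_variant, v, abs(adv - best_adv)) for v, adv in scored[1:]]
    pairs.sort(key=lambda p: -p[2])            # strongest preferences first
    return pairs[:max_pairs]
```

Step 5 (the LensNet update) then consumes these `(better, worse, strength)` tuples.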

This replaces the legacy “best WC LOD map” regression and trace-log replay buffer. All supervision is local to the current batch and amortizes the base-model forward pass we already perform for GistNet training.

Data Specification

| Field | Shape | Source |
| --- | --- | --- |
| baseline_variant | WorkingContextVariant | _build_lod0_baseline_variant |
| random_variants | List of WorkingContextVariant | _build_random_variant_set |
| adv_delta | scalar per variant | _compute_variant_losses |
| preference_pairs | List of (better, worse, strength) | _build_preference_pairs |

Notes:

  • Variants always respect coverage + tail invariants before entering the loss.
  • strength stores the raw |Δloss|; _build_pairwise_targets later applies tanh(strength) when turning it into per-entry targets so outliers stay bounded.
  • If stochastic allocator steps happen to produce near-identical WCs, we inject an additional “aggressive compression” variant (≈½ the target length) so every batch contains at least one obviously different WC.
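The tanh bounding mentioned above can be illustrated in a few lines (hypothetical shapes, not the real _build_pairwise_targets signature):

```python
import math

def bounded_targets(signs, strength):
    """Turn a raw |Δloss| strength into bounded per-entry targets.

    signs: +1 (expand) / -1 (collapse) per WC entry. tanh keeps outlier
    strengths from dominating the preference loss: strength 100 and
    strength 5 both map close to ±1, while small strengths stay small.
    """
    t = math.tanh(strength)
    return [s * t for s in signs]
```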

Preference Loss

LensNet outputs signed policy scores s_i per WC entry (tanh-clamped to ±1). For each (better, worse) pair we compute:

  1. Align entries via the best WC’s LOD map (_build_pairwise_targets), resulting in a per-entry target t_j ∈ {−strength, +strength} and mask m_j.
  2. Apply a Bradley–Terry / logistic preference loss with temperature τ = mc_lens_temperature. Per masked entry this reduces to a logistic loss on the signed score:

     L_pref = −(1 / Σ_j m_j) · Σ_j m_j · log σ( sign(t_j) · s_j · max(1, |t_j|) / τ )
Implementation detail (mc/runtime.py::_compute_lens_losses):

  • t_j > 0 ⇒ pushing scores positive (expand).
  • t_j < 0 ⇒ pushing scores negative (collapse).
  • collapse_weight optionally reweights collapse targets to balance expand-heavy batches.
  • Larger strength values shrink the effective temperature (s_j is multiplied by max(1, strength) / τ) so undeniable preferences push the logits harder than ambiguous ones.
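A minimal sketch of that per-entry logistic loss, under the shapes assumed above (the real _compute_lens_losses is richer and vectorized):

```python
import math

def preference_loss(scores, targets, mask, tau=1.0, collapse_weight=1.0):
    """Per-entry Bradley-Terry / logistic preference loss (sketch).

    scores:  signed LensNet policy scores s_j in [-1, 1]
    targets: t_j in {-strength, +strength} from the pair alignment
    mask:    1.0 for aligned entries, 0.0 otherwise
    """
    total, denom = 0.0, 0.0
    for s, t, m in zip(scores, targets, mask):
        if m == 0.0 or t == 0.0:
            continue
        strength = abs(t)
        # Larger strengths shrink the effective temperature.
        logit = math.copysign(1.0, t) * s * max(1.0, strength) / tau
        # Optionally reweight collapse targets for expand-heavy batches.
        w = collapse_weight if t < 0 else 1.0
        total += -w * math.log(1.0 / (1.0 + math.exp(-logit)))
        denom += w
    return total / denom if denom else 0.0
```

Scores that agree with the targets (positive where t_j > 0, negative where t_j < 0) yield a lower loss than anti-aligned scores.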

Rank & Budget Penalties (Optional)

We retain the legacy hooks:

  • Rank loss (lens_rank_weight): hinge loss that forces the mean score over positive targets to exceed the mean over negative targets by lens_margin.
  • Budget loss (lens_budget_weight): squared difference between collapse/expand mass weighted by span sizes to discourage “expand everything”.

Phase 1 keeps these weights low (0.5 / 0.1) so the preference loss dominates.
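Under the same assumed shapes, the two regularizers might look like this (a sketch, not the actual implementation):

```python
def rank_loss(scores, targets, margin=0.1):
    """Hinge: mean score over expand targets must beat the collapse mean
    by at least `margin` (lens_margin)."""
    pos = [s for s, t in zip(scores, targets) if t > 0]
    neg = [s for s, t in zip(scores, targets) if t < 0]
    if not pos or not neg:
        return 0.0
    gap = sum(pos) / len(pos) - sum(neg) / len(neg)
    return max(0.0, margin - gap)

def budget_loss(scores, span_sizes):
    """Squared net expand/collapse mass, span-weighted, to discourage
    'expand everything'."""
    net = sum(s * w for s, w in zip(scores, span_sizes))
    total = sum(span_sizes) or 1.0
    return (net / total) ** 2
```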

Temperature & Hyperparameters

All CLI knobs surface through run10.sh and MCConfig:

| Flag | Description |
| --- | --- |
| --mc_num_random_variants | Number of random WCs per sequence. |
| --mc_train_wc_length | Target length for random variants at the end of training (we anneal linearly from 0.8 × max_seq_len down to this value; default end = 0.75 × max_seq_len). |
| --mc_max_lens_pairs | Upper bound on (better, worse) pairs per sample. |
| --mc_lens_temperature | Bradley–Terry temperature (default 1.0). |
| --mc_lens_rank_weight, --mc_lens_budget_weight, --mc_lens_margin, --mc_lens_collapse_weight | Legacy regularizer knobs that still work. |
| --mc_lens_hard_negative_ratio | Fraction of preference pairs to keep after sorting by advantage (default 1.0 = keep all). |

Lowering the temperature sharpens comparisons (steeper gradients for a given Δloss). Raising it smooths updates when the random variants are noisy.

Telemetry

We log the following metrics to W&B (scripts/base_train.py):

| Metric | Meaning |
| --- | --- |
| mc/adv_delta_mean, mc/adv_delta_p95, mc/adv_delta_std | Statistics of Δloss relative to the baseline (want ≤ 0). |
| mc/preference_corr_mean | Correlation between policy scores and adv_delta (want negative). Check mc/preference_corr_mean_valid to know whether the value is meaningful. |
| mc/preference_agreement, mc/preference_pair_count | Fraction / count of preference pairs where LensNet’s signed scores pick the same winner as the measured Δloss. |
| mc/lens_loss | Mean preference loss value. |
| mc/variants_total, mc/variants_mean | How many WCs were evaluated per batch. |
| mc/policy_score_abs_mean, mc/policy_score_std_mean | How much of the tanh range LensNet is actively using across variants. |
| mc/lod_loss/{0,1,2} | LOD-specific losses weighted by each variant’s coverage histogram so every active LOD level is represented. |

--mc_log_lens_debug prints per-variant stats (“PrefDebug”) so we can inspect score distributions and correlations during training.
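For intuition, mc/preference_agreement can be thought of as the following (a hypothetical reimplementation, not the logged code; score_fn here stands in for "mean signed LensNet score over the variant's WC entries"):

```python
def preference_agreement(pairs, score_fn):
    """Fraction of (better, worse, strength) pairs where LensNet ranks
    the variants the same way the measured Δloss did."""
    if not pairs:
        return 0.0
    hits = sum(1 for better, worse, _ in pairs
               if score_fn(better) > score_fn(worse))
    return hits / len(pairs)
```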

Curriculum & Hard-Negative Mining

  • Curriculum: The random-variant target length anneals linearly from 80 % of max_seq_len at the beginning of training down to mc_train_wc_length (default 0.75 × max_seq_len). Because the baseline WC is trimmed to the same length, every comparison is length-fair.
  • Hard negatives: Every non-baseline variant is paired with the current best WC before we sort the remaining pairs by raw Δloss and keep the top mc_lens_hard_negative_ratio fraction (default 1.0). This guarantees a “real” hard negative for every variant while still allowing us to focus on the most informative extra comparisons.
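The length curriculum described above is a straight linear interpolation; a sketch (the step/total_steps arguments are assumptions, not the actual scheduler signature):

```python
def curriculum_target_length(step, total_steps, max_seq_len,
                             start_frac=0.8, end_frac=0.75):
    """Linearly anneal the random-variant target length from
    start_frac * max_seq_len down to end_frac * max_seq_len
    (end_frac corresponds to mc_train_wc_length's default)."""
    progress = min(max(step / max(total_steps, 1), 0.0), 1.0)
    frac = start_frac + (end_frac - start_frac) * progress
    return int(round(frac * max_seq_len))
```

Because the baseline WC is trimmed to the same value, the comparison stays length-fair at every point of the schedule.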

Stability Tricks

| Mechanism | Knobs | Purpose |
| --- | --- | --- |
| Advantage normalization | lens_adv_norm_beta | Maintains an EMA of the adv_delta mean/variance so normalized advantages (norm_adv_delta) drive the preference strength. |
| Policy KL regularization | lens_kl_weight | Keeps LensNet from thrashing by penalizing divergence from the previous policy scores per working context. |
| Budget smoothing | lens_budget_smooth_weight, lens_budget_smooth_beta | Tracks an EMA of net expand/collapse mass and penalizes deviations to keep scores budget-neutral despite random variants. |

All three reuse the WC variants already generated for base LLM + GistNet training; no extra model passes are required.
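Advantage normalization keeps preference strengths on a stable scale across training. A minimal sketch of the EMA tracker, assuming lens_adv_norm_beta is the EMA decay (the real state layout may differ):

```python
class AdvNormalizer:
    """EMA of adv_delta mean/variance; emits normalized advantages."""

    def __init__(self, beta=0.99, eps=1e-6):
        self.beta, self.eps = beta, eps
        self.mean, self.var = 0.0, 1.0

    def update(self, adv_deltas):
        # Fold the batch statistics into the running EMA.
        batch_mean = sum(adv_deltas) / len(adv_deltas)
        batch_var = sum((a - batch_mean) ** 2
                        for a in adv_deltas) / len(adv_deltas)
        self.mean = self.beta * self.mean + (1 - self.beta) * batch_mean
        self.var = self.beta * self.var + (1 - self.beta) * batch_var
        # norm_adv_delta drives the preference strength downstream.
        return [(a - self.mean) / (self.var + self.eps) ** 0.5
                for a in adv_deltas]
```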

Future Work (Phase 2 Ideas)

  • Reintroduce tail-gist cross conditioning once preference training is stable.
  • Log per-entry legality masks and re-enable a soft illegality penalty if we observe the allocator fighting LensNet.
  • Explore replay buffers / curriculum sampling so LensNet sees more diverse focus plans than pure random variants.

For implementation details see mc/runtime.py (_build_random_variant_set, _compute_variant_losses, _compute_lens_losses) and scripts/base_train.py (W&B logging, CLI plumbing).