MC Lens Loss Bottleneck — What’s Slow and How to Fix It
Problem (today)
`mc/lens_loss_ms` is ~2.9 s per micro-step even though `mc/lens_forward_ms` is ~2 ms. The controller, not the base model, dominates step time.
- At your current settings (device micro-batch=13 after auto-batch, `mc_num_random_variants=2`, grad_accum=5), we process ~195 working contexts (WCs) per optimizer step:
  - 13 samples × (baseline + 2 variants) = 39 WCs per micro-step
  - 5 grad-accum micro-steps → ~195 WCs total
- The loss code is Python-heavy: building pairwise targets, masks, and stats per variant/pair, then launching many tiny ops. That’s where the ~3 s comes from.
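For reference, the arithmetic behind those counts (a throwaway sketch using the settings above):

```python
# Working-context (WC) count per optimizer step at the current settings.
micro_batch = 13   # device micro-batch after auto-batch
num_variants = 2   # mc_num_random_variants
grad_accum = 5

wcs_per_micro_step = micro_batch * (1 + num_variants)  # 13 * 3 = 39
wcs_per_step = wcs_per_micro_step * grad_accum         # 39 * 5 = 195
```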
Where the time is spent (code pointers)
- `mc/runtime.py:_compute_lens_losses` (around line 2128):
  - Flattens all variants, runs `global_score_cache = self._batch_variant_scores(all_variants)` once (fast, ~2 ms).
  - For each sample: `_build_preference_pairs`, `_build_lod_lookup`, `_build_pairwise_targets`, then `_batched_pref_losses` and `_batched_aux_losses`.
  - The heavy part is the per-pair/variant Python work (mask/target construction, scatter into lists) and multiple small tensor ops per pair.
- `mc/runtime.py:_build_pairwise_targets` and `_build_lod_lookup`:
  - Python loops over positions/lods with tensor `.item()`/masking; lots of host-side work and GPU syncs.
- `mc/runtime.py:_batched_aux_losses`:
  - Iterates over `pair_entries`, does elementwise ops per pair; not fused or vectorized.
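To make the sync cost concrete, a minimal illustration (hypothetical shapes, not the actual `_build_pairwise_targets` code): each `.item()` on a CUDA tensor forces a host-device round trip, so per-position loops serialize the GPU, while a vectorized comparison stays on-device.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
lods = torch.randint(0, 4, (512,), device=device)

# Slow pattern: one host<->device sync per element, plus Python loop overhead.
targets_slow = [1.0 if lods[i].item() > 2 else 0.0 for i in range(len(lods))]

# Fast pattern: a single vectorized op; no sync until the result is consumed.
targets_fast = (lods > 2).float()
```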
Constraints / goal
- Keep the training objective intact (pairwise preference + budget/coverage terms) but make `lens_loss_ms` negligible compared to the base forward. Target: tens of ms, ideally sub-ms.
- Maintain the same variant semantics: `(1 + mc_num_random_variants)` variants per sample; grad-accum stays, no reuse across micro-steps.
Fix options (stack-ranked by impact)
1. Batch the entire preference loss on GPU (eliminate Python loops):
   - Build flat tensors for all pairs across the micro-batch: `scores_flat`, `targets_flat`, `weights_flat`, `scale_flat`, `span_tokens_flat`.
   - Compute the Bradley–Terry (or hinge) loss in one fused kernel or a small number of vectorized ops. Use torch ops only; no per-pair Python. (See the sketch after this list.)
   - Compute budget/coverage penalties from precomputed histograms via `torch.bincount`/`scatter_add` on GPU.
   - Expected impact: removes ~90% of `lens_loss_ms` (turns seconds into single-digit ms).
2. Precompute / cache pairwise targets on CPU without `.item()` syncs:
   - Move `positions` and `lods` to CPU once per WC; build masks/targets with numpy/torch CPU ops; avoid per-element `.item()` on CUDA tensors.
   - Cache LOD lookups per WC to reuse across pairs.
   - Expected impact: a large cut to Python overhead even before full batching.
3. Reduce pair count aggressively (sampling):
   - Sample a fixed, small number of preference pairs per variant (e.g., top-K spans or K random spans) instead of all spans.
   - Downsample span tokens before the loss (e.g., stride > 1) to shrink targets/masks.
   - Expected impact: a linear reduction in loss compute, proportional to pair count.
4. Gate the loss frequency:
   - Compute the lens loss every N steps (or every M micro-steps), accumulate/average, and set it to zero otherwise.
   - Expected impact: amortizes the cost without changing the math when the loss does run.
5. Toggle off auxiliary terms that add overhead:
   - Temporarily disable the budget/coverage penalties (`lens_budget_weight`, `lens_margin`, etc.) to isolate the pairwise loss cost. Re-enable after batching.
   - Expected impact: simplifies the loss graph; less Python/tensor churn.
6. AOT/fused custom kernel:
   - Write a custom CUDA or torch.compile fused op for the Bradley–Terry loss over flattened pairs. This bypasses Python and fully leverages the GPU.
   - Expected impact: minimal overhead; a stretch goal if vectorized torch ops aren't enough.
7. Variant count sanity:
   - Verify we only score `micro_batch_size × (1 + mc_num_random_variants)` variants per micro-step and that grad-accum is the only multiplier. (Current telemetry shows 39 per micro-step for micro-batch=13, which is correct.)
   - No change expected; just ensure we don't inadvertently inflate pair counts.
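As a sketch of option 1: assuming the pair payloads can be flattened into per-pair score/weight tensors gathered from `global_score_cache`, the Bradley–Terry loss reduces to a few vectorized ops. Tensor and function names here are illustrative, not the actual fields in `mc/runtime.py`.

```python
import torch
import torch.nn.functional as F

def batched_bt_loss(winner_scores: torch.Tensor,  # (P,) scores of preferred variants
                    loser_scores: torch.Tensor,   # (P,) scores of non-preferred variants
                    weights: torch.Tensor) -> torch.Tensor:  # (P,) per-pair weights
    """Bradley-Terry loss over all flattened pairs at once; no per-pair Python."""
    margins = winner_scores - loser_scores
    # softplus(-m) == -log(sigmoid(m)): the BT negative log-likelihood per pair.
    losses = F.softplus(-margins)
    return (weights * losses).sum() / weights.sum().clamp_min(1e-8)

# Usage: gather all pairs across the micro-batch into flat tensors, then one call.
P = 1024
loss = batched_bt_loss(torch.randn(P), torch.randn(P), torch.ones(P))
```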
Recommended path (pragmatic + fast to implement)
- Implement a vectorized preference loss:
  - Build the per-variant score cache once (`global_score_cache` is already batched).
  - Flatten all pair payloads into contiguous tensors and compute the Bradley–Terry/hinge loss in a single torch op block.
  - Compute budget/coverage penalties with `torch.bincount`/`scatter_add` (see the sketch after this list).
  - Remove per-pair loops and `.item()` calls.
- If still >50 ms, add pair sampling (e.g., cap pairs per variant) and/or loss gating (every N steps).
- Only then consider a fused kernel if torch ops aren’t sufficient.
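A sketch of the histogram-based budget penalty referenced above, assuming each selected span carries an integer LOD level and a token count (all names hypothetical; the actual budget/coverage terms live in the runtime):

```python
import torch

def budget_penalty(span_lods: torch.Tensor,      # (S,) int64 LOD level per span
                   span_tokens: torch.Tensor,    # (S,) token count per span
                   target_tokens: torch.Tensor,  # (L,) token budget per LOD level
                   num_lods: int) -> torch.Tensor:
    # Tokens actually spent per LOD level, computed in a single histogram op.
    spent = torch.bincount(span_lods, weights=span_tokens.float(),
                           minlength=num_lods)
    # Hinge on the overshoot of each level's budget; zero when under budget.
    overshoot = (spent - target_tokens).clamp_min(0)
    return overshoot.sum() / target_tokens.sum().clamp_min(1e-8)
```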
Stretch ideas (if we need near-zero cost)
- Two-stage training: train LensNet offline or in a separate stage with cached WCs, then freeze during main training (lens loss = 0 during main loop).
- Proxy loss: replace pairwise comparisons with a cheap L2/hinge on a small subset of spans or a distilled target (e.g., match a slower teacher's scores offline); see the sketch after this list.
- Asynchronous loss: compute LensNet loss on a background stream / secondary device and update weights less frequently (requires careful optimizer handling).
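If the proxy-loss route is taken, a minimal sketch, assuming teacher scores for the spans are cached offline (names hypothetical):

```python
import torch
import torch.nn.functional as F

def proxy_lens_loss(student_scores: torch.Tensor,
                    teacher_scores: torch.Tensor,
                    sample_frac: float = 0.25) -> torch.Tensor:
    # Match cached teacher scores on a random subset of spans: O(n * frac)
    # instead of all pairwise comparisons.
    n = student_scores.numel()
    k = max(1, int(n * sample_frac))
    idx = torch.randperm(n, device=student_scores.device)[:k]
    return F.mse_loss(student_scores[idx], teacher_scores[idx])
```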
Debug checkpoints
- After vectorization, log `mc/lens_loss_ms`, `mc/lens_forward_ms`, and `mc/variants_total` at step 1 to confirm the drop.
- Track peak memory; batching should also reduce per-step allocations by avoiding many tiny tensors.
If we land the vectorized loss and optional pair sampling, `lens_loss_ms` should fall from ~2.9 s to the low-millisecond range, bringing tok/sec back toward the expected ~3× slowdown relative to baseline.