# Torch.compile Stabilization Plan
## Goals
- Keep `torch.compile(mode="reduce-overhead")` enabled for LensNet (and eventually GistNet) without turning off Inductor’s cudagraph capture.
- Ensure compiled modules run reliably during training and inference by auditing tensor lifetimes and entry points.
- Add instrumentation and tests so future regressions are caught immediately.
## Plan
1. Instrument LensNet/GistNet usage
   - Add lightweight counters/logging in `MCController` to record:
     - How many times LensNet/GistNet run per micro-step (training vs. inference).
     - Shapes of the inputs.
     - Whether we re-enter LensNet within the same backward pass (e.g., from the inference allocator).
   - Expose these metrics via debug telemetry so we can validate behavior quickly.
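The counters described above could be sketched roughly as follows. This is a minimal illustration, not the actual `MCController` API; the class and method names here are hypothetical.

```python
from collections import Counter
from dataclasses import dataclass, field


@dataclass
class ModuleUsageStats:
    """Hypothetical per-run counters for LensNet/GistNet invocations."""
    calls: Counter = field(default_factory=Counter)   # (module, phase) -> count
    shapes: list = field(default_factory=list)        # recorded input shapes

    def record(self, module: str, phase: str, shape: tuple) -> None:
        # Count one invocation and remember the input shape for telemetry.
        self.calls[(module, phase)] += 1
        self.shapes.append((module, phase, shape))

    def summary(self) -> dict:
        # Flatten into a dict suitable for debug telemetry output.
        return {f"{module}/{phase}": n for (module, phase), n in self.calls.items()}


stats = ModuleUsageStats()
stats.record("lensnet", "train", (8, 128))
stats.record("lensnet", "train", (8, 128))
stats.record("lensnet", "inference", (1, 128))
print(stats.summary())  # {'lensnet/train': 2, 'lensnet/inference': 1}
```

Keeping the counters keyed on `(module, phase)` makes it trivial to spot unexpected re-entry (e.g., an inference-phase call appearing inside a training micro-step).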
2. Audit tensor lifetimes
   - Trace every place where LensNet outputs are stored beyond the compiled call:
     - `WorkingContextVariant.policy_scores`
     - Focus-allocator cached scores
     - Preference/policy telemetry
   - For each path, decide whether to:
     - Consume the data immediately inside the compiled graph.
     - Clone the tensor at the user boundary (via a single helper, no ad-hoc `.clone()`s sprinkled everywhere).
     - Re-run LensNet explicitly so we never hold onto compiled outputs.
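The "single helper, no ad-hoc clones" idea might look like the sketch below. The helper name is hypothetical, and a stub tensor stands in for `torch.Tensor` (which supports the same `detach()`/`clone()` methods) so the sketch runs without torch.

```python
def consume_compiled_output(tensor):
    """Hypothetical boundary helper: every consumer that stores a compiled
    module's output beyond the call goes through here, so the detach+clone
    policy lives in exactly one place instead of scattered .clone() calls."""
    return tensor.detach().clone()


# Stand-in for torch.Tensor so the sketch is self-contained.
class FakeTensor:
    def __init__(self, data):
        self.data = list(data)

    def detach(self):
        return self

    def clone(self):
        return FakeTensor(self.data)


raw = FakeTensor([0.1, 0.9])
stored = consume_compiled_output(raw)
print(stored is raw)  # False: the stored copy no longer aliases graph memory
```

With a real cudagraph-backed tensor, the clone is what prevents later mutation of graph-owned memory from corrupting the stored scores.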
3. Restructure buffer management
   - Introduce helper functions:
     - `_run_lensnet_batched(inputs) -> scores`, which performs the cudagraph step mark + clone once.
     - `_lensnet_allocator_scores(wc)` for the inference allocator.
   - Ensure these helpers manage clones and mark-step boundaries consistently so compiled graphs aren’t reused incorrectly.
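A rough shape for the batched helper is sketched below. The signature is an assumption (the plan only names the helper); the mark-step hook is injected so the sketch runs without a GPU, whereas the real helper would call `torch.compiler.cudagraph_mark_step_begin`.

```python
def run_lensnet_batched(model, inputs, mark_step):
    """Hypothetical sketch of _run_lensnet_batched: mark the cudagraph step
    once, run the compiled model once, and clone at the boundary so callers
    never hold cudagraph-managed memory. In real code, mark_step would be
    torch.compiler.cudagraph_mark_step_begin."""
    mark_step()                     # tell Inductor a new iteration begins
    scores = model(inputs)          # single compiled forward
    return scores.detach().clone()  # the only clone in the pipeline


# Stub score object so the sketch runs without torch.
class StubScores:
    def detach(self):
        return self

    def clone(self):
        return StubScores()


marks = []
out = run_lensnet_batched(lambda x: StubScores(), [1, 2, 3],
                          mark_step=lambda: marks.append(1))
print(len(marks))  # 1: exactly one mark-step per batched call
```

Centralizing the mark-step and clone in one helper is what makes the "consumed or cloned via the helper" invariant enforceable.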
4. Build a torch.compile harness
   - Create a standalone script/test that:
     - Compiles LensNet with the actual config.
     - Runs the same batched call as training.
     - Runs the allocator-style repeated calls.
     - Checks for cudagraph errors.
   - Run this harness before enabling compile in `mc_run`/`run10.sh` so we only toggle the flag in the standard workflow after it passes.
   - ✅ Implemented in `scripts/mc_compile_harness.py`. We still run all official experiments through `mc_run.sh … run10.sh`; the harness exists purely for debugging, so we can reproduce compile bugs before toggling the flag in the standard workflow.
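The core loop of such a harness might look like this sketch: repeat allocator-style calls and collect runtime errors instead of crashing on the first one. The function name and the simulated error message are illustrative, not taken from `scripts/mc_compile_harness.py`.

```python
def stress_compiled_calls(call, n_iters, mark_step):
    """Hypothetical harness loop: repeat allocator-style calls, marking each
    cudagraph step, and collect any RuntimeErrors for a final report."""
    errors = []
    for step in range(n_iters):
        mark_step()                 # one mark per iteration, as in training
        try:
            call(step)
        except RuntimeError as exc:
            errors.append((step, str(exc)))
    return errors


# Simulate a compiled module that fails on its fourth invocation.
def flaky(step):
    if step == 3:
        raise RuntimeError("static input data pointer changed")
    return step


report = stress_compiled_calls(flaky, n_iters=8, mark_step=lambda: None)
print(report)  # [(3, 'static input data pointer changed')]
```

Collecting every failure (rather than stopping at the first) makes it easier to see whether an error is deterministic or tied to a particular iteration count.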
5. Re-enable compile progressively
   - Once the harness passes, re-enable compile for training (single batch) and gate the inference allocator behind a config flag so we can roll it out gradually.
   - Update `mc_run.sh`/`run10.sh` so the supported way to flip compile on is passing `--mc_compile_lensnet=1 [--mc_compile_lensnet_inference=1]` through the normal entry point. Avoid bespoke scripts so the behavior mirrors production.
   - Status: enabling `--mc_compile_lensnet=1` via `mc_run.sh … run10.sh` still throws Inductor/cudagraph exceptions (the same failure we saw before this refactor). Compile remains disabled in default configs until we fix those runtime errors.
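The flag-gated rollout could reduce to a helper like the one below, assuming a plain config dict mirroring `--mc_compile_lensnet=1`; the helper name is hypothetical. `torch.compile(module, mode="reduce-overhead")` is the real PyTorch call, imported lazily so disabled runs never touch it.

```python
def maybe_compile_lensnet(module, config):
    """Hypothetical gating helper: apply torch.compile only when the config
    flag (mirroring --mc_compile_lensnet=1) is set, so default runs are
    completely unaffected."""
    if not config.get("mc_compile_lensnet"):
        return module               # compile disabled: identity pass-through
    import torch                    # lazy import; only needed when enabled
    return torch.compile(module, mode="reduce-overhead")


# With the flag off (the current default), the module passes through untouched.
identity = object()
print(maybe_compile_lensnet(identity, {}) is identity)  # True
```

Routing the toggle through one helper keeps the enabled and disabled code paths identical except for the single wrap, which matches the goal of mirroring production behavior.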
6. Documentation & tests
   - Update `obsidian/reference/LensNet Pairwise Training.md` with the new invariants (“LensNet outputs must be consumed or cloned via the helper”).
   - Add unit tests that:
     - Mock LensNet to ensure `_run_lensnet_batched` is used everywhere.
     - Simulate repeated calls (allocator) to confirm we call `cudagraph_mark_step_begin`.
   - Verify `tests/test_mc_components.py` covers these cases.
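The shape of such a unit test might look like the following sketch, using `unittest.mock` stand-ins; the test name is hypothetical and the mocked hook stands in for `torch.compiler.cudagraph_mark_step_begin`.

```python
from unittest.mock import MagicMock


def test_allocator_marks_every_repeated_call():
    """Hypothetical test shape: mock the model and the mark-step hook, drive
    allocator-style repeated calls, and assert one cudagraph mark per call."""
    mark_step = MagicMock()
    model = MagicMock(return_value=MagicMock())
    for step in range(4):           # allocator re-invokes once per candidate
        mark_step()
        model(step)
    assert mark_step.call_count == model.call_count == 4


test_allocator_marks_every_repeated_call()
print("ok")
```

In the real suite, the mocks would be patched over the controller's LensNet attribute so the assertion also proves the helper path is the only entry point.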
## Execution checklist
- Add instrumentation/logging for LensNet/GistNet invocation counts.
- Implement the `_run_lensnet_batched` and `_lensnet_allocator_scores` helpers with mark-step + clone.
- Refactor the runtime/allocator to use the helpers exclusively.
- Build/run the torch.compile harness.
- Update docs and unit tests.
- Verify the full pytest suite + harness + end-to-end training smoke test (or as close as feasible without a full GPU run).
## Current Status
- Instrumentation + helper refactors landed (`mc/runtime.py`), and the new `tests/test_mc_components.py::test_lensnet_timers_and_usage` covers the controller helpers.
- The harness (`scripts/mc_compile_harness.py`) reproduces LensNet/GistNet compile issues locally, but running `mc_run.sh … run10.sh --mc_compile_lensnet=1` still throws Inductor/cudagraph exceptions, so compile remains disabled in default configs.
- Next steps: finish the doc/test polish (item 6), integrate the compile toggles/testing into the run10 workflow, and chase down the remaining compile-time exceptions so we can check off item 5.