feat: add class-based callback system for training lifecycle hooks#706
feat: add class-based callback system for training lifecycle hooks#706hrathina wants to merge 2 commits into
Conversation
📝 WalkthroughWalkthroughAdds a callback subsystem for training: new callback contracts and context types, asynchronous dispatch, serialization across ChangesTraining Callback System
Sequence Diagram(s)sequenceDiagram
participant run_training
participant torchrun
participant main
participant train
participant BatchLossManager
participant CallbackManager
participant TrainerCallback
run_training->>torchrun: --callbacks=encoded callbacks
torchrun->>main: args.callbacks
main->>CallbackManager: deserialize + configure context
main->>train: callback_manager
train->>CallbackManager: fire lifecycle hooks
CallbackManager->>TrainerCallback: dispatch snapshot context
train->>BatchLossManager: process_batch(callback_manager)
BatchLossManager->>CallbackManager: fire on_before_forward / on_after_backward
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Poem
🚥 Pre-merge checks | ✅ 3 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 6
🧹 Nitpick comments (2)
tests/unit/test_callbacks.py (1)
195-205: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winStrengthen
test_on_train_end_blocksto verify actual blocking behavior.Current callback body is instant, so this test can pass without proving
fire("on_train_end")waits. Add a deliberate delay in callback and assert elapsed wall time.Suggested assertion upgrade
def test_on_train_end_blocks(self): - called = [] + called = [] class SlowCb(TrainerCallback): def on_train_end(self, context): + time.sleep(0.05) called.append(True) mgr = CallbackManager() mgr.add_callback(SlowCb()) + t0 = time.perf_counter() mgr.fire("on_train_end") + elapsed = time.perf_counter() - t0 assert called == [True] + assert elapsed >= 0.05 + mgr.close()🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/test_callbacks.py` around lines 195 - 205, The test_on_train_end_blocks test only verifies that the callback is called, but does not actually prove that the fire method waits for the callback to complete. Strengthen this test by adding a deliberate time delay inside the SlowCb.on_train_end method, then measure the elapsed wall time around the mgr.fire("on_train_end") call and assert that the elapsed time is at least as long as the injected delay, which will verify that fire actually blocks and waits for the callback to finish executing.src/instructlab/training/main_ds.py (1)
355-371: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
on_logfires only on rank-0, unlike the other hooks.This hook is nested inside the
if local_rank == 0:block, so it dispatches only on rank-0, whereason_step_begin,on_step_end,on_epoch_begin,on_epoch_end,on_save, andon_train_endfire on all ranks. This asymmetry is reasonable since the logging metrics (current_lr,cuda_mem_allocated, throughput) are computed only on rank-0, but it's an inconsistency callback authors won't expect. Worth documenting in the hook contract thaton_logis rank-0-only.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/instructlab/training/main_ds.py` around lines 355 - 371, The callback_manager.fire("on_log") call is nested inside the if local_rank == 0: block, making it rank-0-only, unlike other hooks such as on_step_begin, on_step_end, on_epoch_begin, on_epoch_end, on_save, and on_train_end which fire on all ranks. Document this rank-0-only behavior in the callback hook contract or interface documentation to clarify this asymmetry for callback authors who may not expect on_log to behave differently from the other hooks.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/instructlab/training/callbacks.py`:
- Around line 219-262: Add documentation to the serialize_callback and
deserialize_callback functions clearly stating that TrainerCallback subclasses
must be self-contained (with no external symbol dependencies) and have
zero-argument constructors. Then enhance serialize_callback to validate these
constraints by attempting to instantiate the callback class with no arguments
and checking for any NameError when the callback is executed, raising
informative errors during serialization in the parent process rather than
allowing silent failures in the worker. Additionally, update the test cases to
cover callbacks with constructor parameters and external dependencies to ensure
the validation catches these cases.
In `@src/instructlab/training/main_ds.py`:
- Around line 448-453: The `_save_and_exit()` function returns early at lines
285 and 309, which bypasses the `on_train_end` callback dispatch and
`callback_manager.close()` cleanup code at lines 451-453. To fix this, ensure
that `on_train_end` and `close()` are invoked even when `_save_and_exit()`
triggers an early exit. Either add these cleanup calls directly in the
`_save_and_exit()` function before each return statement, or wrap the training
logic in a try-finally block that guarantees execution of these callbacks and
cleanup regardless of how the function exits.
- Around line 404-407: The code is accessing args.ckpt_output_dir which does not
exist as a defined argument in the parser, causing an AttributeError crash
during the first checkpoint save. Replace all occurrences of
args.ckpt_output_dir with args.output_dir in the three callback_manager.fire()
calls where checkpoint_path is being set as a parameter. The correct attribute
available from the argument parser is args.output_dir.
- Around line 683-703: The TrainerCallback class docstring lacks critical
documentation about the callback execution model. Add prominent documentation to
the TrainerCallback class that clearly states callbacks fire on all distributed
ranks, not just rank 0, and explicitly require callback authors to check
context.is_world_process_zero before executing any rank-specific side effects
such as checkpointing, logging, or job submission. Include concrete examples of
when this guard is necessary to prevent callback authors from inadvertently
running duplicate operations across ranks.
In `@tests/unit/test_callbacks.py`:
- Line 70: Replace all fixed time.sleep(0.1) calls with event-based
synchronization using threading.Event for deterministic completion checks.
Create a threading.Event object before each test section, set it inside the
callback function to signal completion, and then use event.wait(timeout=0.1)
instead of the sleep call. This applies to all occurrences throughout the test
file at the locations mentioned: lines 70, 84, 109, 121, 138, 151, 192, 238, and
253, where the pattern involves waiting for a callback to complete rather than
using arbitrary timing delays.
- Around line 58-255: The test class TestCallbackManager creates multiple
CallbackManager instances throughout its test methods but does not properly
clean up the background threads they spawn. Add a call to mgr.close() at the end
of each test method in TestCallbackManager (in test_fire_dispatches,
test_fire_skips_non_overridden, test_has_callbacks, test_snapshot_isolation,
test_exception_isolation, test_multiple_callbacks, test_kwargs_set_on_snapshot,
test_add_callback_type_error, test_remove_callback_by_instance,
test_remove_callback_by_type, test_fire_all_ranks, test_on_train_end_blocks,
test_fire_invalid_kwarg_raises, test_empty_manager_no_callbacks,
test_hook_name_set_on_snapshot, and test_dict_fields_snapshot_isolation) to
ensure the background thread is properly terminated after each test completes
and prevent thread leakage.
---
Nitpick comments:
In `@src/instructlab/training/main_ds.py`:
- Around line 355-371: The callback_manager.fire("on_log") call is nested inside
the if local_rank == 0: block, making it rank-0-only, unlike other hooks such as
on_step_begin, on_step_end, on_epoch_begin, on_epoch_end, on_save, and
on_train_end which fire on all ranks. Document this rank-0-only behavior in the
callback hook contract or interface documentation to clarify this asymmetry for
callback authors who may not expect on_log to behave differently from the other
hooks.
In `@tests/unit/test_callbacks.py`:
- Around line 195-205: The test_on_train_end_blocks test only verifies that the
callback is called, but does not actually prove that the fire method waits for
the callback to complete. Strengthen this test by adding a deliberate time delay
inside the SlowCb.on_train_end method, then measure the elapsed wall time around
the mgr.fire("on_train_end") call and assert that the elapsed time is at least
as long as the injected delay, which will verify that fire actually blocks and
waits for the callback to finish executing.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: bd59a26c-5aa7-48c2-816a-43015e0fc798
📒 Files selected for processing (6)
src/instructlab/training/__init__.pysrc/instructlab/training/batch_loss_manager.pysrc/instructlab/training/callbacks.pysrc/instructlab/training/config.pysrc/instructlab/training/main_ds.pytests/unit/test_callbacks.py
278713d to
d118f17
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/unit/test_callbacks.py`:
- Around line 211-215: In the test_close method, restructure the test to wrap
the CallbackManager instance creation and assertions in a try/finally block,
ensuring that m.close() is always called in the finally block, even if any
assertion fails before it. This prevents the background thread from leaking into
subsequent tests when assertions fail.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 936f5620-2645-4239-b9fb-9b7c9f79ca07
📒 Files selected for processing (6)
src/instructlab/training/__init__.pysrc/instructlab/training/batch_loss_manager.pysrc/instructlab/training/callbacks.pysrc/instructlab/training/config.pysrc/instructlab/training/main_ds.pytests/unit/test_callbacks.py
🚧 Files skipped from review as they are similar to previous changes (4)
- src/instructlab/training/batch_loss_manager.py
- src/instructlab/training/config.py
- src/instructlab/training/init.py
- src/instructlab/training/main_ds.py
|
@RobotSail please review this PR. Thank you. |
Summary
Adds a callback mechanism to the InstructLab Training library, enabling users and Training Hub to hook into training lifecycle events without modifying source code.
TrainerCallbackbase class with 13 lifecycle hooks that users subclass and overridecontext.is_world_process_zeroorcontext.is_local_process_zeroinspect.getsource+ base64 encodingon_train_endblocks up to 10 seconds to allow cleanup callbacks to finish before process exitFiles changed
src/instructlab/training/callbacks.pysrc/instructlab/training/config.pycallbacksfield to TrainingArgssrc/instructlab/training/batch_loss_manager.pysrc/instructlab/training/main_ds.pysrc/instructlab/training/__init__.pytests/unit/test_callbacks.pyTest plan
Resolves #694
Summary by CodeRabbit
New Features
Tests