feat: add class-based callback system for training lifecycle hooks #706

Open
hrathina wants to merge 2 commits into instructlab:main from hrathina:feat/callback-mechanism

hrathina commented Jun 23, 2026

Summary

Adds a callback mechanism to the InstructLab Training library, enabling users and Training Hub to hook into training lifecycle events without modifying source code.

  • Introduces TrainerCallback base class with 13 lifecycle hooks that users subclass and override
  • Callbacks are async fire-and-forget with exception isolation — they never block or crash training
  • Fires on all ranks; each callback gates its own behavior via context.is_world_process_zero or context.is_local_process_zero
  • Serializable across the torchrun subprocess boundary via inspect.getsource + base64 encoding
  • on_train_end blocks up to 10 seconds to allow cleanup callbacks to finish before process exit
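
The dispatch model described in the bullets above (fire-and-forget, exception isolation, a bounded wait on on_train_end) can be sketched as follows. This is a minimal illustration under stated assumptions, not the PR's CallbackManager: the class name MiniCallbackManager, the block_timeout parameter, and the dict-based context are all hypothetical.

```python
import asyncio
import concurrent.futures
import threading


class MiniCallbackManager:
    """Hypothetical toy dispatcher: hooks run on a background asyncio loop."""

    def __init__(self):
        self._callbacks = []
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._loop.run_forever, daemon=True)
        self._thread.start()

    def add_callback(self, callback):
        self._callbacks.append(callback)

    async def _dispatch(self, hook_name, context):
        for callback in self._callbacks:
            hook = getattr(callback, hook_name, None)
            if hook is None:
                continue  # this callback does not implement the hook
            try:
                hook(context)
            except Exception as exc:
                # Exception isolation: report the failure, keep training alive.
                print(f"{type(callback).__name__}.{hook_name} failed: {exc!r}")

    def fire(self, hook_name, context=None, block_timeout=None):
        """Schedule a hook; wait for it only when block_timeout is given."""
        snapshot = dict(context or {})  # shallow copy taken at fire time
        future = asyncio.run_coroutine_threadsafe(
            self._dispatch(hook_name, snapshot), self._loop
        )
        if block_timeout is not None:
            try:
                future.result(timeout=block_timeout)
            except concurrent.futures.TimeoutError:
                pass  # stop waiting so shutdown can proceed

    def close(self):
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join(timeout=5)
        if not self._thread.is_alive():
            self._loop.close()
```

In this sketch a training loop would call fire("on_step_end", {...}) without waiting, and fire("on_train_end", block_timeout=10) once at the end. Because the loop runs scheduled coroutines in order, waiting on the final hook also drains the earlier ones.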

Files changed

| File | Change |
|---|---|
| src/instructlab/training/callbacks.py | New: TrainerCallback, TrainingContext, CallbackManager, serialization |
| src/instructlab/training/config.py | Add callbacks field to TrainingArgs |
| src/instructlab/training/batch_loss_manager.py | Add 2 hooks (on_before_forward, on_after_backward) |
| src/instructlab/training/main_ds.py | Add 11 hooks in train(), serialization in run_training(), deserialization in main() |
| src/instructlab/training/__init__.py | Export TrainerCallback and TrainingContext |
| tests/unit/test_callbacks.py | New: 34 unit tests |

Test plan

  • 34 unit tests covering dispatch, snapshot isolation, exception isolation, all-ranks firing, rank gating, kwargs validation, close(), serialization round-trip, public API exposure
  • Full regression suite: 199 passed, 4 pre-existing failures (unrelated LoRA/tensorboard/wandb dependencies)
  • Lint and format clean

Resolves #694

Summary by CodeRabbit

  • New Features

    • Added a training callback system with lifecycle hooks (train begin/end, epoch/step boundaries, before forward/after backward, optimizer step, logging, evaluation, and checkpoint save).
    • Exposed callbacks via training configuration and added CLI support to serialize/deserialize callbacks and enable them during distributed runs.
  • Tests

    • Added comprehensive unit tests covering callback registration/removal, hook dispatch behavior, exception suppression, context snapshot isolation, rank gating, and CLI serialization round-trips.

coderabbitai Bot commented Jun 23, 2026

📝 Walkthrough

Adds a callback subsystem for training: new callback contracts and context types, asynchronous dispatch, serialization across torchrun, lifecycle hook wiring in the training loop, and package exports plus tests.

Changes

Training Callback System

| Layer | File(s) | Summary |
|---|---|---|
| Callback contracts and config | src/instructlab/training/callbacks.py, src/instructlab/training/config.py | Defines HOOK_NAMES, TrainingContext, and TrainerCallback, and adds callbacks: list \| None plus arbitrary_types_allowed=True to TrainingArgs. |
| Async callback manager | src/instructlab/training/callbacks.py | Implements CallbackManager with background asyncio dispatch, context snapshotting and validation, callback registration/removal, exception suppression, and shutdown. |
| Callback serialization and CLI transport | src/instructlab/training/callbacks.py, src/instructlab/training/main_ds.py | Adds base64 source serialization for callbacks, JSON list transport helpers, and --callbacks wiring through run_training and the CLI parser. |
| Training loop hook wiring | src/instructlab/training/batch_loss_manager.py, src/instructlab/training/main_ds.py | Passes callback_manager into BatchLossManager and train(), fires lifecycle hooks around minibatches, optimizer steps, logging, evaluation, saves, and shutdown, and builds the manager in main() from decoded callbacks. |
| Exports and unit tests | src/instructlab/training/__init__.py, tests/unit/test_callbacks.py | Exports TrainerCallback and TrainingContext from the package namespace and adds unit tests for context defaults, hook dispatch, serialization, gating, and wiring. |
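
To make the "base64 source serialization" layer above concrete, here is a rough sketch of a source-based round trip built on inspect.getsource and base64. The payload layout (class name, newline, source) and the helper load_user_callback are assumptions for illustration only; the PR's actual wire format and function bodies are not shown on this page. The stand-in TrainerCallback is also hypothetical.

```python
import base64
import importlib.util
import inspect
import pathlib
import sys
import tempfile
import textwrap


class TrainerCallback:
    """Stand-in base class for this sketch only."""

    def on_step_end(self, context):
        pass


def serialize_callback(callback) -> str:
    """Encode the callback's class name and source as one base64 string."""
    cls = type(callback)
    source = textwrap.dedent(inspect.getsource(cls))
    return base64.b64encode(f"{cls.__name__}\n{source}".encode()).decode("ascii")


def deserialize_callback(encoded: str):
    """Rebuild the class from its source and call it with no arguments."""
    name, source = base64.b64decode(encoded).decode().split("\n", 1)
    # Only the base class is visible here, so the subclass must be self-contained.
    namespace = {"TrainerCallback": TrainerCallback}
    exec(source, namespace)
    return namespace[name]()


def load_user_callback():
    """Write a small user module to disk so inspect.getsource can locate it."""
    code = textwrap.dedent(
        '''
        class CountingCallback(TrainerCallback):
            def on_step_end(self, context):
                return context["step"] + 1
        '''
    )
    path = pathlib.Path(tempfile.mkdtemp()) / "user_callbacks.py"
    path.write_text(code)
    spec = importlib.util.spec_from_file_location("user_callbacks", path)
    module = importlib.util.module_from_spec(spec)
    module.TrainerCallback = TrainerCallback  # make the base class resolvable
    sys.modules["user_callbacks"] = module  # inspect looks the module up by name
    spec.loader.exec_module(module)
    return module.CountingCallback()
```

The sketch also shows why a scheme like this constrains callback authors: only the class source travels, so constructor arguments, instance state, and any module-level names the class relies on are lost on the other side.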

Sequence Diagram(s)

sequenceDiagram
  participant run_training
  participant torchrun
  participant main
  participant train
  participant BatchLossManager
  participant CallbackManager
  participant TrainerCallback

  run_training->>torchrun: --callbacks=encoded callbacks
  torchrun->>main: args.callbacks
  main->>CallbackManager: deserialize + configure context
  main->>train: callback_manager
  train->>CallbackManager: fire lifecycle hooks
  CallbackManager->>TrainerCallback: dispatch snapshot context
  train->>BatchLossManager: process_batch(callback_manager)
  BatchLossManager->>CallbackManager: fire on_before_forward / on_after_backward

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐇 Hop hop, the hooks now sing,
Through async threads they spread their wing.
A callback hops from step to save,
Then back to training, calm and brave.

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Linked Issues check | ⚠️ Warning | The PR adds callbacks and wiring, but it appears to fire on all ranks with context-based gating instead of rank-0-only by default. | Make callback dispatch default to rank-0/main-process only and add the requested flat run_training/TrainingArgs hook registration API. |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 8.41%, which is below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly summarizes the main change: adding a class-based training callback system. |
| Out of Scope Changes check | ✅ Passed | The changed files all support the callback system, its serialization, training-loop hooks, exports, or tests. |
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |

mergify Bot added the testing label Jun 23, 2026

coderabbitai Bot left a comment

Actionable comments posted: 6

🧹 Nitpick comments (2)
tests/unit/test_callbacks.py (1)

195-205: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Strengthen test_on_train_end_blocks to verify actual blocking behavior.

The current callback body returns instantly, so this test can pass without proving that fire("on_train_end") waits. Add a deliberate delay in the callback and assert on the elapsed wall time.

Suggested assertion upgrade
     def test_on_train_end_blocks(self):
         called = []

         class SlowCb(TrainerCallback):
             def on_train_end(self, context):
+                time.sleep(0.05)
                 called.append(True)

         mgr = CallbackManager()
         mgr.add_callback(SlowCb())
+        t0 = time.perf_counter()
         mgr.fire("on_train_end")
+        elapsed = time.perf_counter() - t0
         assert called == [True]
+        assert elapsed >= 0.05
+        mgr.close()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/test_callbacks.py` around lines 195 - 205, The
test_on_train_end_blocks test only verifies that the callback is called, but
does not actually prove that the fire method waits for the callback to complete.
Strengthen this test by adding a deliberate time delay inside the
SlowCb.on_train_end method, then measure the elapsed wall time around the
mgr.fire("on_train_end") call and assert that the elapsed time is at least as
long as the injected delay, which will verify that fire actually blocks and
waits for the callback to finish executing.
src/instructlab/training/main_ds.py (1)

355-371: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

on_log fires only on rank-0, unlike the other hooks.

This hook is nested inside the if local_rank == 0: block, so it dispatches only on rank-0, whereas on_step_begin, on_step_end, on_epoch_begin, on_epoch_end, on_save, and on_train_end fire on all ranks. This asymmetry is reasonable since the logging metrics (current_lr, cuda_mem_allocated, throughput) are computed only on rank-0, but it's an inconsistency callback authors won't expect. Worth documenting in the hook contract that on_log is rank-0-only.
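
As an illustration of the gating this comment refers to, a callback that may be invoked on every rank can guard its side effects on the context flag. The sketch below uses stand-in TrainingContext and TrainerCallback classes; only is_world_process_zero and is_local_process_zero come from the PR description, while global_step and ReportOnceCallback are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class TrainingContext:
    """Stand-in context with only the fields this sketch needs."""

    is_world_process_zero: bool = False
    is_local_process_zero: bool = False
    global_step: int = 0  # hypothetical field, for illustration


class TrainerCallback:
    """Stand-in base class with no-op hooks."""

    def on_step_end(self, context):
        pass

    def on_log(self, context):
        pass


class ReportOnceCallback(TrainerCallback):
    """Records one report per step, no matter how many ranks call the hook."""

    def __init__(self):
        self.reports = []

    def on_step_end(self, context):
        # Dispatched on all ranks: skip side effects everywhere but world rank 0.
        if not context.is_world_process_zero:
            return
        self.reports.append(("step", context.global_step))

    def on_log(self, context):
        # Keeping the same guard here makes the callback independent of
        # where the training loop happens to fire on_log.
        if context.is_world_process_zero:
            self.reports.append(("log", context.global_step))
```

If local_rank is the per-node rank, as is usual under torchrun, then local_rank == 0 holds for one process on each node. On multi-node runs a world-rank guard inside on_log may therefore still be worth keeping.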

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/instructlab/training/main_ds.py` around lines 355 - 371, The
callback_manager.fire("on_log") call is nested inside the if local_rank == 0:
block, making it rank-0-only, unlike other hooks such as on_step_begin,
on_step_end, on_epoch_begin, on_epoch_end, on_save, and on_train_end which fire
on all ranks. Document this rank-0-only behavior in the callback hook contract
or interface documentation to clarify this asymmetry for callback authors who
may not expect on_log to behave differently from the other hooks.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/instructlab/training/callbacks.py`:
- Around line 219-262: Add documentation to the serialize_callback and
deserialize_callback functions clearly stating that TrainerCallback subclasses
must be self-contained (with no external symbol dependencies) and have
zero-argument constructors. Then enhance serialize_callback to validate these
constraints by attempting to instantiate the callback class with no arguments
and checking for any NameError when the callback is executed, raising
informative errors during serialization in the parent process rather than
allowing silent failures in the worker. Additionally, update the test cases to
cover callbacks with constructor parameters and external dependencies to ensure
the validation catches these cases.

In `@src/instructlab/training/main_ds.py`:
- Around line 448-453: The `_save_and_exit()` function returns early at lines
285 and 309, which bypasses the `on_train_end` callback dispatch and
`callback_manager.close()` cleanup code at lines 451-453. To fix this, ensure
that `on_train_end` and `close()` are invoked even when `_save_and_exit()`
triggers an early exit. Either add these cleanup calls directly in the
`_save_and_exit()` function before each return statement, or wrap the training
logic in a try-finally block that guarantees execution of these callbacks and
cleanup regardless of how the function exits.
- Around line 404-407: The code is accessing args.ckpt_output_dir which does not
exist as a defined argument in the parser, causing an AttributeError crash
during the first checkpoint save. Replace all occurrences of
args.ckpt_output_dir with args.output_dir in the three callback_manager.fire()
calls where checkpoint_path is being set as a parameter. The correct attribute
available from the argument parser is args.output_dir.
- Around line 683-703: The TrainerCallback class docstring lacks critical
documentation about the callback execution model. Add prominent documentation to
the TrainerCallback class that clearly states callbacks fire on all distributed
ranks, not just rank 0, and explicitly require callback authors to check
context.is_world_process_zero before executing any rank-specific side effects
such as checkpointing, logging, or job submission. Include concrete examples of
when this guard is necessary to prevent callback authors from inadvertently
running duplicate operations across ranks.

In `@tests/unit/test_callbacks.py`:
- Line 70: Replace all fixed time.sleep(0.1) calls with event-based
synchronization using threading.Event for deterministic completion checks.
Create a threading.Event object before each test section, set it inside the
callback function to signal completion, and then use event.wait(timeout=0.1)
instead of the sleep call. This applies to all occurrences throughout the test
file at the locations mentioned: lines 70, 84, 109, 121, 138, 151, 192, 238, and
253, where the pattern involves waiting for a callback to complete rather than
using arbitrary timing delays.
- Around line 58-255: The test class TestCallbackManager creates multiple
CallbackManager instances throughout its test methods but does not properly
clean up the background threads they spawn. Add a call to mgr.close() at the end
of each test method in TestCallbackManager (in test_fire_dispatches,
test_fire_skips_non_overridden, test_has_callbacks, test_snapshot_isolation,
test_exception_isolation, test_multiple_callbacks, test_kwargs_set_on_snapshot,
test_add_callback_type_error, test_remove_callback_by_instance,
test_remove_callback_by_type, test_fire_all_ranks, test_on_train_end_blocks,
test_fire_invalid_kwarg_raises, test_empty_manager_no_callbacks,
test_hook_name_set_on_snapshot, and test_dict_fields_snapshot_isolation) to
ensure the background thread is properly terminated after each test completes
and prevent thread leakage.
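
For the time.sleep(0.1) item in the tests/unit/test_callbacks.py comments above, an event-based wait could look roughly like this. BackgroundDispatcher is a hypothetical stand-in for an asynchronous manager, not the PR's CallbackManager, and the helper name is invented for the sketch.

```python
import threading


class BackgroundDispatcher:
    """Hypothetical stand-in: runs each hook on its own worker thread."""

    def __init__(self):
        self._threads = []

    def fire(self, hook, *args):
        thread = threading.Thread(target=hook, args=args, daemon=True)
        thread.start()
        self._threads.append(thread)

    def close(self):
        for thread in self._threads:
            thread.join(timeout=1.0)


def run_event_synchronized_check():
    """Fire one hook and wait for it with an Event instead of a fixed sleep."""
    done = threading.Event()
    seen = []

    def on_step_end(step):
        seen.append(step)
        done.set()  # signal completion to the waiting test

    dispatcher = BackgroundDispatcher()
    dispatcher.fire(on_step_end, 7)
    # wait() returns True as soon as the event is set, or False on timeout.
    completed = done.wait(timeout=2.0)
    dispatcher.close()
    return completed, seen
```

Because wait() returns as soon as the event is set, a timeout longer than the suggested 0.1 s costs time only when the callback never runs, so a generous value makes the test less flaky on slow CI machines without slowing the passing case.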


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: bd59a26c-5aa7-48c2-816a-43015e0fc798

📥 Commits

Reviewing files that changed from the base of the PR and between fa5b53b and 2c59bcb.

📒 Files selected for processing (6)
  • src/instructlab/training/__init__.py
  • src/instructlab/training/batch_loss_manager.py
  • src/instructlab/training/callbacks.py
  • src/instructlab/training/config.py
  • src/instructlab/training/main_ds.py
  • tests/unit/test_callbacks.py

hrathina force-pushed the feat/callback-mechanism branch from 278713d to d118f17 on June 23, 2026 12:56

coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/unit/test_callbacks.py`:
- Around line 211-215: In the test_close method, restructure the test to wrap
the CallbackManager instance creation and assertions in a try/finally block,
ensuring that m.close() is always called in the finally block, even if any
assertion fails before it. This prevents the background thread from leaking into
subsequent tests when assertions fail.
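
The cleanup pattern requested here can be sketched as below. ThreadedManager and run_check are hypothetical stand-ins that only model an object owning a background thread; the point is that close() sits in a finally clause and so runs whether or not the check raises.

```python
import threading


class ThreadedManager:
    """Hypothetical stand-in for a manager that owns a background thread."""

    def __init__(self):
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._stop.wait, daemon=True)
        self._thread.start()

    @property
    def alive(self):
        return self._thread.is_alive()

    def close(self):
        self._stop.set()
        self._thread.join(timeout=1.0)


def run_check(manager, should_fail):
    """Run a check against the manager and always close it afterwards."""
    try:
        if not manager.alive:
            raise AssertionError("worker thread should be running")
        if should_fail:
            raise AssertionError("simulated assertion failure")
    finally:
        manager.close()  # executes on success and on failure
```

In a real test the manager is usually constructed just before the try statement rather than inside it, so that the finally clause never refers to an unbound name if construction itself fails.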

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 936f5620-2645-4239-b9fb-9b7c9f79ca07

📥 Commits

Reviewing files that changed from the base of the PR and between 2c59bcb and d118f17.

📒 Files selected for processing (6)
  • src/instructlab/training/__init__.py
  • src/instructlab/training/batch_loss_manager.py
  • src/instructlab/training/callbacks.py
  • src/instructlab/training/config.py
  • src/instructlab/training/main_ds.py
  • tests/unit/test_callbacks.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • src/instructlab/training/batch_loss_manager.py
  • src/instructlab/training/config.py
  • src/instructlab/training/__init__.py
  • src/instructlab/training/main_ds.py

hrathina (Author)

@RobotSail please review this PR. Thank you.

Labels: testing (Relates to testing)

Development: successfully merging this pull request may close the linked issue "Add Callbacks to Instructlab Training".