[train v2] Add Train v2 user-facing callback interface #49819

Merged: 15 commits merged Jan 17, 2025

@justinvyu (Contributor) commented Jan 14, 2025

Summary

Adds UserCallback, which is a DeveloperAPI that exposes the reported results from ray.train.report calls as well as exceptions that are raised during training.

For now, it only exposes 2 callback methods:

  • after_report(run_context: TrainRunContext, metrics: List[Dict[str, Any]], checkpoint: Optional[Checkpoint])
    • The metrics argument is a list of all dicts reported via ray.train.report, where metrics[i] is the dict reported by rank i.
    • The checkpoint argument is populated if any worker reported a checkpoint. Only one checkpoint is passed because all workers persist to the same storage location.
  • after_exception(run_context: TrainRunContext, worker_exceptions: Dict[int, Exception])
    • worker_exceptions maps each worker rank to its exception. Ranks without an associated error are omitted.
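A minimal sketch of a subclass that uses the two signatures above. So that it runs without Ray, it defines a plain-Python stand-in for UserCallback with no-op default methods (an assumption; in Ray the class is imported from ray.train.v2.api.callback). RecordingCallback and its attributes are hypothetical, and run_context is simply passed through unused.

```python
from typing import Any, Dict, List, Optional


class UserCallback:
    """Stand-in for ray.train.v2.api.callback.UserCallback (assumed no-op defaults)."""

    def after_report(
        self, run_context: Any, metrics: List[Dict[str, Any]], checkpoint: Optional[Any]
    ) -> None:
        pass

    def after_exception(self, run_context: Any, worker_exceptions: Dict[int, Exception]) -> None:
        pass


class RecordingCallback(UserCallback):
    """Hypothetical callback that records what the Train driver passes in."""

    def __init__(self) -> None:
        self.rank0_history: List[Dict[str, Any]] = []
        self.checkpoints: List[Any] = []
        self.failed_ranks: List[int] = []

    def after_report(self, run_context, metrics, checkpoint):
        # metrics[i] is the dict reported by rank i, so metrics[0] is rank 0's.
        self.rank0_history.append(metrics[0])
        # At most one checkpoint per report; None if no worker reported one.
        if checkpoint is not None:
            self.checkpoints.append(checkpoint)

    def after_exception(self, run_context, worker_exceptions):
        # Only ranks that actually raised appear as keys.
        self.failed_ranks.extend(sorted(worker_exceptions))


cb = RecordingCallback()
cb.after_report(None, [{"loss": 0.5, "rank": 0}, {"loss": 0.6, "rank": 1}], None)
cb.after_report(None, [{"loss": 0.4, "rank": 0}, {"loss": 0.45, "rank": 1}], "ckpt-1")
cb.after_exception(None, {1: ValueError("error")})
print(cb.rank0_history)  # [{'loss': 0.5, 'rank': 0}, {'loss': 0.4, 'rank': 0}]
print(cb.failed_ranks)  # [1]
```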

The TrainRunContext is the run-level context that lives on the Train driver.

Here's a diagram showing the dependencies of the UserCallback.

[Diagram: UserCallback dependencies]

New CallbackHandler concept

As part of this PR, I consolidated the CheckpointHandler logic that gathers the checkpoints/metrics reported by all workers into a ReportCallbackHandler, since UserCallback.after_report relies on the same logic of gathering results across multiple poll_status calls.
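The gathering logic can be sketched roughly as follows. This is a hypothetical, Ray-free illustration rather than the actual ReportCallbackHandler: ReportAggregator, on_poll, and the per-rank (metrics, checkpoint) tuples are assumptions. Since a worker may have reported zero or several times between two polls, reports are buffered per rank and a round is emitted only once every rank has contributed.

```python
from collections import deque
from typing import Any, Deque, Dict, List, Optional, Tuple

# One report from one worker: the (metrics, checkpoint) pair passed to ray.train.report.
Report = Tuple[Dict[str, Any], Optional[Any]]


class ReportAggregator:
    """Hypothetical sketch: buffers per-rank reports across polls and emits one
    after_report call per round once every rank has reported."""

    def __init__(self, num_workers: int, run_context: Any, callbacks: List[Any]) -> None:
        if num_workers < 1:
            raise ValueError("num_workers must be at least 1")
        self._pending: List[Deque[Report]] = [deque() for _ in range(num_workers)]
        self._run_context = run_context
        self._callbacks = callbacks

    def on_poll(self, new_reports: Dict[int, List[Report]]) -> None:
        # A single poll may carry zero, one, or several new reports per rank.
        for rank, reports in new_reports.items():
            self._pending[rank].extend(reports)
        # Emit every complete round: the i-th report from each rank forms round i.
        while all(self._pending):
            round_reports = [queue.popleft() for queue in self._pending]
            metrics = [m for m, _ in round_reports]  # metrics[i] comes from rank i
            # Workers share one storage location, so keep the first non-None checkpoint.
            checkpoint = next((c for _, c in round_reports if c is not None), None)
            for callback in self._callbacks:
                callback.after_report(self._run_context, metrics, checkpoint)


class Collector:
    """Hypothetical subscriber that just records calls."""

    def __init__(self) -> None:
        self.calls: List[Tuple[List[Dict[str, Any]], Optional[Any]]] = []

    def after_report(self, run_context, metrics, checkpoint) -> None:
        self.calls.append((metrics, checkpoint))


collector = Collector()
aggregator = ReportAggregator(num_workers=2, run_context=None, callbacks=[collector])
aggregator.on_poll({0: [({"step": 1}, None)]})  # rank 1 has not reported yet: no call
aggregator.on_poll({1: [({"step": 1}, "ckpt")]})  # round complete: one after_report call
print(collector.calls)  # [([{'step': 1}, {'step': 1}], 'ckpt')]
```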

I also introduced a UserCallbackHandler which implements the logic needed to call the UserCallback interface.

These two new callbacks are examples of "callback handlers" which take in a list of callbacks and are responsible for calling their methods. The naming convention is [SubscriberCallback]Handler.
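To illustrate the [SubscriberCallback]Handler convention, the sketch below chains two handlers: a ReportCallbackHandler fans each gathered round out to its ReportCallback subscribers, and one of those subscribers is a UserCallbackHandler that fans out again to UserCallbacks, adding the driver's run context. All classes are simplified stand-ins; the ReportCallback.after_report signature, the handle_round method, and Rank0Recorder are assumptions, not the internal Ray Train API.

```python
from typing import Any, Dict, List, Optional


class ReportCallback:
    """Stand-in subscriber interface for gathered reports (assumed signature)."""

    def after_report(self, metrics: List[Dict[str, Any]], checkpoint: Optional[Any]) -> None:
        pass


class UserCallback:
    """Stand-in for the user-facing UserCallback."""

    def after_report(
        self, run_context: Any, metrics: List[Dict[str, Any]], checkpoint: Optional[Any]
    ) -> None:
        pass


class ReportCallbackHandler:
    """Takes a list of ReportCallbacks and is responsible for calling their methods."""

    def __init__(self, report_callbacks: List[ReportCallback]) -> None:
        self._report_callbacks = report_callbacks

    def handle_round(self, metrics: List[Dict[str, Any]], checkpoint: Optional[Any]) -> None:
        for callback in self._report_callbacks:
            callback.after_report(metrics, checkpoint)


class UserCallbackHandler(ReportCallback):
    """A ReportCallback subscriber that in turn calls every UserCallback."""

    def __init__(self, user_callbacks: List[UserCallback], run_context: Any) -> None:
        self._user_callbacks = user_callbacks
        self._run_context = run_context

    def after_report(self, metrics, checkpoint) -> None:
        # Adds the driver-side run context before forwarding to user code.
        for callback in self._user_callbacks:
            callback.after_report(self._run_context, metrics, checkpoint)


class Rank0Recorder(UserCallback):
    """Hypothetical user callback that keeps rank 0's metrics."""

    def __init__(self) -> None:
        self.seen: List[Any] = []

    def after_report(self, run_context, metrics, checkpoint) -> None:
        self.seen.append((run_context, metrics[0]))


user_cb = Rank0Recorder()
handler = ReportCallbackHandler([UserCallbackHandler([user_cb], run_context="run-1")])
handler.handle_round([{"rank": 0}, {"rank": 1}], None)
print(user_cb.seen)  # [('run-1', {'rank': 0})]
```

Each handler owns the dispatch for exactly one subscriber interface, which is what makes the two-level structure discussed in review explicit in the names.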

Follow-up

A follow-up PR will implement TuneReportCallback as a provided UserCallback.

Example Usage

```python
from ray.train.v2.api.callback import UserCallback


class MyUserCallback(UserCallback):
    # Passed to the trainer through RunConfig(callbacks=[MyUserCallback()]).

    def after_report(self, run_context, metrics, checkpoint):
        print(run_context)
        # TrainRunContext(run_config=RunConfig(name='ray_train_run-2025-01-16_18-18-39', storage_path='/private/var/folders/v6/r_z7clls5hl9bhbc2s7yyvjm0000gq/T/pytest-of-justin/pytest-21/test_user_callback0', storage_filesystem=None, failure_config=FailureConfig(), checkpoint_config=CheckpointConfig(), sync_config='DEPRECATED', verbose='DEPRECATED', stop='DEPRECATED', callbacks=[<test_data_parallel_trainer.MyUserCallback object at 0x11d4c1d10>], progress_reporter='DEPRECATED', log_to_file='DEPRECATED', local_dir=None))

        print(metrics)
        # [{'rank': 0}, {'rank': 1}]

    def after_exception(self, run_context, worker_exceptions):
        print(worker_exceptions)
        # {0: UserExceptionWithTraceback(ValueError('error'), 'Traceback (most recent call last):\n  File "/Users/justin/Developer/ray/python/ray/train/v2/tests/test_data_parallel_trainer.py", line 201, in _train_fn\n    raise ValueError("error")\nValueError: error\n')}
```

commit 5ad1842dadb8ee69d0c1d9baa6374dd0ebb4af01
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Mon Jan 13 15:46:48 2025 -0800

    move the base RayTrainCallback

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

commit c5b503ba987e8282839440eb631b0076e30adb44
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Mon Jan 13 15:41:11 2025 -0800

    add small comment

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

commit ae14cef6109aa307580cca8a695cf44537ee6b74
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Mon Jan 13 14:55:21 2025 -0800

    demote to developer API

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

commit 3aa90038cf435f224fcb3ba6137ca172e5b41e4b
Merge: ac825b553c f70bc4570c
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Mon Jan 13 14:55:07 2025 -0800

    Merge branch 'master' of https://github.com/anyscale/rayturbo into justinvyu/v2/user_callbacks

commit ac825b553c82508d96a6acd95d7b17e1bc89384b
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Thu Dec 19 15:45:33 2024 -0800

    rename report aggregator -> report handler

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

commit b02deae96196c02924f6a224948efe6364a0cf98
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Thu Dec 19 15:41:16 2024 -0800

    create user callback invoker

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

commit ac1cb60d467aa139d1b1ce5a75a59531b41c191e
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Thu Dec 19 14:43:01 2024 -0800

    CheckpointCallback -> ReportCallback + updated interface

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

commit 0d5e705c6b34b98276c9da562a37e277efa45800
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Thu Dec 19 14:01:38 2024 -0800

    update callback interface (no leaking DeveloperAPI callbacks)

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

commit a1c6c665ad211a9ccd9f8dada7edbd701bdc5675
Merge: 3eb2d344c8 d3c8f6be52
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Thu Dec 19 13:33:33 2024 -0800

    Merge branch 'master' of https://github.com/anyscale/runtime into justinvyu/v2/user_callbacks

commit 3eb2d344c859261598c1fa8afcd396049479eb5c
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Tue Dec 17 00:36:22 2024 -0800

    rename test

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

commit 870092621e2e5868771d63a0f19f08ebecb8ed2c
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Tue Dec 17 00:35:20 2024 -0800

    fix test

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

commit 5e5230e94d1b18327add1b0edd8f184059273ca5
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Tue Dec 17 00:34:57 2024 -0800

    fix typos

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

commit 43f02d929ef19527da4a24eb96c1d7ef4d9e045d
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Mon Dec 16 18:01:05 2024 -0800

    fix import

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

commit bc00026a39bc568ac54545b141db744d068c90dd
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Mon Dec 16 17:59:05 2024 -0800

    enable RunConfig(callbacks)

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

commit 0738c2628fd7a576502b5e359f73fa1bb2b95a32
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Mon Dec 16 17:57:02 2024 -0800

    remove checkpoint handler

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

commit 00471ef8f5bfbd57118d1409d1d730f6b73ce9a6
Author: Justin Yu <justinvyu@anyscale.com>
Date:   Mon Dec 16 17:55:27 2024 -0800

    add new report aggregator callback

    Signed-off-by: Justin Yu <justinvyu@anyscale.com>

Comment on lines +36 to +38
```python
def after_worker_group_poll_status(
    self, worker_group_status: WorkerGroupStatus
) -> None:
```
Contributor Author

Note: we shouldn't have many callbacks implementing this method, since it runs every health-check period (e.g., every 2 seconds) and the overhead can accumulate.

UserCallbackInvoker also implements this now to capture the worker group failures.
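One way such a hook could turn a poll result into after_exception calls while keeping the common no-error path cheap, as the note above recommends. WorkerStatus, WorkerGroupStatus, and their fields are hypothetical stand-ins (the real WorkerGroupStatus may be shaped differently), and ExceptionForwarder is an illustration, not the actual UserCallbackInvoker.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class WorkerStatus:
    """Hypothetical per-worker status from one poll."""

    running: bool
    error: Optional[Exception] = None


@dataclass
class WorkerGroupStatus:
    """Hypothetical stand-in: rank -> WorkerStatus for one poll."""

    worker_statuses: Dict[int, WorkerStatus] = field(default_factory=dict)


class ExceptionForwarder:
    """Hypothetical: forwards per-rank errors from a poll to after_exception."""

    def __init__(self, user_callbacks: List[Any], run_context: Any) -> None:
        self._user_callbacks = user_callbacks
        self._run_context = run_context

    def after_worker_group_poll_status(self, worker_group_status: WorkerGroupStatus) -> None:
        # Runs on every health-check poll, so the no-error path does minimal work.
        worker_exceptions = {
            rank: status.error
            for rank, status in worker_group_status.worker_statuses.items()
            if status.error is not None
        }
        if not worker_exceptions:
            return
        # Only ranks that actually failed are included, matching after_exception.
        for callback in self._user_callbacks:
            callback.after_exception(self._run_context, worker_exceptions)


class ErrorRecorder:
    """Hypothetical subscriber that records after_exception calls."""

    def __init__(self) -> None:
        self.calls: List[Dict[int, Exception]] = []

    def after_exception(self, run_context, worker_exceptions) -> None:
        self.calls.append(dict(worker_exceptions))


recorder = ErrorRecorder()
forwarder = ExceptionForwarder([recorder], run_context=None)
# Healthy poll: no callback is invoked.
forwarder.after_worker_group_poll_status(
    WorkerGroupStatus({0: WorkerStatus(True), 1: WorkerStatus(True)})
)
# Rank 0 failed: after_exception receives {0: err}.
err = ValueError("error")
forwarder.after_worker_group_poll_status(
    WorkerGroupStatus({0: WorkerStatus(False, err), 1: WorkerStatus(True)})
)
```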

@@ -16,7 +16,6 @@

import ray
from ray.train._internal.utils import count_required_parameters
from ray.train.v2._internal.execution.callback import Callback
Contributor Author

I refactored the invoke_callbacks_context_managers a bit to fix a cyclical dependency issue.

Comment on lines +9 to +16
```python
class RayTrainCallback:
    """Base Ray Train callback interface."""

    pass


@DeveloperAPI
class UserCallback(RayTrainCallback):
```
Contributor Author

Naming?

Contributor

UserCallback sounds good & straightforward to me.

Comment on lines +37 to +38
checkpoint: A Checkpoint object that has been persisted to
storage. This is None if no workers reported a checkpoint
Contributor

nit: Clarify that the checkpoint points to the persisted location, since this is currently a bit ambiguous.

```diff
 """

-callbacks: Optional[List["Callback"]] = None
+callbacks: Optional[List["RayTrainCallback"]] = None
```
Contributor

Should we keep this as UserCallback for now to prevent misuse, and expand as needed?

Contributor

It's a bit confusing if we have a UserCallback class but it's not exposed in the public API.

Contributor Author

What about making the type UserCallback but then keeping this https://github.com/ray-project/ray/pull/49819/files#diff-730d6551cfa9d3b55d3d413ce0500fad427b6409bc47d7e103f4f586341c3b2eR210-R214

That way, custom callbacks are still possible as workarounds and in unit tests. The only downside is that editor type checking will flag them.

Contributor

Yeah that works for now.
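A sketch of the compromise agreed on in this thread: annotate the public field with the narrower UserCallback type while the runtime check accepts any RayTrainCallback. RunConfigSketch and the validation in __post_init__ are hypothetical; the actual check lives in the linked lines and may differ.

```python
from dataclasses import dataclass
from typing import List, Optional


class RayTrainCallback:
    """Base Ray Train callback interface."""


class UserCallback(RayTrainCallback):
    """Public, user-facing callback type."""


@dataclass
class RunConfigSketch:
    """Hypothetical stand-in for RunConfig."""

    # The public annotation advertises only UserCallback...
    callbacks: Optional[List[UserCallback]] = None

    def __post_init__(self) -> None:
        # ...while the runtime check accepts any RayTrainCallback, so internal
        # callbacks still work in unit tests (a static type checker would flag them).
        for callback in self.callbacks or []:
            if not isinstance(callback, RayTrainCallback):
                raise ValueError(f"Unsupported callback type: {type(callback).__name__}")


class InternalTestCallback(RayTrainCallback):
    """Hypothetical internal callback that is not a UserCallback."""
```

With this split, `RunConfigSketch(callbacks=[InternalTestCallback()])` is accepted at runtime, while `RunConfigSketch(callbacks=[object()])` raises ValueError.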

```diff
@@ -111,16 +110,20 @@ def __init__(
 checkpoint_config=self._run_config.checkpoint_config,
 storage_context=self._storage_context,
 )

-self._checkpoint_handler = CheckpointHandler(self._checkpoint_manager)
+self._report_handler = ReportHandler(
```
Contributor

nit: Does this need to be an instance variable, or could it just be a local variable? It seems unused outside of __init__.

Suggested change
```diff
-self._report_handler = ReportHandler(
+report_handler = ReportHandler(
```

from ray.train.v2.api.callback import UserCallback


class UserCallbackInvoker(WorkerGroupCallback, ReportCallback):
Contributor

Would prefer if all callbacks have Callback as their suffix, so it's easy to tell if it is a Callback or a related class.

For this one, it can be something like CombinedUserCallback.

Contributor

Same for ReportHandler -> ReportHandlerCallback.

Contributor Author

@hongpeng-guo actually thought it was more confusing for callbacks that are not meant to be further subclassed to have a Callback suffix.

Ex: ReportCallback is a callback interface, but UserCallbackInvoker just implements the interface but should not be a new callback interface.

Contributor

Hmm isn't that what we do for all the other callbacks (e.g. AcceleratorSetupCallback, BackendSetupCallback, ...)?

Contributor Author

Yeah, the idea was to remove the suffix later but didn't do it yet.

Contributor

Yeah, as the callback structure gets more complicated (some callbacks now take other callbacks as input), the invoker callbacks are not meant to be subclassed; their main purpose is to invoke other, lower-level callbacks. I think it's clearer for us to differentiate these two levels of callbacks.

Right now, UserCallbackInvoker and ReportHandler are the two callbacks that take other callbacks as input, i.e., they are the higher-level internal callbacks. I think it makes sense to name these two specially. For the other, first-level callbacks, I think the Callback suffix is fine.

Another way to think about it: callbacks are relatively standalone modules, so modifying one callback should not affect the behavior of another. But any change to the two invokers changes how their subscribers behave. That's why I think these two callbacks should at least be named specially. I would still prefer a world with only one level of callbacks, if possible.

Contributor

@hongpeng-guo Jan 17, 2025

Result of offline discussion: an XXXXCallbackHandler takes in a list of XXXXCallbacks and is responsible for calling their methods.

Comment on lines +9 to +16
```python
class RayTrainCallback:
    """Base Ray Train callback interface."""

    pass


@DeveloperAPI
class UserCallback(RayTrainCallback):
```
Contributor

Add tests?

Contributor

@hongpeng-guo left a comment

Overall looks good to me. The following comments are not blocking.

It seems the main goal of this PR's UserCallback is to enable TuneReportCallback, which is a P0 feature for integrating Tune with Train V2. I think we should land this PR ASAP so that Train V2 can be released as a whole.

However, thinking about possible features such as async checkpointing and evaluation, I feel report(metrics: Dict, checkpoint: Optional[Checkpoint]) may not be the optimal API in the long run. For example, if evaluation happens on a separate cluster, it may not be intuitive for every training worker to report a metric. We may need to revamp the report/checkpoint-related callbacks in the near future. But this should be non-blocking for now.

@justinvyu
Contributor Author

@hongpeng-guo Yeah, I'm also not too confident about how future-proof the current callback API is, so I'm keeping it as a DeveloperAPI and not surfacing it in docs. The only thing in docs will be the TuneReportCallback that is our "out of the box" integration API between Train v2 and Tune.

@hongpeng-guo added the go label (add ONLY when ready to merge, run all tests) Jan 17, 2025
@justinvyu enabled auto-merge (squash) January 17, 2025 20:03
lint
@github-actions bot disabled auto-merge January 17, 2025 21:42
@justinvyu enabled auto-merge (squash) January 17, 2025 22:02
@justinvyu merged commit 5e42bdd into ray-project:master Jan 17, 2025
6 checks passed
@justinvyu deleted the train_v2/user_callbacks branch January 17, 2025 22:50
justinvyu added a commit that referenced this pull request Jan 21, 2025
… Train results to Tune (#49927)

Add `TuneReportCallback`, which implements the `UserCallback` interface
introduced in #49819.

This is a callback provided by Ray Train out of the box to support the
Ray Tune integration. The callback collects intermediate metrics
reported by Train workers and propagates the rank 0 metrics to the Tune
driver. This allows Ray Tune searchers, schedulers, etc. to kick in.

---------

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
win5923 pushed a commit to win5923/ray that referenced this pull request Jan 23, 2025
… Train results to Tune (ray-project#49927)

simonsays1980 pushed a commit to simonsays1980/ray that referenced this pull request Jan 23, 2025
… Train results to Tune (ray-project#49927)

anson627 pushed a commit to anson627/ray that referenced this pull request Jan 31, 2025
…9819)

Adds `UserCallback`, which is a `DeveloperAPI` that exposes the reported
results from `ray.train.report` calls as well as exceptions that are
raised during training.

For now, it only exposes 2 callback methods:
* `after_report(run_context: TrainRunContext, metrics: List[Dict[str,
Any]], checkpoint: Optional[Checkpoint])`
* `after_exception(run_context: TrainRunContext, worker_exceptions:
Dict[int, Exception])`

---------

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Anson Qian <anson627@gmail.com>
anson627 pushed a commit to anson627/ray that referenced this pull request Jan 31, 2025
… Train results to Tune (ray-project#49927)

srinathk10 pushed a commit that referenced this pull request Feb 2, 2025
srinathk10 pushed a commit that referenced this pull request Feb 2, 2025
… Train results to Tune (#49927)

anyadontfly pushed a commit to anyadontfly/ray that referenced this pull request Feb 13, 2025
…9819)

anyadontfly pushed a commit to anyadontfly/ray that referenced this pull request Feb 13, 2025
… Train results to Tune (ray-project#49927)

park12sj pushed a commit to park12sj/ray that referenced this pull request Mar 18, 2025
…9819)

park12sj pushed a commit to park12sj/ray that referenced this pull request Mar 18, 2025
… Train results to Tune (ray-project#49927)

Labels
go add ONLY when ready to merge, run all tests

3 participants