Make DistributedSampler stateful #1315
Conversation
AI Store test can be safely ignored for now.
Looks pretty good, but I'd like to simplify the code a bit and move the tests around as well.
@@ -1947,6 +1960,116 @@ def test_sampler_reproducibility(self):
            ls[i].append(next(its[i]))
        self.assertEqual(ls[0], ls[1])

    def test_initialization_StatefulDistributedSampler(self):
Let's move all of these tests out to a new file called test_sampler.py. You can update https://github.com/pytorch/data/blob/main/.github/workflows/stateful_dataloader_ci.yml to call it in an additional step.
Created here: https://github.com/pytorch/data/blob/stateful_distributedsampler/test/stateful_dataloader/test_sampler.py

Added a new line here:

    - name: Run StatefulDataSampler tests with pytest - datasampler
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

        dataset = self.dataset
        sampler = StatefulDistributedSampler(dataset, num_replicas=10, rank=0, shuffle=False, seed=42, drop_last=False)
For testing state_dict, let's have most of the tests set up by passing sampler + dataset to StatefulDataLoader so we can test that it works end-to-end.
You might need to use a dummy collate function to easily inspect elements; check the test_state_dict.py file for examples.
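The dummy collate the reviewer mentions can be as small as an identity function. This is a hypothetical sketch (the name identity_collate is illustrative, not from the PR) of the pattern: batches pass through unchanged, so yielded elements stay directly comparable to expected indices.

```python
# Hypothetical identity collate for tests: returns the batch list unchanged,
# so each yielded batch can be compared directly against expected values.
def identity_collate(batch):
    return batch


# Batches stay inspectable as plain lists instead of collated tensors.
batches = [identity_collate([0, 1, 2]), identity_collate([3, 4])]
print(batches)  # [[0, 1, 2], [3, 4]]
```

In a real test this would be passed as collate_fn so that assertions can compare raw dataset elements rather than tensor-collated batches.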
New tests here: data/test/stateful_dataloader/test_sampler.py, line 173 in cdc5d31:

    def test_dataloader_state_dict(self):
        self.next_yielded = None

    def __iter__(self):
Is it possible to fork the DistributedSampler.__iter__ code here and just update it, instead of having a separate Iterator class?
data/torchdata/stateful_dataloader/sampler.py, line 149 in cdc5d31:

    self.indices = list(super().__iter__())
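The resolution above, reusing the parent's __iter__ rather than forking its body, can be sketched in plain Python. The class names here are stand-ins (a real DistributedSampler needs torch.distributed, so a toy base sampler substitutes for it); the resumption mechanism via next_yielded follows the pattern in the diff.

```python
class BaseSampler:
    """Stand-in for DistributedSampler: yields a fixed index order."""

    def __init__(self, indices):
        self._indices = indices

    def __iter__(self):
        return iter(self._indices)


class ResumableSampler(BaseSampler):
    """Reuses the parent's __iter__ instead of forking its body."""

    def __init__(self, indices):
        super().__init__(indices)
        self.next_yielded = None  # set by load_state_dict on resume

    def __iter__(self):
        # Materialize the parent's index order once, then skip
        # however many indices were already yielded before the checkpoint.
        indices = list(super().__iter__())
        start = self.next_yielded or 0
        self.next_yielded = None
        yield from indices[start:]


s = ResumableSampler([4, 0, 3, 1, 2])
s.next_yielded = 2  # pretend two indices were consumed before a checkpoint
print(list(s))  # [3, 1, 2]
```

The advantage the reviewer is after: any future change to the parent's shuffling or sharding logic is picked up automatically, because only one copy of that logic exists.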
        if self.sampler.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.sampler.seed + self.sampler.epoch)
            indices = torch.randperm(len(self.sampler.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.sampler.dataset)))  # type: ignore[arg-type]

        if not self.sampler.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.sampler.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible
            indices = indices[: self.sampler.total_size]
        assert len(indices) == self.sampler.total_size

        # subsample
        indices = indices[self.sampler.rank : self.sampler.total_size : self.sampler.num_replicas]
        assert len(indices) == self.sampler.num_samples

        self.parent_iterator = iter(indices)
        self.indices = list(self.parent_iterator)
        self.current_index = 0
Is there a way to call the original code instead of forking it here?
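The padding-and-subsample arithmetic being forked above can be checked with plain Python. This sketch (function name shard_indices is illustrative, not from the PR) reproduces the same steps for a toy index list: pad or truncate to an evenly divisible length, then take every num_replicas-th element starting at this rank's offset.

```python
import math


def shard_indices(indices, rank, num_replicas, drop_last=False):
    """Pad (or truncate) indices so the length divides evenly by
    num_replicas, then return this rank's strided slice."""
    if drop_last:
        # remove tail of data to make it evenly divisible
        total_size = (len(indices) // num_replicas) * num_replicas
        indices = indices[:total_size]
    else:
        # add extra samples (repeating from the front) to make it divisible
        total_size = math.ceil(len(indices) / num_replicas) * num_replicas
        padding_size = total_size - len(indices)
        if padding_size <= len(indices):
            indices = indices + indices[:padding_size]
        else:
            indices = indices + (indices * math.ceil(padding_size / len(indices)))[:padding_size]
    assert len(indices) == total_size
    # subsample: rank r of n replicas takes elements r, r+n, r+2n, ...
    return indices[rank:total_size:num_replicas]


print(shard_indices(list(range(10)), rank=0, num_replicas=4))  # [0, 4, 8]
print(shard_indices(list(range(10)), rank=1, num_replicas=4))  # [1, 5, 9]
```

With 10 indices and 4 replicas, the list is padded to 12 entries, so every rank sees exactly 3 samples; with drop_last=True it is truncated to 8 and every rank sees 2.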
    def state_dict(self) -> Dict[str, Any]:
        return self.sampler.state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.sampler.load_state_dict(state_dict)
I don't think we need this both here and in the main sampler class; can we consolidate to have this in just one place?
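The consolidation being asked for can be sketched in plain Python: the sampler is the single owner of serialized state, and nothing else re-implements state_dict/load_state_dict. The class and key names here are hypothetical stand-ins, not the PR's actual schema.

```python
from typing import Any, Dict


class ResumableSampler:
    """Single owner of resumable state: state_dict/load_state_dict
    live only here, so there is one serialization schema to maintain."""

    def __init__(self):
        self.yielded = 0  # how many indices this epoch has produced
        self.next_yielded = None  # where to resume after a restore

    def state_dict(self) -> Dict[str, Any]:
        return {"yielded": self.yielded}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.next_yielded = state_dict["yielded"]


# Checkpoint on one sampler, restore into a fresh one.
s = ResumableSampler()
s.yielded = 5
restored = ResumableSampler()
restored.load_state_dict(s.state_dict())
print(restored.next_yielded)  # 5
```

Keeping serialization in one class means the iterator can stay a thin consumer of the sampler's state rather than a second place the schema has to be kept in sync.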
Couple of suggestions, but looks great! Very nice test suite.
When you're done making changes, please run the fbcode CI for media_dataloader
Co-authored-by: Andrew Ho <andrewkh@meta.com>
@ramanishsingh has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
LGTM!
This pull request was exported from Phabricator. Differential Revision: D61772177
Fixes #1269

Changes
- torchdata/stateful_dataloader/sampler.py: added new classes StatefulDistributedSampler and _StatefulDistributedSamplerIterator
- test/stateful_dataloader/test_dataloader.py: new tests for StatefulDistributedSampler