[refactor] CogVideoX followups + tiled decoding support #9150
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Something interesting/fishy going on with
Very nice! I have left some questions. LMK if they are unclear.
Additionally, could we include a note in the docs on the memory savings from tiling?
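Why tiling saves memory: each tile is decoded on its own, so the working memory of a single decode call scales with the tile size rather than with the full frame, and overlapping tiles are cross-faded to hide seams. The toy sketch below shows that idea in one dimension, assuming a linear blend over the overlap as diffusers-style VAEs commonly use. `tiled_decode_1d`, `tile_size`, and `overlap` are hypothetical names; the real decoder tiles the spatial height and width of video latents (and upsamples them), which this sketch ignores.

```python
from typing import Callable, List


def tiled_decode_1d(
    latents: List[float],
    decode: Callable[[List[float]], List[float]],
    tile_size: int = 8,
    overlap: int = 2,
) -> List[float]:
    """Hypothetical 1D tiled decoding with linear blending over tile overlaps.

    `decode` is assumed to return one output value per input value.
    """
    if not 0 <= overlap < tile_size:
        raise ValueError("need 0 <= overlap < tile_size")
    stride = tile_size - overlap
    starts = range(0, len(latents), stride)
    # Each decode call only sees `tile_size` values, so its working memory
    # scales with the tile, not with the whole input.
    tiles = [decode(latents[s:s + tile_size]) for s in starts]

    out: List[float] = []
    for k, tile in enumerate(tiles):
        tile = list(tile)  # copy so the raw tiles stay untouched for blending
        if k > 0:
            prev = tiles[k - 1]
            # Width of the region this tile shares with the previous one.
            extent = min(overlap, len(prev) - stride, len(tile))
            for y in range(extent):
                w = y / extent  # 0 -> all previous tile, rising towards this tile
                tile[y] = prev[stride + y] * (1 - w) + tile[y] * w
        # Keep only the part that the next tile does not cover again.
        out.extend(tile[:stride])
    return out


# Usage: an identity "decoder" must reproduce the input exactly.
signal = [float(i) for i in range(20)]
print(tiled_decode_1d(signal, lambda t: list(t)) == signal)  # prints True
```

The identity check is a quick way to confirm that the overlap bookkeeping covers every position exactly once; a larger `overlap` gives smoother seams at the cost of decoding more values in total.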
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
            module.gradient_checkpointing = value
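For context, a hook like `_set_gradient_checkpointing` is typically run on every submodule through something like `torch.nn.Module.apply`, so the `isinstance` check decides which blocks actually flip the flag. The pure-Python sketch below mimics that flow with stand-in classes. `Module`, `Autoencoder`, and `enable_gradient_checkpointing` here are hypothetical stand-ins, and the assumption that the base class calls the hook via `apply(partial(...))` is ours, not taken from this PR.

```python
from functools import partial
from typing import Callable, List


class Module:
    """Minimal stand-in for torch.nn.Module (hypothetical, for illustration)."""

    def __init__(self, *children: "Module") -> None:
        self.children: List["Module"] = list(children)
        self.gradient_checkpointing = False

    def apply(self, fn: Callable[["Module"], None]) -> "Module":
        # Same traversal order as nn.Module.apply: children first, then self.
        for child in self.children:
            child.apply(fn)
        fn(self)
        return self


class CogVideoXEncoder3D(Module):  # stand-in for the real encoder
    pass


class CogVideoXDecoder3D(Module):  # stand-in for the real decoder
    pass


class Autoencoder(Module):  # hypothetical stand-in for the VAE wrapper
    def _set_gradient_checkpointing(self, module: Module, value: bool = False) -> None:
        # Only the encoder and decoder honour the flag.
        if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self) -> None:
        # Assumed wiring: visit every submodule with the hook bound to value=True.
        self.apply(partial(self._set_gradient_checkpointing, value=True))


encoder, decoder, other = CogVideoXEncoder3D(), CogVideoXDecoder3D(), Module()
vae = Autoencoder(encoder, decoder, other)
vae.enable_gradient_checkpointing()
# prints: True True False
print(encoder.gradient_checkpointing, decoder.gradient_checkpointing, other.gradient_checkpointing)
```

Because the filter lives in the hook rather than in the traversal, any other submodule (and the wrapper itself) is visited but left unchanged.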
-    def clear_fake_context_parallel_cache(self):
+    def _clear_fake_context_parallel_cache(self):
Better!
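The renamed `_clear_fake_context_parallel_cache` concerns the cache that lets the causal 3D convolutions in the CogVideoX VAE process a video in temporal chunks: each chunk is left-padded with the last `kernel_size - 1` input frames of the previous chunk, so the chunked output matches a single full pass. As we understand it, clearing that cache is what stops one video's frames from leaking into the next. The 1D sketch below illustrates this under those assumptions; `causal_conv1d` and `chunked_causal_conv1d` are hypothetical helpers, and padding the first chunk by repeating its first frame is an assumed choice.

```python
from typing import List, Optional, Tuple


def causal_conv1d(
    frames: List[float],
    weights: List[float],
    cache: Optional[List[float]] = None,
) -> Tuple[List[float], List[float]]:
    """Hypothetical causal temporal convolution with a carry-over cache.

    Output t depends only on inputs t - (k - 1) .. t, where k = len(weights).
    Returns (outputs, new_cache); new_cache holds the last k - 1 inputs.
    """
    k = len(weights)
    if cache is None:
        # No previous chunk: pad by repeating the first frame (assumed choice).
        padded = [frames[0]] * (k - 1) + list(frames)
    else:
        padded = list(cache) + list(frames)
    out = [
        sum(w * padded[t + j] for j, w in enumerate(weights))
        for t in range(len(frames))
    ]
    new_cache = padded[len(padded) - (k - 1):]
    return out, new_cache


def chunked_causal_conv1d(
    frames: List[float], weights: List[float], chunk_size: int
) -> List[float]:
    """Process `frames` chunk by chunk, threading the cache between chunks."""
    cache = None  # "clearing the cache" means starting again from None
    out: List[float] = []
    for i in range(0, len(frames), chunk_size):
        y, cache = causal_conv1d(frames[i:i + chunk_size], weights, cache)
        out.extend(y)
    return out


video = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3]
kernel = [1, 2, 3]
full, _ = causal_conv1d(video, kernel)
print(chunked_causal_conv1d(video, kernel, chunk_size=4) == full)  # prints True
```

Reusing a stale cache for a new video would pad its first chunk with the previous video's last frames, which is why the cache has to be reset between calls.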
I used this method and the result is also 90 seconds. I couldn't reproduce the issue you mentioned, so I need to check further, but this shouldn't be an issue.
@sayakpaul I've added a few explanations here. Could you please review again?
I think it would be good to add dynamic positional embeddings as well, to test the generalization capabilities of CogVideoX and remove the 48-frame, 480-height, 720-width limit. I have a POC almost ready. Should I push it here and share results in a while, or do it in a separate PR? IMO it shouldn't break anything existing. @sayakpaul
Let's do a separate PR.
I've pushed the code to https://github.com/huggingface/diffusers/tree/cogvideox-dynamic-pos-embeds for possible future reference. After further testing with more than 49 frames and different resolutions, I don't think the results are convincing enough to support it, so it's best not to add it at the moment.
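The frame and resolution limit discussed above comes from the positional embedding being precomputed for one latent grid. Assuming 4x temporal and 8x spatial VAE compression and a transformer patch size of 2 (our assumption about the CogVideoX defaults), the token count for a video can be worked out as below; a "dynamic" variant would instead compute sinusoidal embeddings for whatever grid size arrives. `num_video_tokens` and `sincos_1d` are hypothetical helpers, and `sincos_1d` is the standard Transformer-style sine/cosine scheme, not necessarily CogVideoX's exact 3D layout.

```python
import math
from typing import List


def num_video_tokens(
    num_frames: int,
    height: int,
    width: int,
    temporal_compression: int = 4,  # assumed VAE temporal factor
    spatial_compression: int = 8,   # assumed VAE spatial factor
    patch_size: int = 2,            # assumed transformer patch size
) -> int:
    """Hypothetical helper: number of patch tokens for a video of this size."""
    if (num_frames - 1) % temporal_compression:
        raise ValueError("num_frames - 1 must be divisible by temporal_compression")
    latent_frames = (num_frames - 1) // temporal_compression + 1
    grid_h = height // spatial_compression // patch_size
    grid_w = width // spatial_compression // patch_size
    return latent_frames * grid_h * grid_w


def sincos_1d(num_positions: int, dim: int) -> List[List[float]]:
    """Standard sine/cosine embeddings, computed on the fly for any length."""
    if dim % 2:
        raise ValueError("dim must be even")
    half = dim // 2
    freqs = [1.0 / 10000 ** (i / half) for i in range(half)]
    return [
        [math.sin(p * f) for f in freqs] + [math.cos(p * f) for f in freqs]
        for p in range(num_positions)
    ]


print(num_video_tokens(49, 480, 720))  # prints 17550 (13 * 30 * 45)
```

Under these assumptions, 49 frames at 480x720 give 13 latent frames on a 30x45 patch grid; a precomputed table sized for exactly that grid cannot index a longer or larger video, whereas `sincos_1d` can be evaluated for any length per axis.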
LGTM 👍🏽
@sayakpaul, could you check the note about memory optimizations here? If it looks good, I think we can merge this. cc @zRzRzRzRzRzRzR for visibility.
Edit: By the way, accelerate must be installed from source to replicate the memory numbers here. Until the next accelerate release, should we add a note saying so?
LGTM, thanks for the memory optims. Sleek!
* refactor context parallel cache; update torch compile time benchmark
* add tiling support
* make style
* remove num_frames % 8 == 0 requirement
* update default num_frames to original value
* add explanations + refactor
* update torch compile example
* update docs
* update
* clean up if-statements
* address review comments
* add test for vae tiling
* update docs
* update docs
* update docstrings
* add modeling test for cogvideox transformer
* make style
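One of the commits above drops the `num_frames % 8 == 0` requirement. A common way to do that is to split the frames into fixed-size batches and let the first batch absorb the remainder, so any frame count works. The sketch below shows that partitioning; `frame_batches` is a hypothetical helper and not necessarily how this PR implements it.

```python
from typing import List, Tuple


def frame_batches(num_frames: int, batch_size: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: split frames [0, num_frames) into contiguous batches.

    The first batch absorbs the remainder, so num_frames does not have to be
    a multiple of batch_size.
    """
    if num_frames <= 0 or batch_size <= 0:
        raise ValueError("num_frames and batch_size must be positive")
    num_batches = max(num_frames // batch_size, 1)
    # Frames left over after num_batches full batches (negative if
    # num_frames < batch_size, in which case the single batch is shorter).
    remainder = num_frames - num_batches * batch_size
    batches = []
    for i in range(num_batches):
        start = 0 if i == 0 else i * batch_size + remainder
        end = (i + 1) * batch_size + remainder
        batches.append((start, end))
    return batches


print(frame_batches(13, 2))
# prints [(0, 3), (3, 5), (5, 7), (7, 9), (9, 11), (11, 13)]
```

Putting the extra frames in the first batch keeps every later batch the same size, which is convenient if per-batch state (such as a convolution cache) is carried from one batch to the next.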
What does this PR do?
Code
Memory usage:
Results:
output.webm
output_tiling.webm
Note that you will need to install accelerate from source (accelerate:main) for this to work and to get the numbers I'm reporting above. If you're using the stable version of accelerate, you might see an additional 5-7 GB of memory usage.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@DN6 @sayakpaul @zRzRzRzRzRzRzR