[core] FreeNoise #8948

Merged

merged 40 commits into main from freenoise on Aug 7, 2024

Conversation

@a-r-r-o-w (Member) commented Jul 23, 2024

What does this PR do?

FreeNoise is a training-free, "free lunch" technique for existing short video diffusion models that enables longer video generation with almost no additional inference overhead.

Project: http://haonanqiu.com/projects/FreeNoise.html
Paper: https://arxiv.org/abs/2310.15169
Code: https://github.com/arthur-qiu/FreeNoise-AnimateDiff

Fixes #5576.

Results

All inputs can be found here.

AnimateDiff Text-to-Video

  • context_length=16, context_stride=4, shuffle=False: pipeline_animatediff_freenoise-shuffle_False-context_length_16-context_stride_4.webm
  • context_length=16, context_stride=4, shuffle=True: pipeline_animatediff_freenoise-shuffle_True-context_length_16-context_stride_4.webm
  • context_length=20, context_stride=4, shuffle=True: pipeline_animatediff_freenoise-shuffle_True-context_length_20-context_stride_4.webm
  • context_length=20, context_stride=8, shuffle=True: pipeline_animatediff_freenoise-shuffle_True-context_length_20-context_stride_8.webm
  • context_length=24, context_stride=4, shuffle=True: pipeline_animatediff_freenoise-shuffle_True-context_length_24-context_stride_4.webm
  • context_length=24, context_stride=8, shuffle=True: pipeline_animatediff_freenoise-shuffle_True-context_length_24-context_stride_8.webm
  • num_frames: 64
  • duration: ~50-60s (25 steps)
Code
import torch

from diffusers import AnimateDiffPipeline, DPMSolverMultistepScheduler, AutoencoderKL, MotionAdapter
from diffusers.utils import export_to_video

device = "cuda:0"

# Initialize models and pipeline
motion_adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to(device)
pipe = AnimateDiffPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=motion_adapter, vae=vae, torch_dtype=torch.float16,
).to(device)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, beta_schedule="linear", algorithm_type="dpmsolver++", use_karras_sigmas=True)

num_seconds = 8
fps = 8
num_frames = num_seconds * fps

for context_length in [16, 20, 24]:
    for context_stride in [4, 8]:
        print(f"Processing {context_length=}, {context_stride=}")
        
        # Enable FreeNoise for long video generation
        pipe.enable_free_noise(context_length=context_length, context_stride=context_stride, weighting_scheme="pyramid", shuffle=True)

        # Run inference
        video = pipe(
            prompt="a panda, playing a guitar, sitting in a boat, in the ocean, mountains in background, sunny day, realistic, high quality",
            negative_prompt="bad quality, worst quality",
            num_frames=num_frames,
            num_inference_steps=25,
            guidance_scale=8,
            generator=torch.Generator().manual_seed(1337),
        ).frames[0]

        export_to_video(video, f"animatediff_freenoise/pipeline_animatediff_freenoise-shuffle_True-context_length_{context_length}-context_stride_{context_stride}.mp4", fps=fps)

        # Disable FreeNoise shuffling
        # pipe.disable_free_noise() # optional
        pipe.enable_free_noise(context_length=context_length, context_stride=context_stride, shuffle=False)

        # Run inference
        video = pipe(
            prompt="a panda, playing a guitar, sitting in a boat, in the ocean, mountains in background, sunny day, realistic, high quality",
            negative_prompt="bad quality, worst quality",
            num_frames=num_frames,
            num_inference_steps=25,
            guidance_scale=8,
            generator=torch.Generator().manual_seed(1337),
        ).frames[0]

        export_to_video(video, f"animatediff_freenoise/pipeline_animatediff_freenoise-shuffle_False-context_length_{context_length}-context_stride_{context_stride}.mp4", fps=fps)

AnimateDiff ControlNet

  • context_length=16, context_stride=4: pipeline_animatediff_controlnet_freenoise-shuffle_True-context_length_16-context_stride_4.webm
  • context_length=16, context_stride=8: pipeline_animatediff_controlnet_freenoise-shuffle_True-context_length_16-context_stride_8.webm

Additionally, using the code here:

  • context_length=16, context_stride=4: animatediff_controlnet_long_1.webm
  • context_length=16, context_stride=8: animatediff_controlnet_long_2.webm
  • num_frames: 104
  • duration: ~1m 45s (10 steps, 16, 4), ~1m 30s (10 steps, 16, 8)
Code
import torch

from controlnet_aux.processor import LineartAnimeDetector, OpenposeDetector
from diffusers.pipelines.animatediff.pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline
from diffusers import ControlNetModel, LCMScheduler, AutoencoderKL, MotionAdapter
from diffusers.utils import export_to_video, load_video

device = "cuda:1"

# Initialize models and pipeline
controlnet1 = ControlNetModel.from_single_file("/raid/aryan/hub/models--lllyasviel--ControlNet-v1-1/snapshots/69fc48b9cbd98661f6d0288dc59b59a5ccb32a6b/control_v11p_sd15_openpose.pth", torch_dtype=torch.float16).to(device)
controlnet2 = ControlNetModel.from_single_file("/raid/aryan/hub/models--lllyasviel--ControlNet-v1-1/snapshots/69fc48b9cbd98661f6d0288dc59b59a5ccb32a6b/control_v11p_sd15_lineart.pth", torch_dtype=torch.float16)
motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to(device)
pipe = AnimateDiffControlNetPipeline.from_pretrained(
    "stablediffusionapi/darksushimixv225", controlnet=[controlnet1, controlnet2], motion_adapter=motion_adapter, vae=vae, torch_dtype=torch.float16,
).to(device)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")

pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
pipe.set_adapters(["lcm-lora"], [0.8])

# Video credits: https://stable-diffusion-art.com/animatediff-prompt-travel-video2video/
select_nth_frame = 2
video = load_video("https://stable-diffusion-art.com/wp-content/uploads/2023/10/man_dance_2to3_24fps_9s.mp4")[::select_nth_frame]
width = 512
height = 768

# Preprocess video
lineart_processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators").to(device)
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to(device)
conditioning_frames1 = []
conditioning_frames2 = []

with pipe.progress_bar(total=len(video)) as progress_bar:
    for frame in video:
        conditioning_frames1.append(openpose_processor(frame, include_body=True, include_hand=True, include_face=True))
        conditioning_frames2.append(lineart_processor(frame))
        progress_bar.update()

for context_length in [16, 24]:
    for context_stride in [4, 8]:
        print(f"Processing {context_length=}, {context_stride=}")

        # Enable FreeNoise for long video generation
        pipe.enable_free_noise(context_length=context_length, context_stride=context_stride, weighting_scheme="pyramid", shuffle=True)

        # Run inference (store the result separately so the input `video` is not overwritten across iterations)
        output_video = pipe(
            prompt="man dancing, blue shirt, red shorts, psychedelic",
            negative_prompt="bad quality, worst quality, jpeg artifacts, ugly",
            conditioning_frames=[conditioning_frames1, conditioning_frames2],
            controlnet_conditioning_scale=[0.5, 0.4],
            width=width,
            height=height,
            num_frames=len(video),
            num_inference_steps=10,
            guidance_scale=2,
            generator=torch.Generator().manual_seed(42),
        ).frames[0]

        export_to_video(video, f"animatediff_freenoise/pipeline_animatediff_controlnet_freenoise-shuffle_True-context_length_{context_length}-context_stride_{context_stride}.mp4", fps=12)

AnimateDiff Video2Video

  • context_length=16, context_stride=4: pipeline_animatediff_vid2vid_freenoise-shuffle_True-context_length_16-context_stride_4.webm
  • context_length=16, context_stride=8: pipeline_animatediff_vid2vid_freenoise-shuffle_True-context_length_16-context_stride_8.webm
  • num_frames: 150
  • duration: ~5m (25 steps, 16, 4), ~2m 30s (25 steps, 16, 8)
Code
import torch

from diffusers import AnimateDiffVideoToVideoPipeline, DPMSolverMultistepScheduler, AutoencoderKL, MotionAdapter
from diffusers.utils import export_to_video, load_video

device = "cuda:0"

# Initialize models and pipeline
motion_adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to(device)
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=motion_adapter, vae=vae, torch_dtype=torch.float16,
).to(device)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, beta_schedule="linear", algorithm_type="dpmsolver++", use_karras_sigmas=True)

select_nth_frame = 2
video = load_video("racecar.mp4")[::select_nth_frame]
width = 512
height = 768

for context_length in [16, 24]:
    for context_stride in [4, 8]:
        # Enable FreeNoise for long video generation
        pipe.enable_free_noise(context_length=context_length, context_stride=context_stride, weighting_scheme="pyramid", shuffle=True)

        # Run inference (store the result separately so the original input `video` is reused on later iterations)
        output_video = pipe(
            prompt="racecar, vaporwave style, cyberpunk, intricately detailed, bright colors, 8k resolution, photorealistic, masterpiece, cinematic lighting",
            negative_prompt="bad quality, worst quality, jpeg artifacts, ugly",
            video=video,
            strength=0.6,
            width=width,
            height=height,
            num_inference_steps=25,
            guidance_scale=8.5,
            generator=torch.Generator().manual_seed(42),
        ).frames[0]

        export_to_video(video, f"animatediff_freenoise/pipeline_animatediff_vid2vid_freenoise-shuffle_True-context_length_{context_length}-context_stride_{context_stride}.mp4", fps=24)

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@DN6 @sayakpaul

cc @yiyixuxu as well for library design

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w (Member Author) commented Jul 24, 2024

The implementation here gives the following result:

Context length = 16, stride = 8 Context length = 16, stride = 8

I'm not sure why the last few frames go berserk. I will take a better look soon since we need to rewrite this anyway.

The implementation you see here is horrible. Let me explain why it's done this way at the moment:

  • FreeNoise requires modifying the temporal forward pass by breaking the [bhw, f, c] tensors into k = num_frames / context_length chunks of shape [bhw, context_length, c], performing self/cross-attention on each chunk, and then taking a weighted average over the overlapping frames (see the sketch after this list).
  • In Diffusers, both the spatial and temporal forward passes are implemented with BasicTransformerBlock, which makes frame-wise chunked inference challenging because there is no clean way to tell whether a given pass is spatial or temporal. Currently, to make FreeNoise work, I determine this with some hardcoded logic that will be removed once we address the design problem.
  • I've retained the forward-pass code of the original implementation to highlight that the changes needed to support FreeNoise are quite minimal; doing the same within BasicTransformerBlock, however, would be difficult.
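
For concreteness, here is a minimal sketch of the window-wise temporal attention with weighted averaging described in the first bullet. This is not the Diffusers code: `attn` stands in for any temporal self-attention module operating on `[bhw, frames, channels]` tensors, and the sketch assumes the sliding windows cover every frame.

import torch

def free_noise_temporal_pass(hidden_states, attn, context_length=16, context_stride=4, weights=None):
    # hidden_states: [batch * height * width, num_frames, channels]
    bhw, num_frames, channels = hidden_states.shape
    accum = torch.zeros_like(hidden_states)
    counts = torch.zeros(num_frames, device=hidden_states.device, dtype=hidden_states.dtype)
    if weights is None:
        weights = torch.ones(context_length, device=hidden_states.device, dtype=hidden_states.dtype)

    # Slide a window of `context_length` frames with step `context_stride`
    for start in range(0, num_frames - context_length + 1, context_stride):
        end = start + context_length
        window_out = attn(hidden_states[:, start:end])   # attention restricted to this window
        accum[:, start:end] += window_out * weights.view(1, -1, 1)   # weighted accumulation
        counts[start:end] += weights

    # Weighted average over all windows that touched each frame
    return accum / counts.view(1, -1, 1)

With num_frames=64, context_length=16 and context_stride=4, this produces 13 overlapping windows whose outputs are blended per frame.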
Code to replicate
import torch

from diffusers import AnimateDiffPipeline
from diffusers.models import AutoencoderKL, MotionAdapter
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image


model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-2"
vae_id = "stabilityai/sd-vae-ft-mse"
device = "cuda"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    algorithm_type="dpmsolver++",
    use_karras_sigmas=True,
)
pipe = AnimateDiffPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)

pipe.enable_free_noise(context_length=16, context_stride=4, shuffle=True)

prompt = "a racoon playing a guitar, sitting in a boat, floating in the ocean, high quality, realistic"
negative_prompt = "bad quality, worst quality"

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=512,
    height=512,
    # num_frames=16,
    num_frames=80, # must be 80 for the current hardcoded logic to work
    num_inference_steps=25,
    guidance_scale=8,
    generator=torch.Generator().manual_seed(1337),
).frames[0]

export_to_gif(video, "animatediff_freenoise.gif")

I would like to hear your thoughts on the following.

  • Is the FreeNoiseMixin implementation a good way to go? We would somehow need to pass the mixin parameters from pipelines to model when FreeNoise is enabled. I'm thinking added_cond_kwargs could help here.
  • What do you think about a BasicTemporalTransformerBlock? It'll be a copy of the BasicTransformerBlock but only containing features specific to AnimateDiff Temporal Blocks. Will require some rewriting of existing code in a non-breaking manner. Any alternative suggestions are welcome.
  • FreeNoise could potentially be applied to any video model that has spatial block - temporal block - spatial block - ... pattern (free lunch). We should try it on our models after addressing the previous issues, no?

@sayakpaul (Member)

Just out of curiosity, what is the limit of "long" here? To me, 16 frames is not long enough. Is it fair to expect a minute-long video with FreeNoise? If not, perhaps it might be better to wait, as the value add is not that evident. Of course, I could be looking at it entirely wrong, so happy to stand corrected.

@a-r-r-o-w (Member Author) commented Jul 24, 2024

Just out of curiosity, what is the limit of "long" here? To me, 16 frames is not long enough. Is it fair to expect a minute-long video with FreeNoise? If not, perhaps it might be better to wait, as the value add is not that evident. Of course, I could be looking at it entirely wrong, so happy to stand corrected.

For text-to-video, I would say 5-30 seconds at 8-12 fps is currently considered "long" enough if the video is consistent. For control-based vid2vid, there is no real limit and you can animate very long videos (1+ minutes), especially with Comfy.

The added value may not be evident from the example I showed. Keep in mind that I want to get FreeNoise working first and iron out design/implementation bugs, and then apply it to other pipelines. For text-to-video, it is indeed hard to achieve good videos even with FreeNoise. Where it really shines is naive vid2vid, ControlNet vid2vid, and combinations with other tricks, and I think we should really support it because it opens up the potential for many workflows, such as the following, within Diffusers. It is widely considered the best open-source method for long generations (at least in vid2vid, from what I've seen). Essentially, all that's needed to support it is chunked frame-wise inference in the BasicTransformerBlock and weighted averaging of latents.

1, 2, 3, 4, 5, 6, 7 and 8 off the top of my saved reddit posts.

I don't have good workflows set up yet (WIP), so I stole this to generate videos with Comfy for the time being. You can find the results here.

  • The FreeNoise generation at 768x432 (AnimateDiff_00003.gif) takes 3 minutes 12 seconds, and the background is somewhat consistent.
  • The non-FreeNoise generation at 768x432 (AnimateDiff_00003_nofreenoise.gif) takes 9 minutes; the background keeps changing and there are a few more artifacts on the body. This uses the Context Scheduler approach, for which I have an open PR, but we're still deciding whether it's worth adding compared to just FreeNoise.

cc @DN6

@sayakpaul (Member)

Thanks for explaining.

My take is that we first make it work and convince ourselves that we feel good about the results. We can then work through this PR to reach a design. Before that, achieving those results would be a nice thing to optimize for.

@sayakpaul (Member)

The FreeNoise generation at 768x432 (AnimateDiff_00003.gif) takes 3 minutes 12 seconds, and the background is somewhat consistent.
The non-FreeNoise generation at 768x432 (AnimateDiff_00003_nofreenoise.gif) takes 9 minutes; the background keeps changing and there are a few more artifacts on the body. This uses the Context Scheduler approach, for which I have an open PR, but we're still deciding whether it's worth adding compared to just FreeNoise.

Which results are these in the videos you shared?

@a-r-r-o-w (Member Author)

Thanks for explaining.

My take is that we first make it work and convince ourselves that we feel good about the results. We can then work through this PR to reach a design. Before that, achieving those results would be a nice thing to optimize for.

I see what you mean. Alright, I'll get it to work with our current community AnimateDiff controlnet implementation (which I really think should now be in core because of how broadly the Comfy equivalent is used) and SparseCtrl once it's merged.

@a-r-r-o-w (Member Author) commented Jul 24, 2024

Which results are these in the videos you shared?

Check this. AnimateDiff_00003.gif is the FreeNoise version. AnimateDiff_00003-nofreenoise.gif is the non-FreeNoise version. The other files are input and different settings with FreeNoise.

@a-r-r-o-w (Member Author) commented Jul 24, 2024

Here are some results to demonstrate the effectiveness of FreeNoise in vid2vid settings:

Code
import requests
from io import BytesIO

import imageio
import torch
from controlnet_aux.processor import LineartDetector, OpenposeDetector
from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, DPMSolverMultistepScheduler, LCMScheduler
from diffusers.pipelines.animatediff.pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline
from diffusers.utils import export_to_gif, export_to_video
from PIL import Image


# model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
model_id = "stablediffusionapi/darksushimixv225"
# model_id = "emilianJR/epiCRealism"
# motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
motion_adapter_id = "wangfuyun/AnimateLCM"
controlnet1_id = "/raid/aryan/hub/models--lllyasviel--ControlNet-v1-1/snapshots/69fc48b9cbd98661f6d0288dc59b59a5ccb32a6b/control_v11p_sd15_openpose.pth"
controlnet2_id = "/raid/aryan/hub/models--lllyasviel--ControlNet-v1-1/snapshots/69fc48b9cbd98661f6d0288dc59b59a5ccb32a6b/control_v11p_sd15_lineart.pth"
vae_id = "stabilityai/sd-vae-ft-mse"
device = "cuda:0"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id)
controlnet1 = ControlNetModel.from_single_file(controlnet1_id, torch_dtype=torch.float16)
controlnet2 = ControlNetModel.from_single_file(controlnet2_id, torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16)
pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    controlnet=[controlnet1, controlnet2],
    vae=vae,
).to(device=device, dtype=torch.float16)
# pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
#     model_id,
#     subfolder="scheduler",
#     timestep_spacing="linspace",
#     beta_schedule="linear",
#     algorithm_type="dpmsolver++",
#     use_karras_sigmas=True,
#     steps_offset=1,
# )
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")

pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
pipe.set_adapters(["lcm-lora"], [0.8])

def load_video(file_path: str):
    images = []

    if file_path.startswith(('http://', 'https://')):
        # If the file_path is a URL
        response = requests.get(file_path)
        response.raise_for_status()
        content = BytesIO(response.content)
        reader = imageio.get_reader(content)
    else:
        # Assuming it's a local file path
        reader = imageio.get_reader(file_path)

    for frame in reader:
        pil_image = Image.fromarray(frame)
        images.append(pil_image)

    return images


# skip_nth_frame = 1
# max_frames = 16
# video = load_video("input.gif")[::skip_nth_frame][:max_frames]
# width = 512
# height = 768

skip_nth_frame = 2
max_frames = 80
assert max_frames == 80 # Must be 80 for FreeNoise to work because of the hardcoded implementation at the moment
video = load_video("vid2vid_input2.mov")[::skip_nth_frame][:max_frames]
width = 768
height = 432

p1 = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to(device)
p2 = LineartDetector.from_pretrained("lllyasviel/Annotators").to(device)
cn1, cn2 = [], []

with pipe.progress_bar(total=len(video)) as progress_bar:
    for frame in video:
        cn1.append(p1(frame, include_body=True, include_hand=True, include_face=True))
        cn2.append(p2(frame))
        progress_bar.update()

prompt = "girl dancing, blue hair, high quality, surreal"
negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"

# pipe.enable_free_init(use_fast_sampling=True)
# pipe.enable_free_init(use_fast_sampling=False)

pipe.enable_free_noise(context_length=16, context_stride=4, shuffle=True)

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=len(video),
    num_inference_steps=10,
    guidance_scale=2.0,
    conditioning_frames=[cn1, cn2],
    controlnet_conditioning_scale=[0.5, 0.8],
    decode_batch_size=16,
    generator=torch.Generator().manual_seed(42),
).frames[0]

export_to_video(video, "animatediff_controlnet_long_1.mp4", fps=8)
Both videos are generated with FreeNoise context_length=16 and context_stride=4. The left video has FreeInit (a different method) disabled (10 inference steps in total), whereas the right video has it enabled in fast mode (3 + 6 + 10 inference steps).

Viewing the GIF in your browser will show various artifacts, so the mp4 versions, along with the inputs, can be found here. Ignore the last 8 frames in the output; they are caused by an implementation bug that I'll look into soon.

cc @asomoza since we can write a few guides to improve long video generation quality once we have a more stable implementation merged

@sayakpaul (Member)

Looks pretty good. Good luck with the bug hunting!

@a-r-r-o-w a-r-r-o-w requested review from yiyixuxu and sayakpaul July 28, 2024 05:02
@a-r-r-o-w (Member Author)

This PR requires #8972 to be merged so that the AnimateDiffControlNet changes added here (only so the demo code runs for anyone who wants to replicate the results) can be removed. While #8979 shouldn't really cause any problems for the inference implementation here, it would be better for that to be merged as well and to handle the merge conflicts here.

@a-r-r-o-w (Member Author) left a comment

Left some comments on why certain changes were made and what will get removed as some of the linked PRs are merged.

Big thanks to @DN6 for the design suggestions and for helping me integrate this. It's been long overdue and I can't wait to cook up some good tutorials on long video generation with Diffusers.

@@ -272,6 +272,17 @@ def __init__(
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
Member Author:

These changes were made to initialize the FreeNoiseTransformerBlock correctly. I'm not sure how else we could determine these attributes in a "simple" way without accessing the internal PyTorch dimensions, which adds many extra LOC after make style.
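
As an illustration of why storing these attributes is handy, here is a hypothetical helper (names and signature invented for this comment, not library API) that rebuilds a temporal block as a FreeNoise-aware block from the saved constructor arguments and copies the trained weights over:

def to_free_noise_block(block, free_noise_block_cls, context_length, context_stride, weighting_scheme):
    new_block = free_noise_block_cls(
        dim=block.dim,                                   # read back from the stored attributes
        num_attention_heads=block.num_attention_heads,   # (other saved args would be passed the same way)
        attention_head_dim=block.attention_head_dim,
        context_length=context_length,
        context_stride=context_stride,
        weighting_scheme=weighting_scheme,
    )
    # FreeNoise only changes how the forward pass is chunked, so the trained weights are reused as-is
    new_block.load_state_dict(block.state_dict(), strict=False)
    return new_block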


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
Member Author:

From #8995 to fix LoRA in UNetMotionModel. Once that's merged, these changes shouldn't be visible here


# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
Member Author:

I've made some assumptions here and removed all the branches handling different configurations of layernorms and attention. We support multiple norm_type values in BasicTransformerBlock. I think the SD1.5 checkpoints used in AnimateDiff always use LayerNorm, but I'm only 99% sure. LMK if any other norm types must be handled.


return frame_indices

def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
Member Author:

The original FreeNoise implementation proposes pyramid weighted averaging (see Eq. 9 of the paper). However, the diffusion community has found other weighting schemes that also seem to work well in practice. While I haven't tested them deeply, I would like to keep the implementation open to extension in the future. For now, let's roll with the original unless we can test different methods qualitatively before the next release.
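
For reference, a minimal sketch of what a "pyramid" weighting scheme of this kind can look like (illustrative only; the exact values in the final implementation may differ):

from typing import List

def pyramid_weights(num_frames: int) -> List[float]:
    # Weights increase towards the middle of the window and decrease symmetrically,
    # so frames near a window border contribute less to the averaged result.
    if num_frames % 2 == 0:
        half = list(range(1, num_frames // 2 + 1))
        weights = half + half[::-1]                    # e.g. 16 -> [1, 2, ..., 8, 8, ..., 2, 1]
    else:
        peak = num_frames // 2 + 1
        half = list(range(1, peak))
        weights = half + [peak] + half[::-1]           # odd length -> single peak
    return [float(w) for w in weights]

print(pyramid_weights(6))  # [1.0, 2.0, 3.0, 3.0, 2.0, 1.0]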

@@ -0,0 +1,1095 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
Member Author:

This file will be removed from here once my animatediff controlnet PR is merged :)

weighting_scheme (`str`, defaults to `4`):
TODO(aryan)
shuffle (`str`, defaults to `True`):
TODO(aryan): decide if this is even needed
Member Author:

Latent shuffling (better described as reusing `context_length` latent frames instead of `num_frames` in the pipeline) is very much required to improve temporal consistency. In my initial pass over the paper, I misunderstood what it meant and implemented it incorrectly. It is now done correctly, and you can see that the text2vid quality has improved significantly.

We can either remove this parameter and always shuffle, or leave it in for more experimental freedom in different settings. Either is okay with me.
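
To make the "reuse and shuffle" idea concrete, here is a hedged sketch of FreeNoise-style noise initialization (a paraphrase of the paper's noise rescheduling, not the pipeline code): noise for frames beyond the first window is produced by reusing the first `context_length` noise frames in shuffled order, so every temporal window sees correlated noise.

import torch

def reschedule_noise(num_frames: int, context_length: int, frame_shape, generator=None):
    # Noise for the first window, shape [context_length, *frame_shape]
    base = torch.randn((context_length, *frame_shape), generator=generator)
    chunks = [base]
    remaining = num_frames - context_length
    while remaining > 0:
        perm = torch.randperm(context_length, generator=generator)
        chunks.append(base[perm][:remaining])   # shuffled reuse of the first window's noise
        remaining -= context_length
    return torch.cat(chunks, dim=0)[:num_frames]

latents = reschedule_noise(num_frames=64, context_length=16, frame_shape=(4, 64, 64))
print(latents.shape)  # torch.Size([64, 4, 64, 64])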

self._free_noise_weighting_scheme = weighting_scheme
self._free_noise_shuffle = shuffle

blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
Member Author:

I've not added FreeNoise to AnimateDiff SparseCtrl yet because it fails somewhere in the mid block. It seems it will require more time, and there are more important things to attend to at the moment, so I propose we roll with this and revisit it in the near future.

@@ -401,6 +402,64 @@ def test_free_init_with_schedulers(self):
"Enabling of FreeInit should lead to results different from the default pipeline results",
)

def test_free_noise_blocks(self):
Member Author:

I've added two tests for FreeNoise based on what I think are the most important parts. LMK if anything else is needed
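
For reference, a hedged sketch of the kind of check such a test can perform (illustrative, not the actual test code): enabling FreeNoise should swap the temporal transformer blocks for FreeNoise-aware ones, and disabling it should restore the originals.

def check_free_noise_blocks(pipe, free_noise_block_cls):
    # Enable FreeNoise and verify that FreeNoise-aware blocks are now present in the UNet
    pipe.enable_free_noise(context_length=16, context_stride=4)
    swapped = [m for m in pipe.unet.modules() if isinstance(m, free_noise_block_cls)]
    assert len(swapped) > 0, "expected FreeNoise blocks after enabling"

    # Disable FreeNoise and verify the original blocks are restored
    pipe.disable_free_noise()
    restored = [m for m in pipe.unet.modules() if isinstance(m, free_noise_block_cls)]
    assert len(restored) == 0, "expected no FreeNoise blocks after disabling"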

@@ -569,6 +601,7 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
vae_batch_size: int = 16,
Collaborator:

Let's use naming/logic similar to SVD for batch decoding.

frames = self.decode_latents(latents, num_frames, decode_chunk_size)

Member Author:

This is also used in the VAE encode for animatediff_video2video, by the way, but we can rename it to that.
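
For context, this is roughly what SVD-style chunked decoding looks like; a hedged sketch (not the pipeline code), assuming latents laid out as `[batch * frames, channels, height, width]` and an SD1.5 AutoencoderKL:

import torch

def decode_latents_chunked(vae, latents, decode_chunk_size: int = 16):
    latents = latents / vae.config.scaling_factor
    frames = []
    for i in range(0, latents.shape[0], decode_chunk_size):
        chunk = latents[i : i + decode_chunk_size]
        frames.append(vae.decode(chunk).sample)  # decode a small batch of frames at a time
    return torch.cat(frames, dim=0)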

@a-r-r-o-w a-r-r-o-w requested a review from DN6 August 2, 2024 15:10
DN6 and others added 5 commits August 3, 2024 13:58
@DN6 DN6 merged commit 16a93f1 into main Aug 7, 2024
18 checks passed
@a-r-r-o-w a-r-r-o-w deleted the freenoise branch August 7, 2024 06:00
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* initial work draft for freenoise; needs massive cleanup

* fix freeinit bug

* add animatediff controlnet implementation

* revert attention changes

* add freenoise

* remove old helper functions

* add decode batch size param to all pipelines

* make style

* fix copied from comments

* make fix-copies

* make style

* copy animatediff controlnet implementation from #8972

* add experimental support for num_frames not perfectly fitting context length, context stride

* make unet motion model lora work again based on #8995

* copy load video utils from #8972

* copied from AnimateDiff::prepare_latents

* address the case where last batch of frames does not match length of indices in prepare latents

* decode_batch_size->vae_batch_size; batch vae encode support in animatediff vid2vid

* revert sparsectrl and sdxl freenoise changes

* revert pia

* add freenoise tests

* make fix-copies

* improve docstrings

* add freenoise tests to animatediff controlnet

* update tests

* Update src/diffusers/models/unets/unet_motion_model.py

* add freenoise to animatediff pag

* address review comments

* make style

* update tests

* make fix-copies

* fix error message

* remove copied from comment

* fix imports in tests

* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Successfully merging this pull request may close these issues: Request to implement FreeNoise, a new diffusion scheduler