[cuda kernels] only compile them when initializing #29133

Merged
12 commits merged on Feb 20, 2024
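Summary of the pattern this PR applies across the touched model files: the `torch.utils.cpp_extension.load` call is moved out of module import and behind a guard in each layer's `__init__`, with a module-level global acting as a "compiled once" cache. Below is a minimal, self-contained sketch of that pattern; the names (`my_cuda_kernel`, `MyAttention`) and source paths are placeholders, not the actual transformers code.

```python
import logging

import torch

logger = logging.getLogger(__name__)

my_cuda_kernel = None  # module-level cache; stays None until the first layer is built


def load_cuda_kernels():
    """Compile the extension on first call and cache it (idempotent)."""
    global my_cuda_kernel
    if my_cuda_kernel is not None:  # already compiled in this process
        return
    from torch.utils.cpp_extension import load

    my_cuda_kernel = load(
        "my_cuda_kernel",
        sources=["my_kernel.cpp", "cuda/my_kernel.cu"],  # placeholder source files
        with_cuda=True,
    )


class MyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Importing this module is now cheap; compilation only happens here,
        # when the first attention layer is constructed on a CUDA-capable setup.
        if torch.cuda.is_available() and my_cuda_kernel is None:
            try:
                load_cuda_kernels()
            except Exception as e:
                logger.warning(f"Could not load the custom CUDA kernel: {e}")
```

The real implementations additionally check `is_ninja_available()`, so a missing build toolchain only produces a warning and the model falls back to the pure-PyTorch path instead of failing at import time.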
53 changes: 42 additions & 11 deletions src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -17,8 +17,10 @@

import copy
import math
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
@@ -46,21 +48,42 @@
from ...utils import is_accelerate_available, is_ninja_available, logging
from ...utils.backbone_utils import load_backbone
from .configuration_deformable_detr import DeformableDetrConfig
from .load_custom import load_cuda_kernels


logger = logging.get_logger(__name__)

# Move this to not compile only when importing, this needs to happen later, like in __init__.
if is_torch_cuda_available() and is_ninja_available():
logger.info("Loading custom CUDA kernels...")
try:
MultiScaleDeformableAttention = load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
MultiScaleDeformableAttention = None
else:
MultiScaleDeformableAttention = None
MultiScaleDeformableAttention = None


def load_cuda_kernels():
from torch.utils.cpp_extension import load

global MultiScaleDeformableAttention

root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deta"
Review thread on this line:

@EduardoPach (Contributor): Hey @ArthurZucker, was it intended for the root to be deta related?

@amyeroberts (Collaborator): @EduardoPach I think this is a copy-pasta. Would you like to open a PR to fix?

@EduardoPach (Contributor): @amyeroberts, sure!

@ArthurZucker (Collaborator, Author): thanks

src_files = [
root / filename
for filename in [
"vision.cpp",
os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
os.path.join("cuda", "ms_deform_attn_cuda.cu"),
]
]

load(
"MultiScaleDeformableAttention",
src_files,
with_cuda=True,
extra_include_paths=[str(root)],
extra_cflags=["-DWITH_CUDA=1"],
extra_cuda_cflags=[
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
)


if is_vision_available():
from transformers.image_transforms import center_to_corners_format
@@ -590,6 +613,14 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):

def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int):
super().__init__()

kernel_loaded = MultiScaleDeformableAttention is not None
if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
try:
load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")

if config.d_model % num_heads != 0:
raise ValueError(
f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
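Rough illustration of the observable effect of the change above, assuming a machine with CUDA, ninja, and timm installed (the default DeformableDetr config uses a timm backbone); this is a usage sketch, not part of the diff:

```python
# Importing the modeling code no longer triggers a ninja build.
from transformers import DeformableDetrConfig, DeformableDetrModel

config = DeformableDetrConfig()
# The kernel build (or the warning, if it fails) now happens here, when the
# first DeformableDetrMultiscaleDeformableAttention layer is constructed.
model = DeformableDetrModel(config)
```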
32 changes: 16 additions & 16 deletions src/transformers/models/deta/modeling_deta.py
@@ -50,10 +50,18 @@

logger = logging.get_logger(__name__)

MultiScaleDeformableAttention = None


# Copied from models.deformable_detr.load_cuda_kernels
def load_cuda_kernels():
from torch.utils.cpp_extension import load

global MultiScaleDeformableAttention
# Only load the kernel if it's not been loaded yet or if we changed the context length
if MultiScaleDeformableAttention is not None:
return

root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deta"
src_files = [
root / filename
@@ -78,22 +86,6 @@ def load_cuda_kernels():
],
)

import MultiScaleDeformableAttention as MSDA

return MSDA


# Move this to not compile only when importing, this needs to happen later, like in __init__.
if is_torch_cuda_available() and is_ninja_available():
logger.info("Loading custom CUDA kernels...")
try:
MultiScaleDeformableAttention = load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
MultiScaleDeformableAttention = None
else:
MultiScaleDeformableAttention = None


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
class MultiScaleDeformableAttentionFunction(Function):
@@ -596,6 +588,14 @@ class DetaMultiscaleDeformableAttention(nn.Module):

def __init__(self, config: DetaConfig, num_heads: int, n_points: int):
super().__init__()

kernel_loaded = MultiScaleDeformableAttention is not None
if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
try:
load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")

if config.d_model % num_heads != 0:
raise ValueError(
f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
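Note on the deta version above: because `load_cuda_kernels()` now returns early once the global `MultiScaleDeformableAttention` is set, building several attention layers in one process compiles the extension at most once. A hypothetical sketch of that behaviour:

```python
from transformers.models.deta.configuration_deta import DetaConfig
from transformers.models.deta.modeling_deta import DetaMultiscaleDeformableAttention

config = DetaConfig()
# Only the first constructor call can pay the compilation cost; the rest hit
# the early return in load_cuda_kernels().
layers = [DetaMultiscaleDeformableAttention(config, num_heads=8, n_points=4) for _ in range(3)]
```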
40 changes: 15 additions & 25 deletions src/transformers/models/mra/modeling_mra.py
@@ -58,36 +58,19 @@
# See all Mra models at https://huggingface.co/models?filter=mra
]

mra_cuda_kernel = None


def load_cuda_kernels():
global cuda_kernel
global mra_cuda_kernel
src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "mra"

def append_root(files):
return [src_folder / file for file in files]

src_files = append_root(["cuda_kernel.cu", "cuda_launch.cu", "torch_extension.cpp"])

cuda_kernel = load("cuda_kernel", src_files, verbose=True)

import cuda_kernel


cuda_kernel = None


if is_torch_cuda_available() and is_ninja_available():
logger.info("Loading custom CUDA kernels...")

try:
load_cuda_kernels()
except Exception as e:
logger.warning(
"Failed to load CUDA kernels. Mra requires custom CUDA kernels. Please verify that compatible versions of"
f" PyTorch and CUDA Toolkit are installed: {e}"
)
else:
pass
mra_cuda_kernel = load("cuda_kernel", src_files, verbose=True)


def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block):
@@ -112,7 +95,7 @@ def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block):
indices = indices.int()
indices = indices.contiguous()

max_vals, max_vals_scatter = cuda_kernel.index_max(index_vals, indices, query_num_block, key_num_block)
max_vals, max_vals_scatter = mra_cuda_kernel.index_max(index_vals, indices, query_num_block, key_num_block)
max_vals_scatter = max_vals_scatter.transpose(-1, -2)[:, :, None, :]

return max_vals, max_vals_scatter
@@ -178,7 +161,7 @@ def mm_to_sparse(dense_query, dense_key, indices, block_size=32):
indices = indices.int()
indices = indices.contiguous()

return cuda_kernel.mm_to_sparse(dense_query, dense_key, indices.int())
return mra_cuda_kernel.mm_to_sparse(dense_query, dense_key, indices.int())


def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_size=32):
@@ -216,7 +199,7 @@ def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_siz
indices = indices.contiguous()
dense_key = dense_key.contiguous()

dense_qk_prod = cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
dense_qk_prod = mra_cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
dense_qk_prod = dense_qk_prod.transpose(-1, -2).reshape(batch_size, query_num_block * block_size, dim)
return dense_qk_prod

@@ -393,7 +376,7 @@ def mra2_attention(
"""
Use Mra to approximate self-attention.
"""
if cuda_kernel is None:
if mra_cuda_kernel is None:
return torch.zeros_like(query).requires_grad_()

batch_size, num_head, seq_len, head_dim = query.size()
@@ -561,6 +544,13 @@ def __init__(self, config, position_embedding_type=None):
f"heads ({config.num_attention_heads})"
)

kernel_loaded = mra_cuda_kernel is not None
if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
try:
load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")

self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
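The rename to `mra_cuda_kernel` above also makes the fallback path easier to follow: if the kernel never compiled, `mra2_attention` degrades rather than crashing. A minimal sketch of that guard; the function name `approximate_attention` is made up for illustration and is not part of the library:

```python
import torch

mra_cuda_kernel = None  # set by load_cuda_kernels() the first time a layer is built


def approximate_attention(query):
    # Mirrors the guard in modeling_mra.py: with no compiled kernel the MRA
    # approximation is skipped and a zero tensor (still on the autograd graph)
    # is returned in its place.
    if mra_cuda_kernel is None:
        return torch.zeros_like(query).requires_grad_()
    raise NotImplementedError("the compiled-kernel path lives in modeling_mra.py")
```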
17 changes: 16 additions & 1 deletion src/transformers/models/yoso/modeling_yoso.py
@@ -35,7 +35,14 @@
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_ninja_available,
is_torch_cuda_available,
logging,
)
from .configuration_yoso import YosoConfig


@@ -49,6 +56,8 @@
# See all YOSO models at https://huggingface.co/models?filter=yoso
]

lsh_cumulation = None


def load_cuda_kernels():
global lsh_cumulation
@@ -305,6 +314,12 @@ def __init__(self, config, position_embedding_type=None):
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
kernel_loaded = lsh_cumulation is not None
if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
try:
load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")

self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
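One way to confirm the new behaviour end to end (hypothetical snippet, assuming CUDA and ninja are available; on a CPU-only machine the module-level global simply stays `None`):

```python
from transformers import YosoConfig, YosoModel
from transformers.models.yoso import modeling_yoso

model = YosoModel(YosoConfig())  # kernel compilation is attempted during layer construction
print("lsh_cumulation kernel loaded:", modeling_yoso.lsh_cumulation is not None)
```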