
Commit 5e95dca

[cuda kernels] only compile them when initializing (#29133)
* only compile when needed
* fix mra as well
* fix yoso as well
* update
* remove comment
* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py
* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py
* oops
* Update src/transformers/models/deta/modeling_deta.py
* nit
1 parent a7755d2 commit 5e95dca

File tree: 4 files changed (+93, -68 lines)
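
The change applies the same pattern in all four models: the compiled extension handle starts out as a module-level None, and the kernel is built lazily, at most once, inside the attention layer's __init__ instead of at import time. A minimal sketch of that pattern, with illustrative names (`my_kernel`, `load_cuda_kernels`, `MyAttention`) that are not taken from the diff below:

# Minimal sketch of the lazy-compilation pattern; names are illustrative only.
import torch.nn as nn

my_kernel = None  # module-level handle; stays None until a layer needs it


def load_cuda_kernels():
    global my_kernel
    # The real models call torch.utils.cpp_extension.load(...) here, which
    # JIT-compiles the .cu/.cpp sources with ninja. A stand-in object keeps
    # this sketch runnable without the kernel sources.
    my_kernel = object()


class MyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        # Compile only on the first construction, never at import time.
        if my_kernel is None:
            try:
                load_cuda_kernels()
            except Exception as e:
                print(f"Could not load the custom kernel: {e}")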

src/transformers/models/deformable_detr/modeling_deformable_detr.py

Lines changed: 42 additions & 11 deletions
@@ -17,8 +17,10 @@
 
 import copy
 import math
+import os
 import warnings
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
@@ -46,21 +48,42 @@
 from ...utils import is_accelerate_available, is_ninja_available, logging
 from ...utils.backbone_utils import load_backbone
 from .configuration_deformable_detr import DeformableDetrConfig
-from .load_custom import load_cuda_kernels
 
 
 logger = logging.get_logger(__name__)
 
-# Move this to not compile only when importing, this needs to happen later, like in __init__.
-if is_torch_cuda_available() and is_ninja_available():
-    logger.info("Loading custom CUDA kernels...")
-    try:
-        MultiScaleDeformableAttention = load_cuda_kernels()
-    except Exception as e:
-        logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
-        MultiScaleDeformableAttention = None
-else:
-    MultiScaleDeformableAttention = None
+MultiScaleDeformableAttention = None
+
+
+def load_cuda_kernels():
+    from torch.utils.cpp_extension import load
+
+    global MultiScaleDeformableAttention
+
+    root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deta"
+    src_files = [
+        root / filename
+        for filename in [
+            "vision.cpp",
+            os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
+            os.path.join("cuda", "ms_deform_attn_cuda.cu"),
+        ]
+    ]
+
+    MultiScaleDeformableAttention = load(
+        "MultiScaleDeformableAttention",
+        src_files,
+        with_cuda=True,
+        extra_include_paths=[str(root)],
+        extra_cflags=["-DWITH_CUDA=1"],
+        extra_cuda_cflags=[
+            "-DCUDA_HAS_FP16=1",
+            "-D__CUDA_NO_HALF_OPERATORS__",
+            "-D__CUDA_NO_HALF_CONVERSIONS__",
+            "-D__CUDA_NO_HALF2_OPERATORS__",
+        ],
+    )
+
 
 if is_vision_available():
     from transformers.image_transforms import center_to_corners_format
@@ -590,6 +613,14 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
 
     def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int):
         super().__init__()
+
+        kernel_loaded = MultiScaleDeformableAttention is not None
+        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
+            try:
+                load_cuda_kernels()
+            except Exception as e:
+                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+
         if config.d_model % num_heads != 0:
             raise ValueError(
                 f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"

src/transformers/models/deta/modeling_deta.py

Lines changed: 13 additions & 16 deletions
@@ -50,10 +50,15 @@
 
 logger = logging.get_logger(__name__)
 
+MultiScaleDeformableAttention = None
 
+
+# Copied from models.deformable_detr.load_cuda_kernels
 def load_cuda_kernels():
     from torch.utils.cpp_extension import load
 
+    global MultiScaleDeformableAttention
+
     root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deta"
     src_files = [
         root / filename
@@ -78,22 +83,6 @@ def load_cuda_kernels():
         ],
     )
 
-    import MultiScaleDeformableAttention as MSDA
-
-    return MSDA
-
-
-# Move this to not compile only when importing, this needs to happen later, like in __init__.
-if is_torch_cuda_available() and is_ninja_available():
-    logger.info("Loading custom CUDA kernels...")
-    try:
-        MultiScaleDeformableAttention = load_cuda_kernels()
-    except Exception as e:
-        logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
-        MultiScaleDeformableAttention = None
-else:
-    MultiScaleDeformableAttention = None
-
 
 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
 class MultiScaleDeformableAttentionFunction(Function):
@@ -596,6 +585,14 @@ class DetaMultiscaleDeformableAttention(nn.Module):
 
     def __init__(self, config: DetaConfig, num_heads: int, n_points: int):
         super().__init__()
+
+        kernel_loaded = MultiScaleDeformableAttention is not None
+        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
+            try:
+                load_cuda_kernels()
+            except Exception as e:
+                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+
        if config.d_model % num_heads != 0:
            raise ValueError(
                f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"

src/transformers/models/mra/modeling_mra.py

Lines changed: 15 additions & 25 deletions
@@ -58,36 +58,19 @@
 # See all Mra models at https://huggingface.co/models?filter=mra
 ]
 
+mra_cuda_kernel = None
+
 
 def load_cuda_kernels():
-    global cuda_kernel
+    global mra_cuda_kernel
     src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "mra"
 
     def append_root(files):
         return [src_folder / file for file in files]
 
     src_files = append_root(["cuda_kernel.cu", "cuda_launch.cu", "torch_extension.cpp"])
 
-    cuda_kernel = load("cuda_kernel", src_files, verbose=True)
-
-    import cuda_kernel
-
-
-cuda_kernel = None
-
-
-if is_torch_cuda_available() and is_ninja_available():
-    logger.info("Loading custom CUDA kernels...")
-
-    try:
-        load_cuda_kernels()
-    except Exception as e:
-        logger.warning(
-            "Failed to load CUDA kernels. Mra requires custom CUDA kernels. Please verify that compatible versions of"
-            f" PyTorch and CUDA Toolkit are installed: {e}"
-        )
-else:
-    pass
+    mra_cuda_kernel = load("cuda_kernel", src_files, verbose=True)
 
 
 def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block):
@@ -112,7 +95,7 @@ def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block):
     indices = indices.int()
     indices = indices.contiguous()
 
-    max_vals, max_vals_scatter = cuda_kernel.index_max(index_vals, indices, query_num_block, key_num_block)
+    max_vals, max_vals_scatter = mra_cuda_kernel.index_max(index_vals, indices, query_num_block, key_num_block)
     max_vals_scatter = max_vals_scatter.transpose(-1, -2)[:, :, None, :]
 
     return max_vals, max_vals_scatter
@@ -178,7 +161,7 @@ def mm_to_sparse(dense_query, dense_key, indices, block_size=32):
     indices = indices.int()
     indices = indices.contiguous()
 
-    return cuda_kernel.mm_to_sparse(dense_query, dense_key, indices.int())
+    return mra_cuda_kernel.mm_to_sparse(dense_query, dense_key, indices.int())
 
 
 def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_size=32):
@@ -216,7 +199,7 @@ def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_size=32):
     indices = indices.contiguous()
     dense_key = dense_key.contiguous()
 
-    dense_qk_prod = cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
+    dense_qk_prod = mra_cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
     dense_qk_prod = dense_qk_prod.transpose(-1, -2).reshape(batch_size, query_num_block * block_size, dim)
     return dense_qk_prod
 
@@ -393,7 +376,7 @@ def mra2_attention(
     """
     Use Mra to approximate self-attention.
     """
-    if cuda_kernel is None:
+    if mra_cuda_kernel is None:
         return torch.zeros_like(query).requires_grad_()
 
     batch_size, num_head, seq_len, head_dim = query.size()
@@ -561,6 +544,13 @@ def __init__(self, config, position_embedding_type=None):
                 f"heads ({config.num_attention_heads})"
             )
 
+        kernel_loaded = mra_cuda_kernel is not None
+        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
+            try:
+                load_cuda_kernels()
+            except Exception as e:
+                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
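
A hedged usage sketch for the MRA change, using the standard transformers API (not part of this diff): constructing an MRA model is now what triggers the compile attempt, and, per the hunk above, `mra2_attention` already degrades to returning zeros while `mra_cuda_kernel` is still None.

# Hedged sketch; assumes a working transformers install. With CUDA and ninja
# available, building the model compiles the MRA kernel once; otherwise a
# warning is logged and mra_cuda_kernel stays None.
from transformers import MraConfig, MraModel

model = MraModel(MraConfig())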

src/transformers/models/yoso/modeling_yoso.py

Lines changed: 23 additions & 16 deletions
@@ -35,7 +35,14 @@
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_ninja_available,
+    is_torch_cuda_available,
+    logging,
+)
 from .configuration_yoso import YosoConfig
 
 
@@ -49,28 +56,22 @@
 # See all YOSO models at https://huggingface.co/models?filter=yoso
 ]
 
+lsh_cumulation = None
+
 
 def load_cuda_kernels():
     global lsh_cumulation
-    try:
-        from torch.utils.cpp_extension import load
+    from torch.utils.cpp_extension import load
 
-        def append_root(files):
-            src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "yoso"
-            return [src_folder / file for file in files]
-
-        src_files = append_root(
-            ["fast_lsh_cumulation_torch.cpp", "fast_lsh_cumulation.cu", "fast_lsh_cumulation_cuda.cu"]
-        )
+    def append_root(files):
+        src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "yoso"
+        return [src_folder / file for file in files]
 
-        load("fast_lsh_cumulation", src_files, verbose=True)
+    src_files = append_root(["fast_lsh_cumulation_torch.cpp", "fast_lsh_cumulation.cu", "fast_lsh_cumulation_cuda.cu"])
 
-        import fast_lsh_cumulation as lsh_cumulation
+    load("fast_lsh_cumulation", src_files, verbose=True)
 
-        return True
-    except Exception:
-        lsh_cumulation = None
-        return False
+    import fast_lsh_cumulation as lsh_cumulation
 
 
 def to_contiguous(input_tensors):
@@ -305,6 +306,12 @@ def __init__(self, config, position_embedding_type=None):
                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                 f"heads ({config.num_attention_heads})"
             )
+        kernel_loaded = lsh_cumulation is not None
+        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
+            try:
+                load_cuda_kernels()
+            except Exception as e:
+                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
 
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
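
One more Python detail behind the YOSO loader: `import fast_lsh_cumulation as lsh_cumulation` inside the function rebinds the module-level `lsh_cumulation` because of the `global` declaration above it. A small illustration, with a standard-library module standing in for the compiled extension:

# `import ... as name` is a name binding, so together with `global name` it
# updates the module-level variable. Here `math` stands in for the compiled
# fast_lsh_cumulation extension.
handle = None


def bind_handle():
    global handle
    import math as handle


bind_handle()
assert handle is not None and hasattr(handle, "sqrt")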
