Fixes issues with Colab being updated to Python 3.10 and CUDA 11.8 #48

Merged · 6 commits · May 10, 2024
87 changes: 60 additions & 27 deletions torch_utils/ops/conv2d_gradfix.py
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
@@ -9,9 +9,9 @@
"""Custom replacement for `torch.nn.functional.conv2d` that supports
arbitrarily high order gradients with zero performance penalty."""

import warnings
import contextlib
import torch
from pkg_resources import parse_version

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
@@ -21,12 +21,14 @@

enabled = False # Enable the custom op by setting this to true.
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11

@contextlib.contextmanager
def no_weight_gradients():
def no_weight_gradients(disable=True):
global weight_gradients_disabled
old = weight_gradients_disabled
weight_gradients_disabled = True
if disable:
weight_gradients_disabled = True
yield
weight_gradients_disabled = old

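The new `disable` argument keeps existing `with no_weight_gradients():` call sites working while letting callers turn the guard into a no-op at runtime. A minimal usage sketch (the surrounding function and flag names are made up for illustration):

```python
from torch_utils.ops import conv2d_gradfix

def discriminator_step(skip_weight_grads: bool):
    # With disable=False the context manager changes nothing, so the same
    # code path serves both the regularized and the plain training step.
    with conv2d_gradfix.no_weight_gradients(disable=skip_weight_grads):
        pass  # hypothetical forward/backward pass goes here
```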
@@ -48,12 +50,12 @@ def _should_use_custom_op(input):
assert isinstance(input, torch.Tensor)
if (not enabled) or (not torch.backends.cudnn.enabled):
return False
if _use_pytorch_1_11_api:
# The work-around code doesn't work on PyTorch 1.11.0 onwards
return False
if input.device.type != 'cuda':
return False
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
return True
warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
return False
return True

def _tuple_of_ints(xs, ndim):
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
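Comparing against `'1.11.0a'` rather than `'1.11.0'` is what the `# Allow prerelease builds of 1.11` comment refers to: nightly builds such as `1.11.0a0+git...` compare below the final `1.11.0` release but not below `1.11.0a`, so they are still steered away from the old work-around. A quick illustration, assuming `pkg_resources` is installed (it ships with setuptools):

```python
from pkg_resources import parse_version

# Pre-release builds sort below the final release...
print(parse_version('1.11.0a0+gitabc123') >= parse_version('1.11.0'))   # False
# ...but '1.11.0a' (normalized to 1.11.0a0) still catches them.
print(parse_version('1.11.0a0+gitabc123') >= parse_version('1.11.0a'))  # True
print(parse_version('1.10.2') >= parse_version('1.11.0a'))              # False
```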
@@ -64,6 +66,7 @@ def _tuple_of_ints(xs, ndim):
#----------------------------------------------------------------------------

_conv2d_gradfix_cache = dict()
_null_tensor = torch.empty([0])

def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
# Parse arguments.
@@ -108,24 +111,39 @@ class Conv2d(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias):
assert weight.shape == weight_shape
if not transpose:
output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
else: # transpose
output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
ctx.save_for_backward(input, weight)
return output
ctx.save_for_backward(
input if weight.requires_grad else _null_tensor,
weight if input.requires_grad else _null_tensor,
)
ctx.input_shape = input.shape

# Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))

# General case => cuDNN.
if transpose:
return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)

@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
input_shape = ctx.input_shape
grad_input = None
grad_weight = None
grad_bias = None

if ctx.needs_input_grad[0]:
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
assert grad_input.shape == input.shape
p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
grad_input = op.apply(grad_output, weight, None)
assert grad_input.shape == input_shape

if ctx.needs_input_grad[1] and not weight_gradients_disabled:
grad_weight = Conv2dGradWeight.apply(grad_output, input)
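The `_null_tensor` placeholder in `save_for_backward` above is a memory optimization: each input is only kept alive for the backward pass when the gradient that actually needs it can be requested. A stripped-down sketch of the same pattern on a toy function (not part of the PR):

```python
import torch

_null = torch.empty([0])  # shared stand-in for "nothing to save"

class _Mul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # a is only needed to compute grad_b, and b only for grad_a.
        ctx.save_for_backward(
            a if b.requires_grad else _null,
            b if a.requires_grad else _null,
        )
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        grad_a = grad_out * b if ctx.needs_input_grad[0] else None
        grad_b = grad_out * a if ctx.needs_input_grad[1] else None
        return grad_a, grad_b
```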
@@ -140,31 +158,46 @@ def backward(ctx, grad_output):
class Conv2dGradWeight(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input):
op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
ctx.save_for_backward(
grad_output if input.requires_grad else _null_tensor,
input if grad_output.requires_grad else _null_tensor,
)
ctx.grad_output_shape = grad_output.shape
ctx.input_shape = input.shape

# Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))

# General case => cuDNN.
name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
assert grad_weight.shape == weight_shape
ctx.save_for_backward(grad_output, input)
return grad_weight
return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)

@staticmethod
def backward(ctx, grad2_grad_weight):
grad_output, input = ctx.saved_tensors
grad_output_shape = ctx.grad_output_shape
input_shape = ctx.input_shape
grad2_grad_output = None
grad2_input = None

if ctx.needs_input_grad[0]:
grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
assert grad2_grad_output.shape == grad_output.shape
assert grad2_grad_output.shape == grad_output_shape

if ctx.needs_input_grad[1]:
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
assert grad2_input.shape == input.shape
p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
grad2_input = op.apply(grad_output, grad2_grad_weight, None)
assert grad2_input.shape == input_shape

return grad2_grad_output, grad2_input

_conv2d_gradfix_cache[key] = Conv2d
return Conv2d

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
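The new forward path above routes 1x1 convolutions through a plain matrix multiply (cuBLAS) instead of cuDNN on pre-Ampere GPUs. The equivalence it relies on is easy to sanity-check on CPU; a small sketch assuming `groups == 1` and no bias:

```python
import torch

N, Cin, Cout, H, W = 2, 8, 16, 4, 4
x = torch.randn(N, Cin, H, W)
w = torch.randn(Cout, Cin, 1, 1)

ref = torch.nn.functional.conv2d(x, w)                 # regular conv path
mm = w.reshape(Cout, Cin) @ x.reshape(N, Cin, H * W)   # 1x1 conv as matmul
mm = mm.reshape(N, Cout, H, W)

print(torch.allclose(ref, mm, atol=1e-5))              # expected: True
```

The extra `.contiguous(memory_format=...)` call in the diff exists so that channels-last inputs keep their layout, which the reshape/matmul route would otherwise lose.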
23 changes: 13 additions & 10 deletions torch_utils/ops/grid_sample_gradfix.py
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
@@ -11,8 +11,8 @@
Only works on 2D images and assumes
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""

import warnings
import torch
from pkg_resources import parse_version

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
@@ -21,6 +21,8 @@
#----------------------------------------------------------------------------

enabled = False # Enable the custom op by setting this to true.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
_use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12

#----------------------------------------------------------------------------

@@ -32,12 +34,7 @@ def grid_sample(input, grid):
#----------------------------------------------------------------------------

def _should_use_custom_op():
if not enabled:
return False
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
return True
warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
return False
return enabled

#----------------------------------------------------------------------------

@@ -62,7 +59,13 @@ class _GridSample2dBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input, grid):
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
if _use_pytorch_1_12_api:
op = op[0]
if _use_pytorch_1_11_api:
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
else:
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
ctx.save_for_backward(grid)
return grad_input, grad_grid

@@ -80,4 +83,4 @@ def backward(ctx, grad2_grad_input, grad2_grad_grid):
assert not ctx.needs_input_grad[2]
return grad2_grad_output, grad2_input, grad2_grid

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
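The grid_sample changes follow the same version-gating pattern: detect the running PyTorch release once, then branch on two API changes, `torch._C._jit_get_operation` returning an `(op, overload_names)` tuple from 1.12, and `aten::grid_sampler_2d_backward` taking an extra `output_mask` argument from 1.11. A condensed sketch of the dispatch (the wrapper function name is illustrative only):

```python
import torch
from pkg_resources import parse_version

_torch_version = parse_version(torch.__version__)

def _grid_sampler_2d_backward(grad_output, input, grid, needs_input, needs_grid):
    op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
    if _torch_version >= parse_version('1.12.0a'):
        op = op[0]  # 1.12+ returns (op, overload_names)
    if _torch_version >= parse_version('1.11.0a'):
        # 1.11+ expects an output_mask saying which gradients to compute.
        return op(grad_output, input, grid, 0, 0, False, (needs_input, needs_grid))
    return op(grad_output, input, grid, 0, 0, False)
```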