
Commit dd7d188
add realesrgan to basicsr
1 parent: 2f0ad00
10 files changed: +1545 −0 lines

basicsr/archs/discriminator_arch.py  (+65)

@@ -1,4 +1,6 @@
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm

from basicsr.utils.registry import ARCH_REGISTRY

@@ -83,3 +85,66 @@ def forward(self, x):
        feat = self.lrelu(self.linear1(feat))
        out = self.linear2(feat)
        return out


@ARCH_REGISTRY.register(suffix='basicsr')
class UNetDiscriminatorSN(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN).

    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        skip_connection (bool): Whether to use skip connections in the U-Net. Default: True.
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm
        # the first convolution
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        # downsample
        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # upsample
        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
        # extra convolutions
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        # downsample
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x4 = x4 + x2
        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x5 = x5 + x1
        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x6 = x6 + x0

        # extra convolutions
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out
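A minimal sketch of how the new discriminator might be exercised, not part of the diff: the batch size and the 256x256 patch size are illustrative values, chosen so the spatial dimensions stay divisible by 8 across the three stride-2 downsamples and the bilinear upsamples that feed the skip connections.

# Illustrative sketch (not part of this commit): a forward pass through UNetDiscriminatorSN.
import torch

from basicsr.archs.discriminator_arch import UNetDiscriminatorSN

disc = UNetDiscriminatorSN(num_in_ch=3, num_feat=64, skip_connection=True)
fake = torch.randn(4, 3, 256, 256)  # a batch of generated RGB patches (example size)
logits = disc(fake)                 # per-pixel realness logits, one channel per pixel
print(logits.shape)                 # torch.Size([4, 1, 256, 256])

Because the output keeps the input resolution, the U-Net discriminator provides per-pixel feedback rather than a single score per image.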

basicsr/archs/srvgg_arch.py  (+70)

@@ -0,0 +1,70 @@
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register(suffix='basicsr')
class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.

    It is a compact network that performs upsampling only in the last layer, so no convolution is
    conducted on the HR feature space.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_conv (int): Number of convolution layers in the body network. Default: 16.
        upscale (int): Upsampling factor. Default: 4.
        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: 'prelu'.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
        super(SRVGGNetCompact, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        self.body = nn.ModuleList()
        # the first conv
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        # the first activation
        if act_type == 'relu':
            activation = nn.ReLU(inplace=True)
        elif act_type == 'prelu':
            activation = nn.PReLU(num_parameters=num_feat)
        elif act_type == 'leakyrelu':
            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.body.append(activation)

        # the body structure
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            # activation
            if act_type == 'relu':
                activation = nn.ReLU(inplace=True)
            elif act_type == 'prelu':
                activation = nn.PReLU(num_parameters=num_feat)
            elif act_type == 'leakyrelu':
                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
            self.body.append(activation)

        # the last conv
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    def forward(self, x):
        out = x
        for i in range(0, len(self.body)):
            out = self.body[i](out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        out += base
        return out
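A minimal usage sketch for the compact SR network, not part of the diff; the 64x64 input size and the default x4 upscale are example values only.

# Illustrative sketch (not part of this commit): a forward pass through SRVGGNetCompact.
import torch

from basicsr.archs.srvgg_arch import SRVGGNetCompact

model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
lr = torch.rand(1, 3, 64, 64)  # low-resolution input in [0, 1] (example size)
with torch.no_grad():
    sr = model(lr)             # pixel-shuffle output plus a nearest-upsampled residual base
print(sr.shape)                # torch.Size([1, 3, 256, 256])

All convolutions run at the low-resolution size and the last conv simply produces num_out_ch * upscale^2 channels for a single PixelShuffle, which is what keeps this architecture cheap compared with networks that convolve on the HR feature space.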

basicsr/data/realesrgan_dataset.py  (+193)

@@ -0,0 +1,193 @@
import cv2
import math
import numpy as np
import os
import os.path as osp
import random
import time
import torch
from torch.utils import data as data

from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANDataset(data.Dataset):
    """Dataset used for the Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It loads gt (Ground-Truth) images and augments them.
    It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed in tensors on GPUs for faster processing.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwargs.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
        Please see more options in the code.
    """

    def __init__(self, opt):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.gt_folder]
            self.io_backend_opt['client_keys'] = ['gt']
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip().split(' ')[0] for line in fin]
                self.paths = [os.path.join(self.gt_folder, v) for v in paths]

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except (IOError, OSError) as e:
                logger = get_root_logger()
                logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change another file to read (randint is inclusive on both ends)
                index = random.randint(0, self.__len__() - 1)
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)

        # -------------------- Do augmentation for training: flip, rotation -------------------- #
        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

        # crop or pad to 400
        # TODO: 400 is hard-coded. You may change it accordingly
        h, w = img_gt.shape[0:2]
        crop_pad_size = 400
        # pad
        if h < crop_pad_size or w < crop_pad_size:
            pad_h = max(0, crop_pad_size - h)
            pad_w = max(0, crop_pad_size - w)
            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
        # crop
        if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
            h, w = img_gt.shape[0:2]
            # randomly choose top and left coordinates
            top = random.randint(0, h - crop_pad_size)
            left = random.randint(0, w - crop_pad_size)
            img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob']:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob2']:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)

        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
        return return_d

    def __len__(self):
        return len(self.paths)
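For context, a sketch of the kind of options dict this dataset reads, not part of the diff: the paths are hypothetical placeholders and the numeric values are illustrative only; the values actually used for training live in the project's yml configs.

# Illustrative sketch (not part of this commit): the option keys RealESRGANDataset expects.
# Paths and numbers below are placeholders, not values taken from this commit.
from basicsr.data.realesrgan_dataset import RealESRGANDataset

opt = {
    'dataroot_gt': 'datasets/DF2K',              # hypothetical GT folder
    'meta_info': 'datasets/DF2K/meta_info.txt',  # hypothetical meta-info file
    'io_backend': {'type': 'disk'},
    'use_hflip': True,
    'use_rot': False,
    # first-degradation blur kernels
    'blur_kernel_size': 21,
    'kernel_list': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
    'kernel_prob': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    'sinc_prob': 0.1,
    'blur_sigma': [0.2, 3],
    'betag_range': [0.5, 4],
    'betap_range': [1, 2],
    # second-degradation blur kernels
    'blur_kernel_size2': 21,
    'kernel_list2': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
    'kernel_prob2': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    'sinc_prob2': 0.1,
    'blur_sigma2': [0.2, 1.5],
    'betag_range2': [0.5, 4],
    'betap_range2': [1, 2],
    # final sinc filter
    'final_sinc_prob': 0.8,
}

dataset = RealESRGANDataset(opt)
sample = dataset[0]
# 'gt' is an RGB CHW tensor; 'kernel1' and 'kernel2' are 21x21 blur kernels;
# 'sinc_kernel' is either a sinc kernel or the 21x21 identity pulse tensor.
print(sample['gt'].shape, sample['kernel1'].shape, sample['sinc_kernel'].shape)

Note that only the GT image and the kernels are returned here; the actual low-quality inputs are synthesized later on the GPU during training, as the class docstring states.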
