SRGAN

SRGAN

背景介绍

  SRGAN(Supre Resolution Generative Adversarial Networks, 超分辨率生成式对抗网络):于2016年发表在CVPR上,图像处理的一个重要任务就是超分辨率,图像的大小是由分辨率决定的,常见的分辨率有128x128,256x256,512x512,1024x1024,像素越大,说明像素点越多,图像的表达能力更强,细节更加明显,看起来也会更加舒适。

srgan

分辨率提升的方法

在GAN问世以前,人们就做了许多关于分辨率提升的方法,其中最简单的也是最常用的方法就是插值法,在opencv库中有imresize函数,其中可以指定目标图像的尺寸,而且可以选择插值方法,一般选择双线性插值,但是插值方法有一个最致命的问题—模糊,插值的原理就是根据周围的点进行估计,因此插值的结果会导致某个区域的值都i非常接近,会导致照片模糊。

随着GAN的发展,人们发现GAN既然能够生成图像,能不能生成更高分辨率的图像,答案是肯定的,今天给小伙伴们介绍SRGAN的原理。
SRGAN引入了三个网络,一个是生成器,一个是判别器,还有一个是特征提取器(VGG19)
生成器的输入是低分辨率图像,输出是高分辨率图像,目的是根据输入的低分辨率图像生成高分辨率图像
判别器的输入是高分辨率图像,输出是对输入图像的分类,目的是判断输入的图像是生成的高分辨率图像还是原始的高分辨率图像
特征提取器的输入是高分辨率图像,输出是对高分辨率图像的特征提取,目的是使生成的图像和原始的高分辨率图像具有相同的特征

SRGAN的特点

  生成器使用ResNet结构+上采样对图像进行分辨率提升
  引入特征提取网络VGG19,对高分辨率特征进行提取
  特征提取损失函数采用均方误差,判别器损失函数采用二分类交叉熵
  对生成器损失函数的权重进行调节,使网络更多关注于生成的图像质量

SRGAN图像分析

generator
discriminator

TensorFlow2.0实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import os
import glob
import numpy as np
import cv2 as cv
from functools import reduce
import tensorflow as tf
import tensorflow.keras as keras


def compose(*funcs):
if funcs:
return reduce(lambda f, g: lambda *a, **kw: g(f(*a, **kw)), funcs)
else:
raise ValueError('Composition of empty sequence not supported.')


def resblock(x, filters, name):
shortcut = x
x = compose(keras.layers.Conv2D(filters, (3, 3), (1, 1), 'same', name='{}_conv1'.format(name)),
keras.layers.BatchNormalization(momentum=0.8, name='{}_bn1'.format(name)),
keras.layers.ReLU(name='{}_relu1'.format(name)),
keras.layers.Conv2D(filters, (3, 3), (1, 1), 'same', name='{}_conv2'.format(name)),
keras.layers.BatchNormalization(momentum=0.8, name='{}_bn2'.format(name)))(x)
x = keras.layers.Add(name='{}_add'.format(name))([x, shortcut])

return x


class Conv_LeakyRelu_Bn(keras.layers.Layer):
def __init__(self, filters, kernel_size, strides, padding, name):
super(Conv_LeakyRelu_Bn, self).__init__()
self._name = name
self.block = keras.Sequential([keras.layers.Conv2D(filters, kernel_size, strides, padding),
keras.layers.LeakyReLU(0.2)])
if name.find('bn') != -1:
self.block.add(keras.layers.BatchNormalization(momentum=0.8))

def call(self, inputs, **kwargs):

return self.block(inputs)


def vgg(input_shape):
input_tensor = keras.layers.Input(input_shape, name='input')
x = input_tensor

vgg = keras.applications.VGG19()
vgg.outputs = [vgg.layers[9].output]
x = vgg(x)

model = keras.Model(input_tensor, x, name='VGG19')

return model


def generator(input_shape):
input_tensor = keras.layers.Input(input_shape, name='input')
x = input_tensor

x = keras.layers.Conv2D(64, (9, 9), (1, 1), 'same', name='conv1')(x)
shortcut = x

for i in range(16):
x = resblock(x, 64, name='resblock{}'.format(i + 1))

x = compose(keras.layers.Conv2D(64, (3, 3), (1, 1), 'same', name='conv2'),
keras.layers.BatchNormalization(momentum=0.8, name='bn2'))(x)

x = keras.layers.Add(name='add2')([x, shortcut])

x = compose(keras.layers.UpSampling2D((2, 2), name='upsampling1'),
keras.layers.Conv2D(256, (3, 3), (1, 1), 'same', activation='relu', name='conv3_relu'),
keras.layers.UpSampling2D((2, 2), name='upsampling2'),
keras.layers.Conv2D(256, (3, 3), (1, 1), 'same', activation='relu', name='conv4_relu'),
keras.layers.Conv2D(3, (9, 9), (1, 1), 'same', activation='tanh', name='conv5_tanh'))(x)

model = keras.Model(input_tensor, x, name='SRGAN-Generator')

return model


def discriminator(input_shape):
input_tensor = keras.layers.Input(input_shape, name='input')
x = input_tensor

x = compose(Conv_LeakyRelu_Bn(64, (3, 3), (1, 1), 'same', name='conv_leakyrelu1'),
Conv_LeakyRelu_Bn(64, (3, 3), (2, 2), 'same', name='conv_leakyrelu_bn2'),
Conv_LeakyRelu_Bn(128, (3, 3), (1, 1), 'same', name='conv_leakyrelu_bn3'),
Conv_LeakyRelu_Bn(128, (3, 3), (2, 2), 'same', name='conv_leakyrelu_bn4'),
Conv_LeakyRelu_Bn(256, (3, 3), (1, 1), 'same', name='conv_leakyrelu_bn5'),
Conv_LeakyRelu_Bn(256, (3, 3), (2, 2), 'same', name='conv_leakyrelu_bn6'),
Conv_LeakyRelu_Bn(512, (3, 3), (1, 1), 'same', name='conv_leakyrelu_bn7'),
Conv_LeakyRelu_Bn(512, (3, 3), (2, 2), 'same', name='conv_leakyrelu_bn8'),
Conv_LeakyRelu_Bn(1024, (1, 1), (1, 1), 'same', name='conv_leakyrelu9'),
keras.layers.Conv2D(1, (1, 1), (1, 1), 'same', activation='sigmoid', name='conv_sigmoid'))(x)

model = keras.Model(input_tensor, x, name='SRGAN-Discriminator')

return model


def srgan(input_shape, model_vgg, model_g, model_d):
input_tensor = keras.layers.Input(input_shape, name='input')
x = input_tensor

model_vgg.trainable = False
model_d.trainable = False

fake_image = model_g(x)
fake_feature = model_vgg(fake_image)
conf = model_d(fake_image)

model = keras.Model(input_tensor, [conf, fake_feature], name='SRGAN')

return model


def read_data(data_path, high_resolution, low_resolution, batch_size):
filename = glob.glob(data_path)
choose_name = np.random.choice(filename, batch_size)

hr_image, lr_image = [], []
for name in choose_name:
image = cv.imread(name).astype(np.float32)
hr_image.append(cv.resize(image, high_resolution))
lr_image.append(cv.resize(image, low_resolution))

hr_image = np.array(hr_image) / 127.5 - 1
lr_image = np.array(lr_image) / 127.5 - 1

return hr_image, lr_image


if __name__ == '__main__':
batch_size = 2
epochs = 2000
tf.random.set_seed(22)
low_resolution = (56, 56)
high_resolution = (224, 224)
data_path = r'.\monet2photo\trainB\*.jpg'
save_path = r'.\srgan'
if not os.path.exists(save_path):
os.makedirs(save_path)

optimizer = keras.optimizers.Adam(0.0002, 0.5)
loss = keras.losses.BinaryCrossentropy()

real_dacc = keras.metrics.BinaryAccuracy()
fake_dacc = keras.metrics.BinaryAccuracy()
gacc = keras.metrics.BinaryAccuracy()

model_vgg = vgg(input_shape=(high_resolution[0], high_resolution[1], 3))

model_d = discriminator(input_shape=(high_resolution[0], high_resolution[1], 3))
model_d.compile(optimizer=optimizer, loss='binary_crossentropy')

model_g = generator(input_shape=(low_resolution[0], low_resolution[1], 3))

model_vgg.build(input_shape=(high_resolution[0], high_resolution[1], 3))
model_vgg.summary()
keras.utils.plot_model(model_vgg, 'SRGAN-vgg19.png', show_shapes=True, show_layer_names=True)

model_g.build(input_shape=(low_resolution[0], low_resolution[1], 3))
model_g.summary()
keras.utils.plot_model(model_g, 'SRGAN-generator.png', show_shapes=True, show_layer_names=True)

model_d.build(input_shape=(high_resolution[0], high_resolution[1], 3))
model_d.summary()
keras.utils.plot_model(model_d, 'SRGAN-discriminator.png', show_shapes=True, show_layer_names=True)

model = srgan(input_shape=(low_resolution[0], low_resolution[1], 3), model_vgg=model_vgg, model_g=model_g, model_d=model_d)
model.compile(optimizer=optimizer, loss=['binary_crossentropy', 'mse'], loss_weights=[1, 100])

model.build(input_shape=(low_resolution[0], low_resolution[1], 3))
model.summary()
keras.utils.plot_model(model, 'SRGAN.png', show_shapes=True, show_layer_names=True)

for epoch in range(epochs):
real_hr_image, real_lr_image = read_data(data_path, high_resolution, low_resolution, batch_size)

real_hr_feature = model_vgg(real_hr_image)
fake_hr_image = model_g(real_lr_image)

real_dacc(np.ones((batch_size, high_resolution[0] // 16, high_resolution[1] // 16, 1)), model_d(real_hr_image))
fake_dacc(np.zeros((batch_size, high_resolution[0] // 16, high_resolution[1] // 16, 1)), model_d(fake_hr_image))
gacc(np.ones((batch_size, high_resolution[0] // 16, high_resolution[1] // 16, 1)), model(real_lr_image)[0])

real_dloss = model_d.train_on_batch(real_hr_image, np.ones((batch_size, high_resolution[0] // 16, high_resolution[1] // 16, 1)))
fake_dloss = model_d.train_on_batch(fake_hr_image, np.zeros((batch_size, high_resolution[0] // 16, high_resolution[1] // 16, 1)))
gloss = model.train_on_batch(real_lr_image, [np.ones((batch_size, high_resolution[0] // 16, high_resolution[1] // 16, 1)), real_hr_feature])

if epoch % 20 == 0:
print('epoch = {}, real_dacc = {}, fake_dacc = {}, gacc = {}'.format(epoch, real_dacc.result(), fake_dacc.result(), gacc.result()))
real_dacc.reset_states()
fake_dacc.reset_states()
gacc.reset_states()
real_hr_image, real_lr_image = read_data(data_path, high_resolution, low_resolution, batch_size=1)
fake_hr_image = ((model_g(real_lr_image).numpy().squeeze() + 1) * 127.5).astype(np.uint8)
scale_hr_image = ((cv.resize(real_lr_image.squeeze(), high_resolution) + 1) * 127.5).astype(np.uint8)
cv.imwrite(save_path + '\\epoch{}.jpg'.format(epoch), np.concatenate([scale_hr_image, fake_hr_image], axis=1))

srgan

模型运行结果

srgan

小技巧

  1. 图像输入可以先将其归一化到0-1之间或者-1-1之间,因为网络的参数一般都比较小,所以归一化后计算方便,收敛较快。
  2. 注意其中的一些维度变换和numpytensorflow常用操作,否则在阅读代码时可能会产生一些困难。
  3. 可以设置一些权重的保存方式学习率的下降方式早停方式
  4. SRGAN对于网络结构,优化器参数,网络层的一些超参数都是非常敏感的,效果不好不容易发现原因,这可能需要较多的工程实践经验
  5. 先创建判别器,然后进行compile,这样判别器就固定了,然后创建生成器时,不要训练判别器,需要将判别器的trainable改成False,此时不会影响之前固定的判别器,这个可以通过模型的_collection_collected_trainable_weights属性查看,如果该属性为空,则模型不训练,否则模型可以训练,compile之后,该属性固定,无论后面如何修改trainable,只要不重新compile,都不影响训练。
  6. 要注意使用VGG19特征提取网络权重时,因为VGG19的输入尺寸是224x224x3的,因此图像的尺寸要匹配,如果想生成更大尺寸的高分辨率图像,需要自己训练一个合适的特征提取网络权重
  7. 在SRGAN的测试图像中,为了说明模型的优越性,左边为低分辨率图像直接resize到高分辨率图像的结果,右边为SRGAN生成的高分辨率图像的结果。只是训练了2000代,每一代只有2个图像就可以看出SRGAN的效果。小伙伴们可以选择更大的数据集,更加快速的GPU,训练更长的时间,这两种算法之间的差距会更加明显。

SRGAN小结

  SRGAN是一种有效的超分辨率生成式对抗网络,从上图可以看出SRGAN模型的参数量只有7M,最近AI老图像复原引起了人们的注意,训练好SRGAN模型后可以运用在AI老图像复原,可以将原来拍摄的低分辨率的图像转化为清晰的高分辨率图像,是不是非常有趣呢?

-------------本文结束感谢您的阅读-------------
0%