DiscoGAN


Background

  DiscoGAN (Learning to Discover Cross-Domain Relations with Generative Adversarial Networks) was published at ICML 2017 and performs image-to-image style transfer. Style transfer only entered the spotlight after GANs were proposed; before generative adversarial networks, it was very difficult to achieve with traditional image-processing algorithms. Let's take a look at how it works.


The Idea Behind DiscoGAN

DiscoGAN introduces four networks: generator GAB, generator GBA, discriminator DA, and discriminator DB.
GAB takes a domain-A image and outputs a domain-B image; its job is to translate style A into style B.
GBA takes a domain-B image and outputs a domain-A image; its job is to translate style B into style A.
DA takes a domain-A image and outputs a classification: whether the input is a real domain-A image or one translated from B.
DB takes a domain-B image and outputs a classification: whether the input is a real domain-B image or one translated from A.

Several image names appear throughout; here is a quick glossary.
image_A and image_B are the real images read from the dataset; when DA and DB evaluate them, the target output is all ones.
fake_A is image_B translated into style A by GBA; when training the generators, we want DA's output on it to be all ones. Likewise, fake_B is image_A translated into style B by GAB, and we want DB's output on it to be all ones.
recon_A is fake_B mapped back to style A by GBA, i.e. the original image_A passed through GAB and then GBA; it should be as close to image_A as possible.
recon_B is fake_A mapped back to style B by GAB, i.e. the original image_B passed through GBA and then GAB; it should be as close to image_B as possible.
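The reconstruction idea above can be sketched in plain NumPy: the cycle penalty is just the mean absolute error between the original image and the twice-translated image. This is a minimal sketch; `toy_gAB` and `toy_gBA` are hypothetical stand-ins for the real generators, chosen to be exactly invertible.

```python
import numpy as np

def mae(x, y):
    # mean absolute error: the reconstruction loss DiscoGAN applies to recon_A/recon_B
    return float(np.mean(np.abs(x - y)))

# hypothetical stand-ins for GAB and GBA: a sign flip, which is trivially invertible
toy_gAB = lambda img: -img   # "translate" domain A -> domain B
toy_gBA = lambda img: -img   # "translate" domain B -> domain A

image_A = np.random.rand(4, 4, 3)
fake_B = toy_gAB(image_A)    # A -> B
recon_A = toy_gBA(fake_B)    # B -> A: should recover image_A

print(mae(image_A, recon_A))  # 0.0, since the toy mappings invert each other exactly
```

With real generators the reconstruction error is never exactly zero; minimizing it is what forces GAB and GBA to learn mappings that are (approximately) inverses of each other.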

Features of DiscoGAN

  InstanceNormalization is used in place of BatchNormalization.
  The generators use a U-Net architecture for deep feature extraction.
  The generator loss uses mean absolute error, while the discriminator loss uses mean squared error.
  The generator loss weights are tuned so that the network focuses more on the quality of the generated images.
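The first feature deserves a closer look: instance normalization computes its statistics per sample and per channel over the spatial axes only, whereas batch normalization also averages over the batch axis. A NumPy sketch of the core computation (omitting the learnable gamma/beta used in the full layer below):

```python
import numpy as np

def instance_norm(x, eps=1e-5):
    # x has shape (N, H, W, C); statistics are computed over H and W only,
    # independently for every sample and every channel
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.rand(2, 8, 8, 3) * 10 + 5   # arbitrary scale and offset
y = instance_norm(x)
# every (sample, channel) slice now has roughly zero mean and unit variance
print(float(y[0, :, :, 0].mean()), float(y[0, :, :, 0].std()))
```

Because each image is normalized independently of the rest of the batch, its contrast and style statistics are decoupled from the batch composition, which is why instance normalization is the common choice in style-transfer GANs.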

DiscoGAN Architecture Diagrams

(Figure: generator architecture)
(Figure: discriminator architecture)

TensorFlow 2.0 Implementation

import os
import glob
import numpy as np
import cv2 as cv
from functools import reduce
import tensorflow as tf
import tensorflow.keras as keras


def compose(*funcs):
    if funcs:
        return reduce(lambda f, g: lambda *a, **kw: g(f(*a, **kw)), funcs)
    else:
        raise ValueError('Composition of empty sequence not supported.')


class InstanceNormalization(keras.layers.Layer):
    def __init__(self, beta_initializer='zeros', gamma_initializer='ones',
                 beta_regularizer=None, gamma_regularizer=None,
                 beta_constraint=None, gamma_constraint=None, epsilon=1e-5,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.epsilon = epsilon
        self.beta_initializer = keras.initializers.get(beta_initializer)
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        self.beta_regularizer = keras.regularizers.get(beta_regularizer)
        self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
        self.beta_constraint = keras.constraints.get(beta_constraint)
        self.gamma_constraint = keras.constraints.get(gamma_constraint)

    def build(self, input_shape):
        assert len(input_shape) == 4
        self.gamma = self.add_weight(shape=(input_shape[-1],), name='gamma', initializer=self.gamma_initializer,
                                     regularizer=self.gamma_regularizer, constraint=self.gamma_constraint)
        self.beta = self.add_weight(shape=(input_shape[-1],), name='beta', initializer=self.beta_initializer,
                                    regularizer=self.beta_regularizer, constraint=self.beta_constraint)

    def call(self, inputs, **kwargs):
        mean, variance = tf.nn.moments(inputs, axes=[1, 2])
        mean = tf.reshape(mean, shape=[-1, 1, 1, inputs.shape[-1]])
        variance = tf.reshape(variance, shape=[-1, 1, 1, inputs.shape[-1]])
        outputs = (inputs - mean) / tf.sqrt(variance + self.epsilon)
        return outputs * self.gamma + self.beta

    def get_config(self):
        config = {
            'epsilon': self.epsilon,
            'beta_initializer': keras.initializers.serialize(self.beta_initializer),
            'gamma_initializer': keras.initializers.serialize(self.gamma_initializer),
            'beta_regularizer': keras.regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': keras.regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': keras.constraints.serialize(self.beta_constraint),
            'gamma_constraint': keras.constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()

        return dict(list(base_config.items()) + list(config.items()))


class Conv_Relu_In(keras.layers.Layer):
    def __init__(self, filters, kernel_size, strides, padding, name):
        super(Conv_Relu_In, self).__init__()
        self._name = name
        self.block = keras.Sequential([keras.layers.Conv2D(filters, kernel_size, strides, padding),
                                       keras.layers.LeakyReLU(0.2)])
        # blocks whose name contains 'in' additionally apply instance normalization
        if name.find('in') != -1:
            self.block.add(InstanceNormalization())

    def call(self, inputs, **kwargs):

        return self.block(inputs)


class Upsampling_Conv_Relu_In_Concatenate(keras.layers.Layer):
    def __init__(self, filters, kernel_size, strides, padding, name):
        super(Upsampling_Conv_Relu_In_Concatenate, self).__init__()
        self._name = name
        self.block = keras.Sequential([keras.layers.UpSampling2D((2, 2)),
                                       keras.layers.Conv2D(filters, kernel_size, strides, padding, activation='relu'),
                                       InstanceNormalization()])
        self.concatenate = keras.layers.Concatenate()

    def call(self, inputs, **kwargs):
        x, shortcut = inputs
        x = self.block(x)
        output = self.concatenate([x, shortcut])

        return output


def generator(input_shape, name):
    input_tensor = keras.layers.Input(input_shape, name='input')
    x = input_tensor

    # U-Net encoder: downsample by 2 at each stage
    x1 = Conv_Relu_In(64, (4, 4), (2, 2), 'same', name='conv_leakyrelu1')(x)
    x2 = Conv_Relu_In(128, (4, 4), (2, 2), 'same', name='conv_leakyrelu_in2')(x1)
    x3 = Conv_Relu_In(256, (4, 4), (2, 2), 'same', name='conv_leakyrelu_in3')(x2)
    x4 = Conv_Relu_In(512, (4, 4), (2, 2), 'same', name='conv_leakyrelu_in4')(x3)
    x5 = Conv_Relu_In(512, (4, 4), (2, 2), 'same', name='conv_leakyrelu_in5')(x4)
    x6 = Conv_Relu_In(512, (4, 4), (2, 2), 'same', name='conv_leakyrelu_in6')(x5)
    x7 = Conv_Relu_In(512, (4, 4), (2, 2), 'same', name='conv_leakyrelu_in7')(x6)

    # U-Net decoder: upsample and concatenate the matching encoder feature map
    y6 = Upsampling_Conv_Relu_In_Concatenate(512, (4, 4), (1, 1), 'same', name='upsampling_conv_relu_in_concatenate1')([x7, x6])
    y5 = Upsampling_Conv_Relu_In_Concatenate(512, (4, 4), (1, 1), 'same', name='upsampling_conv_relu_in_concatenate2')([y6, x5])
    y4 = Upsampling_Conv_Relu_In_Concatenate(512, (4, 4), (1, 1), 'same', name='upsampling_conv_relu_in_concatenate3')([y5, x4])
    y3 = Upsampling_Conv_Relu_In_Concatenate(256, (4, 4), (1, 1), 'same', name='upsampling_conv_relu_in_concatenate4')([y4, x3])
    y2 = Upsampling_Conv_Relu_In_Concatenate(128, (4, 4), (1, 1), 'same', name='upsampling_conv_relu_in_concatenate5')([y3, x2])
    y1 = Upsampling_Conv_Relu_In_Concatenate(64, (4, 4), (1, 1), 'same', name='upsampling_conv_relu_in_concatenate6')([y2, x1])

    # final upsampling back to the input resolution, with tanh output in [-1, 1]
    y = compose(keras.layers.UpSampling2D((2, 2), name='upsampling'),
                keras.layers.Conv2D(3, (4, 4), (1, 1), 'same', activation='tanh', name='conv_tanh'))(y1)

    model = keras.Model(input_tensor, y, name=name)

    return model


def discriminator(input_shape, name):
    input_tensor = keras.layers.Input(input_shape, name='input')
    x = input_tensor

    x = compose(Conv_Relu_In(64, (4, 4), (2, 2), 'same', name='conv_leakyrelu1'),
                Conv_Relu_In(128, (4, 4), (2, 2), 'same', name='conv_leakyrelu_in2'),
                Conv_Relu_In(256, (4, 4), (2, 2), 'same', name='conv_leakyrelu_in3'),
                Conv_Relu_In(512, (4, 4), (2, 2), 'same', name='conv_leakyrelu_in4'),
                keras.layers.Conv2D(1, (4, 4), (1, 1), 'same', name='conv'))(x)

    model = keras.Model(input_tensor, x, name=name)

    return model


def discogan(input_shapeA, input_shapeB, model_gAB, model_gBA, model_dA, model_dB):
    input_tensorA = keras.layers.Input(input_shapeA, name='input_A')
    input_tensorB = keras.layers.Input(input_shapeB, name='input_B')

    # fake_A: the domain-B input translated to domain A by GBA;
    # fake_B: the domain-A input translated to domain B by GAB
    fake_A = model_gBA(input_tensorB)
    fake_B = model_gAB(input_tensorA)

    # recon_A: fake_B mapped back to domain A by GBA (i.e. image_A -> GAB -> GBA);
    # recon_B: fake_A mapped back to domain B by GAB (i.e. image_B -> GBA -> GAB)
    recon_A = model_gBA(fake_B)
    recon_B = model_gAB(fake_A)

    # freeze the discriminators inside the combined model; the standalone,
    # already-compiled discriminators are unaffected
    model_dA.trainable = False
    model_dB.trainable = False

    conf_A = model_dA(fake_A)
    conf_B = model_dB(fake_B)

    model = keras.Model([input_tensorA, input_tensorB], [conf_A, conf_B, fake_A, fake_B, recon_A, recon_B], name='DiscoGAN')

    return model


def read_data(data_path, img_size, batch_size):
    filename = glob.glob(data_path + '\\*.jpg')
    choose_name = np.random.choice(filename, batch_size)

    image_A, image_B = [], []
    for i in range(batch_size):
        image = cv.imread(choose_name[i]).astype(np.float32)
        # each dataset image is a side-by-side pair; split it into its two halves
        image_A.append(cv.resize(image[:, 256:, :], img_size))   # right half -> domain A
        image_B.append(cv.resize(image[:, :256, :], img_size))   # left half -> domain B

    # normalize pixel values to [-1, 1] to match the generators' tanh output
    image_A = np.array(image_A) / 127.5 - 1
    image_B = np.array(image_B) / 127.5 - 1

    return image_A, image_B


if __name__ == '__main__':
    batch_size = 2
    epochs = 2000
    tf.random.set_seed(22)
    img_size = (128, 128)
    data_path = r'.\edges2shoes\train'
    save_path = r'.\discogan'
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    optimizer = keras.optimizers.Adam(0.0002, 0.5)
    loss = keras.losses.BinaryCrossentropy()

    real_dAmse = keras.metrics.MeanSquaredError()
    fake_dAmse = keras.metrics.MeanSquaredError()
    real_dBmse = keras.metrics.MeanSquaredError()
    fake_dBmse = keras.metrics.MeanSquaredError()
    gAmse = keras.metrics.MeanSquaredError()
    gBmse = keras.metrics.MeanSquaredError()

    model_dA = discriminator(input_shape=(img_size[0], img_size[1], 3), name='DiscoGAN-DiscriminatorA')
    model_dA.compile(optimizer=optimizer, loss='mse')
    model_dB = discriminator(input_shape=(img_size[0], img_size[1], 3), name='DiscoGAN-DiscriminatorB')
    model_dB.compile(optimizer=optimizer, loss='mse')

    model_gAB = generator(input_shape=(img_size[0], img_size[1], 3), name='DiscoGAN-GeneratorAB')
    model_gBA = generator(input_shape=(img_size[0], img_size[1], 3), name='DiscoGAN-GeneratorBA')

    model_gAB.build(input_shape=(img_size[0], img_size[1], 3))
    model_gAB.summary()
    keras.utils.plot_model(model_gAB, 'DiscoGAN-generatorAB.png', show_shapes=True, show_layer_names=True)

    model_gBA.build(input_shape=(img_size[0], img_size[1], 3))
    model_gBA.summary()
    keras.utils.plot_model(model_gBA, 'DiscoGAN-generatorBA.png', show_shapes=True, show_layer_names=True)

    model_dA.build(input_shape=(img_size[0], img_size[1], 3))
    model_dA.summary()
    keras.utils.plot_model(model_dA, 'DiscoGAN-discriminatorA.png', show_shapes=True, show_layer_names=True)

    model_dB.build(input_shape=(img_size[0], img_size[1], 3))
    model_dB.summary()
    keras.utils.plot_model(model_dB, 'DiscoGAN-discriminatorB.png', show_shapes=True, show_layer_names=True)

    model = discogan(input_shapeA=(img_size[0], img_size[1], 3), input_shapeB=(img_size[0], img_size[1], 3), model_gAB=model_gAB, model_gBA=model_gBA, model_dA=model_dA, model_dB=model_dB)
    # adversarial patch outputs use MSE, translation and reconstruction outputs use MAE,
    # with larger weights on the image losses so training focuses on image quality
    model.compile(optimizer=optimizer, loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'], loss_weights=[0.5, 0.5, 5, 5, 5, 5])

    model.build(input_shape=[(img_size[0], img_size[1], 3), (img_size[0], img_size[1], 3)])
    model.summary()
    keras.utils.plot_model(model, 'DiscoGAN.png', show_shapes=True, show_layer_names=True)

    for epoch in range(epochs):
        image_A, image_B = read_data(data_path, img_size, batch_size)

        fake_A = model_gBA(image_B)
        fake_B = model_gAB(image_A)

        # the discriminators output (H/16, W/16, 1) patch maps:
        # real patches are labeled 1, fake patches 0
        real_dAmse(np.ones((batch_size, img_size[0] // 16, img_size[1] // 16, 1)), model_dA(image_A))
        fake_dAmse(np.zeros((batch_size, img_size[0] // 16, img_size[1] // 16, 1)), model_dA(fake_A))
        real_dBmse(np.ones((batch_size, img_size[0] // 16, img_size[1] // 16, 1)), model_dB(image_B))
        fake_dBmse(np.zeros((batch_size, img_size[0] // 16, img_size[1] // 16, 1)), model_dB(fake_B))
        gAmse(np.ones((batch_size, img_size[0] // 16, img_size[1] // 16, 1)), model([image_A, image_B])[0])
        gBmse(np.ones((batch_size, img_size[0] // 16, img_size[1] // 16, 1)), model([image_A, image_B])[1])

        real_dAloss = model_dA.train_on_batch(image_A, np.ones((batch_size, img_size[0] // 16, img_size[1] // 16, 1)))
        fake_dAloss = model_dA.train_on_batch(fake_A, np.zeros((batch_size, img_size[0] // 16, img_size[1] // 16, 1)))
        real_dBloss = model_dB.train_on_batch(image_B, np.ones((batch_size, img_size[0] // 16, img_size[1] // 16, 1)))
        fake_dBloss = model_dB.train_on_batch(fake_B, np.zeros((batch_size, img_size[0] // 16, img_size[1] // 16, 1)))

        # the combined model trains only the generators (its discriminators are frozen)
        gloss = model.train_on_batch([image_A, image_B], [np.ones((batch_size, img_size[0] // 16, img_size[1] // 16, 1)), np.ones((batch_size, img_size[0] // 16, img_size[1] // 16, 1)), image_A, image_B, image_A, image_B])

        if epoch % 20 == 0:
            print('epoch = {}, real_dAmse = {}, fake_dAmse = {}, real_dBmse = {}, fake_dBmse = {}, gAmse = {}, gBmse = {}'.format(epoch, real_dAmse.result(), fake_dAmse.result(), real_dBmse.result(), fake_dBmse.result(), gAmse.result(), gBmse.result()))
            real_dAmse.reset_states()
            fake_dAmse.reset_states()
            real_dBmse.reset_states()
            fake_dBmse.reset_states()
            gAmse.reset_states()
            gBmse.reset_states()
            # save a sample grid: top row [image_B | fake_A], bottom row [image_A | fake_B]
            image_A, image_B = read_data(data_path, img_size, batch_size=1)
            fake_A = ((model_gBA(image_B).numpy().squeeze() + 1) * 127.5).astype(np.uint8)
            fake_B = ((model_gAB(image_A).numpy().squeeze() + 1) * 127.5).astype(np.uint8)
            image_A = ((image_A.squeeze() + 1) * 127.5).astype(np.uint8)
            image_B = ((image_B.squeeze() + 1) * 127.5).astype(np.uint8)
            cv.imwrite(save_path + '\\epoch{}.jpg'.format(epoch), np.concatenate([np.concatenate([image_B, fake_A], axis=1), np.concatenate([image_A, fake_B], axis=1)], axis=0))


Results

(Figure: sample style-transfer results saved during training)

Tips

  1. Normalize input images to [0, 1] or [-1, 1] first. Network weights are generally small, so normalized inputs keep the computation well-scaled and speed up convergence.
  2. Pay close attention to the dimension manipulations and the common NumPy/TensorFlow operations used here; otherwise the code can be hard to follow.
  3. Consider adding weight checkpointing, learning-rate scheduling, and early stopping.
  4. DiscoGAN is very sensitive to the network architecture, the optimizer settings, and the layer hyperparameters. When results are poor, the cause is hard to diagnose, which takes considerable engineering experience.
  5. Create the discriminators first and compile them; this fixes their training configuration. Then, when building the combined model, set the discriminators' trainable attribute to False so they are frozen inside it, without affecting the already-compiled standalone discriminators. You can check this through a model's _collection_collected_trainable_weights attribute: if it is empty, the model will not train. The attribute is fixed at compile time, so changing trainable afterwards has no effect unless you compile again.
  6. In the test images, the odd columns of the first row are shoe edge sketches and the even columns are the style-transferred shoe images; the odd columns of the second row are shoe photos and the even columns are the transferred edge sketches. The effect of DiscoGAN is already visible after only 2000 iterations with just 2 images each. With a larger dataset, a faster GPU, and longer training, the style transfer becomes much more pronounced.
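The pixel-range conversion from tip 1 (the same `/ 127.5 - 1` mapping used in the training script above, and its inverse used when saving results) can be sketched as follows; explicit rounding is added here so the round trip is exact.

```python
import numpy as np

def to_model_range(img_uint8):
    # map pixels from [0, 255] to [-1, 1], matching the generators' tanh output range
    return img_uint8.astype(np.float32) / 127.5 - 1.0

def to_pixel_range(img_float):
    # inverse mapping back to [0, 255]; rounding before the uint8 cast
    # avoids off-by-one truncation errors
    return np.rint((img_float + 1.0) * 127.5).clip(0, 255).astype(np.uint8)

img = np.random.randint(0, 256, size=(4, 4, 3), dtype=np.uint8)
roundtrip = to_pixel_range(to_model_range(img))
print(np.array_equal(img, roundtrip))  # True
```

The training script casts with a plain `astype(np.uint8)`, which truncates; for saving preview images the difference is invisible, but rounding is the safer choice.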

DiscoGAN Summary

  DiscoGAN is an effective generative adversarial network for style transfer. It is very similar to CycleGAN; the main differences are some of the loss functions and the generator backbone, which swaps ResNet for U-Net. As the model summary shows, DiscoGAN has about 89M parameters. It can transfer between arbitrary styles, and with a suitable dataset it could even generate cartoon-style character stickers. Isn't that fun? Be sure to master it.
