WGAN

WGAN

背景介绍

  WGAN(Wasserstein Generative Adversarial Networks):于2017年提出,和LSGAN类似,没有对网络结构做太多修改,分析了GAN网络中判别器效果越好,生成器梯度消失越严重的问题,而且提出了一种新的损失函数,构建了一个更加稳定,收敛更快,质量更高的生成式对抗网络

wgan

WGAN特点

  保持GAN的网络结构不变,将判别器网络最后的sigmoid删去
  将损失函数中的log删去
  每次更新判别器的参数,将参数绝对值截断到一个固定常数c
  不使用基于动量的优化算法(Adam),推荐使用RMSProp,SGD等方法

WGAN图像分析

generator
discriminator

TensorFlow2.0实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import numpy as np
import cv2 as cv
from functools import reduce
import tensorflow as tf
import tensorflow.keras as keras


def compose(*funcs):
if funcs:
return reduce(lambda f, g: lambda *a, **kw: g(f(*a, **kw)), funcs)
else:
raise ValueError('Composition of empty sequence not supported.')


def generator(input_shape):
input_tensor = keras.layers.Input(input_shape, name='input')
x = input_tensor

x = compose(keras.layers.Dense(256, activation='relu', name='dense_relu1'),
keras.layers.BatchNormalization(momentum=0.8, name='bn1'),
keras.layers.Dense(512, activation='relu', name='dense_relu2'),
keras.layers.BatchNormalization(momentum=0.8, name='bn2'),
keras.layers.Dense(1024, activation='relu', name='dense_relu3'),
keras.layers.BatchNormalization(momentum=0.8, name='bn3'),
keras.layers.Dense(784, activation='tanh', name='dense_tanh'),
keras.layers.Reshape((28, 28, 1), name='reshape'))(x)

model = keras.Model(input_tensor, x, name='WGAN-Generator')

return model


def discriminator(input_shape):
input_tensor = keras.layers.Input(input_shape, name='input')
x = input_tensor

x = compose(keras.layers.Flatten(name='flatten'),
keras.layers.Dense(512, activation='relu', name='dense_relu1'),
keras.layers.Dense(256, activation='relu', name='dense_relu2'),
keras.layers.Dense(1, name='dense'))(x)

model = keras.Model(input_tensor, x, name='WGAN-Discriminator')

return model


def wgan(input_shape, model_g, model_d):
input_tensor = keras.layers.Input(input_shape, name='input')
x = input_tensor

x = model_g(x)
model_d.trainable = False
x = model_d(x)

model = keras.Model(input_tensor, x, name='WGAN')

return model


def wasserstein_loss(y_true, y_pred):
return -tf.reduce_mean(y_true * y_pred)


def save_picture(image, save_path, picture_num):
image = ((image + 1) * 127.5).astype(np.uint8)
image = np.concatenate([image[i * picture_num:(i + 1) * picture_num] for i in range(picture_num)], axis=2)
image = np.concatenate([image[i] for i in range(picture_num)], axis=0)
cv.imwrite(save_path, image)


if __name__ == '__main__':
(x, _), (_, _) = keras.datasets.mnist.load_data()
batch_size = 256
epochs = 50
tf.random.set_seed(22)
c = 0.01
save_path = r'.\wgan'
if not os.path.exists(save_path):
os.makedirs(save_path)

x = x[..., np.newaxis].astype(np.float32) / 127.5 - 1
x = tf.data.Dataset.from_tensor_slices(x).batch(batch_size)

optimizer_g = keras.optimizers.RMSprop(0.00005)
optimizer_d = keras.optimizers.RMSprop(0.0002)

real_dmean = keras.metrics.Mean()
fake_dmean = keras.metrics.Mean()
gmean = keras.metrics.Mean()

model_d = discriminator(input_shape=(28, 28, 1))
model_d.compile(optimizer=optimizer_d, loss=wasserstein_loss)

model_g = generator(input_shape=(100,))

model_g.build(input_shape=(100,))
model_g.summary()
keras.utils.plot_model(model_g, 'WGAN-generator.png', show_shapes=True, show_layer_names=True)

model_d.build(input_shape=(28, 28, 1))
model_d.summary()
keras.utils.plot_model(model_d, 'WGAN-discriminator.png', show_shapes=True, show_layer_names=True)

model = wgan(input_shape=(100,), model_g=model_g, model_d=model_d)
model.compile(optimizer=optimizer_g, loss=wasserstein_loss)

model.build(input_shape=(100,))
model.summary()
keras.utils.plot_model(model, 'WGAN.png', show_shapes=True, show_layer_names=True)

for epoch in range(epochs):
x = x.shuffle(np.random.randint(0, 10000))
x_db = iter(x)

for step, real_image in enumerate(x_db):
noise = np.random.normal(0, 1, (real_image.shape[0], 100))
fake_image = model_g(noise)

real_dmean(model_d(real_image))
fake_dmean(model_d(fake_image))
gmean(model_d(fake_image))

real_dloss = model_d.train_on_batch(real_image, np.ones((real_image.shape[0], 1)))
fake_dloss = model_d.train_on_batch(fake_image, -np.ones((real_image.shape[0], 1)))

_ = [x.assign(tf.clip_by_value(x, -c, c)) for x in model_d.variables]

gloss = model.train_on_batch(noise, np.ones((real_image.shape[0], 1)))

if step % 20 == 0:
print('epoch = {}, step = {}, real_dmean = {}, fake_dmean = {}, gmean = {}'.format(epoch, step, real_dmean.result(), fake_dmean.result(), gmean.result()))
real_dmean.reset_states()
fake_dmean.reset_states()
gmean.reset_states()
fake_data = np.random.normal(0, 1, (100, 100))
fake_image = model_g(fake_data)
save_picture(fake_image.numpy(), save_path + '\\epoch{}_step{}.jpg'.format(epoch, step), 10)

wgan

模型运行结果

wgan

小技巧

  1. 图像输入可以先将其归一化到0-1之间或者-1-1之间,因为网络的参数一般都比较小,所以归一化后计算方便,收敛较快。
  2. 注意其中的一些维度变换和numpytensorflow常用操作,否则在阅读代码时可能会产生一些困难。
  3. 可以设置一些权重的保存方式学习率的下降方式早停方式
  4. WGAN对于网络结构,优化器参数,网络层的一些超参数都是非常敏感的,效果不好不容易发现原因,这可能需要较多的工程实践经验
  5. 先创建判别器,然后进行compile,这样判别器就固定了,然后创建生成器时,不要训练判别器,需要将判别器的trainable改成False,此时不会影响之前固定的判别器,这个可以通过模型的_collection_collected_trainable_weights属性查看,如果该属性为空,则模型不训练,否则模型可以训练,compile之后,该属性固定,无论后面如何修改trainable,只要不重新compile,都不影响训练。
  6. 本博客中的WGAN是在GAN的基础上进行修改,当然小伙伴们也可以尝试在DCGAN,CGAN等模型上进行尝试,可能一些超参数设置的不是非常合理,所以WGAN的效果不是特别好,小伙伴们在使用时可以自己修改

WGAN小结

  WGAN在提出时对网络的损失函数进行了大量的分析,引入W距离,Lipschitz常数等等,我不是大佬,也不对数学公式进行过多的阐述,可能我说了会让小伙伴们更加迷糊,因此有需要的小伙伴们可以去网上搜索相关资料。因为WGAN基本没有修改网络结构,因此网络参数和GAN完全相同

-------------本文结束感谢您的阅读-------------
0%