CGAN


Background

  CGAN (Conditional Generative Adversarial Network): proposed in 2014, it introduces a label variable into the GAN framework; by controlling the value of this label, images of different classes can be generated. Its network structure is essentially the same as a plain GAN's, with only some extra handling for the conditional variable.
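
For reference, the objective in the original CGAN paper (Mirza & Osindero, 2014) is just the standard GAN minimax game with both networks conditioned on the label $y$:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x \mid y)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y))\big)\big]$$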

[Figure: CGAN architecture]

CGAN Characteristics

  The generator takes two inputs: a random noise vector and the one-hot encoding of the label; a Concatenate layer fuses the two inputs.
  The discriminator also takes two inputs: an input image and the one-hot encoding of the label; the image is first flattened into a one-dimensional vector with Flatten, and then a Concatenate layer fuses the two inputs (see the shape sketch below).
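
A minimal sketch of just this input fusion, not a full model; the shapes assume 100-dimensional noise, 10 classes, and 28×28×1 MNIST images, matching the full implementation later in this post:

import tensorflow.keras as keras

# Generator side: noise and one-hot label are concatenated before the MLP.
noise = keras.layers.Input((100,))                    # (None, 100)
label = keras.layers.Input((10,))                     # (None, 10)
g_in = keras.layers.Concatenate()([noise, label])     # (None, 110)

# Discriminator side: the image must be flattened before it can be concatenated with the label.
image = keras.layers.Input((28, 28, 1))               # (None, 28, 28, 1)
flat = keras.layers.Flatten()(image)                  # (None, 784)
d_in = keras.layers.Concatenate()([flat, label])      # (None, 794)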

CGAN Architecture Diagrams

[Figure: generator architecture]
[Figure: discriminator architecture]

TensorFlow 2.0 Implementation

import os
import numpy as np
import cv2 as cv
from functools import reduce
import tensorflow as tf
import tensorflow.keras as keras


def compose(*funcs):
    # Compose layers/functions left to right: compose(f, g)(x) == g(f(x)).
    if funcs:
        return reduce(lambda f, g: lambda *a, **kw: g(f(*a, **kw)), funcs)
    else:
        raise ValueError('Composition of empty sequence not supported.')


def generator(input_label_shape, input_noise_shape):
    # Generator: concatenate noise and one-hot label, map them through an MLP to a 28x28x1 image.
    label = keras.layers.Input(input_label_shape, name='input_label')
    noise = keras.layers.Input(input_noise_shape, name='input_noise')

    x = keras.layers.Concatenate(name='concatenate')([noise, label])

    x = compose(keras.layers.Dense(256, activation='relu', name='dense_relu1'),
                keras.layers.BatchNormalization(momentum=0.8, name='bn1'),
                keras.layers.Dense(512, activation='relu', name='dense_relu2'),
                keras.layers.BatchNormalization(momentum=0.8, name='bn2'),
                keras.layers.Dense(1024, activation='relu', name='dense_relu3'),
                keras.layers.BatchNormalization(momentum=0.8, name='bn3'),
                keras.layers.Dense(784, activation='tanh', name='dense_tanh'),
                keras.layers.Reshape((28, 28, 1), name='reshape'))(x)

    model = keras.Model([noise, label], x, name='CGAN-Generator')

    return model


def discriminator(input_image_shape, input_label_shape):
    # Discriminator: flatten the image, concatenate the one-hot label, output a real/fake probability.
    label = keras.layers.Input(input_label_shape, name='input_label')
    image = keras.layers.Input(input_image_shape, name='input_image')

    image_tensor = keras.layers.Flatten(name='flatten')(image)

    x = keras.layers.Concatenate(name='concatenate')([image_tensor, label])

    x = compose(keras.layers.Dense(512, activation='relu', name='dense_relu1'),
                keras.layers.Dense(256, activation='relu', name='dense_relu2'))(x)

    conf = keras.layers.Dense(1, activation='sigmoid', name='dense_sigmoid')(x)

    model = keras.Model([image, label], conf, name='CGAN-Discriminator')

    return model


def cgan(input_noise_shape, input_label_shape, model_g, model_d):
    # Combined model: noise + label -> generator -> discriminator, with the discriminator frozen.
    label = keras.layers.Input(input_label_shape, name='input_label')
    noise = keras.layers.Input(input_noise_shape, name='input_noise')

    x = model_g([noise, label])
    model_d.trainable = False
    conf = model_d([x, label])

    model = keras.Model([noise, label], conf, name='CGAN')

    return model


def save_picture(image, save_path, picture_num):
    # Rescale from [-1, 1] to [0, 255] and tile picture_num x picture_num samples into one grid image.
    image = ((image + 1) * 127.5).astype(np.uint8)
    image = np.concatenate([image[i * picture_num:(i + 1) * picture_num] for i in range(picture_num)], axis=2)
    image = np.concatenate([image[i] for i in range(picture_num)], axis=0)
    cv.imwrite(save_path, image)


if __name__ == '__main__':
    (x, y), (_, _) = keras.datasets.mnist.load_data()
    batch_size = 256
    epochs = 20
    tf.random.set_seed(22)
    save_path = r'.\cgan'
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Normalize images to [-1, 1] (matching the generator's tanh output) and one-hot encode the labels.
    x = x[..., np.newaxis].astype(np.float32) / 127.5 - 1
    y = tf.one_hot(y, depth=10)
    x = tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)

    optimizer = keras.optimizers.Adam(0.0002, 0.5)
    loss = keras.losses.BinaryCrossentropy()

    real_dacc = keras.metrics.BinaryAccuracy()
    fake_dacc = keras.metrics.BinaryAccuracy()
    gacc = keras.metrics.BinaryAccuracy()

    # Build and compile the discriminator first, while it is still trainable.
    model_d = discriminator(input_image_shape=(28, 28, 1), input_label_shape=(10,))
    model_d.compile(optimizer=optimizer, loss=['binary_crossentropy'])

    model_g = generator(input_noise_shape=(100,), input_label_shape=(10,))

    model_g.build(input_shape=[(100,), (10,)])
    model_g.summary()
    keras.utils.plot_model(model_g, 'CGAN-generator.png', show_shapes=True, show_layer_names=True)

    model_d.build(input_shape=[(28, 28, 1), (10,)])
    model_d.summary()
    keras.utils.plot_model(model_d, 'CGAN-discriminator.png', show_shapes=True, show_layer_names=True)

    # The combined model freezes the discriminator, so only the generator is updated through it.
    model = cgan(input_noise_shape=(100,), input_label_shape=(10,), model_g=model_g, model_d=model_d)
    model.compile(optimizer=optimizer, loss=['binary_crossentropy'])

    model.build(input_shape=[(100,), (10,)])
    model.summary()
    keras.utils.plot_model(model, 'CGAN.png', show_shapes=True, show_layer_names=True)

    for epoch in range(epochs):
        # Reshuffle with a random (non-zero) buffer size each epoch.
        x = x.shuffle(np.random.randint(1, 10000))
        x_db = iter(x)

        for step, (real_image, real_label) in enumerate(x_db):
            # Sample noise and random labels, then generate a batch of fake images.
            noise = np.random.normal(0, 1, (real_image.shape[0], 100)).astype(np.float32)
            fake_label = tf.one_hot(np.random.randint(0, 10, (real_image.shape[0])), depth=10)

            fake_image = model_g([noise, fake_label])

            # Track discriminator accuracy on real/fake batches and the generator's "fooling" accuracy.
            real_dacc(np.ones((real_image.shape[0], 1)), model_d([real_image, real_label]))
            fake_dacc(np.zeros((real_image.shape[0], 1)), model_d([fake_image, fake_label]))
            gacc(np.ones((real_image.shape[0], 1)), model([noise, fake_label]))

            # Train the discriminator on real (target 1) and fake (target 0) batches,
            # then train the generator through the combined model with target 1.
            real_dloss = model_d.train_on_batch([real_image, real_label], np.ones((real_image.shape[0], 1)))
            fake_dloss = model_d.train_on_batch([fake_image, fake_label], np.zeros((real_image.shape[0], 1)))
            gloss = model.train_on_batch([noise, fake_label], np.ones((real_image.shape[0], 1)))

            if step % 20 == 0:
                print('epoch = {}, step = {}, real_dacc = {}, fake_dacc = {}, gacc = {}'.format(
                    epoch, step, real_dacc.result(), fake_dacc.result(), gacc.result()))
                real_dacc.reset_states()
                fake_dacc.reset_states()
                gacc.reset_states()
                # Generate a 10x10 grid with the digits 0-9 repeated in each row and save it.
                fake_data = np.random.normal(0, 1, (100, 100)).astype(np.float32)
                fake_label = tf.one_hot(np.array(list(range(10)) * 10), depth=10)
                fake_image = model_g([fake_data, fake_label])
                save_picture(fake_image.numpy(), save_path + '\\epoch{}_step{}.jpg'.format(epoch, step), 10)
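
As a quick usage sketch (not part of the original script): once training finishes, the generator can be asked for a specific class by feeding the matching one-hot label, for example a 10×10 grid of the digit 7. This assumes model_g, save_picture, np, and tf from the script above are still in scope:

# Hypothetical example: generate 100 samples of a chosen digit with the trained generator.
digit = 7
noise = np.random.normal(0, 1, (100, 100)).astype(np.float32)
labels = tf.one_hot(np.full((100,), digit), depth=10)
images = model_g([noise, labels])                                # shape (100, 28, 28, 1), values in [-1, 1]
save_picture(images.numpy(), 'digit_{}.jpg'.format(digit), 10)   # 10x10 grid of the same digit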

[Figure: combined CGAN model]

Results

[Figure: CGAN generated samples]

Tips

  1. Normalize the input images to [0, 1] or [-1, 1] first; since network parameters are generally small, normalization simplifies the computation and speeds up convergence.
  2. Pay attention to the dimension transformations and the common NumPy/TensorFlow operations used here, otherwise the code may be hard to follow.
  3. You can also add weight checkpointing, learning-rate decay, and early stopping.
  4. CGAN is very sensitive to the network architecture, the optimizer settings, and the layer hyperparameters; when the results are poor, the cause is hard to pin down, which usually takes a fair amount of engineering experience.
  5. Create the discriminator first and compile it, which fixes its training behavior; then, when building the combined model, set the discriminator's trainable to False so it is not trained through the generator. This does not affect the already compiled discriminator, which you can verify through the model's _collected_trainable_weights attribute: if it is empty, the model will not be trained, otherwise it will. After compile, this attribute is fixed, so no matter how trainable is changed afterwards, training is unaffected as long as the model is not compiled again (a toy check of this behavior is sketched after this list).
  6. CGAN's network structure is basically the same as GAN's, so it is likewise only suitable for generating small images. The generator uses Concatenate to combine the label with the input noise, and the discriminator uses Concatenate to combine the label with the input image; be sure to Flatten the image first, otherwise an error occurs.
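
A minimal toy check of point 5, not from the original post and assuming a tf.keras version that captures the trainable state at compile time (as the script above relies on): the discriminator keeps learning through its own train_on_batch, while the combined model leaves it untouched.

import numpy as np
import tensorflow.keras as keras

# Toy stand-ins for D and the combined model, only to show how compile() snapshots the trainable state.
d = keras.Sequential([keras.layers.Dense(1, activation='sigmoid', input_shape=(4,))])
d.compile(optimizer='adam', loss='binary_crossentropy')        # compiled while trainable

g = keras.Sequential([keras.layers.Dense(4, input_shape=(2,))])
d.trainable = False                                            # frozen only for models compiled afterwards
gan = keras.Sequential([g, d])
gan.compile(optimizer='adam', loss='binary_crossentropy')

w0 = d.get_weights()[0].copy()
d.train_on_batch(np.random.rand(8, 4), np.ones((8, 1)))        # d was compiled as trainable, so it still updates
print('d changed by d.train_on_batch:', not np.allclose(w0, d.get_weights()[0]))    # expect True

w0 = d.get_weights()[0].copy()
gan.train_on_batch(np.random.rand(8, 2), np.ones((8, 1)))      # gan was compiled with d frozen, so only g updates
print('d changed by gan.train_on_batch:', not np.allclose(w0, d.get_weights()[0]))  # expect False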

CGAN Summary

  CGAN is a simple generative adversarial network. As the figure above shows, the CGAN model has only about 2M parameters, roughly the same as a plain GAN. With CGAN you can generate images of a specified class rather than purely random samples, which makes it meaningful for practical engineering applications and well worth learning.
