前言

CGAN这篇论文算是十分的简单了,和DCGAN的篇幅比起来不知道少到哪里去了,我这里主要挑选了论文中的主要片段进行理解

CGAN的原理

<math> <semantics> <mrow> </mrow> <annotation encoding="application&#47;x&#45;tex"> \quad </annotation> </semantics> </math>论文提出了GAN的有条件(限制)的版本,在数据中添加一个数据 y ,y 是在生成器和辨别器中都需要考虑的。对抗网络相对于马尔科夫决策链优点众多:梯度下降时,只需要反向传播算法,在学习规程中不需要做推断,许多因素以及因素之间的相互关系可以在模型中融合的很好。论文提出的 CGAN 是在某些特定条件下,增加目标或者限制(任何标签)进而影响生成器的生成过程。
<math> <semantics> <mrow> </mrow> <annotation encoding="application&#47;x&#45;tex"> \quad </annotation> </semantics> </math>对于原始的GAN就不赘述了,这里直接说到CGAN。条件生成式对抗网络(CGAN)是对原始GAN的一个扩展,生成器和判别器都增加额外信息 y为条件, y 可以使任意信息,例如类别信息,或者其他模态的数据。如 Figure 1 所示,通过将额外信息 y 输送给判别模型和生成模型,作为输入层的一部分,从而实现条件GAN。在生成模型中,先验输入噪声 p(z) 和条件信息 y 联合组成了联合隐层表征。对抗训练框架在隐层表征的组成方式方面相当地灵活。类似地,条件 GAN 的目标函数是带有条件概率的二人极小极大值博弈。条件GAN可以表示为:
网络结构如下:

实验

Mnist数据集

<math> <semantics> <mrow> </mrow> <annotation encoding="application&#47;x&#45;tex"> \quad </annotation> </semantics> </math>在MNIST上以类别标签为条件(one-hot编码)训练条件GAN,可以根据标签条件信息,生成对应的数字。生成模型的输入是100维服从均匀分布的噪声向量,条件变量y是类别标签的one hot编码。噪声z和标签y分别映射到隐层(200和1000个单元),在映射到第二层前,联合所有单元。最终有一个sigmoid生成模型的输出(784维),即28*28的单通道图像。 判别模型的输入是784维的图像数据和条件变量y(类别标签的one hot编码),输出是该样本来自训练集的概率。 结果如下:

多模态学习用于图像自动标注

<math> <semantics> <mrow> </mrow> <annotation encoding="application&#47;x&#45;tex"> \quad </annotation> </semantics> </math> 大概就是说自动标注图像:使用多标签预测。使用条件GAN生成tag-vector在图像特征条件上的分布。数据集: MIR Flickr 25,000 dataset 语言模型:训练一个skip-gram模型,带有一个200维的词向量。细节可以结合论文理解,效果大概为这样:

将来的工作

  • 提出更复杂的方法,探索CGAN的细节和详细地分析它们的性能和特性。
  • 当前生成的每个tag是相互独立的,没有体现更丰富的信息。
  • 另一个遗留下的方向是构建一个联合训练的调度方法去学习language model。

########################################################################################

利用CGAN和MNIST数据集生成特定数字

#coding=utf-8
#CGAN.py
import tensorflow as tf
import numpy as np
import os
import pickle #序列化对象并保存到磁盘中,并在需要的时候读取出来,任何对象都可以执行序列化操作
import matplotlib.pyplot as plt
# 输入数据
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
#创建保存模型文件的文件夹
if not os.path.exists("logdir"):
    os.makedirs("logdir")
LOGDIR = "logdir"
#定义固定参数
real_img_size = mnist.train.images[0].shape[0]
noise_size = 100
noise = 'normal0-1'
learning_rate = 0.001

#训练参数
batch_size = 100
epochs = 120

def leakyRelu(x, alpha=0.01):
    return tf.maximum(x, alpha*x)
def get_inputs(real_img_size, noise_size):
    real_img = tf.placeholder(tf.float32, shape=[None, real_img_size], name = "real_img")
    real_img_digit = tf.placeholder(tf.float32, shape=[None, 10])
    noise_img = tf.placeholder(tf.float32, shape=[None, noise_size], name = "noise_img")
    return real_img, noise_img, real_img_digit
# WX + b
def fully_connected(name, value, output_shape):
    with tf.variable_scope(name, reuse=None) as scope:
        shape = value.get_shape().as_list()
        w = tf.get_variable('w', [shape[1], output_shape], dtype=tf.float32, initializer=tf.random_normal_initializer(stddev=0.01))
        b = tf.get_variable('b', [output_shape], dtype=tf.float32, initializer=tf.constant_initializer(0.0))
        return tf.matmul(value, w) + b
# 输入加性噪声
def get_noise(noise, batch_size):
    if noise == 'uniform':
        batch_size = np.random.uniform(-1, 1, size=(batch_size, noise_size)) #从均匀分布[low, high)中采样
    elif noise == 'normal':
        batch_size = np.random.normal(-1, 1, size=(batch_size, noise_size)) #高斯分布,参数分别为,均值,标准差,输出的shape
    elif noise == 'normal0-1':
        batch_noise = np.random.normal(0, 1, size=(batch_size, noise_size))
    elif noise == 'uniform0-1':
        batch_size = np.random.uniform(0, 1, size=(batch_size, noise_size))
    return batch_noise
# 构造生成器
def get_generator(digit, noise_img, reuse=False):
    with tf.variable_scope("generator", reuse=reuse):
        concatenated_img_digit = tf.concat([digit, noise_img], 1)
        output = fully_connected('gf1', concatenated_img_digit, 128)
        output = leakyRelu(output)
        output = tf.layers.dropout(output, rate=0.5)

        output = fully_connected('gf2', output, 128)
        output = leakyRelu(output)
        output = tf.layers.dropout(output, rate=0.5)

        logits = fully_connected('gf3', output, 784)
        outputs = tf.tanh(logits)
        return logits, outputs
# 构造鉴别器
def get_discriminator(digit, img, reuse = False):
    with tf.variable_scope("discriminator", reuse=reuse):
        concatenated_img_digit = tf.concat([digit, img], 1)
        output = fully_connected('df1', concatenated_img_digit, 128)
        output = leakyRelu(output)
        output = tf.layers.dropout(output, rate=0.5)

        output = fully_connected('df2', output, 128)
        output = leakyRelu(output)
        output = tf.layers.dropout(output, rate=0.5)

        logits = fully_connected('df3', output, 1)
        output = tf.sigmoid(logits)
        return logits, output
# 保存生成器产生的手写数字
def save_genImages(gen, epoch):
    r, c = 10, 10
    fig, axs = plt.subplots(r, c)
    cnt = 0
#    print(gen.shape)
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen[cnt][:,:], cmap='Greys_r')
            axs[i, j].axis('off')
            cnt += 1
    if not os.path.exists('gen_mnist'):
        os.makedirs('gen_mnist')
    fig.savefig('gen_mnist/%d.jpg' % epoch)
    plt.close()
#保存loss曲线
def  plot_loss(loss):
    fig, ax = plt.subplots(figsize=(20, 7))
    losses = np.array(loss)
    plt.plot(losses.T[0], label='Discriminator Loss')
    plt.plot(losses.T[1], label='Discriminator_real_loss')
    plt.plot(losses.T[2], label='Discriminator_fake_loss')
    plt.plot(losses.T[3], label='Generator Loss')
    plt.title("Training Losses")
    plt.legend()
    plt.savefig('loss1.jpg')
    plt.show()
# 保存损失函数的值
def Save_lossValue(e, epochs, train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g):
    with open('loss1.txt', 'a') as f:
        f.write("Epoch {}/{}".format(e+1, epochs), "Discriminator loss: {:.4f}(Real: {:.4f} + Fake: {:.4f})".format(train_loss_d, train_loss_d_real, train_loss_d_fake),
                "Generator loss: {:.4f}".format(train_loss_d))

tf.reset_default_graph() #清除每次运行时,tensorflow不断增加的节点并重置整个default graph

real_img, noise_img, real_img_digit = get_inputs(real_img_size, noise_size)
# 生成器
g_logits, g_outputs = get_generator(real_img_digit, noise_img)
sample_images = tf.reshape(g_outputs, [-1, 28, 28, 1])
tf.summary.image("sample_images", sample_images, 10) #10代表要生成图像的最大批处理元素数
# 判别器
d_logits_real, d_outputs_real = get_discriminator(real_img_digit, real_img)
d_logits_fake, d_outputs_fake = get_discriminator(real_img_digit, g_outputs, reuse=True)
# 判别器损失
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, labels=tf.ones_like(d_logits_real)) * (1 - 0.05))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))
d_loss = tf.add(d_loss_fake, d_loss_real)

# 生成器损失
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.ones_like(d_logits_fake) * (1 - 0.05)))

#tesnorboard序列化损失
tf.summary.scalar("d_loss_real", d_loss_real) #用来显示标量信息
tf.summary.scalar("d_loss_fake", d_loss_fake)
tf.summary.scalar("d_loss", d_loss)
tf.summary.scalar("g_loss", g_loss)

# 分别训练生成器和判别器
# optimizer
train_vars = tf.trainable_variables()
# generator tensor
g_vars = [var for var in train_vars if var.name.startswith("generator")]
#discriminator tensor
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
# optimizer
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

summary = tf.summary.merge_all() #将所有的summary全部保存到磁盘

saver = tf.train.Saver()
def train():
    #保存loss值
    losses = []
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        summary_writer = tf.summary.FileWriter(LOGDIR, sess.graph)
        for e in range(epochs):
            for i in range(mnist.train.num_examples//(batch_size * 10)):
                for j in range(10):
                    batch = mnist.train.next_batch(batch_size)
                    digits = batch[1]
                    images = batch[0].reshape((batch_size, 784))
                    images = 2 * images - 1 #生成器激活函数tanh(-1,1),将原始图像(0-1)也变为(-1,1)
                    noises = get_noise(noise, batch_size)
                    sess.run([d_train_opt, g_train_opt], feed_dict={real_img:images, noise_img:noises, real_img_digit:digits})

            #训练损失
            summary_str, train_loss_d_real, train_loss_d_fake, train_loss_g = sess.run([summary, d_loss_real, d_loss_fake, g_loss],
                                                                                       feed_dict={real_img : images, noise_img : noises, real_img_digit : digits})
            train_loss_d = train_loss_d_fake + train_loss_d_real
            losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g))
            summary_writer.add_summary(summary_str, e)
            summary_writer.flush()
            print("Epoch {}/{}".format(e+1, epochs), "Discriminator loss : {:.4f}(Real: {:.4f} + Fake: {:.4f})".format(train_loss_d,
                                                                                                                       train_loss_d_real, train_loss_d_fake),
                  "Generator loss: {:.4f}".format(train_loss_g))
            #保存模型
            saver.save(sess, 'checkpoints/cgan.ckpt')
            #查看每轮结果
            gen_sample = get_noise(noise, batch_size)
            lable = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0] * batch_size #给定标签条件生成指定的数字
            labels = np.array(lable)
            labels = labels.reshape(-1, 10)
            _, gen = sess.run(get_generator(real_img_digit, noise_img, reuse=True), feed_dict={noise_img:gen_sample, real_img_digit:labels})
            if e % 1 == 0:
                gen = gen.reshape(-1, 28, 28)
                gen = (gen + 1) / 2 #拉回到原来取值范围
                save_genImages(gen, e)
        plot_loss(losses)

def test():
    saver = tf.train.Saver(var_list=g_vars)
    with tf.Session() as sess:
        saver.restore(sess, 'checkpoints/cgan.ckpt')
        sample_noise = get_noise(noise, batch_size)
        label = [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]*100
        labels = np.array(label)
        labels = labels.reshape(-1, 10)
        _, gen_samples = sess.run(get_generator(real_img_digit, noise_img, reuse=True), feed_dict={noise_img:sample_noise, real_img_digit:labels})
        for i in range(len(gen_samples)):
            plt.imshow(gen_samples[i].reshape(28, 28), cmap='Greys_r')
            plt.show()

if __name__ == '__main__':
    test()