本文为对抗样本生成系列文章的第二篇文章,主要对GAN的原理进行介绍,并对其中关键部分的使用pytorch代码进行介绍,另外如果有需要完整代码的同学可以关注我的github。
该系列包含的文章还包括:
GAN(Generative Adversarial Network)
GAN中文名称生成对抗网络,是一种利用模型对抗技术来生成指定类型样本的技术,与VAE一起是目前主要的两种文本生成技术之一。GAN主要包含generater(生成器)和discriminator(判别器)两部分,generator负责生成假的样本来骗过discriminator,discriminator负责对样本进行打分,判断是否为生成网络生成的样本。
Generator
输入:noise sample(一个随机生成的指定纬度向量)
输出:目标样本(fake image等)
Generator在GAN中负责接收随机的噪声输入,进行目标文本、图像的生成,其目标就是尽可能的生成更加真实的图片、文字去欺骗discriminator。具体的实现可以使用任何在其他领域证明有效的神经网络,本文使用最简单的全连接网络作为Generator进行实验。
1 | ### 生成器结构 |
Discriminator
输入:样本(包含生成的样本和真实样本两部分)
输出:score(一个是否为真实样本的分数,分数越高是真实样本的置信的越高,越低越可能时生成样本)
Discriminator在GAN网络中负责将对输入的图像、文本进行判别,对其进行打分,打分越高越接近真实的图片,打分越低越可能是Generator生成的图像、文本,其目标是尽可能准确的对真实样本与生成样本进行准确的区分。与Generator一样Discriminator也可以使用任何网络实现,下面是pytorch中最简单的一种实现。
1 | ### 判别器结构 |
Model train
GAN中由于两部分需要进行对抗,因此两部分并不是与一般神经网络一样整个网络同时进行跟新训练的,而是两部分分别进行训练。训练的基本思路如下所示:
Epoch:
1. 生成器使用初始化的参数随机输入向量生成图片。 2. 生成器进行判别,使用判别器结果对判器参数进行更新。 3. 固定判别器参数,对生成器使用更新好的判别器进行
1 | for epoch in range(num_epochs): |
从上面的实现过程我们可以发现一个问题:在进行判别模型训练损失函数的计算由两部分组成,而生成模型进行训练时只由一部分组成,并且该部分的交叉熵还是一种反常的使用方式,这是为什么呢?
损失函数
整体的损失函数表现形式:
Generator Loss
对于判别器进行训练时,其目标为:
而对比交叉熵损失函数的计算公式:
二者其实在表现形式形式上是完全一致的,这是因为判别器就是区分样本是否为真实的样本,是一个简单的0/1分类问题,所以形式与交叉熵一致。在另一个角度我们可以观察,当输入样本为真实的样本时,$E{x\in\ P{G}}\ [log(1-G(D(x)))]$为0,只剩下$E{x\in\ P{data}}\ [logD(x)]$,为了使其最大只能优化网络时D(x)尽可能大,即真实样本判别器给出的得分更高。当输入为生成样本时,$E{x\in\ P{data}}\ [logD(x)]$为0,只剩下$E{x\in\ P{G}}\ [log(1-G(D(x)))]$,为使其最大只能使D(x)尽可能小,即使生成样本判别器给出的分数尽可能低,使用交叉熵损失函数正好与目标相符。
因此,判别器训练相关的代码如下,其中可以看到损失函数直接使用了二进制交叉熵进行。
1 | criterion = nn.BCELoss() |
Discriminator Loss
对于生成器其训练的目标为:
对于生成器,在D固定的情况下,$E{x\in\ P{data}}\ [logD(x)]$为固定值,因此可以不做考虑,表达式转为:
使用该表达式作为目标函数进行参数更新存在的问题就是在训练的起始阶段,由于开始时生成样本的质量很低,因此判别器很容易给一个很低的分数,即D(x)非常小,而log(1-x)的函数在值接近0时斜率也很小,因此使用该函数作为损失函数在开始时很难进行参数更新。
因此生成器采用了一种与log(1-x)的更新方向一致并且在起始时斜率更大的函数。
该损失函数在代码实现中一般还是使用反标签的二进制交叉熵损失函数来进行实现,所谓反标签即为将生成的样本标注为1进行训练(正常生成样本标签为0),涉及到该部分的代码为:
1 | criterion = nn.BCELoss() |
GAN与VAE对比
GAN和VAE都是样本生成领域非常常用的两个模型流派,那这两种模型有什么不同点呢?
VAE进行对抗样本生成时,VAE的Encoder和GAN的Generator输入同样都为图片等真实样本,但VAE的Encoder输出的中间结果为隐藏向量值,而GAN的Generator输出的中间结果为生成的图片等生成样本。
最终用来生成样本的部分不同。VAE最终使用Decoder部分来进行样本生成,GAN使用Generator进行样本生成。
在实际的使用过程中还存在这下面的区别使GAN比VAE更被广泛使用:
VAE生成样本点的连续性不好。VAE进行生成采用的方式是每个像素点进行生成的,很难考虑像素点之间的联系,因此经常出现一些不连续的坏点。
要生成同样品质的样本,VAE需要更大的神经网络。
【参考文献】
李宏毅在线课程:https://www.youtube.com/watch?v=DQNNMiAP5lw&list=PLJV_el3uVTsMq6JEFPW35BCiOQTsoqwNw