GAN生成式对抗网络

GAN生成式对抗网络,GAN网络由生成网络和判别网络两个网络组成,生成式网络通过学习图片后将噪声转换为图片,而判别式网络将区分数据集图片和生成图片进行判别。生成式与判别式网络类似与小偷与警察,通过不断的博弈互相升级,在这场游戏中我们更加希望小偷取得胜利。既是生成式网络生成的图片可以让判别器难以判断。

网络结构

GAN网络由两部分网络过程:生成式网络和判别式网络构成
GANs.jpg

GAN中鉴别器的目标

引入了JS散度的概念来计算,事实上也可以看成是交叉熵乘一个负号,如图所示:
jS.png

pytorch实现

生成式网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class Generator(nn.Module):
def __init__(self):
super(Generator,self).__init__()
self.main=nn.Sequential(
nn.Linear(100,256),
nn.ReLU(),
nn.Linear(256,512),
nn.ReLU(),
nn.Linear(512,784),
nn.Tanh()#对于生成器,最后一个激活函数是tanh,值域:-1到1
)
#定义前向传播
def forward(self,x): #x表示长度为100的noise输入
img = self.main(x)
img=img.view(-1,28,28)#转换成图片的形式
return img

通常生成图片采用长度为100的一维随机噪声用于生成图像,当然也能自定义

判别式网络

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.main = nn.Sequential(
        nn.Linear(784,512),
        nn.LeakyReLU(),
        nn.Linear(512,256),
        nn.LeakyReLU(),
        nn.Linear(256,1),
        nn.Sigmoid()
        )
    def forward(self,x):
        x =x.view(-1,784) #展平
        x =self.main(x)
        return x

使用sigmoid激活函数是因为能够将判决概率固定到【0,1】,而我们本身是知晓程序所产生的图片的真假。

训练过程

Step 1:初始化生成器和判别器的参数
Step 2:固定住生成器更新判别器,判别器给生成器的生成图像打低分给真实标签图像打高分
Step 3:固定住判别器更新生成器,生成器通过输入的随机采样的vector生成图像尝试骗过判别器使其让它打高分。
通过交替训练两个模型使其不断完善

扩展

原始GAN网络存在训练收敛问题,容易早停,这是训练过程中需要生成的内容过于复杂,而判别网络训练过快导致生成网络得不到训练,进而早停。

WGAN网络

cycle GAN网络

风格转换网络,能够将输入图像转化风格,例如真人图像转换为二次元图像。
改模型一共需要训练四个网络
cycle GAN.png

参考文献

  1. https://openatomworkshop.csdn.net/664eed3db12a9d168eb725f4.html?dp_token=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6NTc5ODYwMywiZXhwIjoxNzI2ODE0NzE3LCJpYXQiOjE3MjYyMDk5MTcsInVzZXJuYW1lIjoiYXBwbGVfNjg0MDQ1NTkifQ.GePWcNoUb6VhMXd6SQFfUEJf-ELE8pmzPiKz4d-2vSk&spm=1001.2101.3001.6650.8&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromBaidu%7Eactivity-8-125825727-blog-125389000.235%5Ev43%5Epc_blog_bottom_relevance_base2&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromBaidu%7Eactivity-8-125825727-blog-125389000.235%5Ev43%5Epc_blog_bottom_relevance_base2
  2. https://devpress.csdn.net/awstech/64ddd6faff5c3157f8babb8b.html?dp_token=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6NTc5ODYwMywiZXhwIjoxNzI3MzQzMjczLCJpYXQiOjE3MjY3Mzg0NzMsInVzZXJuYW1lIjoiYXBwbGVfNjg0MDQ1NTkifQ.6IhDSqK8AehCrMBCyq9OUyhorbX0k-f2Pd469jexnjU&spm=1001.2101.3001.6650.2&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromBaidu%7Eactivity-2-120001421-blog-125323622.235%5Ev43%5Epc_blog_bottom_relevance_base2&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromBaidu%7Eactivity-2-120001421-blog-125323622.235%5Ev43%5Epc_blog_bottom_relevance_base2&utm_relevant_index=5