譯者 | Sambodhi??
生成對抗網絡(Generative Adversarial Network,GAN)由 Goodfellow 等人在 2014 年提出,它徹底改變了計算機視覺中的圖像生成領域:沒有人能夠相信這些令人驚嘆而生動的圖像實際上是純粹由機器生成的。
事實上,人們曾經認為生成的任務是不可能的,并且被 GAN 的力量所震驚,因為傳統上,根本沒有任何事實可以比較我們生成的圖像。
本文介紹了創建 GAN 背后的簡單直覺,然后介紹了通過 PyTorch 實現的卷積 GAN 及其訓練過程。
GAN 背后的直覺
不同于傳統分類方法,我們的網絡預測可以直接與事實的正確答案相比較,而生成圖像的“正確性”是很難定義和衡量的。Goodfellow 等人在他們的原創論文《生成對抗網絡》(Generative Adversarial Network)中提出了一個有趣的想法:使用經過訓練的分類器來區分生成的圖像和實際圖像。如果存在這樣的分類器,我們可以創建并訓練一個生成器網絡,直到它輸出的圖像能完全騙過分類器。
圖 1 GAN 管道
GAN 是這一過程的產物:它包含一個根據給定的數據集生成圖像的生成器,以及一個區分圖像是真實的還是生成的判別器(分類器)。GAN 的詳細管道見圖 1。
損失函數
對生成器和判別器進行優化都很困難,因為正如你所想象的那樣,這兩個網絡的目標完全相反:生成器希望盡可能地創造出真實的東西,但判別器希望區分生成的材料。
為了說明這一點,我們讓 D(x) 是判別器的輸出,也就是 x 是真實圖像的概率,而 G(z) 是我們的生成器的輸出。判別器類似于一個二元分類器,因此判別器的目標是使函數最大化:
本質上是二元交叉熵損失,沒有開頭的負號。另一方面,生成器的目標是使判別器做出正確判斷的機會最小化,因此它的目標是最小化函數。所以,最終的損失函數將是兩個分類器之間的一個極小極大博弈(minimax game),具體如下:
從理論上講,這將收斂到判別器,預測所有事件的概率為 0.5。
但在實踐中,極小極大博弈往往會導致網絡無法收斂,因此仔細調整訓練過程非常重要。像學習率這樣的超參數對于訓練 GAN 時顯然更為重要:一個微小的變化會導致 GAN 產生一個輸出,而與輸入噪聲無關。
運算環境
庫
我們通過 PyTorch 庫(包括 torchvision)來構建整個程序。GAN 的生成結果的可視化是通過 Matplotlib 庫繪制的。下面的代碼導入了所有的庫:
importGAN.py
""" Import necessary libraries to create a generative adversarial network The code is mainly developed using the PyTorch library """ import time import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets from torchvision.transforms import transforms from model import discriminator, generator import numpy as np import matplotlib.pyplot as plt
數據集
在 GAN 訓練中,數據集是一個重要方面。圖像的非結構化性質意味著任何給定的類別(如狗、貓或手寫的數字)都可以有一個可能的數據分布,而這種分布最終是 GAN 生成內容的基礎。
為了演示,本文將使用最簡單的 MNIST 數據集,其中包含 60000 張從 0 到 9 的手寫數字圖像。事實上,像 MNIST 這樣的非結構化數據集可以在 Graviti 上找到。這是一家年輕的創業公司,他們希望通過非結構化數據集為社區提供幫助,在他們的 平臺 上有一些最好的公共非結構化數據集,包括 MNIST。
硬件要求
最好的方法是用 GPU 訓練神經網絡,它可以顯著地提高訓練速度。但是,如果只有 CPU 可用,你仍然可以測試程序。要使你的程序能夠自行確定硬件,你可以使用以下方法:
torchDevice.py
""" Determine if any GPUs are available """ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
實施
網絡架構
由于數字的簡單性,這兩種架構——判別器和生成器,都是由全連接層構建的。請注意,在某些情況下,全連接的 GAN 也比 DCGAN 略微容易收斂。
以下是兩種架構的 PyTorch 實現:
GANArchitecture.py
""" Network Architectures The following are the discriminator and generator architectures """ class discriminator(nn.Module): def __init__(self): super(discriminator, self).__init__() self.fc1 = nn.Linear(784, 512) self.fc2 = nn.Linear(512, 1) self.activation = nn.LeakyReLU(0.1) def forward(self, x): x = x.view(-1, 784) x = self.activation(self.fc1(x)) x = self.fc2(x) return nn.Sigmoid()(x) class generator(nn.Module): def __init__(self): super(generator, self).__init__() self.fc1 = nn.Linear(128, 1024) self.fc2 = nn.Linear(1024, 2048) self.fc3 = nn.Linear(2048, 784) self.activation = nn.ReLU() def forward(self, x): x = self.activation(self.fc1(x)) x = self.activation(self.fc2(x)) x = self.fc3(x) x = x.view(-1, 1, 28, 28) return nn.Tanh()(x)訓練
在訓練 GAN 時,我們優化了判別器的結果,同時也改進了我們的生成器。這樣,在每次迭代過程中會有兩個相互矛盾的損失來同時優化它們。我們送入生成器的是隨機噪聲,而生成器理應根據給定噪聲的微小差異來生成圖像:
trainGAN.py
""" Network training procedure Every step both the loss for disciminator and generator is updated Discriminator aims to classify reals and fakes Generator aims to generate images as realistic as possible """ for epoch in range(epochs): for idx, (imgs, _) in enumerate(train_loader): idx += 1 # Training the discriminator # Real inputs are actual images of the MNIST dataset # Fake inputs are from the generator # Real inputs should be classified as 1 and fake as 0 real_inputs = imgs.to(device) real_outputs = D(real_inputs) real_label = torch.ones(real_inputs.shape[0], 1).to(device) noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5 noise = noise.to(device) fake_inputs = G(noise) fake_outputs = D(fake_inputs) fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device) outputs = torch.cat((real_outputs, fake_outputs), 0) targets = torch.cat((real_label, fake_label), 0) D_loss = loss(outputs, targets) D_optimizer.zero_grad() D_loss.backward() D_optimizer.step() # Training the generator # For generator, goal is to make the discriminator believe everything is 1 noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5 noise = noise.to(device) fake_inputs = G(noise) fake_outputs = D(fake_inputs) fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device) G_loss = loss(fake_outputs, fake_targets) G_optimizer.zero_grad() G_loss.backward() G_optimizer.step() if idx % 100 == 0 or idx == len(train_loader): print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item())) if (epoch+1) % 10 == 0: torch.save(G, 'Generator_epoch_{}.pth'.format(epoch)) print('Model saved.')?結? 果
當 100 個輪數(epoch)之后,我們可以繪制數據集,并看到從隨機噪音中生成的數字的結果:
圖 2:GAN 生成的結
如上圖所示,生成的結果看起來確實相當像真實的結果。鑒于網絡非常簡單,所以結果看起來確實很有希望!
超越單純的內容創作
GAN 的創造與計算機視覺領域的先前工作如此不同。隨后的眾多應用使學術界對深度網絡的能力感到驚訝。下面將介紹一些令人驚訝的工作。
CycleGAN
Zhu 等人的 CycleGAN 引入了一種概念,它無需配對樣本就可以將圖像從 X 域翻譯成 Y 域。馬被轉化為斑馬,夏日的陽光被轉化為暴風雪,CycleGAN 的結果令人驚訝且準確。
GauGAN
Nvidia 利用 GAN 的力量,把簡單的繪畫,根據畫筆的語義,轉換成優雅而逼真的照片。盡管訓練資源的計算成本很高,但它創造了一個全新的研究和應用領域。
AdvGAN
GAN 還擴展到清理對抗性圖像,并將其轉化為不會欺騙分類器的干凈樣本。關于對抗性攻擊和防御的更多信息可以在 這里 到。
結? 語
所以,你已經擁有了它!希望這篇文章對如何構建 GAN 提供了一個概覽。
作者簡介:
Ta-ying Cheng,中國香港人,牛津大學哲學博士新生,愛好 3D 視覺、深度學習。
編輯:黃飛
?
評論
查看更多