生成對抗網(wǎng)絡(luò)(Generative Adversarial Network,GAN)迅猛地占領(lǐng)了機(jī)器學(xué)習(xí)社區(qū)。優(yōu)雅的理論基礎(chǔ)和在計算機(jī)視覺領(lǐng)域不斷提升的優(yōu)越表現(xiàn)使其成為近年來機(jī)器學(xué)習(xí)最活躍的研究課題之一。事實(shí)上,F(xiàn)acebook AI Research的領(lǐng)導(dǎo)人Yann Lecun在2016年說過,“在我看來,GAN及其新提出的變體是機(jī)器學(xué)習(xí)在過去10年最有意思的想法。”想要了解這一課題的最新進(jìn)展,請參閱這篇The GAN Zoo(GAN動物園)。
盡管GAN已被證明是很出色的圖像生成模型,例如生成面部圖像和臥室圖像,GAN尚未在其他數(shù)據(jù)集上進(jìn)行過廣泛測試,例如由工廠提供的數(shù)據(jù)集,其中包含大量來自生產(chǎn)線上的傳感器的測量值。不同于諸如圖片之類的靜態(tài)數(shù)據(jù),這樣的數(shù)據(jù)集甚至可能包括時序信息,機(jī)器學(xué)習(xí)模型需要利用這些時序信息預(yù)測未來的事件。在這類數(shù)據(jù)上應(yīng)用生成模型可能很有用,例如,如果我們的預(yù)測模型需要更多樣本進(jìn)行訓(xùn)練以提升其概括性。另外,如果我們提出一個可以生成優(yōu)質(zhì)合成數(shù)據(jù)的模型,那么這個模型必定學(xué)習(xí)到了原始數(shù)據(jù)的潛在結(jié)構(gòu)。既然模型學(xué)習(xí)到了潛在結(jié)構(gòu),預(yù)測模型就可以將該表示作為新特征集來利用!
本文將介紹一些可能有助于數(shù)據(jù)集增強(qiáng)的GAN體系結(jié)構(gòu),包括樣本增強(qiáng)和特征增強(qiáng)。讓我們從基本的GAN開始。
生成對抗網(wǎng)絡(luò)
GAN模型由兩部分組成:生成器(generator)和判別器(discriminator)。這里我們認(rèn)為它們都是由參數(shù)確定的神經(jīng)網(wǎng)絡(luò):G和D。判別網(wǎng)絡(luò)的參數(shù)為最大化正確區(qū)分真實(shí)數(shù)據(jù)和偽造數(shù)據(jù)(生成網(wǎng)絡(luò)偽造的數(shù)據(jù))的概率這一目標(biāo)而優(yōu)化,而生成網(wǎng)絡(luò)的目標(biāo)是最大化判別網(wǎng)絡(luò)不能識別其偽造的樣本的概率。
生成網(wǎng)絡(luò)如此產(chǎn)生樣本:接受一個輸入向量z,該向量取樣自一個潛分布(latent distribution),應(yīng)用由網(wǎng)絡(luò)定義的函數(shù)G至該向量,得到G(z)。判別網(wǎng)絡(luò)交替接受G(z)和x(一個真實(shí)數(shù)據(jù)樣本),輸出輸入為真的概率。
通過適當(dāng)?shù)某瑓?shù)調(diào)優(yōu)和足夠的訓(xùn)練迭代次數(shù),生成網(wǎng)絡(luò)和判別網(wǎng)絡(luò)將一起收斂(通過梯度下降方法進(jìn)行參數(shù)更新)至描述偽造數(shù)據(jù)的分布和取樣真實(shí)數(shù)據(jù)的分布相一致的點(diǎn)。
本文接下來的部分將通過基于MNIST數(shù)據(jù)集生成新數(shù)字或編碼原始數(shù)字至潛空間來演示GAN是如何工作的。我們也會看下如何將GAN應(yīng)用到類別數(shù)據(jù)和時序數(shù)據(jù)上。
作為開始,下面是一個在MNIST數(shù)據(jù)集上訓(xùn)練的、基于多層感知器(MLP)的簡單GAN模型生成的一些樣本。
圖二:生成新數(shù)字
GAN并非盡善盡美
盡管GAN能如我們所見的那樣工作,在實(shí)踐中,GAN有一些缺點(diǎn),自Ian Goodfellow等在2014年發(fā)表GAN的原始論文起,如何克服GAN的缺點(diǎn)一直是研究的熱點(diǎn)。GAN的主要缺點(diǎn)涉及它的訓(xùn)練,GAN因極難訓(xùn)練而聲名狼藉:首先,GAN的訓(xùn)練高度依賴超參數(shù)。其次,也是最重要的,(生成網(wǎng)絡(luò)和判別網(wǎng)絡(luò)的)損失函數(shù)不提供必要的信息:盡管生成的樣本可能已經(jīng)開始貼切地重現(xiàn)真實(shí)數(shù)據(jù)——顯著逼近真實(shí)數(shù)據(jù)的分布——一般而言無法通過損失的趨勢來指示這一表現(xiàn)。這意味著我們不能基于損失運(yùn)行skopt之類的超參數(shù)優(yōu)化器,相反必須手工迭代調(diào)優(yōu),真是可恥。
GAN架構(gòu)的另一個缺點(diǎn)和它的功能有關(guān)。使用圖一顯示的基于原始的交叉熵?fù)p失的GAN,我們無法:
控制生成什么數(shù)據(jù)。
生成類別數(shù)據(jù)。
訪問潛空間以便將其作為特征使用。
生成類別數(shù)據(jù)對GAN而言是一個特大難題。Ian Goodfellow在這個reddit帖子中以非常直觀的方式解釋了這一點(diǎn):
僅當(dāng)合成數(shù)據(jù)基于連續(xù)數(shù)值時,你才能對合成數(shù)據(jù)作出微小的改動。基于離散數(shù)值無法作出微小的改動。
例如,如果你輸出的圖像的像素值為1.0,你可以在下一步將該像素值改為1.0001.
如果你輸出單詞“企鵝”,你無法在下一步將其修改為“企鵝 + .001”,因為并不存在“企鵝 + .001”這樣的單詞。你需要經(jīng)歷從“企鵝”到“鴕鳥”的整個過程。
關(guān)鍵的想法是,生成網(wǎng)絡(luò)不可能從一個實(shí)體(如“企鵝”)一路前進(jìn)到另一個實(shí)體(如“鴕鳥”)。因為兩者之間的空間出現(xiàn)實(shí)體的概率為0,判別網(wǎng)絡(luò)可以輕易地識別出該空間內(nèi)的樣本是不真實(shí)的,因而它不可能被生成網(wǎng)絡(luò)所愚弄。
GAN變體
為了解決原始GAN的問題,研發(fā)了一些其他的訓(xùn)練方式和架構(gòu)。下面將加以簡要介紹。這些介紹的目標(biāo)是讓你對如何應(yīng)用這些方法至結(jié)構(gòu)化數(shù)據(jù)(比如Kaggle競賽中的數(shù)據(jù))有所了解。
條件GAN
前面提到的GAN能生成看起來像MNIST數(shù)據(jù)集中的隨機(jī)數(shù)字。但是如果我們想生成特定數(shù)字呢?只需在訓(xùn)練過程中做出一個小小的改動,我們就能告訴生成網(wǎng)絡(luò)生成我們所要求的數(shù)字。在每次迭代中,生成網(wǎng)絡(luò)的輸入不僅包括z,還包括指明數(shù)字的one-hot編碼向量。同樣,判別網(wǎng)絡(luò)的輸入不僅包括真實(shí)樣本或偽造樣本,還包括同樣的標(biāo)簽向量。
圖三:條件GAN
基于與前述GAN相同的流程,但是加上了這一輸出上的微小改動,條件GAN(CGAN)學(xué)習(xí)生成以輸入的標(biāo)簽為條件的樣本。
讓我們?yōu)槊總€數(shù)字生成一個樣本!在潛空間取樣時,我們同時輸入一個one-hot編碼的向量指明我們所需的分類。對所有10個分類中的數(shù)字進(jìn)行這一過程,得到圖四的結(jié)果:
圖四:根據(jù)條件生成的數(shù)字樣本
Wasserstein GAN(WGAN)是最流行的GAN之一,它改變了目標(biāo),從而提高了訓(xùn)練穩(wěn)定性和可解釋性(損失和樣本質(zhì)量的相關(guān)性),同時能夠生成類別數(shù)據(jù)。關(guān)鍵點(diǎn)在于,生成網(wǎng)絡(luò)的目標(biāo)是逼近真實(shí)數(shù)據(jù)分布,因此衡量分布間的距離的指標(biāo)很重要,因為該指標(biāo)將是最小化的目標(biāo)。WGAN選擇了Wasserstein距離。Wasserstein距離也稱為推土機(jī)(Earth-Mover)距離。另外,WGAN實(shí)際上采用的是Wasserstein距離的近似。WGAN選擇Wasserstein距離是因為Wasserstein距離能在Kullback-Leibler散度和Jensen-Shannon散度無法收斂的分布上收斂。如果你對理論感興趣,可以看下原始論文或這篇出色總結(jié)Read-through: Wasserstein GAN。
在實(shí)現(xiàn)層面,總結(jié)一下逼近Wasserstein距離意味著什么:
判別器的輸出不再是概率了,這也是將判別器改名為批評者(critic)的動機(jī)。
判別器的參數(shù)截斷至某個閾值(或者進(jìn)行梯度懲罰)。
在每個訓(xùn)練迭代中,判別器的參數(shù)比生成器的參數(shù)更新更頻繁。
用于類別數(shù)據(jù)的Wasserstein GAN
WGAN論文的作者展示了通過這種方式訓(xùn)練的GAN顯示了訓(xùn)練上的穩(wěn)定性和可解釋性,但之后有研究證明,Wasserstein距離的使用賦予了GAN生成類別(categorical)數(shù)據(jù)的能力(即,并非圖像之類的連續(xù)值數(shù)據(jù),甚至不是像用1表示周日、用2表示周一這樣的整型編碼數(shù)據(jù))。當(dāng)在這類數(shù)據(jù)上訓(xùn)練原始的GAN時,判別網(wǎng)絡(luò)的損失會在多次迭代中保持較低的水平,而生成網(wǎng)絡(luò)的損失會不停增長。而WGAN在類別數(shù)據(jù)上訓(xùn)練的方式和在連續(xù)值數(shù)據(jù)一樣。
我們只需如此做(圖五是一個例子):數(shù)據(jù)集中的每個類別變量都對應(yīng)一個生成網(wǎng)絡(luò)的softmax輸出,該輸出的維度和可能的離散值數(shù)目相等。判別網(wǎng)絡(luò)并不接受one-hot編碼的softmax輸出作為輸入,相反,將原始的softmax輸出當(dāng)做一組連續(xù)值變量,傳給判別網(wǎng)絡(luò)作為輸入。這樣訓(xùn)練就能收斂!在測試時,只需one-hot編碼生成網(wǎng)絡(luò)的離散輸出即可生成偽造的類別數(shù)據(jù)。
圖五:混合類別變量和連續(xù)變量的生成器的例子
上圖中的類別變量1為3個可能值中的1個,類別變量2為2個可能值中的1個。此外還有1個連續(xù)變量。
圖六展示了一個在類別值的數(shù)據(jù)集上訓(xùn)練基于梯度懲罰的WGAN的例子,你可以在圖中看到穩(wěn)定的、收斂的損失函數(shù)的美麗曲線。這一個例子是在Kaggle競賽中的Sberbank Russian Housing Market數(shù)據(jù)集(俄羅斯聯(lián)邦儲蓄銀行的房產(chǎn)市場數(shù)據(jù)集)上訓(xùn)練的,該數(shù)據(jù)集同時包含連續(xù)變量和類別變量。
圖六:在俄羅斯聯(lián)邦儲蓄銀行的房產(chǎn)市場數(shù)據(jù)集上訓(xùn)練的WGAN-GP
當(dāng)然,你也可以組合WGAN和CGAN,以監(jiān)督學(xué)習(xí)的方式訓(xùn)練WGAN,以生成以分類標(biāo)簽為條件的樣本!
注意:Cramer GAN進(jìn)一步改進(jìn)了Wasserstein GAN,其目標(biāo)是提供質(zhì)量更優(yōu)的樣本,同時提高訓(xùn)練穩(wěn)定性。是否能用它生成類別數(shù)據(jù)是以后的研究課題。
雙向GAN
盡管WGAN看上去解決了很多問題,但它不允許訪問數(shù)據(jù)的潛空間表示。尋找這樣的表示可能很有幫助,不僅是因為可以通過在潛空間的連續(xù)移動控制生成什么樣的數(shù)據(jù),還因為可以通過潛空間提取特征。
圖七:雙向GAN
雙向GAN(Bidirectional GAN,BiGAN)是解決這一問題的一個嘗試。它如此工作:不僅學(xué)習(xí)一個生成式網(wǎng)絡(luò),同時學(xué)習(xí)一個編碼網(wǎng)絡(luò)E,該編碼網(wǎng)絡(luò)映射數(shù)據(jù)至生成網(wǎng)絡(luò)的潛空間。對抗配置中,使用一個判別網(wǎng)絡(luò)應(yīng)對生成任務(wù)和編碼任務(wù)。BiGAN的作者展示了,在這一限制下,G和E這一對網(wǎng)絡(luò)形成了一個自動編碼器(autoencoder):通過E編碼數(shù)據(jù)樣本,再通過G解碼,可以得到原始樣本。
InfoGAN
之前我們看到,CGAN允許調(diào)節(jié)生成網(wǎng)絡(luò)以根據(jù)標(biāo)簽生成樣本。不過,是否可以通過在GAN的潛空間中強(qiáng)制一個類別化的結(jié)構(gòu),以完全無監(jiān)督的方式學(xué)習(xí)辨別數(shù)字呢?可不可以設(shè)置一個連續(xù)的代碼空間,讓我們可以訪問這一空間以描述數(shù)據(jù)樣本的連續(xù)語義變體呢?(在MNIST的例子中,連續(xù)語義變體可能是數(shù)字的寬度和斜度。)
上述兩個問題的答案都是可以。比那更好的是:我們可以同時做到這兩點(diǎn)。真相是,我們可以施加任何我們發(fā)現(xiàn)有用的代碼空間分布,然后訓(xùn)練GAN編碼這些分布中有意義的特性。每份代碼將學(xué)習(xí)包含數(shù)據(jù)的不同語義特性,結(jié)果等效于信息退相干(information disentanglement)。
圖八:InfoGAN
允許我們這么干的GAN是InfoGAN。簡單來說,InfoGAN試圖最大化生成網(wǎng)絡(luò)代碼空間和推斷網(wǎng)絡(luò)輸出的共同信息。推斷網(wǎng)絡(luò)可以簡單配置為判別網(wǎng)絡(luò)的一個輸出層,共享其他參數(shù),意味著它是算力免費(fèi)(computationally free)的。一旦訓(xùn)練完成,InfoGAN的判別網(wǎng)絡(luò)的推斷輸出層可以用來提取特征,或者,如果代碼空間包含標(biāo)簽信息,可以用來分類!
創(chuàng)建一個配有兩個代碼空間的InfoGAN——一個連續(xù)的二維空間和一個離散的十維空間——我們能夠以離散代碼為條件生成特定的數(shù)字,同時以連續(xù)代碼為條件生成特定風(fēng)格的數(shù)字,生成如圖九所示的數(shù)據(jù)。注意,在整個無監(jiān)督學(xué)習(xí)計劃中,沒有標(biāo)簽的位置——在潛空間中施加一個類別分布足以讓模型學(xué)習(xí)編碼該分布的標(biāo)簽信息!
圖九:固定生成網(wǎng)絡(luò)的離散代碼輸入,同時使用不同的連續(xù)代碼輸入
對抗自動編碼器
圖十:對抗自動編碼器
對抗自動編碼器(Adversarial Autoencoder,AAE)結(jié)合了自動編碼器和GAN。這一模型優(yōu)化兩個目標(biāo):其一,最小化通過編碼網(wǎng)絡(luò)P和解碼網(wǎng)絡(luò)Q的數(shù)據(jù)x的重建錯誤。其二,通過對抗訓(xùn)練在代碼P(x)上施加一個先驗分布,在對抗訓(xùn)練中,P為生成網(wǎng)絡(luò)。所以,優(yōu)化P和Q以最小化x和Q(z)的距離,其中z是自動編碼器的代碼空間向量,同時優(yōu)化作為GAN的P和D,以迫使代碼空間P(x)匹配預(yù)先定義的結(jié)構(gòu)。這可以看成對自動編碼器的正則化,迫使它學(xué)習(xí)有意義、結(jié)構(gòu)化、內(nèi)聚的代碼空間(而不是斷裂的代碼空間,參考Geoffrey Hinton講座筆記第76頁),以允許進(jìn)行有效的特征提取和降維。同時,由于在代碼上施加了一個已知先驗分布,從該先驗分布取樣,并將樣本傳給解碼網(wǎng)絡(luò)Q,形成了一個生成式建模計劃!
讓我們在自動編碼器的對抗訓(xùn)練中,在代碼空間上施加一個標(biāo)準(zhǔn)差為5的二維高斯分布。取樣該空間的相鄰點(diǎn),得到一些生成數(shù)字的連續(xù)變體集。
左:在二維代碼空間中驗證數(shù)據(jù)。右:從代碼空間的相鄰點(diǎn)中取樣,解碼樣本以生成數(shù)字
我們還可以基于標(biāo)簽訓(xùn)練AAE,以強(qiáng)制標(biāo)簽和數(shù)字風(fēng)格信息的退相干。這樣,通過固定想要的標(biāo)簽,施加的連續(xù)潛空間中的變體將導(dǎo)致不同風(fēng)格的同一數(shù)字。以數(shù)字八為例:
圖十二:固定標(biāo)簽,從潛AAE空間的相鄰點(diǎn)中取樣
很明顯,相鄰點(diǎn)間存在有意義的關(guān)系!為我們的數(shù)據(jù)集增強(qiáng)問題生成樣本時,這一性質(zhì)可能會提供便利。
時序數(shù)據(jù)?
現(xiàn)實(shí)世界的結(jié)構(gòu)化數(shù)據(jù)常常包含時序。在這樣的數(shù)據(jù)中,每個樣本和之前的樣本間存在某種依賴關(guān)系。經(jīng)常選擇基于循環(huán)神經(jīng)網(wǎng)絡(luò)的模型來處理這種數(shù)據(jù),原因是它們具備建模這種數(shù)據(jù)的內(nèi)在能力。在我們的GAN模型中利用這些神經(jīng)網(wǎng)絡(luò),在原則上可以產(chǎn)生更高質(zhì)量的樣本和特征!
循環(huán)GAN
讓我們將之前的GAN中的MLP替換為RNN,就像這篇論文所做的那樣。具體而言,我們將采用RNN的變體長短時記憶(LSTM)單元(事實(shí)上我們在談?wù)撋疃葘W(xué)習(xí)最最時髦的行話——哎喲,我又這么干了),在波形(Waves)數(shù)據(jù)集上進(jìn)行訓(xùn)練。這一數(shù)據(jù)集包含偏移、頻率、振幅不同的一維正弦信號和鋸齒信號,所有信號的時步相同。從RNN的視角來看,每個樣本包含一個30時步的波形。
我們的CGAN的生成網(wǎng)絡(luò)和判別網(wǎng)絡(luò)都將采用基于LSTM的神經(jīng)網(wǎng)絡(luò),將其轉(zhuǎn)化為一個RCGAN。我們將訓(xùn)練該RCGAN學(xué)習(xí)按需生成正弦、鋸齒波形。
圖十三:左:生成的正弦波形;右:生成的鋸齒波形
訓(xùn)練之后,我們也將查看潛空間中的變體是如何產(chǎn)生生成樣本特性體現(xiàn)的連續(xù)變化的。具體而言,如果我們施加一個二維正態(tài)分布潛空間,并將分類標(biāo)簽固定為正弦波形,我們將得到圖十四中顯示的樣本。其中,我們能很明顯地看到頻率和振幅由低到高的連續(xù)變化,這意味著RCGAN學(xué)習(xí)到了一個有意義的潛空間!
圖十四:固定標(biāo)簽,從潛RCGAN空間的相鄰點(diǎn)中取樣
盡管在GAN中使用RNN對生成實(shí)值的序列化數(shù)據(jù)很有用,它仍然無法用于離散序列,是否可以配合RNN使用Wasserstein距離尚不清楚(在RNN上施加Lipschitz限制是以后的研究課題)。SeqGAN和最近的ARAE的目標(biāo)是解決這一問題。
結(jié)論
我們看到,在因為GAN具有生成非常酷的圖像的能力而生成的那些大驚小怪的報道(看過沒有?)之外,一些架構(gòu)也可能有助于處理更一般的機(jī)器學(xué)習(xí)問題,包括連續(xù)和離散的數(shù)據(jù)。本文的目的是介紹這一想法,并不打算嚴(yán)格地比較這些多用途生成式模型,不過本文確實(shí)證明了應(yīng)該進(jìn)行這樣的涉及GAN的研究。
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4779瀏覽量
101052 -
生成器
+關(guān)注
關(guān)注
7文章
319瀏覽量
21083 -
數(shù)據(jù)集
+關(guān)注
關(guān)注
4文章
1209瀏覽量
24793
原文標(biāo)題:一文概覽用于數(shù)據(jù)集增強(qiáng)的對抗生成網(wǎng)絡(luò)架構(gòu)
文章出處:【微信號:jqr_AI,微信公眾號:論智】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論