自從擴散模型發布以來,GAN的關注度和論文是越來越少了,但是它們里面的一些思路還是值得我們了解和學習。所以本文我們來使用Pytorch 來實現SN-GAN
譜歸一化生成對抗網絡是一種生成對抗網絡,它使用譜歸一化技術來穩定鑒別器的訓練。譜歸一化是一種權值歸一化技術,它約束了鑒別器中每一層的譜范數。這有助于防止鑒別器變得過于強大,從而導致不穩定和糟糕的結果。
SN-GAN由Miyato等人(2018)在論文“生成對抗網絡的譜歸一化”中提出,作者證明了sn - gan在各種圖像生成任務上比其他gan具有更好的性能。
SN-GAN的訓練方式與其他gan相同。生成器網絡學習生成與真實圖像無法區分的圖像,而鑒別器網絡學習區分真實圖像和生成圖像。這兩個網絡以競爭的方式進行訓練,它們最終達到一個點,即生成器能夠產生逼真的圖像,從而欺騙鑒別器。
以下是SN-GAN相對于其他gan的優勢總結:
- 更穩定,更容易訓練
- 可以生成更高質量的圖像
- 更通用,可以用來生成更廣泛的內容。
模式崩潰
模式崩潰是生成對抗網絡(GANs)訓練中常見的問題。當GAN的生成器網絡無法產生多樣化的輸出,而是陷入特定的模式時,就會發生模式崩潰。這會導致生成的輸出出現重復,缺乏多樣性和細節,有時甚至與訓練數據完全無關。
GAN中發生模式崩潰有幾個原因。一個原因是生成器網絡可能對訓練數據過擬合。如果訓練數據不夠多樣化,或者生成器網絡太復雜,就會發生這種情況。另一個原因是生成器網絡可能陷入損失函數的局部最小值。如果學習率太高,或者損失函數定義不明確,就會發生這種情況。
以前有許多技術可以用來防止模式崩潰。比如使用更多樣化的訓練數據集?;蛘呤褂谜齽t化技術,例如dropout或批處理歸一化,使用合適的學習率和損失函數也很重要。
Wassersteian損失
Wasserstein損失,也稱為Earth Mover’s Distance(EMD)或Wasserstein GAN (WGAN)損失,是一種用于生成對抗網絡(GAN)的損失函數。引入它是為了解決與傳統GAN損失函數相關的一些問題,例如Jensen-Shannon散度和Kullback-Leibler散度。
Wasserstein損失測量真實數據和生成數據的概率分布之間的差異,同時確保它具有一定的數學性質。他的思想是最小化這兩個分布之間的Wassersteian距離(也稱為地球移動者距離)。Wasserstein距離可以被認為是將一個分布轉換為另一個分布所需的最小“成本”,其中“成本”被定義為將概率質量從一個位置移動到另一個位置所需的“工作量”。
Wasserstein損失的數學定義如下:
對于生成器G和鑒別器D, Wasserstein損失(Wasserstein距離)可以表示為:
Jensen-Shannon散度(JSD): Jensen-Shannon散度是一種對稱度量,用于量化兩個概率分布之間的差異
對于概率分布P和Q, JSD定義如下:
JSD(P∥Q)=1/2(KL(P∥M)+KL(Q∥M))
M為平均分布,KL為Kullback-Leibler散度,P∥Q為分布P與分布Q之間的JSD。
JSD總是非負的,在0和1之間有界,并且對稱(JSD(P|Q) = JSD(Q|P))。它可以被解釋為KL散度的“平滑”版本。
Kullback-Leibler散度(KL散度):Kullback-Leibler散度,通常被稱為KL散度或相對熵,通過量化“額外信息”來測量兩個概率分布之間的差異,這些“額外信息”需要使用另一個分布作為參考來編碼一個分布。
對于兩個概率分布P和Q,從Q到P的KL散度定義為:KL(P∥Q)=∑x P(x)log(Q(x)/P(x))。KL散度是非負非對稱的,即KL(P∥Q)≠KL(Q∥P)。當且僅當P和Q相等時它為零。KL散度是無界的,可以用來衡量分布之間的不相似性。
1-Lipschitz Contiunity
1- lipschitz函數是斜率的絕對值以1為界的函數。這意味著對于任意兩個輸入x和y,函數輸出之間的差不超過輸入之間的差。
數學上函數f是1-Lipschitz,如果對于f定義域內的所有x和y,以下不等式成立:
|f(x) — f(y)| <= |x — y|
在生成對抗網絡(GANs)中強制Lipschitz連續性是一種用于穩定訓練和防止與傳統GANs相關的一些問題的技術,例如模式崩潰和訓練不穩定。在GAN中實現Lipschitz連續性的主要方法是通過使用Lipschitz約束或正則化,一種常用的方法是Wasserstein GAN (WGAN)。
在標準gan中,鑒別器(也稱為WGAN中的批評家)被訓練來區分真實和虛假數據。為了加強Lipschitz連續性,WGAN增加了一個約束,即鑒別器函數應該是Lipschitz連續的,這意味著函數的梯度不應該增長得太大。在數學上,它被限制為:
∥∣D(x)?D(y)∣≤K?∥x?y∥
其中D(x)是評論家對數據點x的輸出,D(y)是y的輸出,K是Lipschitz 常數。
WGAN的權重裁剪:在原始的WGAN中,通過在每個訓練步驟后將鑒別器網絡的權重裁剪到一個小范圍(例如,[-0.01,0.01])來強制執行該約束。權重裁剪確保了鑒別器的梯度保持在一定范圍內,并加強了利普希茨連續性。
WGAN的梯度懲罰: WGAN的一種變體,稱為WGAN-GP,它使用梯度懲罰而不是權值裁剪來強制Lipschitz約束。WGAN-GP基于鑒別器的輸出相對于真實和虛假數據之間的隨機點的梯度,在損失函數中添加了一個懲罰項。這種懲罰鼓勵了Lipschitz約束,而不需要權重裁剪。
譜范數
從符號上看矩陣
-
編碼器
+關注
關注
45文章
3663瀏覽量
135028 -
生成器
+關注
關注
7文章
319瀏覽量
21075 -
頻譜儀
+關注
關注
7文章
342瀏覽量
36152 -
pytorch
+關注
關注
2文章
808瀏覽量
13321
發布評論請先 登錄
相關推薦
評論