深度學習Paper系列(10):CycleGAN
前兩回我們介紹了兩個經典的image-to-image translation模型,分別是Pix2Pix和他的改良版Pix2Pix HD,不過對於這兩個模型其都是基於解決所謂的Paired translation問題,也就是說這兩個模型只適用在source domain和target domain是成對的情況下才能訓練。
但是實際上在這個世界上,要找到成對的影像幾乎是不太可能的一件事,比方說我們想要把狗轉換成貓,我們不可能找到姿勢、眼神、大小都一樣成對的貓狗照片。所以這種不成對的圖像轉換問題,我們稱之為
Unpaired image-to-image translation
而為了解決這個問題,就出現了我們今天所要介紹的這篇論文,名稱叫做:
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
CycleGAN模型架構簡介
對於這篇paper所介紹了模型,其簡稱就叫做「CycleGAN」。
我們回到狗轉換成貓的例子,在CycleGAN的架構中,我們會有兩個generators和兩個discriminators,
Generator G:負責把狗的照片(source domain image x)轉換成貓的照片(target domain image y)
Generator F:負責把貓的照片(target domain image y)轉換回狗的照片(source domain image x)
Discriminator Dₓ:負責區分真的狗照片x和轉換的狗照片F(y)
Discriminator Dᵧ:負責區分真的貓照片y和轉換的貓照片G(x)
另外這邊CycleGAN引入了一個非常重要的概念:
Cycle-Consistency
就是說我今天希望狗的照片x經過Generator G轉換為貓的照片G(x),我們再把其丟到Generator F轉回狗的照片F(G(x))後,我們希望前後兩者一模一樣,也就是x = F(G(x))。
相反的我今天也希望貓的照片y經過Generator F轉換為狗的照片F(y),我們再把其丟到Generator G轉回貓的照片G(F(y))後,我們也希望前後兩者一模一樣,也就是y = G(F(y))。
這個概念就很像我們把中文丟到google翻譯翻成英文,然後再把翻出來的英文翻回去中文的概念一樣,我們希望前後兩個要相同。
而這樣做的目的就是希望
轉換出來的影像,其結構、輪廓、特徵等等是基於其原始的影像
方程式推導
OK上面我們簡單介紹了CycleGAN的架構之後,接下來我們就要來詳細介紹他的學習方程式。
首先我們今天的學習目標是將source domain X的影像x∈X轉換成target domain Y的影像y∈Y,在這邊我們把source domain的資料分布設為x~p_data(x)、target domain的資料分布設為y~p_data(y)。
我們現在有兩個generator,其可以做到把圖像到圖像的映射(mapping)
Generator G:X → Y
Generator F:Y → X
另外我們有兩個discriminators,其功能上面提到過了
Discriminator Dₓ:負責區分x和F(y)
Discriminator Dᵧ:負責區分y和G(x)
Adversarial Loss
我們現在有了兩個generator和兩個discriminator,我們的學習目標也非常明確,所以我們就可以列出兩個學習方程式:
和
第一個方程式我們可以看到就是針對把source domain影像x轉換成target domain影像G(x)的最佳化方程式,而另一個就是針對把target domain影像y轉換成source domain影像F(y)的最佳化方程式。
Cycle-Consistency Loss
最後這邊就是加上我們上面所提到的cycle-consistency loss了,我們希望轉換回來的影像和原始的影像相同,其方程式可以寫為下面的式子;
和
其中這邊的參數𝜆_𝑐𝑦𝑐(𝑥)和𝜆_𝑐𝑦𝑐(𝑦)分別就是控制兩個不同cycle-consistency loss的權重。
Full Objective
最後我們就可以把上面的兩個adversarial loss和cycle-consistency loss合併在一起就是我們CycleGAN的學習方程式了!
實作
雖然CycleGAN的架構看起來,非常簡單也非常容易理解,但是在實作層面有許多的細節需要去注意,這邊我們就用實現CycleGAN模型最多人用的repository:
Generator架構
實際上CycleGAN的generator架構是基於ResNet的架構,而不是U-Net
class ResnetGenerator(nn.Module):
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
"""
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
"""Construct a Resnet-based generator
Parameters:
input_nc (int) -- the number of channels in input images
output_nc (int) -- the number of channels in output images
ngf (int) -- the number of filters in the last conv layer
norm_layer -- normalization layer
use_dropout (bool) -- if use dropout layers
n_blocks (int) -- the number of ResNet blocks
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
"""
assert(n_blocks >= 0)
super(ResnetGenerator, self).__init__()
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]
n_downsampling = 2
for i in range(n_downsampling): # add downsampling layers
mult = 2 ** i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]
mult = 2 ** n_downsampling
for i in range(n_blocks): # add ResNet blocks
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
for i in range(n_downsampling): # add upsampling layers
mult = 2 ** (n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input):
"""Standard forward"""
return self.model(input)
這邊不用U-Net的理由是因為,U-Net比較適合用在Paired-image-to-image translation的問題上,因為U-Net有所謂的skip-connection會把每一層取到的特徵在最後面合併起來,所以在細節的捕捉上會比較好,但是對於Unpaired的問題,我們想要轉換過去的影像實際上是不存在的,我們需要做的事情是根據source domain的特徵、樣貌、輪廓轉換成像是target domain的影像,所以這邊在萃取特徵上ResNet是比較適合的。
Identity loss
另外在實作上CycleGAN有時候也會引入一個額外的identity loss:
# G_A should be identity if real_B is fed: ||G_A(B) - B||
self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed: ||G_B(A) - A||
self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
這邊的loss是希望x和F(x)一樣、然後y跟G(y)一樣,這邊大家一定會覺得非常奇怪,generator G和generator F的輸入應該是x和y才對啊,為什麼這邊反過來了呢?
其實這個想法最初是來自於
Unsupervised cross-domain image generation
這篇paper的,我們知道G可以將x轉換成y、F可以將y轉換成x,所以我們今天設置identity loss的目的就是希望,G看到y的時候可以把他的樣貌、顏色、形狀保存好、也就是希望y跟G(y)一樣,反之我們也希望x和F(x)一樣。
主要這樣做的目的,在paper裡面有提到
其目標是想要把painting轉換成photo,我們可以看到原本沒有加入identity loss的CycleGAN,他轉換出來的影像都很真實、也很合理沒有錯,但是我們可以明顯看到,他把天空的顏色、海的顏色都轉錯了,所以為了改善這一點,我們可以引入identity loss讓generator能夠學會保存這些重要的顏色特徵。
LSGAN
另外一個執行上的細節就是,CycleGAN訓練GAN的方式不是使用傳統的negative log likelihood,而是使用
least-squares loss
主要原因是因為,在先前
Least Squares Generative Adversarial Networks
的論文中有證明,使用least-squares loss生成的圖片品質更好,而且訓練來的更穩定。而其最佳化方程式被改寫成:
實驗結果
我們在這邊可以看到作者做了很多有趣的實驗
其測試了5種條件,分別是(1)只使用cycle-consistency loss、(2)只使用adversarial loss、(3)只訓練foward那邊、(4)只訓練backward那邊以及(5)整個CycleGAN,我們可以看到CycleGAN得出來的結果是最好的。
另外我們還可以看到一些有趣的轉換結果
失敗的例子
這邊我們可以看到作者也有列出一些轉換失敗的例子,最經典的就是右上角那張馬轉斑馬那張,其把整個人也轉成斑馬的條紋。
所以在這邊也說明了CycleGAN他的一個很大的天生限制,這邊我可以問大家比較常看到人騎馬還是人騎斑馬?答案很明顯,因為CycleGAN轉換的時候,非常仰賴X和Y兩個domain的資訊要是對等的,如果不對等就會出現上面這種轉換失敗的情況。
不過這種問題其實也很好解決,我們實際上可以搭配文字一起學習,像是搭配"一個人騎著一匹馬",讓模型自己去學會圖片對應關係,就可以改善這個問題。
Reference
[1] Zhu, J. Y., Park, T., Isola, P., & Efros, A. A. (2017). Unpaired image-to-image translation using cycle-consistent adversarial networks. In Proceedings of the IEEE international conference on computer vision (pp. 2223–2232).
[2] Mao, X., Li, Q., Xie, H., Lau, R. Y., Wang, Z., & Paul Smolley, S. (2017). Least squares generative adversarial networks. In Proceedings of the IEEE international conference on computer vision (pp. 2794–2802).
[3] Taigman, Y., Polyak, A., & Wolf, L. (2016). Unsupervised cross-domain image generation. arXiv preprint arXiv:1611.02200.