深度學習Paper系列(12):UNIT
上一回我們介紹了VAE-GAN的架構,其把兩個不同的生成模型VAE和GAN做結合。而我們今天要介紹的這個模型,是基於VAE-GAN的unpaired image-to-image translation模型 — UNIT,其原始論文名稱叫做:
Unsupervised Image-to-Image Translation Networks
UNIT架構概述
下面的圖片就是我們今天所要介紹的UNIT架構
我們可以看到UNIT的架構使用了兩個VAE-GAN,但是有趣的是
這兩個VAE-GAN是共享latent space的
什麼意思呢?在UNIT架構中我們會有6個networks
Encoder 1 E₁:用來壓縮照片x₁到latent space z
Encoder 2 E₂:用來壓縮照片x₂到latent space z
Decoder 1 G₁:用來解壓縮latent space z回照片~x₁
Decoder 2 G₂:用來解壓縮latent space z回照片~x₂
Discriminator 1 D₁:用來衡量原始照片x₁和重建照片~x₁的差異
Discriminator 1 D₂:用來衡量原始照片x₂和重建照片~x₂的差異
所以在這邊我們可以看到,在圖像壓縮的過程中不管是x₁還是x₂都會被encode到相同的latent space。
而共享latent space的好處就是,假設我今天想把X₁ domain的照片轉換成X₂ domain的照片,我只要把X₁ domain的照片先經過E₁壓縮到latent space,然後再用G₂解壓縮就會轉成X₂ domain的照片。相對的,假設我今天想把X₂ domain的照片轉換成X₁ domain的照片,我只要把X₂ domain的照片先經過E₂壓縮到latent space,然後再用G₁解壓縮就會轉成X₁ domain的照片。
UNIT理論
OK簡單介紹完了UNIT的架構後,接下來就是大家最愛的理論部分了。在這邊我們設x₁為屬於X₁ domain的照片,而x₂為屬於X₂ domain的照片,兩個照片是成對的,並且存在一個latent space Z中彼此存在一個共享的latent code。
所以基於這些假設,我們現在有四個network,分別是E*₁、E*₂、G*₁、G*₂,其使得
- z = E*₁(x₁) = E*₂(x₂)
- x₁ = G*₁(z)、x₂ = G*₂(z)
- x₁ = G*₁(E*₂(x₂))、x₂ = G*₂(E*₁(x₁))
- x₁ = G*₁(E*₂(G*₂(E*₁(x₁))))、x₂ = G*₂(E*₁(G*₁(E*₂(x₂))))
其中第四點就是所謂的cylce-consistency。
而在這邊也假設其存在一個共享的intermediate representation h
其意思就是說今天從latent space轉換到照片的空間時,這個共享的latent code會先轉成這個中間的representation h,最後才會再各自轉成x₁和x₂。所以我們可以把兩個generator各自寫成
G*₁ = G*_L,1 ◦ G*_H
G*₂ = G*_L,2 ◦ G*_H
其中G*_H為high-level generation function可以將z轉為h,而G*_L,1和 G*_L,2為low-level generation function分別可以將h轉為x₁和x₂。
所以相對的,我們同樣也可以把兩個encoder各自寫成
E*₁ = E*_L,1 ◦ E*_H
E*₂ = E*_L,2 ◦ E*_H
其中E*_H為high-level encoding function可以將h轉為z,而E*_L,1和 E*_L,2為low-level encoding function分別可以將x₁和x₂轉為h。
VAE
因為UNIT是基於兩個VAE-GAN,所以這邊我們可以把它個別拆解成兩個VAE的問題。
對於E₁和G₁,encode可以寫成
decode可以寫成
對於E₂和G₂,encode可以寫成
decode可以寫成
Weight-sharing
另外在上面的部分有提到所謂的intermediate representation,所以對於兩個encoder E₁和E₂,其共享最後幾層的模型參數。相對的,對於decoder G₁和G₂,其共享前面幾層的模型參數。
GAN
剛剛我們提到了關於兩個VAE壓縮和解壓縮的過程,另外這邊我們要來介紹其相對應GAN的部分。實際上我們可以把decode的過程分成兩個stream:
reconstruction stream
translation stream
我們可以明顯看到兩個stream的取樣來源分別是來自X₁ domain和X₂ domain。不過在UNIT針對GAN的訓練過程中,其僅使用translation stream所生成的圖像去計算adversarial loss,主要原因是reconstruction stream可用簡單的supervised loss去衡量計算。
Cycle-consistency (CC)
對於cycle-consistency的概念,實際上我們在先前CycleGAN那篇就有介紹他的概念,也就是希望轉換的影像,重建回來後和原始影像相同,不過這邊得注意一件事,cycle-consistency指的重建影像,是原始影像經過兩個enocder和兩個decoder所產生出來的,和上面reconstruction stream所講的重建影像是不同的。
Learning function
所以這邊我們就可以把上面所講的東西全部兜在一起,就是我們UNIT的學習方程式了!
VAE loss
針對VAE的部分,其可以寫成:
GAN loss
針對VAE的部分,其可以寫成:
cycle-consistency loss
針對CC的部分,其可以寫成:
對於這個方程式,其可以類比上面的VAE loss,我們會發現其就只是把encoder q中做更改。我們會發現
第一項主要是希望我們enocder學出來的壓縮分布q(z|x)和我們所設定的高斯分布p_η(z)越接近越好
第二項代表我們的轉換影像經過enocder學出來的壓縮分布q(z|x)和我們所設定的高斯分布p_η(z)越接近越好
第三項代表我們希望轉換影像encode成z後,他轉回原來影像的機率越大越好
所以這三項簡單來說,就是構成了cycle-consistency的學習方程式。
而在network參數更新上,我們會先鎖住encoder和decoder的參數來更新discriminator,接下來再鎖住discriminator的參數更新encoder和decoder。
實驗結果
接下來我們來看UNIT轉換出來的結果:
我們可以看到UNIT在圖像轉換上,可以轉換出各種不同風格的影像
實作
在實作上,我們選用這個repository
實際上在原始paper的附錄當中,有給出UNIT的network架構:
我們可以看到對應的程式碼
Encoder:
class ContentEncoder(nn.Module):
def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
super(ContentEncoder, self).__init__()
self.model = []
self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
# downsampling blocks
for i in range(n_downsample):
self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
dim *= 2
# residual blocks
self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
self.model = nn.Sequential(*self.model)
self.output_dim = dim
def forward(self, x):
return self.model(x)
Decoder:
class Decoder(nn.Module):
def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
super(Decoder, self).__init__()
self.model = []
# AdaIN residual blocks
self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
# upsampling blocks
for i in range(n_upsample):
self.model += [nn.Upsample(scale_factor=2),
Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
dim //= 2
# use reflection padding in the last conv layer
self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
self.model = nn.Sequential(*self.model)
def forward(self, x):
return self.model(x)
Discriminator:
class MsImageDis(nn.Module):
# Multi-scale discriminator architecture
def __init__(self, input_dim, params):
super(MsImageDis, self).__init__()
self.n_layer = params['n_layer']
self.gan_type = params['gan_type']
self.dim = params['dim']
self.norm = params['norm']
self.activ = params['activ']
self.num_scales = params['num_scales']
self.pad_type = params['pad_type']
self.input_dim = input_dim
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
self.cnns = nn.ModuleList()
for _ in range(self.num_scales):
self.cnns.append(self._make_net())
def _make_net(self):
dim = self.dim
cnn_x = []
cnn_x += [Conv2dBlock(self.input_dim, dim, 4, 2, 1, norm='none', activation=self.activ, pad_type=self.pad_type)]
for i in range(self.n_layer - 1):
cnn_x += [Conv2dBlock(dim, dim * 2, 4, 2, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)]
dim *= 2
cnn_x += [nn.Conv2d(dim, 1, 1, 1, 0)]
cnn_x = nn.Sequential(*cnn_x)
return cnn_x
def forward(self, x):
outputs = []
for model in self.cnns:
outputs.append(model(x))
x = self.downsample(x)
return outputs
def calc_dis_loss(self, input_fake, input_real):
# calculate the loss to train D
outs0 = self.forward(input_fake)
outs1 = self.forward(input_real)
loss = 0
for it, (out0, out1) in enumerate(zip(outs0, outs1)):
if self.gan_type == 'lsgan':
loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2)
elif self.gan_type == 'nsgan':
all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False)
all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False)
loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) +
F.binary_cross_entropy(F.sigmoid(out1), all1))
else:
assert 0, "Unsupported GAN type: {}".format(self.gan_type)
return loss
def calc_gen_loss(self, input_fake):
# calculate the loss to train G
outs0 = self.forward(input_fake)
loss = 0
for it, (out0) in enumerate(outs0):
if self.gan_type == 'lsgan':
loss += torch.mean((out0 - 1)**2) # LSGAN
elif self.gan_type == 'nsgan':
all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False)
loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1))
else:
assert 0, "Unsupported GAN type: {}".format(self.gan_type)
return loss
另外對於計算loss的部分我們也可以看到相對的程式碼:
self.gen_opt.zero_grad()
# encode
h_a, n_a = self.gen_a.encode(x_a)
h_b, n_b = self.gen_b.encode(x_b)
# decode (within domain)
x_a_recon = self.gen_a.decode(h_a + n_a)
x_b_recon = self.gen_b.decode(h_b + n_b)
# decode (cross domain)
x_ba = self.gen_a.decode(h_b + n_b)
x_ab = self.gen_b.decode(h_a + n_a)
# encode again
h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
# decode again (if needed)
x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
# reconstruction loss
self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a)
self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b)
self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
# GAN loss
self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
# domain-invariant perceptual loss
self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
# total loss
self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
hyperparameters['gan_w'] * self.loss_gen_adv_b + \
hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
hyperparameters['vgg_w'] * self.loss_gen_vgg_b
self.loss_gen_total.backward()
self.gen_opt.step()
這邊我們可以看到在實作上,有引入所謂的perceptual loss,其主要就是用預訓練好的VGG16去抓取特徵來計算loss。
Reference
[1] Liu, M. Y., Breuel, T., & Kautz, J. (2017). Unsupervised image-to-image translation networks. Advances in neural information processing systems, 30.