深度學習Paper系列(15):U-GAT-IT
前幾回當中我們介紹了幾個Unpaired image-to-image translation的模型,今天我們也要來介紹另一個重要的影像轉換模型「U-GAT-IT」,其論文名稱叫做:
U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation
在前幾篇當中,我們介紹了很多GAN的模型,像是CycleGAN、MUNIT、DRIT等等,雖然這些模型在圖像轉換上都達到非常不錯的效果,但是對於這些模型架構,會因為不同的任務轉換效果的品質會有所差異。
具體上來說就是,先前的這些架構在像是照片轉油畫、風景照轉素描,這些專注於轉local texture的任務上表現比較好,但是對於人像轉動漫人物或是貓轉狗,兩個domain型態上差異非常大的任務上表現就不是很理想。
所以為了解決這樣的問題,這篇paper就提出了U-GAT-IT的架構,其中加入了
一個新的attention module和一個learnable normalization function
U-GAT-IT架構介紹
U-GAT-IT的架構,可以讓模型在學習的時候,透過引入一個輔助classifer然後利用CAM的技術產生attention map,如此一來就可以讓模型學習真正重要的區域。而有趣的是,這些attention maps是被嵌入在我們的generator和discriminator當中,所以可以讓模型針對相對應的domain來產生真實的影像。具體而言就是這些attention map
- 對於generator可以幫助模型去專注兩個domain的差異
- 對於discriminator可以幫助模型透過關注目標域中真實圖像和假圖像之間的差異來幫助fine-tuning
除此之外U-GAT-IT引入了所謂的Adaptive Layer-Instance Normalization (AdaLIN),簡單來說就是我們
可以藉由可學習的參數來控制Instance normalization和Layer normalization之間的權重參數,讓模型在訓練的時候可以邊學邊調整,我們是否應該要更傾向於shape和texture上的轉換。
U-GAT-IT理論
OK接下來我們一樣進入理論的部分。首先我們知道我們的任務就是訓練GAN讓其可以在source domain Xs和target domain Xt之間做轉換,所以這邊我們會需要4個基礎的network
- Generator Gs→t:將source domain圖像轉成target domain圖像
- Generator Gt→s:將target domain圖像轉成source domain圖像
- Discriminator Ds:辨識source domain真實圖像和Gt→s轉換圖像
- Discriminator Dt:辨識target domain真實圖像和Gs→t轉換圖像
Generator架構
不過實際上我們可以看到U-GAT-IT generator的架構,可以把它拆解成6個subnetworks分別是
- Source domain encoder Es
- Source domain decoder Gs
- Source domain auxiliary classifier ηs
- Target domain encoder Et
- Target domain decoder Gt
- Target domain auxiliary classifier ηt
encoder基本上做的事情就是對輸入圖像萃取特徵,所以在這邊我們會把每一層encoder抓取到圖像的特徵,針對不同層數乘上一個可學習的權重參數,而這些可學習的權重參數就是我們auxiliary classifier的參數,其方程式可以寫成:
這邊Esk(x)指的就是source domain encoder第k層的特徵圖,而i和j指的就是特徵圖pixel的位置,所以外面可以看到一個總和的運算符號,意思就是我們會在這邊做global average pooling和global max pooling把這一層特徵的每一個pixel數值加總其來,接著外面的wsk就是我們source domain auxiliary classifier的權重參數,最後我們會把剛剛每一層萃取的特徵乘上每一層相對應的參數,一樣加總起來然後以activation function輸出結果。
另外一邊我們可以看到我們會計算一個東西叫做domain specific attention feature map:
基本上他也是把encoder每一層取到的特徵乘上上面的權重參數,只是這邊沒有做加總運算而已。所以最後我們得到了這個domain specific attention feature map as(x),我們就會把他丟進decoder輸出轉換圖像Gt(as(x))。所以說
對於source domain轉到target domain可以寫成
對於target domain轉到source domain可以寫成
不過我們在這邊可以看到上圖的架構還多了一個fully connected layer,然後還有兩個參數γ和β,他們是做什麼的呢?其實他就是U-GAT-IT的另一個重點AdaLIN,其可以方程式可以寫成
在這邊我們可以看上面提到的參數γ和β就是由fully connected layer所產生的,而參數
- µI:channel-wise平均值
- µL:layer-wise mean平均值
- σI:channel-wise標準差
- σL:layer-wise標準差
τ是learning rate、∆ρ為參數更新的gradient,而在這邊我們會限制ρ的數值要藉在0到1之間。
上面的方程式雖然看起來很複雜,但是其想法非常簡單。我們剛剛在encoder和auxiliary classifier的部分,最後會算出所謂的domain specific attention feature map a,所以
- aI^就是對a做instance normalization的計算結果
- aL^則是對a做layer normalization的計算結果
所以我們可以藉由神經網絡學習的過程
不斷的調整ρ也就是instance normalization和layer normalization的比例,並且搭配fully connected layer產生的參數γ和β來做計算。
Discriminator架構
針對discriminator的架構,其實想法和概念也和generator相似,我們可以把discriminator拆解成:
- source domain discriminator encoder EDs
- source domain discriminator classifier CDs
- source domain discriminator auxiliary classifier ηDs
- target domain discriminator encoder EDt
- target domain discriminator classifier CDt
- target domain discriminator auxiliary classifier ηDt
我們會把影像輸入到discriminator encoder中,然後一樣把每一層萃取到的特徵乘上discriminator auxiliary classifier的權重參數,就可以得到我attention feature map,最後這個attention feature map就會輸入到discriminator classifier裡面作判別。所以
對於souce domain discriminator我們可以寫成
對於target domain discriminator我們可以寫成
U-GAT-IT Loss function
對於U-GAT-IT的學習方程式,其包含以下幾個loss
- Adversarial loss
- Cycle loss
- Identity loss
- CAM loss
接下來我們就會對這些loss來進行介紹
Adversarial loss
在GAN的部分,其學習方程式是使用Least Squares GAN的架構,所以我們可以寫下source domain和target domain的adversarial loss方程式:
Cycle loss
針對Cycle loss的地方其就是CycleGAN裡面的cycle consistency loss,希望經過兩次轉換後重建的影像和原始影像相同,這邊方程式可以寫成:
Identity loss
這邊的identy loss也是從CycleGAN論文提出的loss function,其主要是希望轉換出來的影像和原本影像的顏色分布是相似的,其學習方程式可以寫成:
CAM loss
針對CAM loss的部分,其主要就是用來更新我們auxiliary classifier的權重參數,所以針對generator的部分我們可以把學習方程式寫成
關於這兩個方程式,我們可以看到其方程式就是非常標準的cross-entropy loss,所以auxiliary classifier的目標就是
希望能夠區分source domain的特徵和target domain的特徵。
另外一邊就是針對discriminator auxiliary classifier的學習方程式了,其可以表示成:
我們可以看到這兩個loss function是使用LSGAN loss來訓練,這個loss的計算方式是
希望原始影像的特徵和轉換影像的特徵越像越好。
這點其實基本上和原本discriminator的學習目標一模一樣,只是差在discrminator這邊還會把attention map乘上attention weight再丟進去classifer裡面,而discriminator auxiliary classifier是直接針對attention map去計算。
Full objective
最後我們就可以把所有的loss加總在一起,就是我們U-GAT-IT的學習方程式:
實驗結果
針對U-GAT-IT模型的部分,作者在這邊做了一串實驗來證明他的效果。
模型比較
我們可以看到其和先前一些有名的Unpaired image-to-image translation模型去做比較,其中有我們先前介紹過的CycleGAN、UNIT、MUNIT、DRIT。我們可以看到在大部分的任務上U-GAT-IT都達到最好的轉換品質。
CAM ANALYSIS
我們可以看到圖(a)是我們的輸入影像,圖(b)為generator的attention map,圖(c)和圖(d)分別是discriminator local和global部分的attention map。我們會發現一件有趣的事情
- generator的auxiliary classifier學習目標是區分source domain影像特徵和targte domain影像特徵,所以他抓到的特徵重點會是那張圖片有特色的地方,所以我們基本上可以看到CAM highlight了整張圖片
- discriminator的auxiliary classifier學習目標是讓原始影像特徵和轉換影像特徵越像越好,所以他抓到特徵會是輸入的轉換圖片哪個特徵讓他看起來不像真的
我們最後可以看到有使用auxiliary classifier輔助轉換出來的圖片更加真實。
AdaLIN ANALYSIS
在這邊也做了有關於AdaLIN的有效性實驗,我們可以看到圖(a)是輸入影像,圖(b)是使用AdaLIN轉換出來的影像,其看起來是最自然的且同時最根據原始影像做轉換的結果。
在圖(c)我們可以看到只使用IN的話,因為channel-wise normalized feature在我們decoder的residual blocks架構被使用,所以輸入影像的特徵可以被很好的保存在轉換影像上,但是我們可以同時發現,對於target domain的global style沒辦法很有效的在up-sampling convolution blocks被抓到,所以轉換出來的圖片並不理想。
在圖(d)我們可以看到只使用LN的例子,其效果和IN相反,因為在layerwise normalized feature在我們up-sampling convolution上被使用,所以對於target domain的風格可以很好的被轉換過去,但是相對的他犧牲的就是source domain的特徵在residual blocks就沒辦法被很好的保存。
最後兩個例子就是使用AdaIN和group normalization (GN),我們可以看到他們的轉換結果也沒有使用AdaIN來得好。
實作
最後我們來看實作的部分,這個地方我們使用的是這個repository:
一開始我們一樣先來看network架構
Generator
generator的架構在原始paper中有給出其架構,以下是其相對應的程式碼:
class ResnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, img_size=256, light=False):
assert(n_blocks >= 0)
super(ResnetGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
self.n_blocks = n_blocks
self.img_size = img_size
self.light = light
DownBlock = []
DownBlock += [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=False),
nn.InstanceNorm2d(ngf),
nn.ReLU(True)]
# Down-Sampling
n_downsampling = 2
for i in range(n_downsampling):
mult = 2**i
DownBlock += [nn.ReflectionPad2d(1),
nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0, bias=False),
nn.InstanceNorm2d(ngf * mult * 2),
nn.ReLU(True)]
# Down-Sampling Bottleneck
mult = 2**n_downsampling
for i in range(n_blocks):
DownBlock += [ResnetBlock(ngf * mult, use_bias=False)]
# Class Activation Map
self.gap_fc = nn.Linear(ngf * mult, 1, bias=False)
self.gmp_fc = nn.Linear(ngf * mult, 1, bias=False)
self.conv1x1 = nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=1, stride=1, bias=True)
self.relu = nn.ReLU(True)
# Gamma, Beta block
if self.light:
FC = [nn.Linear(ngf * mult, ngf * mult, bias=False),
nn.ReLU(True),
nn.Linear(ngf * mult, ngf * mult, bias=False),
nn.ReLU(True)]
else:
FC = [nn.Linear(img_size // mult * img_size // mult * ngf * mult, ngf * mult, bias=False),
nn.ReLU(True),
nn.Linear(ngf * mult, ngf * mult, bias=False),
nn.ReLU(True)]
self.gamma = nn.Linear(ngf * mult, ngf * mult, bias=False)
self.beta = nn.Linear(ngf * mult, ngf * mult, bias=False)
# Up-Sampling Bottleneck
for i in range(n_blocks):
setattr(self, 'UpBlock1_' + str(i+1), ResnetAdaILNBlock(ngf * mult, use_bias=False))
# Up-Sampling
UpBlock2 = []
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
UpBlock2 += [nn.Upsample(scale_factor=2, mode='nearest'),
nn.ReflectionPad2d(1),
nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0, bias=False),
ILN(int(ngf * mult / 2)),
nn.ReLU(True)]
UpBlock2 += [nn.ReflectionPad2d(3),
nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0, bias=False),
nn.Tanh()]
self.DownBlock = nn.Sequential(*DownBlock)
self.FC = nn.Sequential(*FC)
self.UpBlock2 = nn.Sequential(*UpBlock2)
def forward(self, input):
x = self.DownBlock(input)
gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
gap_weight = list(self.gap_fc.parameters())[0]
gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
gmp_weight = list(self.gmp_fc.parameters())[0]
gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
cam_logit = torch.cat([gap_logit, gmp_logit], 1)
x = torch.cat([gap, gmp], 1)
x = self.relu(self.conv1x1(x))
heatmap = torch.sum(x, dim=1, keepdim=True)
if self.light:
x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
x_ = self.FC(x_.view(x_.shape[0], -1))
else:
x_ = self.FC(x.view(x.shape[0], -1))
gamma, beta = self.gamma(x_), self.beta(x_)
for i in range(self.n_blocks):
x = getattr(self, 'UpBlock1_' + str(i+1))(x, gamma, beta)
out = self.UpBlock2(x)
return out, cam_logit, heatmap
基本上整體架構和CycleGAN的discriminator蠻相似的,一開始我們的輸入影像會經過encoder self.DownBlock
做down-sampling,接下來就是產生attention map的部分,這邊在上面有提到我們會做global average pooling torch.nn.functional.adaptive_avg_pool2d(x, 1)
和global max pooling torch.nn.functional.adaptive_max_pool2d(x, 1)
,如此一來針對auxiliary classifier的部分,我們就會把這些特徵輸入到相對應的fully-connected network self.gap_fc
和self.gmp_fc
,最後把他們的輸出結果合併起來就是我們auxiliary classifier的輸出cam_logit
。
另外一部分就是把輸出結果傳到decoder,這個地方我們會把剛剛經過encoder和fully-connected network的特徵乘上相對應的attention weight gap_weight
和gmp_weight
,最後我們就是把這些特徵gap
和gmp
合併起來丟入decoder前面的fully-connected layer self.FC(x.view(x.shape[0], -1))
,然後再餵到decoder self.UpBlock2
其輸出out
就是我們的轉換影像。
AdaILN
class adaILN(nn.Module):
def __init__(self, num_features, eps=1e-5):
super(adaILN, self).__init__()
self.eps = eps
self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
self.rho.data.fill_(0.9)
def forward(self, input, gamma, beta):
in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
return out
我們可以看到adaLIN的架構,基本上可以對應到我們上面講adaLIN的方程式,我們會根據IN和LN的比例輸出結果 out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
。
Discriminator
最後針對discriminator的部分,paper也有給出他的架構
- local discriminator
- global discriminator
class Discriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=5):
super(Discriminator, self).__init__()
model = [nn.ReflectionPad2d(1),
nn.utils.spectral_norm(
nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)),
nn.LeakyReLU(0.2, True)]
for i in range(1, n_layers - 2):
mult = 2 ** (i - 1)
model += [nn.ReflectionPad2d(1),
nn.utils.spectral_norm(
nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)),
nn.LeakyReLU(0.2, True)]
mult = 2 ** (n_layers - 2 - 1)
model += [nn.ReflectionPad2d(1),
nn.utils.spectral_norm(
nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)),
nn.LeakyReLU(0.2, True)]
# Class Activation Map
mult = 2 ** (n_layers - 2)
self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True)
self.leaky_relu = nn.LeakyReLU(0.2, True)
self.pad = nn.ReflectionPad2d(1)
self.conv = nn.utils.spectral_norm(
nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False))
self.model = nn.Sequential(*model)
def forward(self, input):
x = self.model(input)
gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
gap_weight = list(self.gap_fc.parameters())[0]
gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
gmp_weight = list(self.gmp_fc.parameters())[0]
gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
cam_logit = torch.cat([gap_logit, gmp_logit], 1)
x = torch.cat([gap, gmp], 1)
x = self.leaky_relu(self.conv1x1(x))
heatmap = torch.sum(x, dim=1, keepdim=True)
x = self.pad(x)
out = self.conv(x)
return out, cam_logit, heatmap
這邊我們也可以發現對應paper模型架構的部分,他會先把輸入影像丟進前面的down-sampling block,然後做global average pooling torch.nn.functional.adaptive_avg_pool2d(x, 1)
和global max pooling torch.nn.functional.adaptive_max_pool2d(x, 1)
,然後經過fully-connected network乘上相對應的attention weight gap_weight
和gmp_weight
,最後把這些特徵合併起來丟進classifer就會得到我們的輸出out
。
auxiliary classifier的部分也一樣,就是把fully-connected network輸出的gap_logit
和gmp_logit
合併起來,就是其輸出cam_logit
。
更新discriminator
# Update D
self.D_optim.zero_grad()
fake_A2B, _, _ = self.genA2B(real_A)
fake_B2A, _, _ = self.genB2A(real_B)
real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)
fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)
D_ad_loss_GA = self.MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(self.device)) + self.MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(self.device))
D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(self.device)) + self.MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(self.device))
D_ad_loss_LA = self.MSE_loss(real_LA_logit, torch.ones_like(real_LA_logit).to(self.device)) + self.MSE_loss(fake_LA_logit, torch.zeros_like(fake_LA_logit).to(self.device))
D_ad_cam_loss_LA = self.MSE_loss(real_LA_cam_logit, torch.ones_like(real_LA_cam_logit).to(self.device)) + self.MSE_loss(fake_LA_cam_logit, torch.zeros_like(fake_LA_cam_logit).to(self.device))
D_ad_loss_GB = self.MSE_loss(real_GB_logit, torch.ones_like(real_GB_logit).to(self.device)) + self.MSE_loss(fake_GB_logit, torch.zeros_like(fake_GB_logit).to(self.device))
D_ad_cam_loss_GB = self.MSE_loss(real_GB_cam_logit, torch.ones_like(real_GB_cam_logit).to(self.device)) + self.MSE_loss(fake_GB_cam_logit, torch.zeros_like(fake_GB_cam_logit).to(self.device))
D_ad_loss_LB = self.MSE_loss(real_LB_logit, torch.ones_like(real_LB_logit).to(self.device)) + self.MSE_loss(fake_LB_logit, torch.zeros_like(fake_LB_logit).to(self.device))
D_ad_cam_loss_LB = self.MSE_loss(real_LB_cam_logit, torch.ones_like(real_LB_cam_logit).to(self.device)) + self.MSE_loss(fake_LB_cam_logit, torch.zeros_like(fake_LB_cam_logit).to(self.device))
D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA)
D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB)
Discriminator_loss = D_loss_A + D_loss_B
Discriminator_loss.backward()
self.D_optim.step()
我們可以看到上面的程式碼為訓練U-GAT-IT discriminator的流程。首先我們會
- 透過A domain轉B domain的generator
self.genA2B
把真實照片real_A
轉換成假照片fake_A2B
- 透過B domain轉A domain的generator
self.genB2A
把真實照片real_B
轉換成假照片fake_B2A
接下來我們會
- 把真實照片
real_A
和轉換照片fake_B2A
分別丟到global discriminatorself.disGA
和local discriminatorself.disLA
當中 - 把真實照片
real_B
和轉換照片fake_B2A
分別丟到global discriminatorself.disGB
和local discriminatorself.disLB
當中
最後我們就可以根據這些discriminator的輸出結果計算
- A domain global adversarial loss
D_ad_loss_GA
- A domain global adversarial CAM loss
D_ad_cam_loss_GA
- A domain local adversarial loss
D_ad_loss_LA
- A domain local adversarial CAM loss
D_ad_cam_loss_LA
- B domain global adversarial loss
D_ad_loss_GB
- B domain global adversarial CAM loss
D_ad_cam_loss_GB
- B domain local adversarial loss
D_ad_loss_LB
- B domain local adversarial CAM loss
D_ad_cam_loss_LB
最後加總起來就可以去更新我們所有的discriminator
更新generator
# Update G
self.G_optim.zero_grad()
fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)
fake_A2B2A, _, _ = self.genB2A(fake_A2B)
fake_B2A2B, _, _ = self.genA2B(fake_B2A)
fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)
fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)
G_ad_loss_GA = self.MSE_loss(fake_GA_logit, torch.ones_like(fake_GA_logit).to(self.device))
G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(self.device))
G_ad_loss_LA = self.MSE_loss(fake_LA_logit, torch.ones_like(fake_LA_logit).to(self.device))
G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, torch.ones_like(fake_LA_cam_logit).to(self.device))
G_ad_loss_GB = self.MSE_loss(fake_GB_logit, torch.ones_like(fake_GB_logit).to(self.device))
G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, torch.ones_like(fake_GB_cam_logit).to(self.device))
G_ad_loss_LB = self.MSE_loss(fake_LB_logit, torch.ones_like(fake_LB_logit).to(self.device))
G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, torch.ones_like(fake_LB_cam_logit).to(self.device))
G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)
G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
G_identity_loss_B = self.L1_loss(fake_B2B, real_B)
G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(self.device)) + self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(self.device))
G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(self.device)) + self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(self.device))
G_loss_A = self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
G_loss_B = self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B
Generator_loss = G_loss_A + G_loss_B
Generator_loss.backward()
self.G_optim.step()
接下來對應generator的地方,我們會做兩次translation,第一次先得到兩個domain相對應的轉換影像,第二次得到兩個domain相對應的重建影像,如此一來我們就可以去計算
- A domain global adversarial loss
G_ad_loss_GA
- A domain global adversarial CAM loss
G_ad_cam_loss_GA
- A domain local adversarial loss
G_ad_loss_LA
- A domain local adversarial CAM loss
G_ad_cam_loss_LA
- B domain global adversarial loss
G_ad_loss_GB
- B domain global adversarial CAM loss
G_ad_cam_loss_GB
- B domain local adversarial loss
G_ad_loss_LB
- B domain local adversarial CAM loss
G_ad_cam_loss_LB
- A domain cycle-consistency loss
G_recon_loss_A
- B domain cycle-consistency loss
G_recon_loss_B
- A domain identity loss
G_identity_loss_A
- B domain identity loss
G_identity_loss_B
- A domain generator CAM loss
G_cam_loss_A
- B domain generator CAM loss
G_cam_loss_B
加總起來就可以來更新我們generator的參數。
小結語
個人覺得U-GAT-IT利用輔助分類器產生CAM attention map的方法還蠻聰明的,可以用視覺化的方式讓我們知道轉換圖片有哪些地方需要改進,不過個人覺得貢獻最大的應該還是AdaILN,讓我們可以透過學習參數針對不同的任務調整IN和LN的多寡。
Reference
[1] Kim, J., Kim, M., Kang, H., & Lee, K. (2019). U-gat-it: Unsupervised generative attentional networks with adaptive layer-instance normalization for image-to-image translation. arXiv preprint arXiv:1907.10830.