深度學習Paper系列(16):CUT

劉智皓 (Chih-Hao Liu)
38 min readDec 9, 2023

--

今天我們要來帶大家繼續看unpaired image-to-image translation的模型,而今天要介紹的這個模型和先前講的CycleGAN、MUNIT、DRIT等等的不太一樣,這個架構他不需要一次訓練兩個generator和兩個discriminator,其只需要一個generator和discriminator就可以做影像轉換了,而這個架構叫做「CUT」,論文名稱為

Contrastive Learning for Unpaired Image-to-Image Translation

CUT架構介紹

在影像轉換的問題上,我們是希望從source domain影像轉換到target domain影像的過程中

除了能輸出target domain的appearance,還能夠保存source domain影像原有的content

因為我們大部分訓練的情境都是target domain資料和source domain資料是不成對的,用單一一個GAN做轉換很容易產生的圖片雖然看起來和target domain很像,但是失去原有source domain的型態和輪廓,所以大部分的模型架構都是直接訓練兩個generator然後做cycle-consistency來限制轉換圖象的樣貌。

但是很明顯的,如果這樣的做法會

用到雙倍的訓練資源和訓練時間

所以在這邊我們希望只用一個generator和一個discriminator就能做到轉換,同時轉換圖片有target domain的appearance又有source domain的content,在這邊CUT所用的方式就是用

基於self-supervised learning的PatchNCE loss

這邊簡單扼要說明一下他的原理,self-supervised learning的想法就是在沒有任何資料標註的情況下讓模型自我學習,而其中一種方式為「Contrastive Learning

Contrastive Learning

運作方式就像上圖那樣,我們模型今天看到狗和沙發的照片,他根本不知道這是什麼,對他來說就是一串Pixel矩陣而已,那要怎麼讓模型在沒有標註的情況下還能讓他知道狗和沙發長什麼樣子呢?其實很簡單,我們就只要把兩張照片隨便擷取一個部份甚至把顏色做一下更改,然後告訴模型

來自相同圖片的擷取影像要越相似越好(attract),不同的要越不像越好(repel)。

如此一來我們今天有一大堆沒有標註的照片,就可以利用這套方法讓模型學到東西。

這邊大家一定會有疑問啦!那如果今天兩張不同的照片,他們都是狗的圖片,那我們又希望他們越不像越好不是很奇怪嗎?沒有錯!Contrastive Learning學習過程中的確會希望這兩張越不像越好,但是如果我們本身訓練的影像資料本身很龐大的話,你就會發現因為沙發的型態和狗的型態差太多了,所以最後你會看到當我們把狗和沙發圖像的embedding拿出來投影到平面,兩張不同狗照片的距離會比狗和沙發照片的距離來得接近。

回到CUT的架構,如下圖所示

我們可以看到CUT在學習的時候,我們會從輸入圖象中裁切一些圖片,然後從這些圖片挑選一個部位做為positive sample z+,剩下的作為negative sample z-,然後在轉換圖片中相對應positive sample的地方也裁切一張圖片z,然後我們希望

  • 轉換圖像擷取圖片z和輸入圖像擷取圖片的positive sample z+越像越好
  • 轉換圖像擷取圖片z和輸入圖像擷取圖片的negative sample z-越不像越好

舉例來說,我們可以看到上面圖例子,輸入影像是馬,轉換影像是斑馬,所以我們當然會希望馬的頭的特徵和斑馬頭的特徵越像越好,然後馬的其他部位,甚至是草地這些地方和斑馬頭越不像越好。如此一來

原本的generator和discriminator可以學到target domain的appearance,新加的constrative loss可以學到source domain的content

CUT理論

OK接下來就是來更詳細的講解CUT是怎麼學習的啦!這邊我們有兩個domain,一個是source domain X的影像x和target domain Y的影像y,我們今天想要訓練一個generator可以把影像從source domain轉到target domain。這個過程我們可以寫成

其中generator的部分,我們可以更進一步拆解成encoder Genc和decoder Gdec。接下來我們會有一個discriminator D去辨識生成圖片和原始圖片,透過對抗學習讓產生的圖片越來越像真的,而其adversarial loss可以寫成:

Mutual information maximization

接下來就是進入到我們的重頭戲constrastive learning的地方,這裡我們會

從輸入影像x取樣1個positive sample v+和N個negative sample v-,並依據positive sample的位置從轉換影像y^得到相對應的v

在這邊我們的目標是希望v和v+越像越好、v和v-越不像越好,所以我們可以透過cross-entropy loss達到這個目標

這邊的τ指的就是temperature,可以用來控制v和其他sample之間的距離。所以我們會發現最小化這個loss,就是希望最大化log裡面的東西,更進一步來看,我們希望分母總和v和negative sample v-運算那邊越小越好、v和positive sample v+運算那邊越大越好,簡單來說就是

negative sample v-和v的相似度越低越好、positive sample v+和v的相似度越高越好

如此一來我們就可以透過這個loss來最佳化我們的generator。

Multilayer, patchwise contrastive learning

不過相信大家很快就會看到一個大問題,這邊v、v+和v-他們都是向量,那他們到底怎麼來的?在這邊我們就會

把generator的encoder Genc拿出來,接著把圖片輸入encoder中,選取encoder裡面L層的特徵,然後再把這些特徵丟到兩層的MLP H裡面,其輸出就是我們用來計算上面Loss function的向量

  • source domain 向量
  • target domain 向量

這邊大家一定又會有一個疑問了,影像丟進去encoder裡面照理來說會被壓縮或變小,像是256 x 256 pixel影像,可能經過encoder的第一層,feature map就變成128 x 128 pixel,我怎麼知道這一層的特徵能對應到原始圖片的哪個區塊?

實際上我們根本就不用去擔心這件事,因為我們其實可以把每一層特徵都當成一個獨立的圖像,然後一樣對其進行取樣,因為我們主要的目標是

讓模型透過對比學習原始圖片和轉換圖片之間的相關性

所以在這邊我們可以設我們想要萃取encoder Genc的層數有L個,然後第l層的特徵中,我們取樣Sl個sample,然後丟到對應的MLP Hl就可以得到其向量vl。所以我們在這邊就可以定義我們PatchNCE loss:

這邊上標s代表的就是positive sample,S\s代表的是negative sample,所以最後我們會把L層encoder特徵對比學習算出來的loss加總起來。

我們上面有提到我們計算PatchNCE loss的方式是透過對比學習計算輸入圖像和轉換圖像不同patch的相似度,然後希望相同位置的patch(positive sample)越像越好、不相同位置的patch(negative sample)越不像越好。

那既然我們的輸入圖像取出的negative sample只要和轉換圖像中positive sample對應的patch越不像越好,那麼我們negative sample是不是其實也可以從source domain裡面隨機挑一張影像然後取幾個patch也可以呢?沒錯CUT在這邊把這種方式計算出來的loss叫做external NCE loss,其可以表示成:

這邊的zl~就是我們從source domain隨機取出一張照片然後丟到encoder Genc取出第l層的特徵,然後丟到相對應層的MLP Hl所得出來的向量。

小補充

實際上CUT所使用的兩個loss function,可以對應到兩大self-supervised learning的框架

  • PatchNCE loss可以對應的SimCLR
  • External NCE loss可以對應到MoCo

對於SimCLR來說,其就是每次取樣一個batch,然後在裡面挑取positive sample和negative sample來做對比學習,所以套用SimCLR在CUT上面時,我們的positive sample和negative sample都是取樣於同一張照片。

對於MoCo而言,他的negative sample會從不同的batch取樣,所以我們把MoCo套用在CUT上時,negative sample就是從另一張照片去取樣。

Final objective

最後我們把三個不同的loss function合併起來,就是我們CUT的學習方程式,而在這邊我們同樣會加入所謂的identity loss,讓我們的轉換影像能夠同時保存原本圖像中的顏色。

實驗結果

對於CUT的模型架構,作者在這邊有分成兩個不同的模型,一個是一般版的CUT,另一個是訓練比較快速的FastCUT,兩個的差異在於FastCUT並沒有使用identity loss來訓練。

模型架構比較

我們可以看到作者在這邊跟6個不同的模型去比較,可以看到CUT在轉換上的速度和品質都達到SOTA。

Ablation Study and Analysis

我們可以看到在影像轉換上作者也做了實驗比較,沒有使identity loss (no id)、只取encoder最後一層特徵(last layer only),以及只有使用外部圖片當negative sample沒有使用原始圖片當negative sample(external only)。

Internal negatives比External negatives更有效

我們主要可以看到external only的那組實驗,轉換出來的圖像失去了原本輸入圖像大部分的特徵,畢竟internal negatives和positive sample來自同一張圖像,所以使用internal negatives可以幫助轉換時保有原始圖像的content

使用multiple layers encoder的重要性

我們上面有介紹到我們取特徵算loss的時候,會從encoder不同層來取特徵,在這邊CUT的作法是每四層從pixel到第16層,這個方式實際上和傳統的L1+VGG loss是一樣的,我們在這邊可以看到如果我們只嘗試使用encoder最後一層的特徵(last layer only),整個CUT的性能就會大幅下降

Identity loss穩定訓練

我們可以看到使用identity loss的目的,就是希望如果我們今天輸入target domain的影像到generator裡面,其輸出要和原本的輸入圖像一樣,而這樣的限制可以幫助我們的模型訓練的時候更加穩定,所以我們可以看到no id的模型轉換出來的照片品質非常差。

透過encoder視覺化學習到的相似性

在這邊作者也做了另一個有趣的實驗,我們可以看到圖(a)有藍色的點和紅色的點,這邊我們把藍色的點和圖(b)其他地方去計算相似性exp(v ·v−/τ ),同樣紅色的點也去和圖(b)其他地方去計算相似性,我們會發現CUT所學出來的Network可以神奇的抓到各種相關的特徵,我們可以看到藍色點位於斑馬的脖子上,所以馬照片上,馬身體部分的相似性比較高,相反的紅色點位於草地,所以馬照片上屬於草地部分的相似性比較高。

除此之外作者也把輸入圖像和轉換圖像所計算出來的特徵向量用PCA壓縮成3維,也就是RGB格式,我們可以看到輸入圖像和轉換圖像的結果基本上是非常接近的。

除此之外作者在這邊還有做其他更多的實驗設計,像是用15個或是255個negative sample,encoder取特徵的時候只用最後一層和是全部都用,有沒有使用內部圖像或是外部圖像,不過我個人覺得這邊都算是hyperparameter需要針對不同的情境去調整,所以這邊就不贅述了。

Additional applications

另外這邊也是作者團隊做的一些其他的圖像轉換實驗結果。

High-Resolution Single Image Translation

最後作者在這邊針對CUT,去融合了另一個GAN的架構叫做SinGAN:

SinGAN的架構主要就是我們只有一張照片我們怎麼做圖像轉換,礙於篇幅這邊就不贅述,有興趣的朋友可以去查SinGAN這篇paper。

所以結合了CUT和SinGAN,作者就把這個CUT的變種叫做SinCUT,這邊我們可以把它應用在super-resolution上面,產生更高畫質的影像。

實作

針對CUT的實作部分,我們使用以下repository:

這邊我們先來看network架構

Generator

class ResnetGenerator(nn.Module):
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
"""
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
"""Construct a Resnet-based generator
Parameters:
input_nc (int) -- the number of channels in input images
output_nc (int) -- the number of channels in output images
ngf (int) -- the number of filters in the last conv layer
norm_layer -- normalization layer
use_dropout (bool) -- if use dropout layers
n_blocks (int) -- the number of ResNet blocks
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
"""
assert(n_blocks >= 0)
super(ResnetGenerator, self).__init__()
self.opt = opt
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]
n_downsampling = 2
for i in range(n_downsampling): # add downsampling layers
mult = 2 ** i
if(no_antialias):
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]
else:
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True),
Downsample(ngf * mult * 2)]
mult = 2 ** n_downsampling
for i in range(n_blocks): # add ResNet blocks
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
for i in range(n_downsampling): # add upsampling layers
mult = 2 ** (n_downsampling - i)
if no_antialias_up:
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
else:
model += [Upsample(ngf * mult),
nn.Conv2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=1,
padding=1, # output_padding=1,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input, layers=[], encode_only=False):
if -1 in layers:
layers.append(len(self.model))
if len(layers) > 0:
feat = input
feats = []
for layer_id, layer in enumerate(self.model):
feat = layer(feat)
if layer_id in layers:
feats.append(feat)
else:
pass
if layer_id == layers[-1] and encode_only:
return feats # return intermediate features alone; stop in the last layers
return feat, feats # return both output and intermediate features
else:
"""Standard forward"""
fake = self.model(input)
return fake

基本上我們可以看到在模型架構上,大部分的unpaired image-to-image translation都是用一串residual block ResnetBlock的架構,這邊比較特別的地方就是在forward function,如果我們今天在做影像轉換的時候,我們generator輸出是feat,然後我們也會同時輸出我們encoder裡面萃取的特徵feats,這邊的作法就是我們會給定我們想要擷取的layer layer_id,然後再把到的特徵加到feats裡面。

Discrimininator

class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator"""
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminator, self).__init__()
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
kw = 4
padw = 1
if(no_antialias):
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
else:
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=1, padding=padw), nn.LeakyReLU(0.2, True), Downsample(ndf)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
if(no_antialias):
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
else:
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True),
Downsample(ndf * nf_mult)]
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.model = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.model(input)

這邊的架構也是採用PatchGAN discriminator,基本上也和大部分的模型架構相同。

MLP

class PatchSampleF(nn.Module):
def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]):
# potential issues: currently, we use the same patch_ids for multiple images in the batch
super(PatchSampleF, self).__init__()
self.l2norm = Normalize(2)
self.use_mlp = use_mlp
self.nc = nc # hard-coded
self.mlp_init = False
self.init_type = init_type
self.init_gain = init_gain
self.gpu_ids = gpu_ids
def create_mlp(self, feats):
for mlp_id, feat in enumerate(feats):
input_nc = feat.shape[1]
mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
if len(self.gpu_ids) > 0:
mlp.cuda()
setattr(self, 'mlp_%d' % mlp_id, mlp)
init_net(self, self.init_type, self.init_gain, self.gpu_ids)
self.mlp_init = True
def forward(self, feats, num_patches=64, patch_ids=None):
return_ids = []
return_feats = []
if self.use_mlp and not self.mlp_init:
self.create_mlp(feats)
for feat_id, feat in enumerate(feats):
B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
if num_patches > 0:
if patch_ids is not None:
patch_id = patch_ids[feat_id]
else:
# torch.randperm produces cudaErrorIllegalAddress for newer versions of PyTorch. https://github.com/taesungp/contrastive-unpaired-translation/issues/83
patch_id = np.random.permutation(feat_reshape.shape[1])
patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)
x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # reshape(-1, x.shape[1])
else:
x_sample = feat_reshape
patch_id = []
if self.use_mlp:
mlp = getattr(self, 'mlp_%d' % feat_id)
x_sample = mlp(x_sample)
return_ids.append(patch_id)
x_sample = self.l2norm(x_sample)

if num_patches == 0:
x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
return_feats.append(x_sample)
return return_feats, return_ids

對於MLP的架構,其會根據encoder每一層萃取到的特徵數量來決定其大小,不過基本架構就是兩層nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]),最後MLP會輸出每一層的特徵向量return_feats和相對應的層數ID return_ids

PatchNCE loss

class PatchNCELoss(nn.Module):
def __init__(self, opt):
super().__init__()
self.opt = opt
self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
def forward(self, feat_q, feat_k):
num_patches = feat_q.shape[0]
dim = feat_q.shape[1]
feat_k = feat_k.detach()
# pos logit
l_pos = torch.bmm(
feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
l_pos = l_pos.view(num_patches, 1)
# neg logit
# Should the negatives from the other samples of a minibatch be utilized?
# In CUT and FastCUT, we found that it's best to only include negatives
# from the same image. Therefore, we set
# --nce_includes_all_negatives_from_minibatch as False
# However, for single-image translation, the minibatch consists of
# crops from the "same" high-resolution image.
# Therefore, we will include the negatives from the entire minibatch.
if self.opt.nce_includes_all_negatives_from_minibatch:
# reshape features as if they are all negatives of minibatch of size 1.
batch_dim_for_bmm = 1
else:
batch_dim_for_bmm = self.opt.batch_size
# reshape features to batch size
feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
npatches = feat_q.size(1)
l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
# diagonal entries are similarity between same features, and hence meaningless.
# just fill the diagonal with very small number, which is exp(-10) and almost zero
diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
l_neg_curbatch.masked_fill_(diagonal, -10.0)
l_neg = l_neg_curbatch.view(-1, npatches)
out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
device=feat_q.device))
return loss

OK接下來我們覺得算是CUT的精隨了,就是如何計算PatchNCE loss,我們上面有說到,假設我們會從source domain圖像裡面一次取樣32個patch,然後1張做positive sample,剩下的31張做negative sample,接下來我們把target domain轉換出來的圖片,根據positive sample patch的位置同樣擷取一個patch,來計算PatchNCE loss。不過在實作上

我們其實會在target domain轉換圖像擷取和source domain圖像取樣32個patch相同位置的32個patch。

這樣做的好處就是我們就可以直接算一個32 x 32的相似度矩陣,對角線diagonal的數值就是postive sample,也就是我們希望他們越大越好,其餘位置全部都是negative sample,我們希望他們越小越好。

所以我們會看到我們一開始做的事情就是把要計算的兩組特徵向量feat_qfeat_k

  • 內積(inner product)就可以得到positive logit l_pos
  • 外積(outer product))就可以得到negative logit l_neg

不過這邊要注意的是negative logit對角線的部分其實就是positive logit的數值,所以這邊我們會把他替換成數值-10,避免計算negative sample時也希望這條對角線數值也越小越好。

這邊如果我們有做external PatchNCE loss的話,也就是把其他影像也作為negative sample時,我們就會把nce_includes_all_negatives_from_minibatch設為True,讓batch_dim_for_bmm = 1,因為我們每次訓練時都是一個batch一個batch的訓練,照理來說用一般PatchNCE loss時,batch裡面的每個資料都是獨立的,但是我們使用external PatchNCE loss時,我們等於說可以把同一個batch裡面其他的資料都當作negative sample取樣的來源,所以最後我們會做一個feature reshape view(batch_dim_for_bmm, -1, dim)

def calculate_NCE_loss(self, src, tgt):
n_layers = len(self.nce_layers)
feat_q = self.netG(tgt, self.nce_layers, encode_only=True)

if self.opt.flip_equivariance and self.flipped_for_equivariance:
feat_q = [torch.flip(fq, [3]) for fq in feat_q]

feat_k = self.netG(src, self.nce_layers, encode_only=True)
feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)

total_nce_loss = 0.0
for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
loss = crit(f_q, f_k) * self.opt.lambda_NCE
total_nce_loss += loss.mean()

return total_nce_loss / n_layers

接下來我們在整理計算Loss的地方可以看到,我們會給定要對比的兩個影像srctgt,然後把他們都輸入encoder裡面,所以這邊我們會指定encode_only=True,代表不用輸出圖像,我們只要encoder的特徵就好,來減少運算,如此一來就可以得到encoder每一層的特徵向量feat_kfeat_q,其就可以用來計算我們的loss total_nce_loss

foward function

def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
if self.opt.flip_equivariance:
self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
if self.flipped_for_equivariance:
self.real = torch.flip(self.real, [3])

self.fake = self.netG(self.real)
self.fake_B = self.fake[:self.real_A.size(0)]
if self.opt.nce_idt:
self.idt_B = self.fake[self.real_A.size(0):]

我們可以看到CUT的流程就比CycleGAN和UNIT等簡單多了,第一次就是把原始影像self.real轉換成假照片self.fake,不過這邊我們會同時輸入A domain影像self.real_A和B domain影像self.real_B,前面的就是做一般的影像轉換得到self.fake_B,後面的就是為了要去計算identity loss,所以可以得到self.idt_B

更新discriminator

def compute_D_loss(self):
"""Calculate GAN loss for the discriminator"""
fake = self.fake_B.detach()
# Fake; stop backprop to the generator by detaching fake_B
pred_fake = self.netD(fake)
self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
# Real
self.pred_real = self.netD(self.real_B)
loss_D_real = self.criterionGAN(self.pred_real, True)
self.loss_D_real = loss_D_real.mean()
# combine loss and calculate gradients
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
return self.loss_D

基本上更新discriminator比較簡單,只需要計算adversarial loss就好

更新generator

def compute_G_loss(self):
"""Calculate GAN and NCE loss for the generator"""
fake = self.fake_B
# First, G(A) should fake the discriminator
if self.opt.lambda_GAN > 0.0:
pred_fake = self.netD(fake)
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
else:
self.loss_G_GAN = 0.0
if self.opt.lambda_NCE > 0.0:
self.loss_NCE = self.calculate_NCE_loss(self.real_A, self.fake_B)
else:
self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0
if self.opt.nce_idt and self.opt.lambda_NCE > 0.0:
self.loss_NCE_Y = self.calculate_NCE_loss(self.real_B, self.idt_B)
loss_NCE_both = (self.loss_NCE + self.loss_NCE_Y) * 0.5
else:
loss_NCE_both = self.loss_NCE
self.loss_G = self.loss_G_GAN + loss_NCE_both
return self.loss_G

generator的部分就比較麻煩,除了adversarial loss self.loss_G_GAN之外,我們還會計算PatchNCE loss self.loss_NCE和identity loss self.loss_NCE_Y

小結語

CUT的作法其實真的蠻有趣的,剛好搭上那個時候剛興起的Self-supervised learning,使用SimCLR和MoCo框架套在GAN上,讓我們可以做到unpaired的影像轉換。

Reference

[1] Park, T., Efros, A. A., Zhang, R., & Zhu, J. Y. (2020). Contrastive learning for unpaired image-to-image translation. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part IX 16 (pp. 319–345). Springer International Publishing.

--

--

劉智皓 (Chih-Hao Liu)
劉智皓 (Chih-Hao Liu)

Written by 劉智皓 (Chih-Hao Liu)

豬屎屋AI RD,熱愛AI研究、LLM/SD模型、RAG應用、CUDA/HIP加速運算、訓練推論加速,同時也是5G技術愛好者,研讀過3GPP/ETSI/O-RAN Spec和O-RAN/ONAP/OAI開源軟體。

No responses yet