深度學習Paper系列(09):Pix2Pix HD

劉智皓 (Chih-Hao Liu)
17 min readNov 17, 2023

--

在上一篇當中,我們介紹了conditional GAN,也就是Pix2Pix的架構,接下來我們要來介紹Pix2Pix的改良版,也就是Pix2Pix HD,其論文名稱叫做

High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs

回顧Pix2Pix

我們在上一篇當中介紹conditional GAN的學習方程式可以分成兩個loss,一個是adversarial loss也就是生成對抗網絡主要的學習方程式,另一個是L1 loss其希望我們的generator轉換過去的影像和目標影像相同。

針對adversarial loss這個項目我們知道其方程式如下

在這邊G就是generator、D就是discriminator,而s就是source domain image,x則是target domain image,簡單來說,如果我們想把素描轉成油畫,那s就是素描、x就是油畫。另外這邊我們得注意到Pix2Pix HD和conditional GAN一樣,都是

Paired image-to-image translation model

也就是說s和x必須要是成對的,不能說我今天想要讓模型從素描轉油畫,我選了個城市街景的素描,讓他學習轉成熱帶雨林的油畫。

但是在原本的Pix2Pix當中,我們會發現他應用在地圖片大小為256 × 256 pixels,但是我們今天如果想要訓練更高解析度或是尺寸更大的影像的話,原有的Pix2Pix模型會在訓練的過程中,變得非常的不穩定,所以Pix2Pix HD的目的就是

達到高解析度的圖像到圖像轉換

Pix2Pix HD

OK那Pix2Pix HD他用了什麼東西呢?其實他的想法也非常的簡單暴力

一個generator和一個discriminator不夠,那就多用幾個

沒錯我們可以看到Pix2Pix HD就是用了兩個generator,一個是負責處理global訊息的G₁,另一個則是負責處理local訊息的G₂。

Global Generator

這邊我們先來看Global Generator,其架構基本上可以拆成3個部分:

front-end:做convolution,圖像會從這邊輸入

residual blocks:裡面包含數個residual module

back-end:做transposed convolution,圖像會從這邊輸出

Local Generator

另我們我們來看Local Generator,其架構基本上,也是包含front-end、residual blocks、back-end,這3個部分,但是這邊不一樣的地方是,我們會把Local Generator front-end和Global Generator back-end相加,然後再餵進去Local Generator的residual blocks中。

兩個Generator合起來

所以我們可以看到對於Local Generator,我們會輸入原始尺寸很大的影像,對於Global Generator,我們會輸入經過down sampling尺寸小兩倍的影像,也就是說如果我們原本想要轉換的影像大小如過是1024 × 1024 pixel,那麼給Local Generator的輸入一樣1024 × 1024 pixel,給Global Generator的輸入就會變成512 × 512 pixel。

看到這邊大家一定覺得非常奇怪,Global聽起來就比較大啊,那為什麼他的輸入影像反而比較小?

實際上這個概念我們之前講U-Net的時候有提到過了,因為我們的front-end網絡裡面包含了很多的convolution和max pooling模組,而當我圖像經過一層又一層的模組,我們神經網絡能抓到的receptive field也就是視域就會越大。所以說

因為我們Global Generator和Local Generator的模型架構一模一樣,所以能抓到的receptive field也一樣,也就是說Global Generator才比較能看到整個圖案的全貌。

我們這邊舉個例子,假設我們丟一張600 × 400 pixel的柯基照片,且我們Global Generator和Local Generator架構一樣,receptive field都是120 × 120 pixel,我們可以看到因為丟給Global Generator的影像被下採樣了,所以我們的Global Generator比較能看到整個柯基的全貌,不過相對的其犧牲的就是整個影像的細部特徵,所以這也是為什麼我們需要搭配Local Generator來抓取圖像的原始細節。

Multi-scale discriminators

OK既然我們在generator的地方弄了dowm-sampling搭配兩個不同的generator來抓取global和local特徵,那我們的discriminator也可以玩同一招。

沒錯在這篇paper,其就用了3個discriminator,這三個discriminator的任務分別是:

D₁:區分目標影像和轉換影像

D₂:區分下採樣2倍的目標影像和下採樣2倍的轉換影像

D₃:區分下採樣4倍的目標影像和下採樣4倍的轉換影像

這樣做的目的就跟上面提到的一樣,讓不同的discriminator專注在不同視域大小的特徵。所以這邊我們的最佳化方程式就可以寫成:

就是這麼的簡單粗暴。

Feature matching loss

我們都知道,其實GAN因為他是一個minimax問題,所以訓練的時候會常常不夠穩定,所以更別提我們Pix2Pix HD這樣的架構,他一次集成這麼多的模型,所以作者在這邊也提出了一個

feature matching loss

的概念。其這個想法也很簡單

我們可以看到,k就是我們discriminator的編號,而這個loss的意思就是我們希望當我們目標圖片x和轉換圖片G(s)丟進去這個discriminator Dₖ後,兩者在Dₖ第i層layer的features要一樣。所以整個最佳化方程式就可以被寫成

Using Instance Maps

另外作者在這邊有提到一個小細節,就是今天我們在做街景生成的時候,我們通常是希望把街景的標註轉換成街景的照片。

但這邊作者有提到,因為這些標註基本上都是一塊塊,裡面的數值都相同,所以對於模型訓練上不具有指標性,導致學習上會讓轉換出來照片在中間的地方很模糊。所以這邊作者就使用boundary map (instance maps),來達到更好的效果。

我們可以看到加入instance map之後,邊界的部分在細節的地方有明顯的改善。

Instance-level Feature Embedding

除此之外在這篇paper當中,作者也有提及到,這種paired image-to-image translation,在學習的時候是1對1的學習,也就是一個source domain的影像對應到相對應target domain的影像。這種方式會限制我們生成圖片的diversity。

所以這篇paper的作者引入了

low-dimensional feature channels

作為generator額外的輸入。

在這邊我們額外增加一個預訓練好的enocder E用來提取target domain影像的特徵,接下來利用這些特徵根據原本source domain的標註進行averge pooling,也就是instance-wise averge pooling,就可以得到上圖中的features。

因為這些features的每一類別pixels的數值代表這個類別的先驗機率分布,所以我們可以將提取的特徵進行K-mean clustering,用來代表K個不同的風格。具體來說像在這篇paper裡面K=10,所以對於樹就可以得到10個數的風格、對於車就有10種車的樣式。如此一來我們就可以控制這些feature來生成各種不同樣式的圖像。

結果

我們可以看到使用pix2pix HD在轉換上得出來的結果明顯比原本的pix2pix來的好。另外我們也可以看到作者在這邊也引入所謂的VGG loss,讓整體轉換上顯得更鮮豔。

Generator架構比較

在這邊我們討論的都是從標註轉到圖像的問題,那麼其實我們也可以做反向的圖像到標註的訓練,其就跟影像分割的任務一樣。所以作者在這邊也做了這樣的事情,去比較他們使用的generator架構和先前U-Net架構及CRN架構。

可以看到使用他們propose架構的generator在分割上,Intersection of Union (IoU)得到最好的結果。

Multi-discriminators效果

另外這邊也比較了使用單一和多個discriminator的效果。我們會發現使用多個的得到更好的分割準確率。

實作

對於Pix2Pix HD的實作,我們可以使用以下repository

我們可以看到在models/networks.py的檔案裡面,其會先定義我們的global generator,然後根據我們訂定的local generator數量來做串接。

class LocalEnhancer(nn.Module):
def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9,
n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'):
super(LocalEnhancer, self).__init__()
self.n_local_enhancers = n_local_enhancers
###### global generator model #####
ngf_global = ngf * (2**n_local_enhancers)
model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model
model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers
self.model = nn.Sequential(*model_global)
###### local enhancer layers #####
for n in range(1, n_local_enhancers+1):
### downsample
ngf_global = ngf * (2**(n_local_enhancers-n))
model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
norm_layer(ngf_global), nn.ReLU(True),
nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1),
norm_layer(ngf_global * 2), nn.ReLU(True)]
### residual blocks
model_upsample = []
for i in range(n_blocks_local):
model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer)]
### upsample
model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1),
norm_layer(ngf_global), nn.ReLU(True)]
### final convolution
if n == n_local_enhancers:
model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample))
setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample))
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
def forward(self, input):
### create input pyramid
input_downsampled = [input]
for i in range(self.n_local_enhancers):
input_downsampled.append(self.downsample(input_downsampled[-1]))
### output at coarest level
output_prev = self.model(input_downsampled[-1])
### build up one layer at a time
for n_local_enhancers in range(1, self.n_local_enhancers+1):
model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1')
model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2')
input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers]
output_prev = model_upsample(model_downsample(input_i) + output_prev)
return output_prev

對於discriminator也一樣,其裡面可以一次定義多個discriminator。

class MultiscaleDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
use_sigmoid=False, num_D=3, getIntermFeat=False):
super(MultiscaleDiscriminator, self).__init__()
self.num_D = num_D
self.n_layers = n_layers
self.getIntermFeat = getIntermFeat
for i in range(num_D):
netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
if getIntermFeat:
for j in range(n_layers+2):
setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
else:
setattr(self, 'layer'+str(i), netD.model)
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
def singleD_forward(self, model, input):
if self.getIntermFeat:
result = [input]
for i in range(len(model)):
result.append(model[i](result[-1]))
return result[1:]
else:
return [model(input)]
def forward(self, input):
num_D = self.num_D
result = []
input_downsampled = input
for i in range(num_D):
if self.getIntermFeat:
model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
else:
model = getattr(self, 'layer'+str(num_D-1-i))
result.append(self.singleD_forward(model, input_downsampled))
if i != (num_D-1):
input_downsampled = self.downsample(input_downsampled)
return result

Reference

[1] Wang, T. C., Liu, M. Y., Zhu, J. Y., Tao, A., Kautz, J., & Catanzaro, B. (2018). High-resolution image synthesis and semantic manipulation with conditional gans. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 8798–8807).

--

--

劉智皓 (Chih-Hao Liu)
劉智皓 (Chih-Hao Liu)

Written by 劉智皓 (Chih-Hao Liu)

豬屎屋AI Engineer,熱愛AI研究、LLM/SD模型、RAG應用、CUDA/HIP加速運算、訓練推論加速,同時也是5G技術愛好者,研讀過3GPP/ETSI/O-RAN Spec和O-RAN/ONAP/OAI開源軟體。