吃掉你GPU記憶體的Cross-Entropy Loss
Hello大家好,好久不見最近一直加班沒時間寫文章QQ。相信大家絕對都知道Cross-Entropy這個經典的損失函數吧!
不過有在玩LLM training的朋友一定都知道,計算cross-entropy loss他會吃掉你大量的GPU記憶體,尤其是在那種vocabulary size很大的模型,Gemma2、Qwen2或是Llama3,一定是造成各位每次訓練GPU OOM (Out of Memory)的夢魘吧!
所以今天這篇文章我們就要來講為什麼Cross-entropy會吃掉你大量的GPU記憶體,還有介紹後來開源社群的大神們是怎麼解決這個問題的。
這邊我們會分成幾個部分:
1. Cross-Entropy:我會講Cross-entropy如何計算的,然後為什麼他在LLM訓練上會消耗大量的GPU記憶體。
2. Chunk-based Cross-entropy:會講用Chunk-based cross-entropy如何降低記憶體使用量。
3. Fused Linear Cross Entropy:會提到更猛的Linear cross entropy,把最後一層output embedding layer和cross entropy運算結合,還有把backward propagation搬到forward做省下更多的記憶體。
4. 其他優化:最後會探討針對Linear Cross-Entropy的更多優化及其相關paper。
OK那廢話不多說我們就開始吧!
什麼是vocabulary size?
在進入Cross-entropy之前,我們得先講一些LLM的基礎。假設我們LLM的vocabulary size是32000,那就代表說這個LLM的字彙量是32000,而每個LLM都有他自己的tokenizer,所以今天我們有一段文章,tokenizer就會把這個文章裡面所有的字,用32000個字彙去表示他。
所以你會看到一段文章會被tokenizer轉成一連串的數字,其叫做input ids,而單獨的一個數字稱之為token。而你會發現這些數字可以代表32000個字彙裡面的其中一個,像是你可能會看到38他代表dog,1283代表sofa。而這個編號在每個LLM可能都不一樣。所以這邊你可以很直覺的理解,vocabulary size越大,模型懂得字彙越多,所以「通常」模型能力也會越強。
接下來input ids就會被送進我們的LLM裡面做運算,這邊大家可能又會有疑問了,那LLM怎麼被訓練成看得懂這些一連串的數字的呢?其實很簡單,在LLM當中的第一個module和最後一個module,就是我們的input embedding layer和output linear layer。
Input Embedding Layer
其是一個大小為vocabulary size × hidden size的矩陣,裡面都是浮點數,這邊的hidden size你可能把他想像成模型的寬度,所以今天假設一個文章經過tokenizer轉出來有2048個tokens,我們的input embedding layer會把他轉成大小為2048 × hidden size的hidden state。
所以實際上就是假設這個token的數值是5678,那input embedding layer第5678個column的vector(長度為hidden size),就是代表這個token的hidden state。接下來這個hidden state就會進入一層又一層的Transformer Block運算。
Output Linear Layer
經過一連串Transformer Block的運算,最後我們的會乘上一個大小為hidden size × vocabulary size的矩陣,所以一樣假設我們有2048個tokens輸入,最後我們LLM會輸出一個大小為2048 × 32000 (vocabulary size)的矩陣,其叫做logits。
也就是說每一個token其logit相對應的長度為32000 (vocabulary size),所以我們可以取32000個數值裡面的最大值,像是有可能這個token的logit第4090個數值最大,那我們就可以用tokenizer以相同的方法,從字彙裡面看4090代表哪個字,其就可以轉成我們LLM的文字輸出!
Next Token Prediction
看到上面大家可能還是會感到很疑惑,輸出最大值到底代表什麼意思?這邊我們可以假設我們今天的輸入是「I saw a cat on a mat」,然後tokenizer把我這句話轉成input ids 「24, 789, 9090, 8765, 23333, 3456」總共7個tokens,這些input ids經過input embedding layer會轉成 7 × hidden size的矩陣,然後經過Transformer Block做運算。
最後Output Linear Layer會輸出大小為7 × 32000 (vocabulary size)大小的矩陣,而這邊在訓練LLM的時候,我們希望第1個token (24)其logit的最大值在第789個,也就是說我們希望LLM看到I這個字他會把他轉成他的下一個saw這個字,同樣道理,我們會希望第2個token (789)其logit的最大值在第9090的位置,也就是LLM看到I saw這兩個字後,會把第二個token轉成a這個字。
所以我們可以說我們的label是[789, 9090, 8765, 23333, 3456, -100],-100代表不會計算這邊的loss。而這就是所謂的next token prediction,也就是李弘毅老師說的「玩文字接龍」。
Cross-Entropy
OK了解了上面的原理後,我們就要來看Cross-Entropy Loss到底怎麼算
我們可以看到這邊C代表的是類別的數量,也就是vocabulary size,t_i就是我們的target,也就是label,而s_i是logit也就是模型output linear layer的輸出。而實際上算loss的時候,我們會把logits先經過softmax function,然後才算cross-entropy loss。
這邊大家可能會有一個疑惑,為什麼要經過softmax呢?
其主要是為了能夠把數值限制在0~1之間
那你可能又會問了,阿為啥不直接把數值除以加總就好,還要多一個exponential呢?
其實主要的原因就是為了能夠讓數值大的logit更大,數值小的logit更小
這邊舉個例子,假設我們的logits是[2.0, 1.0, 0.0]直接normalize的數值是[0.5, 0.33, 0.17],使用softmax的結果是[0.71,0.26,0.04],而這樣的好處就是可以讓我們的梯度更大。
Forward
所以這邊我們可以把Cross entropy loss寫成
這邊可以把他理解成,左邊的就是logit所有數值取exponential,並且加總起來取log,右邊就是我們label對應到logit的數值,舉例來說label的數值是78,右邊減去的項就是logit中第78個數值。
Backward
接下來我們來看Cross entropy的backward,這邊其實就是把我們的Cross entropy function微分
我們可以看到針對位置等於label的地方和不等於label的地方,就是差一個1。
Overflow
不過事實上Cross-Entropy並沒有那麼簡單,我們可以看到因為我們運算大部分都是exponential,所以
有可能會出現exponential後,數值超過fp32可以表示的範圍,也就是「Overflow」的情況。
所以實作Cross-Entroy時,我們會在exponential中減去logit中的最大值,所以真實Cross-Entropy Loss可以表示成
Forward:
backward:
Memory Analysis
OK那到底為什麼算Cross-Entropy Loss會消耗我們大量的記憶體呢?
我們假設我們的sequence length是N個tokens,然後vocabulary size是V,所以我們logits大小就會是一個N × V的矩陣,而Forward時,因為logits的每個數值都要做exp(s_i-max(s)),所以又會產生一個N × V的矩陣,另外每個token還有一個自己的sumexp(分母)、max值、loss值和label,所以
Foward的時候總共會消耗N × (2V+4) GB的記憶體
而Backward的時候我們用Forward算出來大小為N × V的exp(s_i-max(s)),以及sumexp和max值,回推logits的gradient,這個時候gradient的大小會跟logits相同,所以加上backward回傳的loss gradient (大小為N)
backward記憶體使用量也是N × (2V+4) GB
所以假如說我們的
- Batch Size是16
- sequence legnth是2048個tokens
- vocabulary size是131072 (大概是Llama-3的大小)
而算Loss的時候是用FP32精度去計算,所以總共會有
16 × 4(FP32) × 2048 × (2 × 131073 + 4) = 32 GB
卡在你的GPU上。所以你肯定會看到CUDA RUN Out Of Memory!
而且我們會發現這樣先求exp(s_i-max(s))再算log的過程,他其實需要把記憶體從GPU的HBM和SMEM間搬來搬去,
而這邊HBM指的就是GPU的Global Memory,所以你聽到什麼Nvidia A100 80GB,代表A100的HBM有80GB,而GPU上面有很多個stream multiprocessor (SM),其就是負責做運算的,像是A100就有108個。
而每一個SM都有自己相對應的記憶體SMEM,當今天我們要做運算的時候,我們會把資料從HBM搬到SM的SMEM上,算完後再把結果存回HBM。
使用的過程就是
Forward
- 把s從HBM load到SMEM上面運算求出max(s),然後把max(s)寫回HBM
- 把s從HBM load到SMEM算exp(s_i-max(s)),然後把exp(s_i-max(s))寫回HBM
- 把exp(s_i-max(s))從HBM load到SMEM上算sumexp,然後把sumexp寫回HBM。
- 把exp(s_i-max(s))、sumexp和label從HBM load到SMEM上算出loss,然後把loss寫回HBM。
Backward
- 把exp(s_i-max(s))、sumexp和label從HBM load到SMEM算出Cross-entropy的微分,並把其存回HBM。
- 把loss gradient和Cross-entropy的微分從HBM load到SMEM算出logits的gradient,最後把其存回HBM。
而我們可以看到這樣大量的I/O會使我們的運算的效率很低。
Chunk-based Cross-entropy
所以為了解決中間產生exponential那個N × V超大的矩陣,我們可以非常直覺得想到,那在Foward和Backward的步驟,我們不要讓tensor在HBM和SMEM之間傳來傳去,我們一次load到SMEM把他算完。
但實際上一定沒有那麼美好,這邊我們可以看到A100的SMEM只有192kB,所以如果我們vocabulary size是131072的話,一個token的大小就是4 × 131072 = 512kB,其就超過我們SMEM的大小。
所以Unsloth這家AI新創提出了所謂的Chunk-based Cross-entropy。
這個想法就是當vocabulary size很大的時候,我們把他切成幾個chunk來算,像是我們vocabulary size大小是256k,那我就把它分成4個chunk來算,一個是chunk大小是65536。
你可能會想,奇怪?阿655362還是有256kB,還是超過SMEM大小阿!其主要原因是在Nvidia的GPU上面,每一個SM可以使用的register數量可以達到65536個,所以我們就可以放得下啦!
等等但是我們是不是還忘了一件事情?我們回顧一下Cross-Entropy的算法
算exponential的地方要的是max(s),也就是說vocabulary size是131072的時候,如果我們切成65536,那
這兩個chunk各自算出來max數值都是local maxima,所以最後算出來的loss是錯誤的!
不過別擔心,我們用一些簡單的數學等價公式就可以輕鬆解決這個問題,我們可以把上面的Cross-Entropy分母求logsumexp的地方寫成
得到了logsumexp,我們就可以算出loss了
所以實際上是怎麼情況呢?我們把vocabulary size切成數個chunks,然後每個chunk分別算local logumexp,然後我們可以用下面的公式算出global logsumexp。
這邊不使用直接相加的原因,也是為了numerical stability。
Forward解決了,那Backward呢?其實我們可以直接把backward方程式改寫成這樣。
也就是說我們只要利用Forward算出來的logsumexp就可以得到Backward gradient的數值。
另外這邊還有一個Trick,就是Pytorch裡面很常用的一個手法「in-place」運算,我們算backward gradient的時候是需要s的,但我們回推cross entropy gradient後就再也不會用到s,而s的gradient和s的維度、大小、資料型態相同,所以Backward的時候,我們可以直接把s的gradient覆寫到s上,也就是對s做in-place運算。
Memory Analysis
OK接下來我們來做記憶體的分析
針對Forward的地方,假設sequence length是N個tokens,而vocabulary size大小是V,我們logits的大小是N × V,接下來我們要先把local logsumexp算出來,其大小是N × V / 65536,這邊65536就是chunk的大小。
接下來我們就可以算global logsumexp和loss了,而這邊global logsumexp大小為N,另外我們可以直接把global logsumexp覆蓋掉local logsumexp,而loss大小為N,最後加上label的大小也是N。所以我們
Forward的記憶體使用量是N × (V+V / 65536+2)。
Backward的部分我們需要global logsumexp、logits和labels、loss的 gradient,但是因為logits在回推cross-entropy的gradient後,就不會再用到了,所以就可以如上面提到的作in-place運算。所以這裡
backward記憶體使用量是N × (V+3)。
所以一樣我們假設
- batch size是16
- sequence length是2048
- vocabulary size是131072
Forward的記憶體使用量是
4(FP32) × 16 × 2048 × (131072 + 131072/65536 + 2) = 16GB。
Backward的部分也大概是16GB。所以我們可以看到我們節省了將近快一半的記憶體。
而運算操作流程如下:
Forward
- 把logits和labels從HBM load到SMEM,算出local logsumexp寫到HBM,並且也先把-s_i寫到loss。
- 把local logsumexp從HBM load到SMEM,算出global logsumexp,並寫回HBM。
- 把global logsumexp和loss從HBM load到SMEM,算出真的的loss值,再寫回去原本的loss。
Backward
- 把logits、labels和global logsumexp從HBM load到SMEM,算出logits gradient,並把值直接寫回logits自己。
我們可以看到使用Chunk-based cross entropy也省下了很多I/O。
Fused Linear Cross Entropy
但這個時候大家可能還是會想,雖然省到一半了,但是16GB還是好大,能不能再壓的更低呢?沒錯這個就是我們接下來要介紹的由LinkedIn所發布liger kernel Project的Fused Linear Cross Entropy!(順帶一提發布liger kernel的是臺灣人,臺大的大學長)
首先我們要先了解LLM Training的計算過程,如果我們LLM的hidden size是H。
Forward在最後面的時候,hidden state (N × H)會和output linear layer (H × V)做矩陣相乘算出logits (N × V),接下來logits會labels (N)去計算Cross Entropy Loss (N)。
Backward的時候我們會先把Cross Entropy Loss的gradient(N)回推logits的gradient (N × V),而這個gradient會往前傳到output linear layer,這個時候我們會回推兩個gradient
- 一個是往前回傳的hidden state的gradient (N × H)
- 另一個是output linear layer自己本身參數的gradient (H × V)
那我們在這邊怎麼做可以進一步的省下GPU記憶體呢?
1. In-place Fused Foward & Backward Cross-Entropy
首先liger kernel的第一招就是"Fused Forward & Backward Computation",簡單來說就是我在forward的時候,連backward一起算。什麼?backward不是要前一層的gradient才能做,那怎麼forward的時候做backward?
這邊大家還記得我們做forward的時候,
forward最後一個運算就是cross-entropy loss,所以說backward的第一個運算必定也是回推cross-entropy的gradient,所以我們可以直接把cross-entropy loss的backward移到forward提前去做。
那這樣有什麼好處呢?
在前面的部分我們有提到我們會用logits和labels計算loss,然後用loss的gradient回推logits的gradient。因為logits算出loss後,我們就不會再用到logits了,所以
如果我們把backward移到forward去做,我們就可以直接把logits的gradient存回去logits他自己,如此一來就可以省下額外創建logits gradient的空間。
等等我們是不是中間忘了什麼細節?loss的gradient到底是什麼,其實這個超級簡單,我們還記前面提到,每一個tokens最後都會算出自己的loss,然後我們最後會使用mean,將所有loss取平均,讓最後呈現出來的只有一個loss。所以說loss的gradient就是1/(所有tokens數量)。
2. Softmax-Tiling
OK雖然我們把backward移到forward做減少了記憶體使用量,但是還是有一個大問題,那就是我們還是要算logsumexp,OK那我們一樣用chunk-based的方式去算就好了。但是又出現一個新的問題,我們今天的Kernel,是希望把Forward運算和Backward運算Fused在一起,也就是說
我們可以想辦法不要讓中間的logsumexp在SMEM算出來還要存回HBM,然後再讀上來SMEM做backward,
那要怎麼樣才可以省下這些HBM和SMEM之間的I/O,更精確地說logsumexp要怎麼樣才不會存在HBM上呢?
這個地方liger kernel借鑑了「softmax tiling」的技巧,(P.S. 這個方法也同樣使用在Flash Attention)
假設我們的vocabulary size大小是131072,那我們可以先把0~65535的logits從HBM load到SMEM上算出local sumexp、local maxima,接下來我再把65526~131071的logits一樣從HBM load到SMEM上,這個時候我們就可以用新一輪算出的結果去校正local sumexp和maxima,如此一來經過一輪一輪的迭代就可以算出global sumexp和maxima。最後就可以算出loss。
所以我們可以看到上面一整串的運算,完全都沒有把任何中間產生的數值從SMEM寫回HBM,這些完全在同一個Kernel上全部算完。
接下來我們可以直接在相同的kernel推算logits的gradient,這個時候
因為我們已經得到了global sumexp和maxima,所以我們可以直接再次分批把logits從HBM讀到SMEM,算出其相對應的logits gradient
這個時候我們可以用第一點提到的in-place技巧
直接把SMEM上算出來的logits gradient存回去在HBM上的logits。
到目前為止我們來分析一下記憶體使用量
設sequence length是N個tokens,而vocabulary size大小是V,我們logits的大小是N × V,因為forward和backward fused在一起了,而且利用了in-place操作,所以產生的logits gradient存回去了logits本身,所以總體記憶體使用量是N × V
所以一樣我們假設
- batch size是16
- sequence length是2048
- vocabulary size是131072
記憶體使用量是
4(FP32) × 16 × 2048 × 131072 = 16GB
奇怪做了這麼多優化好像跟Unsloth的kernel沒什麼差,難道不能再更低了嗎?
3. Linear Cross-Entropy
所以針對於此Liger Kernel,其又做了另一個操作就是
把最後一層output linear layer的運算和cross-entropy loss合在一起
為什麼這樣做可以節省記憶體呢?
還記得我們前面提到hidden state會輸入output linear layer,然後計算出logits,而logits算出後就是丟進上述的kernel得出logits的gradient,而接下來就是計算output linear layer的gradient和hidden state的gradient。
那既然我們都把cross-entropy forward和backward fused在一起了,那何不連output linear的forward和backward也一起呢?
沒錯!我們今天進入output linear前大小是N × H,跟據liger kernel的經驗公式,我們會把N切成V / H等分,所以說
每一個mini-batch大小就變成,(N × H / V) × H
所以這組mini-batch經過output linear產生的logits大小就是(N × H / V) × V,接下來我們就可以運用上述的kernel,並透過迴圈依序處理V / H組mini-batch,而這個時候
記憶體使用量就降到了(N × H / V) × V = N × H
這邊假設
- batch size是16
- sequence length是2048
- hidden dimension是4096
- vocabulary size是131072
記憶體使用量就變成
4(FP32) × 16 × 2048 × 4096 = 0.5GB。
4. Cast on SRAM
當我們在訓練的時候,有時候會使用混精度運算,像是BF16或FP16,但算loss的時候都統一用FP32,所以
- 做output lienar是用半精度
- 算loss是用全精度
也就是我們會先把半精度的logits先轉成全精度再算loss,但實際上,
我們可以把精度轉換這件事搬到SMEM上面做,如此一來記憶體使用量再混精度的情況又可以再減半。
Summary
接下來我們總結一下運算操作流程:
- 把hidden state的tokens長度切成H / V組,
- 把mini-batch的hidden state輸入output linear計算出logits
- 把logits和labels從HBM load到SMEM,先把logits轉成全精度,接著算出global maxima和sumexp。
- 計算出loss,並把其從SMEM save回HBM。
- 把logits從HBM load到SMEM,先把logits轉成全精度,接著算出logits的gradient。
- 把logits的gradient轉回半精度,從SMEM save回HBM上logits的記憶體位置。
- 計算出這個mini-batch hidden state的gradient。
- 計算出這個mini-batch hidden state針對output linear的gradient,並把其累加在output linear的gradient。
- 回到第2個步驟,運行下一組mini-batch hidden state。
所以總結來說Liger Kernel用到了兩個迴圈
- Token Chunked Loop: 把hidden state的tokens長度切成H / V組,然後依序處理這些mini-batch
- Vocabulary Chunked Loop: 把logit沿著vocabulary的維度以65536長度為一組,從HBM Load到SMEM。
其他優化
Sequence Packing
針對上述我們可以看到liger kernel大大的減少了記憶體的使用量,但是其有一個蠻大的缺點就是他要把其拆成好幾個mini-batch然後用for迴圈計算,所以其速度一定會比較慢一點。
所以針對於此我們可以更進一步使用一個技巧sequence packing,也就是把padding的地方或是我們不想算loss的tokens移除掉,並把這些有意義的tokens合併在一起來計算loss,如此一來就可以大大減少for迴圈的次數,進而增快運算速度。
Cut-Cross-Entropy
在今年Apple有提出一個全新的Cross-Entropy的優化演算法,號稱比liger kernel更快、更省記憶體,另外Unsloth也很快的把這個演算法整合進了他們的github repo。
但因為篇幅的關係下次有機會再跟大家講解這篇吧!
結論
老實說隨著ChatGPT的出現,針對LLM的優化越來越多,有的是針對AI架構去優化,有的是針對Kernel去做優化,有的甚至直接設計一個針對Transformers架構運行的晶片。
而我們今天所講的就是針對運算的kernel來優化,那這邊大家可能會想那除了cross-entropy以外,還有其他地方可以優化嗎?當然!而且非常多,像是MLP、Attention、Norm、Embedding都可以優化,而Unsloth和Liger Kernel就是專注於LLM Kernel Fusion優化的公開github repo,但一樣礙於篇幅我們下次有空介紹給大家。
作者:劉智皓
LinkedIn:Chih-Hao Liu