前言
今天來聊聊Pytorch的gradient update這個寫法。對Pytorch不陌生的朋友應該知道,一個pytorch model training的起手式大概長這個樣子:1
2
3
4
5
6for idx, (batch_x, batch_y) in enumerate(data_loader):
output = model(batch_x)
loss = criterion(output, batch_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
這段code看似簡單,實際上他做了下列這些事情:
- 將data傳入model進行forward propagation
- 計算loss
- 清空前一次的gradient
- 根據loss進行back propagation,計算gradient
- 做gradient descent
如果用numpy純刻forward/backward是一件很累的事情,所以deep learning framework如Pytorch, tensorflow都幫你做完了,大感恩。
這篇文章主要是針對上面這個起手式來討論一些有的沒的,在往下看之前也可以先想想看這些問題:
- Pytorch為什麼要手動將gradient清空(
optimizer.zero_grad()
),不能把這一步自動化嗎? - 同理,為什麼gradient也要手動計算(
loss.backward()
),不能每一次forward做完就自動算出對應的gradient嗎?
Gradient accumulation
Pytorch不幫你自動清空gradient,而是要你呼叫optimizer.zero_grad()
來做這件事是因為,這樣子你可以有更大的彈性去做一些黑魔法,畢竟,誰規定每一次iteration都要清空gradient?
試想你今天GPU的資源就那麼小,可是你一定要訓練一個很大的model,然後如果batch size不大又train不起來,那這時候該怎麼辦?
雖然沒有課金解決不了的事情,如果有,那就多課一點…不是,這邊提供另外一種設計思維:
你可以將你的model每次都用小的batch size去做forward/backward,但是在update的時候是多做幾個iteration在做一次。
這個想法就是梯度累加(gradient accumulation),也就是說我們透過多次的iteration累積backward的loss,然後只對應做了一次update,間接的做到了大batch size時候的效果。
gradient accumulation寫起來也不難,我們來看一下大概的寫法:
1 | for idx, (batch_x, batch_y) in enumerate(data_loader): |
簡單吧? 當我執行了accumulation_step次之後,才進行gradient descent然後清空gradient。
注意這邊每一次iteration都還是有呼叫loss.backward()
,所以每一次迭代的時候gradient都會一直被累加,直到最後被呼叫了optimizer.zero_grad()
才將他們清空。
- 不過要注意的是,在進行這個trick的時候,learning rate也要對應的做出調整。
現在我們來找一些source code看看這個技巧唄! 知名頂頂的BERT裡面也有用到這一個技巧,有用過的就知道他有一個gradient_accumulation_steps
參數,那在code裡面怎麼寫呢? 我們來看一下transformers/src/transformers/trainer.py
473 | if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( |
Gradient accumulation in multi-task training
接下來我們看另一個情境,假設我們今天在做一個multi-task的訓練,那不同的task我們就會有不同的loss,於是乎我們想把這些loss加起來一起做然後得到最後的loss,在去做backward。
這樣的話,寫法大概會是下面這樣:1
2
3
4
5
6
7
8
9
10
11
12for idx, (batch_x, batch_y) in enumerate(data_loader):
output1 = model1(batch_x)
loss1 = criterion(output1, batch_y)
output2 = model2(batch_x)
loss2 = criterion(output2, batch_y)
loss = loss1 + loss2
optimizer.zero_grad()
loss.backward()
optimizer.step()
要記得在變數計算的時候背後都是有著一張graph的,所以在第2行loss1計算的背後會得到一張graph,在第5行loss2也會有一張graph,直到第8行才將這兩張graph進行了合併成一張,然後在backward()
的時候將梯度更新並釋放掉graph。
於是,在6行到第8行中間,device同時儲存了兩張graph,那如果loss一多,這樣對於memory的消耗就很大了。
但是套用gradient accumulation的概念,我們可以只存一張graph就完成這些計算,考慮以下程式碼:1
2
3
4
5
6
7
8
9
10
11
12for idx, (batch_x, batch_y) in enumerate(data_loader):
optimizer.zero_grad()
output1 = model1(batch_x)
loss1 = criterion(output1, batch_y)
loss1.backward()
output2 = model2(batch_x)
loss2 = criterion(output2, batch_y)
loss2.backward()
optimizer.step()
我們先針對loss1去進行backward,然後算完後就將graph釋放掉然後將gradient存在變數中,之後再去計算loss2,所以從頭到尾device同個時間上最多只會有一張graph的memory。
如果這樣看不懂為什麼可以work的話,我們來看一下數學公式QQ
把loss相加在去算gradient跟兩個分開來做是等價的,然後要記得我們在前一節提過的,沒做optimizer.zero_grad()
前gradient是會被累加的,所以在第12行我們只需要呼叫一次optimizer.step()
就可以做到同時update兩個loss。
Truncated Back Propagation Through Time in RNN
(這一段其實自己也不太熟,如果有錯誤再麻煩指正~)
在RNN訓練中,back propagation的方法叫做Back Propagation Through Time(BPTT),因為backward會牽涉到前一個時間的hidden state。
並且,由於每一個hidden state都跟前一個時刻有關,RNN在input過長的時候會出現gradient vanish/explosion的問題,解決方法向是LSTM或是GRU等具有gated mechanism的模型。
不過在訓練的技巧上有另外一種可以幫助解決這個問題的方法,那就是Truncated Back Propagation Through Time(TBPTT),其實他有很多種case,細節可以參閱A Gentle Introduction to Backpropagation Through Time。
不過這裡講的是其中一種case: 也就是為了避免input過長造成上述的問題,我們不將所有的input都用來計算gradient,也就是,對於某個長度之後的gradient我們都捨棄掉不理他。
1 | # non-truncated |
out.detach()
會在K timestamp的時候截斷graph,然後後面就從該變數之後就不會在做backward了,所以就實現了TBPTT
Reference
- PyTorch中在反向传播前为什么要手动将梯度清零?
- [NLP] RNN 前向传播、延时间反向传播 BPTT 、延时间截断反向传播 TBTT
- A Gentle Introduction to Backpropagation Through Time
- Correct way to do backpropagation through time?