[Pytorch]zero_grad()和backward()使用技巧

Posted by John on 2020-05-28
Words 1.6k and Reading Time 6 Minutes
Viewed Times

前言

今天來聊聊Pytorch的gradient update這個寫法。對Pytorch不陌生的朋友應該知道,一個pytorch model training的起手式大概長這個樣子:

1
2
3
4
5
6
for idx, (batch_x, batch_y) in enumerate(data_loader):
output = model(batch_x)
loss = criterion(output, batch_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

這段code看似簡單,實際上他做了下列這些事情:

  1. 將data傳入model進行forward propagation
  2. 計算loss
  3. 清空前一次的gradient
  4. 根據loss進行back propagation,計算gradient
  5. 做gradient descent

如果用numpy純刻forward/backward是一件很累的事情,所以deep learning framework如Pytorch, tensorflow都幫你做完了,大感恩。

這篇文章主要是針對上面這個起手式來討論一些有的沒的,在往下看之前也可以先想想看這些問題:

  • Pytorch為什麼要手動將gradient清空(optimizer.zero_grad()),不能把這一步自動化嗎?
  • 同理,為什麼gradient也要手動計算(loss.backward()),不能每一次forward做完就自動算出對應的gradient嗎?

Gradient accumulation

Pytorch不幫你自動清空gradient,而是要你呼叫optimizer.zero_grad()來做這件事是因為,這樣子你可以有更大的彈性去做一些黑魔法,畢竟,誰規定每一次iteration都要清空gradient?

試想你今天GPU的資源就那麼小,可是你一定要訓練一個很大的model,然後如果batch size不大又train不起來,那這時候該怎麼辦?

雖然沒有課金解決不了的事情,如果有,那就多課一點…不是,這邊提供另外一種設計思維:

你可以將你的model每次都用小的batch size去做forward/backward,但是在update的時候是多做幾個iteration在做一次。

這個想法就是梯度累加(gradient accumulation),也就是說我們透過多次的iteration累積backward的loss,然後只對應做了一次update,間接的做到了大batch size時候的效果。

gradient accumulation寫起來也不難,我們來看一下大概的寫法:

1
2
3
4
5
6
7
8
9
10
for idx, (batch_x, batch_y) in enumerate(data_loader):
output = model(batch_x)
loss = criterion(output, batch_y)

loss = loss / accumulation_step
loss.backward()

if (idx % accumulation_step) == 0:
optimizer.step() # update
optimizer.zero_grad() # reset

簡單吧? 當我執行了accumulation_step次之後,才進行gradient descent然後清空gradient。

注意這邊每一次iteration都還是有呼叫loss.backward(),所以每一次迭代的時候gradient都會一直被累加,直到最後被呼叫了optimizer.zero_grad()才將他們清空。

  • 不過要注意的是,在進行這個trick的時候,learning rate也要對應的做出調整。

現在我們來找一些source code看看這個技巧唄! 知名頂頂的BERT裡面也有用到這一個技巧,有用過的就知道他有一個gradient_accumulation_steps參數,那在code裡面怎麼寫呢? 我們來看一下transformers/src/transformers/trainer.py

473
474
475
476
477
478
479
480
481
482
483
484
485
486
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
len(epoch_iterator) <= self.args.gradient_accumulation_steps
and (step + 1) == len(epoch_iterator)
):
if self.args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

if is_tpu_available():
xm.optimizer_step(optimizer)
else:
optimizer.step()

Gradient accumulation in multi-task training

接下來我們看另一個情境,假設我們今天在做一個multi-task的訓練,那不同的task我們就會有不同的loss,於是乎我們想把這些loss加起來一起做然後得到最後的loss,在去做backward。

這樣的話,寫法大概會是下面這樣:

1
2
3
4
5
6
7
8
9
10
11
12
for idx, (batch_x, batch_y) in enumerate(data_loader):
output1 = model1(batch_x)
loss1 = criterion(output1, batch_y)

output2 = model2(batch_x)
loss2 = criterion(output2, batch_y)

loss = loss1 + loss2

optimizer.zero_grad()
loss.backward()
optimizer.step()

要記得在變數計算的時候背後都是有著一張graph的,所以在第2行loss1計算的背後會得到一張graph,在第5行loss2也會有一張graph,直到第8行才將這兩張graph進行了合併成一張,然後在backward()的時候將梯度更新並釋放掉graph。

於是,在6行到第8行中間,device同時儲存了兩張graph,那如果loss一多,這樣對於memory的消耗就很大了。

但是套用gradient accumulation的概念,我們可以只存一張graph就完成這些計算,考慮以下程式碼:

1
2
3
4
5
6
7
8
9
10
11
12
for idx, (batch_x, batch_y) in enumerate(data_loader):
optimizer.zero_grad()

output1 = model1(batch_x)
loss1 = criterion(output1, batch_y)
loss1.backward()

output2 = model2(batch_x)
loss2 = criterion(output2, batch_y)
loss2.backward()

optimizer.step()

我們先針對loss1去進行backward,然後算完後就將graph釋放掉然後將gradient存在變數中,之後再去計算loss2,所以從頭到尾device同個時間上最多只會有一張graph的memory。

如果這樣看不懂為什麼可以work的話,我們來看一下數學公式QQ

把loss相加在去算gradient跟兩個分開來做是等價的,然後要記得我們在前一節提過的,沒做optimizer.zero_grad()前gradient是會被累加的,所以在第12行我們只需要呼叫一次optimizer.step()就可以做到同時update兩個loss。

Truncated Back Propagation Through Time in RNN

(這一段其實自己也不太熟,如果有錯誤再麻煩指正~)
在RNN訓練中,back propagation的方法叫做Back Propagation Through Time(BPTT),因為backward會牽涉到前一個時間的hidden state。

並且,由於每一個hidden state都跟前一個時刻有關,RNN在input過長的時候會出現gradient vanish/explosion的問題,解決方法向是LSTM或是GRU等具有gated mechanism的模型。

不過在訓練的技巧上有另外一種可以幫助解決這個問題的方法,那就是Truncated Back Propagation Through Time(TBPTT),其實他有很多種case,細節可以參閱A Gentle Introduction to Backpropagation Through Time

不過這裡講的是其中一種case: 也就是為了避免input過長造成上述的問題,我們不將所有的input都用來計算gradient,也就是,對於某個長度之後的gradient我們都捨棄掉不理他。

1
2
3
4
5
6
7
8
9
10
11
12
# non-truncated
for t in range(T):
out = model(out)
out.backward()

# truncated to the last K timesteps
for t in range(T):
out = model(out)
if T - t == K:
out.backward()
out.detach()
out.backward()

out.detach()會在K timestamp的時候截斷graph,然後後面就從該變數之後就不會在做backward了,所以就實現了TBPTT

Reference


>