前言
最近在使用Pytorch寫RNN相關的模型,然後因為實驗室有兩張GPU可以用,所以我就想把model放到兩張GPU平行處理,參考Pytorch document就知道其實用法很簡單,只要使nn.DataParallel(model)
就可以了,不過在使用RNN相關模型的時候有一些問題要注意,既然最近遇到了就順便記錄一下,以免以後又遇到重複的問題。
坑-1
當使用RNN相關的model搭配nn.DataParallel()
時會出現下列warning:
RuntimeWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().
參照網路上的說法,由於將model的參數放到gpu上的時候不保證放置的memory位置一定是連續的,所以可能會有fragmentation現象,這會造成效能的降低,透過flatten_parameters()
可以使得model的參數在gpu memory上的位置是連續的。
而這樣的情況也同樣地出現在使用多張gpu訓練,也就是呼叫nn.DataParallel()
的情況下,由於nn.DataParallel()
做的事情其實就是把model放到多張gpu上一起訓練,所以單張gpu上面會遇到fragmentation的問題在多張上也會遇到,所以同樣的需要flatten_parameters()
。
好,大概知道warning產生的原因後,那flatten_parameters()
到底要怎麼用、加在哪裡呢?
建議是加在model的forward()
的第一行,如此一來當model被放在多張gpu上訓練時,因為每個gpu上的model都會呼叫forward()
,所以也就都會呼叫到flatten_parameters()
這個function。
用法大概如下:
1 | class Model(nn.Module): |
坑-2
接下來要講的坑是當使用RNN系列搭配pad_packed_sequence()
時會遇到的問題,這邊沒有打算要介紹pack_padded_sequence()
和pad_packed_sequence()
,以後有時間再說(隨然通常這樣說最後都不會寫)(20200720更新: 哼!我有寫了! 看這: [Pytorch]Pack the data to train variable length sequences),不過如果你有在用這兩個function搭配RNN和DataParallel那就可能會遇到下面的Error:
RuntimeError: Gather got an input of invalid size: (blalblabla..反正就是size對不上)
一開始百思不得其解,不斷檢查自己的model的shape發現都沒問題,沒使用nn.DataParallel()時單獨在一張gpu上跑也沒問題,去google一開始也不知道下什麼關鍵字,不過後來查了一陣子終於找到問題:原來是pad_packed_sequence()
的坑。
這個問題主要是因為將data放到不同的gpu上跑時,由於使用了pack_padded_sequence()和pad_packed_sequence()
,每個batch的長度都是不固定的,在每張gpu上執行pad_packed_sequence()
時,會取它當下batch的最大長度來對其他句子進行padding,這時因為每個gpu上data不同導致當下的最大長度都會不同,在gather的時候就會產生維度不匹配的問題。
所以解決方法是,在使用pad_packed_sequence()
時要額外帶入一個參數告訴當下最長的長度是多少,大概像這樣寫:
1 | class MyModule(nn.Module): |
咦,你說這code怎麼有點眼熟?其實這是pytorch document上的,不過遇到問題當下不知道用什麼keyword去查因此費了不少力氣,現在把它記錄下來以後才不會重道覆轍。
Reference
- Pytorch document
- Why do we need “flatten_parameters” when using RNN with DataParallel
- RuntimeError】Gather got an input of invalid size【DataParallel问题】
- My recurrent network doesn’t work with data parallelism