【[Tensorflow]從Pytorch到TF2的學習之路】所有文章:
- [Tensorflow]從Pytorch到TF2的學習之路 - Different Padding Algorithms
- [Tensorflow]從Pytorch到TF2的學習之路 - Training mode v.s. Inference mode
- [Tensorflow]從Pytorch到TF2的學習之路 - Custom Model & Custom training
前言
TF2.0中採用了和Pytorch相同的Eager Mode,並且使用了大量的Keras API,使得我們可以像Pytorch一樣透過建立動態圖來操作我們的運算的同時,能夠更有效率地來設計&訓練我們的深度學習模型。但在寫法上,Pytorch和TF2的到底有著什麼樣的差異呢?
這篇文章針對客製化模型和自定義訓練循環這兩個部分來撰寫,分別介紹使用Pytorch和TF2的不同,期望讓大家能夠快速地切換到不同的深度學習框架上。
客製化模型的不同之處
不囉So,咱們直接上code,先看一下兩個版本的程式碼,然後一併介紹他們的差異
首先先看Pytorch版本
0 | class CustomModel(torch.nn.Module): |
再來看TF2版本
0 | class CustomModel(tf.keras.Model): |
這邊簡單介紹一下差異的部分:
- Pytorch
- 在Pytorch中,如果要自定義一個模型,必須要繼承
torch.nn.Module
,然後實作__init__(self)
和forward(self)
兩個function - layer的定義是透過
torch.nn
來宣告
- 在Pytorch中,如果要自定義一個模型,必須要繼承
- TF2
- 而在TF2中,則是透過了Keras Model API來實作,因此必須繼承
tf.keras.Model
並且實作__init__(self)
和call(self)
兩個function- 實際上
tf.keras.Model
是繼承自tf.Module
,所以也可以直接繼承tf.Module
來實作自己的Model - 不過繼承
tf.keras.Model
還可以使用Model.fit()
,Model.evaluate()
和Model.save()
這些操作,比方說官方文件的範例code,繼承tf.Module
的話就要自己寫訓練循環
- 實際上
- layer的定義是透過
tf.keras.layers
來宣告 tf.Module
以及被他繼承的類別皆提供很方便的properties: variables and trainable_variables,使得我們能夠很好的管理模型的參數
- 而在TF2中,則是透過了Keras Model API來實作,因此必須繼承
最後,在一篇教學文章中看到了下面一個問題: 为什么模型类是重载 call() 方法而不是 __call__() 方法?,覺得這個觀念也值得一提,所以把內容直接複製上來
在 Python 中,对类的实例 myClass 进行形如 myClass() 的调用等价于 myClass.__call__() (具体请见本章初 “前置知识” 的 __call__() 部分)。那么看起来,为了使用 ypred = model(X) 的形式调用模型类,应该重写 \_call_() 方法才对呀?原因是 Keras 在模型调用的前后还需要有一些自己的内部操作,所以暴露出一个专门用于重载的 call() 方法。 tf.keras.Model 这一父类已经包含 \_call__() 的定义。 __call__() 中主要调用了 call() 方法,同时还需要在进行一些 keras 的内部操作。这里,我们通过继承 tf.keras.Model 并重载 call() 方法,即可在保持 keras 结构的同时加入模型调用的代码。
自定義訓練循環的不同之處
再來來看一下一個訓練的起手式寫法上,兩者有著什麼樣的差異:
首先從Pytorch看起,在[Pytorch]zero_grad()和backward()使用技巧中,我們提到了Pytorch訓練的起手式:
0 | for idx, (batch_x, batch_y) in enumerate(data_loader): |
短短幾行裡面做了很多事情:
- 將data傳入model進行forward propagation
- 計算loss
- 清空前一次的gradient
- 根據loss進行back propagation,計算gradient
- 做gradient descent
接下來我們來看一下TF2.0是如何撰寫的,儘管繼承tf.keras.Model
的Model可以無腦model.fit()
和model.predict()
,但我們先看如果要自己寫訓練循環時應該怎麼做:
0 | for batch_idx in range(batch_num): |
TF2做了以下的事情:
- 透過
tf.GradientTape()
紀錄並建構正向傳播的計算圖被包覆的操作 - 將data傳入model進行forward propagation
- 計算loss
- 根據loss計算對model.variables的梯度(使用
tf.GradientTape()
) - 做gradient descent更新梯度
- 透過
tf.keras.optimizer
更新,並且apply_gradients
需要接收grad_and_vars參數(需要將gradient和參數用zip包起來餵進去)
- 透過
對於tf.GradientTape()
,我覺得知乎上面這篇文章: tensorflow计算图与自动求导——tf.GradientTape寫得很詳細,以下節錄幾個重點:
- 預設watch_accessed_variables=True,也就是會記錄所有可訓練變數(trainable=True的tf.Variable)
- 也可以用
tape.watch()
來觀測特定的變數,即使他trainable=False - 預設persistent = False,也就是和Pytorch一樣,計算圖一但執行完就會銷毀(為了節省記憶體),如果需要保留計算圖可以設置persistent = True
- 在用
tape.gradient()
計算梯度時,如果計算失敗(例如計算圖中兩個點根本沒有連接),預設會返回None,但其實也可以指定參數unconnected_gradients=tf.UnconnectedGradients.ZERO來設置返回0
到這裡,了解了模型的寫法和起手式的寫法後,原本對於Pytorch較熟悉的人應該就能較輕鬆的寫出一個可以跑的TF2程式(也希望對於TF2的使用者這篇能夠幫助你快速熟悉Pytorch的寫法)。
References
- Making new Layers & Models via subclassing
- 简单粗暴 TensorFlow 2
- [Pytorch]zero_grad()和backward()使用技巧
- tensorflow计算图与自动求导——tf.GradientTape