[Tensorflow] The Road from PyTorch to TF2 - Training mode vs. Inference mode

Posted by John on 2020-08-16
Words 1.2k and Reading Time 5 Minutes

"This is the story of a hot-blooded youth who used to write PyTorch but, driven by work, jumped ship to TensorFlow 2, determined to write great TF2 code: a grand lyrical epic forged in Taiwan." (adapted from the opening narration of Yakitate!! Japan)


Introduction

In PyTorch, train() and eval() switch a model between training mode and inference mode, which changes the behavior of layers such as Dropout and BatchNormalization:

  • Dropout stops masking neurons in inference mode.
  • BatchNormalization, in inference mode, normalizes with the moving mean and variance accumulated during training instead of the current batch's statistics.

But TF has neither of these functions, so how do we switch between training mode and inference mode? This post digs into the official TF documentation and source code for Dropout and BatchNormalization to find out.

Dropout: the training argument

In TF this is controlled through arguments instead. For example, Dropout takes a training argument; the official documentation says:

Note that the Dropout layer only applies when training is set to True such that no values are dropped during inference. When using model.fit, training will be appropriately set to True automatically, and in other contexts, you can set the kwarg explicitly to True when calling the layer.

In other words, set training=True to enable Dropout, and set training=False at test time to disable it.

import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
layer = tf.keras.layers.Dropout(.2, input_shape=(2,))
data = np.arange(10).reshape(5, 2).astype(np.float32)
print(data)
# [[0. 1.]
#  [2. 3.]
#  [4. 5.]
#  [6. 7.]
#  [8. 9.]]

# training=True enables dropout: surviving values are scaled by
# 1 / (1 - rate) = 1.25 and roughly 20% of entries are zeroed.
outputs = layer(data, training=True)
print(outputs)
# [[ 0.    1.25]
#  [ 2.5   3.75]
#  [ 5.    6.25]
#  [ 7.5   8.75]
#  [10.    0.  ]], shape=(5, 2), dtype=float32)
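Conversely, in inference mode Dropout is the identity. A minimal sketch, reusing the same data as above:

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
layer = tf.keras.layers.Dropout(.2)
data = np.arange(10).reshape(5, 2).astype(np.float32)

# With training=False, Dropout neither drops values nor rescales,
# so the output equals the input exactly.
outputs = layer(data, training=False)
print(np.array_equal(outputs.numpy(), data))  # True
```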

The documentation also notes that training and trainable are different arguments: trainable controls whether a layer's weights get updated during training, but Dropout has no weights, so setting trainable on a Dropout layer has no effect.

(This is in contrast to setting trainable=False for a Dropout layer. trainable does not affect the layer’s behavior, as Dropout does not have any variables/weights that can be frozen during training.)

Finally, what is the default value of the training argument?

The Dropout documentation first says:

When using model.fit, training will be appropriately set to True automatically, and in other contexts, you can set the kwarg explicitly to True when calling the layer.

So unless you are calling model.fit(), you need to set training=True manually.

Does that mean training defaults to False? The source in tensorflow/tensorflow/python/keras/engine/base_layer.py spells out the priority order used to resolve training:

# base_layer.py, lines 945-953:
# Training mode for `Layer.call` is set via (in order of priority):
# (1) The `training` argument passed to this `Layer.call`, if it is not None
# (2) The training mode of an outer `Layer.call`.
# (3) The default mode set by `tf.keras.backend.set_learning_phase` (if set)
# (4) Any non-None default value for `training` specified in the call
#     signature
# (5) False (treating the layer as if it's in inference)
args, kwargs, training_mode = self._set_training_mode(
    args, kwargs, call_context)
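A quick sketch of rule (5): calling a Dropout layer directly, with no training argument, no outer call, and no learning phase set, falls through to inference mode, so nothing is dropped:

```python
import numpy as np
import tensorflow as tf

layer = tf.keras.layers.Dropout(0.5)
data = np.ones((4, 3), dtype=np.float32)

# No `training` kwarg and no surrounding context: priority (5)
# applies, the layer behaves as if in inference, and Dropout is a no-op.
out = layer(data)
print(np.array_equal(out.numpy(), data))  # True
```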

BatchNormalization: the trainable argument

BatchNormalization, by contrast, is controlled through trainable (because BN does have trainable parameters), and trainable defaults to True:

tf.keras.layers.BatchNormalization(
axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True,
beta_initializer='zeros', gamma_initializer='ones',
moving_mean_initializer='zeros', moving_variance_initializer='ones',
beta_regularizer=None, gamma_regularizer=None, beta_constraint=None,
gamma_constraint=None, renorm=False, renorm_clipping=None, renorm_momentum=0.99,
fused=None, trainable=True, virtual_batch_size=None, adjustment=None, name=None,
**kwargs
)

The official documentation also states:

About setting layer.trainable = False on a BatchNormalization layer:
The meaning of setting layer.trainable = False is to freeze the layer, i.e. its internal state will not change during training: its trainable weights will not be updated during fit() or train_on_batch(), and its state updates will not be run.

Usually, this does not necessarily mean that the layer is run in inference mode (which is normally controlled by the training argument that can be passed when calling a layer). “Frozen state” and “inference mode” are two separate concepts.

However, in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).

This behavior has been introduced in TensorFlow 2.0, in order to enable layer.trainable = False to produce the most commonly expected behavior in the convnet fine-tuning use case.

Normally trainable=False only freezes weight updates ("frozen state" and "inference mode" are two separate concepts), but BatchNormalization is a special case: setting trainable = False makes BN run in inference mode, meaning it normalizes with the moving mean and moving variance instead of the current batch's statistics.

  • Note that this behavior was introduced in TF2. In TF1, trainable=False only froze the parameters and did not switch the layer to inference mode (it would still normalize with batch statistics).
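A minimal sketch of this TF2 behavior. With training=True a fresh BN layer normalizes each column to mean ~0; after setting trainable=False the same call uses the moving statistics (still near their initial values of 0 and 1 after a single momentum=0.99 update), so the output mean stays far from 0:

```python
import numpy as np
import tensorflow as tf

data = np.arange(10, dtype=np.float32).reshape(5, 2) * 10.0

bn = tf.keras.layers.BatchNormalization()

# training=True: normalize with the current batch's mean/variance,
# so each output column has mean ~0.
out_batch = bn(data, training=True)
print(abs(float(out_batch.numpy().mean())) < 1e-2)  # True

# After freezing, BN runs in inference mode even with training=True:
# it uses moving_mean/moving_variance rather than batch statistics,
# so the output is no longer centered at 0.
bn.trainable = False
out_frozen = bn(data, training=True)
print(abs(float(out_frozen.numpy().mean())) > 1.0)  # True
```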

As for the default value of trainable, the signature above shows it defaults to True.

Other things to note about the trainable argument

  • Setting trainable on a layer recursively sets trainable on all of the layers it contains.
  • If you change trainable after calling compile(), the new value does not take effect until compile() is called again.
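Both points can be checked in a few lines; a hypothetical two-layer model:

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, input_shape=(3,)),
    tf.keras.layers.BatchNormalization(),
])

# Setting trainable on the model propagates to every nested layer.
model.trainable = False
print([layer.trainable for layer in model.layers])  # [False, False]

# The flag itself flips immediately, but for model.fit() to respect
# the change you must call compile() again afterwards.
model.compile(optimizer="adam", loss="mse")
```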
