Implementing batch normalization seems to require fairly advanced TensorFlow programming. We go over our current implementation below.
Train/Test behavior
Typically we want the following behavior during training, validation, and testing:
- Train:
  - compute and use the batch moments
  - update the exponential moving averages (EMAs)
- Validation:
  - use the saved EMAs (no updates, and no computation of batch statistics)
- Final/Test:
  - ideally, once the model is trained, do one more pass through the data and replace the EMAs with true averages over the whole data set (I think most people are lazy here and just use the last EMAs)
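The three modes above can be sketched in NumPy as follows. This is a minimal sketch, not our implementation: it assumes NHWC inputs with per-channel statistics (the same axes=[0, 1, 2] passed to tf.nn.moments further down), momentum 0.9, and a hypothetical BatchNormStats class. The learned scale and shift (gamma, beta) are omitted, and the final pass computes exact population statistics over all batches, which is one reasonable reading of "true averages over the whole data set".

```python
import numpy as np


class BatchNormStats:
    """Hypothetical per-channel batch-norm statistics for NHWC inputs.

    Scale/shift parameters (gamma, beta) are omitted to keep the sketch short.
    """

    def __init__(self, num_channels, momentum=0.9, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = np.zeros(num_channels)
        self.running_var = np.ones(num_channels)

    def __call__(self, x, in_train_mode):
        if in_train_mode:
            # Train: compute and use the batch moments over batch, height, width.
            mean = x.mean(axis=(0, 1, 2))
            var = x.var(axis=(0, 1, 2))
            # ...and update the exponential moving averages.
            m = self.momentum
            self.running_mean = m * self.running_mean + (1.0 - m) * mean
            self.running_var = m * self.running_var + (1.0 - m) * var
        else:
            # Validation/test: use the saved averages; nothing is updated.
            mean, var = self.running_mean, self.running_var
        return (x - mean) / np.sqrt(var + self.eps)

    def finalize(self, batches):
        """One more pass over the data: replace the EMAs with exact averages."""
        total = 0
        sum_x = np.zeros_like(self.running_mean)
        sum_sq = np.zeros_like(self.running_mean)
        for x in batches:
            sum_x += x.sum(axis=(0, 1, 2))
            sum_sq += (x ** 2).sum(axis=(0, 1, 2))
            total += x.shape[0] * x.shape[1] * x.shape[2]
        self.running_mean = sum_x / total
        # Population variance: E[x^2] - (E[x])^2.
        self.running_var = sum_sq / total - self.running_mean ** 2


# Usage with synthetic data (8 images of 4x4 pixels, 3 channels per batch).
rng = np.random.default_rng(0)
data = [rng.normal(5.0, 2.0, size=(8, 4, 4, 3)) for _ in range(10)]
bn = BatchNormStats(num_channels=3)
for batch in data:
    bn(batch, in_train_mode=True)              # training steps update the EMAs
val_out = bn(data[0], in_train_mode=False)     # validation only reads the EMAs
bn.finalize(data)                              # optional final pass
```

In a real network, finalize would be fed the activations this layer receives from the trained network, so the final pass has to run the data through all earlier layers first.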
...
You have to understand the TensorFlow computational graph to some degree. If this were just Python with NumPy arrays instead of TensorFlow tensors, you could do something like
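```python
import numpy as np

def batch_or_running_mean(input_ndarr, running_mean, in_train_mode):
    if in_train_mode:
        # Per-channel mean over the batch, height and width axes (NHWC input).
        batchMean = input_ndarr.mean(axis=(0, 1, 2))
        # In-place update: the caller's running_mean array is modified.
        running_mean *= 0.9
        running_mean += 0.1 * batchMean
        return batchMean
    else:
        return running_mean
```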
where in_train_mode is a Boolean flag, and the code above is executed for every batch of data. In TensorFlow, when you write something like
```python
import tensorflow as tf

batchMean, batchVar = tf.nn.moments(inputTensor, axes=[0, 1, 2])
# TF 1.x graph mode: these lines only add ops to the graph and rebind the Python name
running_mean *= 0.9
running_mean += 0.1 * batchMean
new_running_mean = tf.add(tf.multiply(momentum, self.running_mean),  # tf.mul before TF 1.0
                          (1.0 - momentum) * batchMean, name='new_running_mean')
```
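Why the naive update leaves the stored average untouched can be sketched with a toy model of deferred, graph-style execution in plain Python. The Op and Var classes below are hypothetical stand-ins, not TensorFlow's API: building an expression computes nothing, x *= k on an object without __imul__ rebinds the name to a new node, and the stored value changes only when an explicit assignment op is itself run.

```python
class Op:
    """Toy deferred computation (hypothetical, not TensorFlow's API).

    Building an Op computes nothing; evaluation happens only in run().
    """

    def __init__(self, fn):
        self.fn = fn

    def run(self):
        return self.fn()

    def __mul__(self, k):
        # Returns a *new* Op; Op has no __imul__, so `x *= k` rebinds x to it.
        return Op(lambda: self.run() * k)

    __rmul__ = __mul__

    def __add__(self, other):
        # Likewise, `x += y` rebinds x to a new Op instead of mutating anything.
        return Op(lambda: self.run() + other.run())


class Var:
    """Toy stateful variable whose value persists between runs."""

    def __init__(self, value):
        self.value = value

    def read(self):
        return Op(lambda: self.value)

    def assign(self, op):
        # An explicit update op: the stored value changes only when this op is run.
        def _update():
            self.value = op.run()
            return self.value
        return Op(_update)


running_mean_var = Var(0.0)
batch_mean = Op(lambda: 10.0)  # stands in for the batch mean computed from the input

# Naive version, mirroring the NumPy code:
running_mean = running_mean_var.read()
running_mean *= 0.9
running_mean += 0.1 * batch_mean
print(running_mean.run())      # 1.0
print(running_mean_var.value)  # 0.0: the stored average was never updated

# Explicit update op, which must itself be run:
update = running_mean_var.assign(running_mean_var.read() * 0.9 + 0.1 * batch_mean)
update.run()
print(running_mean_var.value)  # 1.0
```

In TensorFlow 1.x graph mode the analogous pieces are a tf.Variable, an update op built with tf.assign, and something that forces the update to run, for example computing the returned batch statistics inside tf.control_dependencies([update_op]); tf.train.ExponentialMovingAverage packages the moving-average part.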
Stack Overflow activity: