Implementing batch normalization seems to require fairly advanced TensorFlow programming. We go over our current implementation below.
Train/Test behavior
We typically want the following behavior during training and testing:
- train:
  - compute and use the batch moments
  - update the exponential moving averages
- validation:
  - use the saved exponential moving averages (EMAs); no updates and no computation of batch statistics
- final/test:
  - ideally, once the model is trained, do one more pass through the data and replace the exponential moving averages with true averages over the whole data set (I think most people are lazy here and just use the last EMAs)
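The train/validation/final behavior above can be sketched in NumPy as a small layer class. This is a minimal sketch under stated assumptions, not our TensorFlow implementation: the names `BatchNorm1D` and `set_population_stats`, the momentum of 0.9 and `eps` are illustrative choices, inputs are assumed to be `(batch, features)` arrays, and the learned scale and shift parameters are omitted.

```python
import numpy as np


class BatchNorm1D:
    """Minimal batch norm over axis 0 of a (batch, features) array.

    Illustrative sketch only: names, momentum and eps are assumptions,
    and the learned scale/shift (gamma/beta) are left out.
    """

    def __init__(self, num_features, momentum=0.9, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        # Exponential moving averages (EMAs) of the batch moments.
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)

    def __call__(self, x, mode):
        if mode == "train":
            # Compute and use the batch moments...
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            # ...and update the EMAs.
            m = self.momentum
            self.running_mean = m * self.running_mean + (1.0 - m) * mean
            self.running_var = m * self.running_var + (1.0 - m) * var
        elif mode in ("validation", "test"):
            # Use the saved statistics: no batch moments, no updates.
            mean, var = self.running_mean, self.running_var
        else:
            raise ValueError(f"unknown mode: {mode!r}")
        return (x - mean) / np.sqrt(var + self.eps)

    def set_population_stats(self, batches):
        """Final pass: replace the EMAs with true averages over all batches."""
        n = 0
        total = 0.0
        total_sq = 0.0
        for x in batches:
            n += x.shape[0]
            total = total + x.sum(axis=0)
            total_sq = total_sq + (x ** 2).sum(axis=0)
        mean = total / n
        self.running_mean = mean
        # Population variance from running sums: E[x^2] - E[x]^2.
        self.running_var = total_sq / n - mean ** 2
```

`set_population_stats` accumulates sums batch by batch instead of concatenating everything, so the final pass does not need the whole data set in memory. In a real network, the batches it sees would be this layer's inputs computed with the trained weights.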
Computational Graph
You have to understand the TensorFlow computational graph to some degree. If this were just Python with NumPy arrays instead of TensorFlow tensors, you could do something like
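```python
import numpy as np

def mean_for_batch(input_ndarr, running_mean, in_train_mode):
    # running_mean must be a float ndarray so the in-place updates persist.
    if in_train_mode:
        batchMean = np.mean(input_ndarr, axis=0)  # per-feature batch mean
        running_mean *= 0.9                       # EMA update, in place
        running_mean += 0.1 * batchMean
        return batchMean
    else:
        return running_mean                       # saved EMA, no update
```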
where in_train_mode is a boolean flag, and the code above is executed for every batch of data. In TensorFlow, when you write something like
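```python
import tensorflow as tf

momentum = 0.9
# inputTensor (e.g. an NHWC activation) and running_mean (a tf.Variable)
# are assumed to be defined elsewhere in the graph.
batchMean, batchVar = tf.nn.moments(inputTensor, axes=[0, 1, 2])
# tf.multiply was called tf.mul before TensorFlow 1.0.
new_running_mean = tf.add(tf.multiply(momentum, running_mean),
                          (1.0 - momentum) * batchMean,
                          name='new_running_mean')
```

then, in graph-mode (TF 1.x style) TensorFlow, these lines run only once, while the graph is being built. They add operations to the graph instead of computing anything: running_mean is not changed, and new_running_mean is only evaluated when a session runs it or an op that depends on it. The moving-average update therefore has to be wired into the graph explicitly.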
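The difference between building an update and running it can be seen in a toy model of deferred (graph) execution in plain Python. This is not TensorFlow's API: `Node`, `Variable`, `constant` and `assign` are hypothetical stand-ins. In TF 1.x graph code, the analogue of calling `update.run()` is running a `tf.assign` op with `sess.run`, or making the normalization output depend on it via `tf.control_dependencies`.

```python
class Node:
    """Toy graph node: stores a function and its inputs; computes nothing until run()."""

    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs

    def run(self):
        # Evaluate the input nodes first, then apply this node's function.
        return self.fn(*(node.run() for node in self.inputs))


class Variable(Node):
    """Toy stateful variable (hypothetical); assign() builds an op, it does not write yet."""

    def __init__(self, value):
        super().__init__(None)
        self.value = value

    def run(self):
        return self.value

    def assign(self, new_value_node):
        def do_assign(v):
            self.value = v
            return v
        return Node(do_assign, new_value_node)


def constant(c):
    return Node(lambda: c)


momentum = 0.9
running_mean = Variable(0.0)
batch_mean = constant(5.0)  # stand-in for the batch-mean node

# "Graph construction": these lines run once and only build nodes.
new_running_mean = Node(
    lambda r, b: momentum * r + (1.0 - momentum) * b, running_mean, batch_mean
)
update = running_mean.assign(new_running_mean)
print(running_mean.value)  # 0.0: nothing has been computed yet

# "Session runs": each run of `update` performs one EMA step.
for _ in range(3):
    update.run()
print(running_mean.value)
```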
Stack Overflow activity: