
Implementing batch normalization requires moderately advanced TensorFlow programming. We'll go over our current implementation below.

Train/Test behavior

Typically we want the following behavior during training and testing:

  • Train:
    • compute and use batch moments
    • update exponential moving averages
  • Validation:
    • use the saved exponential moving averages (EMAs); no updates or computation of batch statistics
  • Final/Test:
    • ideally, once the model is trained, do one more pass through the data and replace the exponential moving averages with true averages over the whole dataset (I think most people are lazy here; they just use the last EMAs)
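The three modes above can be sketched in plain NumPy. This is an illustrative sketch, not our actual code; the function names, the 0.9 momentum, and the per-batch layout are assumptions:

```python
import numpy as np

MOMENTUM = 0.9  # EMA decay; 0.9 is an assumed value for illustration

def train_step(batch, running_mean):
    """Training: compute the batch mean, update the EMA, return batch stats."""
    batch_mean = batch.mean(axis=0)
    running_mean = MOMENTUM * running_mean + (1.0 - MOMENTUM) * batch_mean
    return batch_mean, running_mean

def validation_stats(running_mean):
    """Validation: just read the saved EMA; no updates, no batch statistics."""
    return running_mean

def final_pass(batches):
    """Final/Test: one extra pass over the whole dataset to get the true mean,
    instead of reusing the last EMA."""
    total = sum(b.sum(axis=0) for b in batches)
    count = sum(b.shape[0] for b in batches)
    return total / count
```

The difference matters because the EMA is biased toward recent batches, while `final_pass` returns the exact dataset mean.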

Computational Graph

You have to understand the TensorFlow computational graph to some degree. If this were just Python with NumPy arrays instead of TensorFlow tensors, you could do something like

if in_train_mode:
  batch_mean = input_ndarr.mean(axis=0)   # batch statistics
  running_mean *= 0.9                     # EMA update
  running_mean += 0.1 * batch_mean
  return batch_mean
else:
  return running_mean

where in_train_mode is a boolean flag, and for every batch of data, the code above gets executed. In TensorFlow, when you write something like

            batchMean, batchVar = tf.nn.moments(inputTensor, axes=[0, 1, 2])
            new_running_mean = tf.add(tf.mul(momentum, running_mean),
                                      (1.0 - momentum) * batchMean,
                                      name='new_running_mean')

you are not computing anything yet: you are only adding ops to the computational graph, and nothing runs until the graph is executed in a session. In particular, a Python if on a boolean flag is evaluated once, at graph-construction time, so it cannot switch between train and test behavior per batch the way the NumPy version does.
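For reference, tf.nn.moments over axes [0, 1, 2] on an NHWC tensor produces per-channel mean and variance. A NumPy equivalent of that computation, plus one training-mode normalization step (a sketch; the function names, momentum, and eps values are assumptions, not from our code):

```python
import numpy as np

def moments_nhwc(x):
    """Per-channel mean/variance over batch, height, and width -- the NumPy
    analogue of tf.nn.moments(inputTensor, axes=[0, 1, 2])."""
    mean = x.mean(axis=(0, 1, 2))
    var = x.var(axis=(0, 1, 2))
    return mean, var

def batch_norm_train(x, running_mean, momentum=0.9, eps=1e-5):
    """One training-mode step: normalize with batch statistics and
    update the running mean EMA."""
    mean, var = moments_nhwc(x)
    running_mean = momentum * running_mean + (1.0 - momentum) * mean
    x_hat = (x - mean) / np.sqrt(var + eps)
    return x_hat, running_mean
```

In graph-mode TensorFlow the per-batch train/test switch is instead expressed inside the graph (e.g. with tf.cond on an is_training tensor), which is what makes the implementation more involved than the NumPy version.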

Stack overflow activity: