Batch Normalization

Implementing batch normalization turns out to require fairly advanced TensorFlow programming. We'll go over our current implementation below.

Train/Test Behavior

Typically we want the following behavior during training and testing:

  • train:
    • compute and use batch moments
    • update exponential moving averages
  • validation:
    • use the saved exponential moving averages (EMAs) (no updates or computation of batch statistics)
  • Final/Test:
    • Ideally, once the model is trained, do one more pass through the data and replace the exponential moving averages with true averages over the whole dataset (I think most people are lazy here and just use the last EMAs; see the sketch after this list)
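As a rough sketch of that final pass (names like data_batches and activations_into_bn are hypothetical, not from the code linked below), one could accumulate the true per-channel moments over the whole training set with NumPy and then overwrite the EMAs:

import numpy as np

total, total_sq, count = 0.0, 0.0, 0
for batch in data_batches:                       # hypothetical iterable of input batches
  acts = activations_into_bn(batch)              # hypothetical: 4D activations feeding the BN layer
  total += acts.sum(axis=(0, 1, 2))              # per-channel sums
  total_sq += (acts ** 2).sum(axis=(0, 1, 2))
  count += acts.shape[0] * acts.shape[1] * acts.shape[2]

true_mean = total / count
true_var = total_sq / count - true_mean ** 2     # E[x^2] - (E[x])^2
running_mean, running_var = true_mean, true_var  # replace the EMAs before final testing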

Computational Graph

You have to understand the TensorFlow computational graph to some degree. If this were just Python with NumPy arrays instead of TensorFlow tensors, you could do something like this in the code:

if in_train_mode:
  batchMean = input_ndarr.mean(axis=(0, 1, 2))   # per-channel mean over the batch
  running_mean *= 0.9
  running_mean += 0.1 * batchMean
  return batchMean
else:
  return running_mean

where in_train_mode is a boolean flag, and for every batch of data the above gets executed. In TensorFlow, when you write some Python code like

batchMean, batchVar = tf.nn.moments(inputTensor, axes=[0, 1, 2])
# say inputTensor is 4D: (batch sample, row, column, channel)

# running_mean is a tf.Variable, so the update must itself be an op in the graph
update_running_mean = tf.assign(running_mean, 0.9 * running_mean + 0.1 * batchMean)

It only gets executed once, but what it does is add 'ops' to the 'computational graph'. The only op we will execute repeatedly is one of the following (a rough sketch of this run pattern follows the list):

  • the training op, which asks the optimizer to minimize the loss,
  • the predict op, to make predictions with a trained model
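For concreteness, here is a rough sketch of that run pattern (loss, optimizer, predict_op, and the placeholders x, y are assumed names, not from the code linked below); the graph-building lines execute once, and only the ops passed to session.run() execute repeatedly:

train_op = optimizer.minimize(loss)              # built once, added to the graph
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
  sess.run(init_op)
  for batch_x, batch_y in training_batches:      # runs once per batch of data
    sess.run(train_op, feed_dict={x: batch_x, y: batch_y})
  preds = sess.run(predict_op, feed_dict={x: test_x})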

Boolean Logic and the Computational Graph

We don't get to write Python if/else boolean logic that gets executed during the training ops; all we can do is add TensorFlow ops to the computational graph. TensorFlow has a

tf.cond(pred, true_fn, false_fn)

op that one can use instead. Here, one can use a TensorFlow placeholder for the boolean pred, and then feed in True or False when running the training or prediction operations through the session.
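A minimal sketch of that pattern, assuming inputTensor and running_mean from above and a made-up channel count n_channels:

import tensorflow as tf

is_training = tf.placeholder(tf.bool, name='is_training')        # fed per session.run call

batchMean, batchVar = tf.nn.moments(inputTensor, axes=[0, 1, 2])
running_mean = tf.Variable(tf.zeros([n_channels]), trainable=False)

# use the batch mean in training mode, the stored running mean otherwise
mean = tf.cond(is_training, lambda: batchMean, lambda: running_mean)

# at run time:
# sess.run(train_op, feed_dict={inputTensor: x_batch, is_training: True})
# sess.run(predict_op, feed_dict={inputTensor: x_batch, is_training: False})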

Control Dependencies

One way to do batch normalization is to make the training op depend on, or trigger, the updates to the running mean, while using the batch mean for the computation. Creating the op to update running_mean as above, but not tying it to the training op, means it never gets executed. Following some Stack Overflow threads, I tried this using TensorFlow's control_dependencies, but have not yet gotten it to work; we can look at the broken code if we like: bn_broke_BatchNormalization.py. Things to note:

  • Trying to emulate the Keras BatchNormalization 'modes', but I actually mixed up modes 1 and 3 (Keras mode 1 is what I called mode 3)
  • The problem is that in test mode the running mean and std still get updated. Here is the TensorFlow documentation on control dependencies. A rough sketch of the approach follows this list.
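For reference, a rough sketch of the control-dependencies idea (beta and gamma are assumed scale/offset variables; the names are mine): the EMA assignments become prerequisites of the normalized output, so they fire whenever that output is computed. The bug above is that this also happens in test mode unless the updates are guarded by the is_training condition.

decay = 0.9
update_mean = tf.assign(running_mean, decay * running_mean + (1 - decay) * batchMean)
update_var = tf.assign(running_var, decay * running_var + (1 - decay) * batchVar)

# the output below cannot be computed until the two assign ops have run
with tf.control_dependencies([update_mean, update_var]):
  normalized = tf.nn.batch_normalization(inputTensor, batchMean, batchVar,
                                         beta, gamma, 1e-5)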

Adding Ops to Training

Another way to do this is

  • to keep track of the additional ops needed to update the running mean and variance
  • explicitly invoke them during every training step
  • use a TensorFlow boolean placeholder to keep track of training vs. testing
    • set this appropriately during invocation of the code (see the sketch after this list)
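Roughly, the run pattern looks like this (op and placeholder names are assumptions, not the actual code linked below):

# collect the EMA update ops created while building each batch-norm layer
bn_update_ops = [update_mean, update_var]        # one pair per BatchNormalization instance

# training step: run the optimizer step together with the EMA updates
sess.run([train_op] + bn_update_ops,
         feed_dict={inputTensor: x_batch, labels: y_batch, is_training: True})

# test step: run only the prediction op, leaving the EMAs untouched
preds = sess.run(predict_op,
                 feed_dict={inputTensor: x_test, is_training: False})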

If you look at the Stack Overflow posts (links below), you see there is a better, more automatic way to do some of this: there is a collection of update ops on the default computational graph (tf.GraphKeys.UPDATE_OPS), and you should be able to add the running mean/std updates to this collection, then just run the normal training op that minimizes the loss. The computational graph is a sort of global variable that is in scope; this appears to be the real TensorFlow way to do this.
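If I am reading those posts correctly, a rough sketch of that pattern looks like this (optimizer and loss are assumed names):

# built-in layers (e.g. tf.layers.batch_normalization) register their EMA updates here;
# custom update ops can be added with tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, op)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

# make the normal training op depend on those updates
with tf.control_dependencies(update_ops):
  train_op = optimizer.minimize(loss)

# now sess.run(train_op, ...) also runs the running mean/std updates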

Stack Overflow activity:

Saving/Restoring Models in TensorFlow

Save/restore is straightforward and is better documented in TensorFlow. The steps seem to be (a sketch of both directions follows the list):

  • create the session
  • create the op to initialize all the variables
  • create the tf.train.Saver() object; we'll call it saver
  • run the init op
    • the saver automatically ties save/restore ops for the variables into the computational graph
  • call saver.save()
  • For restoring
    • create initialize variables op
    • create saver
    • run init op
    • call saver.restore()
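A rough sketch of both directions (the checkpoint path is a placeholder):

import tensorflow as tf

# build the graph first (model variables, training and prediction ops)
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()                          # ties save/restore ops for the variables into the graph

# saving
with tf.Session() as sess:
  sess.run(init_op)
  # ... train the model ...
  saver.save(sess, 'checkpoints/model.ckpt')

# restoring (with the same graph built)
with tf.Session() as sess:
  sess.run(init_op)
  saver.restore(sess, 'checkpoints/model.ckpt')   # overwrites the freshly initialized variables
  # ... make predictions ...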

Code

  • A BatchNormalization class that keeps track of the additional training ops: BatchNormalization.py (a rough sketch of the idea follows this list)
  • Driver program: ex06_tf_batchnorm.py. Note:
    • use of a new boolean flag placeholder
    • model class keeps list of instances of the batch normalization classes
    • model returns trainOps by querying all the batch normalization instances
    • trainOps are added to ops used during training
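The linked files are not reproduced here, but a minimal sketch of the idea behind the BatchNormalization class, as described above (all names are guesses, not the actual code):

import tensorflow as tf

class BatchNormalization(object):
  """Builds the BN ops and keeps track of its own EMA update (training) ops."""

  def __init__(self, n_channels, decay=0.9, eps=1e-5):
    self.gamma = tf.Variable(tf.ones([n_channels]))
    self.beta = tf.Variable(tf.zeros([n_channels]))
    self.running_mean = tf.Variable(tf.zeros([n_channels]), trainable=False)
    self.running_var = tf.Variable(tf.ones([n_channels]), trainable=False)
    self.decay = decay
    self.eps = eps
    self.train_ops = []                            # EMA updates, run only during training steps

  def __call__(self, x, is_training):
    batch_mean, batch_var = tf.nn.moments(x, axes=[0, 1, 2])
    self.train_ops = [
        tf.assign(self.running_mean,
                  self.decay * self.running_mean + (1 - self.decay) * batch_mean),
        tf.assign(self.running_var,
                  self.decay * self.running_var + (1 - self.decay) * batch_var),
    ]
    mean = tf.cond(is_training, lambda: batch_mean, lambda: self.running_mean)
    var = tf.cond(is_training, lambda: batch_var, lambda: self.running_var)
    return tf.nn.batch_normalization(x, mean, var, self.beta, self.gamma, self.eps)

The model would then gather train_ops from every instance and pass them to session.run alongside the optimizer step, with the boolean flag placeholder fed as True.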