In batch normalization (BN), how exactly are the mean and variance computed?
Below is part of the relevant code I found (\tensorflow\python\layers\normalization.py):
def call(self, inputs, training=False):
  # First, compute the axes along which to reduce the mean / variance,
  # as well as the broadcast shape to be used for all parameters.
  input_shape = inputs.get_shape()
  ndim = len(input_shape)
  reduction_axes = list(range(len(input_shape)))
  del reduction_axes[self.axis]
  broadcast_shape = [1] * len(input_shape)
  broadcast_shape[self.axis] = input_shape[self.axis].value
  # Determines whether broadcasting is needed.
  needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])
  scale, offset = self.gamma, self.beta
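The excerpt stops before the mean and variance themselves are computed, but it already shows which axes they are reduced over: every axis except `self.axis`. For a channels-last (NHWC) input with `axis=-1`, that means one mean and one variance per channel, each taken over the batch, height and width axes. The pure-Python sketch below mirrors the axis bookkeeping from the excerpt and then computes those per-channel moments. The helper names `reduction_and_broadcast`, `element` and `batch_moments` are hypothetical (this is not TensorFlow's API), and the divide-by-N (biased) variance is an assumption of the sketch, not a claim about every TensorFlow code path.

```python
from itertools import product


def reduction_and_broadcast(shape, axis):
    """Mirror the axis bookkeeping from the excerpt (pure Python, no TensorFlow).

    Returns the axes to reduce over, the broadcast shape for per-channel
    parameters, and whether broadcasting is needed (channel axis not last).
    """
    ndim = len(shape)
    reduction_axes = list(range(ndim))
    del reduction_axes[axis]             # reduce over every axis except the channel axis
    broadcast_shape = [1] * ndim
    broadcast_shape[axis] = shape[axis]  # one value per channel
    needs_broadcasting = sorted(reduction_axes) != list(range(ndim))[:-1]
    return reduction_axes, broadcast_shape, needs_broadcasting


def element(x, idx):
    """Index a nested list with a tuple of indices (hypothetical helper)."""
    for i in idx:
        x = x[i]
    return x


def batch_moments(x, shape, axis):
    """Per-channel mean and variance over every axis except `axis`.

    Uses the biased (divide-by-N) variance, an assumption of this sketch.
    """
    channel_axis = axis % len(shape)  # allow negative axes such as -1
    means, variances = [], []
    for c in range(shape[channel_axis]):
        # All elements of channel c, across batch and spatial positions.
        vals = [element(x, idx)
                for idx in product(*(range(n) for n in shape))
                if idx[channel_axis] == c]
        mean = sum(vals) / len(vals)
        means.append(mean)
        variances.append(sum((v - mean) ** 2 for v in vals) / len(vals))
    return means, variances


# NHWC input: batch=2, height=1, width=2, channels=2; channel axis is -1.
x = [
    [[[1.0, 10.0], [3.0, 20.0]]],
    [[[5.0, 30.0], [7.0, 40.0]]],
]
shape = [2, 1, 2, 2]
print(reduction_and_broadcast(shape, -1))  # ([0, 1, 2], [1, 1, 1, 2], False)
print(batch_moments(x, shape, -1))         # ([4.0, 25.0], [5.0, 125.0])
```

Channel 0 collects the values 1, 3, 5, 7 from both batch entries and both spatial positions, giving mean 4 and variance 5. With a channels-first layout (`axis=1` on an NCHW shape), the reduction axes become `[0, 2, 3]` and `needs_broadcasting` is `True`, because the per-channel statistics then have to be reshaped to `broadcast_shape` before they can be combined with the input.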