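In PyTorch, networks frequently contain BN layers. When the basic unit is a convolution immediately followed by a BN layer (Conv → BN), the Conv and BN can be fused offline, so BN no longer has to be computed at inference time.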
In PyTorch, `BatchNorm2d` is defined as follows:

```
class torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
```
Its parameters and buffers (the keys of `state_dict()`) are:

```
dict_keys(['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked'])
```
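These map to the variables in the formulas below as follows (`eps` is a constructor argument rather than a `state_dict` entry; `num_batches_tracked` is a training counter and does not appear in the formulas):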
| Param | Variable |
|---|---|
| `eps` | $\epsilon$ |
| `weight` | $\gamma$ |
| `bias` | $\beta$ |
| `running_mean` | $\mu$ |
| `running_var` | $\sigma^2$ |
Key Idea
Convolution layer:
$$
Y = X * w + b
$$
BatchNorm layer:
$$
Y = \frac{X - \mu}{\sqrt{\sigma^2 + \epsilon}} \gamma + \beta
$$
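As a quick sanity check of this formula and of the parameter mapping in the table, the sketch below recomputes a `BatchNorm2d` output by hand in eval mode and compares it with the module itself. The random parameter values and tensor shapes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A BatchNorm2d layer with 4 channels. Randomize gamma, beta and the running
# statistics so the check is not trivial (a fresh layer has gamma=1, beta=0,
# mu=0, sigma^2=1).
bn = nn.BatchNorm2d(4)
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)         # gamma
    bn.bias.uniform_(-0.5, 0.5)          # beta
    bn.running_mean.uniform_(-1.0, 1.0)  # mu
    bn.running_var.uniform_(0.5, 2.0)    # sigma^2
bn.eval()  # inference mode: use running statistics, not batch statistics

x = torch.randn(2, 4, 8, 8)


def per_channel(t):
    # Per-channel tensors have shape (C,); reshape to (1, C, 1, 1) so they
    # broadcast over the N, H, W dimensions.
    return t.view(1, -1, 1, 1)


gamma, beta = per_channel(bn.weight), per_channel(bn.bias)
mu, var = per_channel(bn.running_mean), per_channel(bn.running_var)
y_manual = (x - mu) / torch.sqrt(var + bn.eps) * gamma + beta

with torch.no_grad():
    y_module = bn(x)

print(torch.allclose(y_manual, y_module, atol=1e-5))
```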
Convolution and BatchNorm fusion:
$$
\begin{aligned}
Y &= \frac{(X * w + b) - \mu}{\sqrt{\sigma^2 + \epsilon}} \gamma + \beta \\
  &= X * \frac{\gamma w}{\sqrt{\sigma^2 + \epsilon}} + \frac{b - \mu}{\sqrt{\sigma^2 + \epsilon}} \gamma + \beta \\
w_{merged} &= a w, \quad b_{merged} = (b - \mu) a + \beta, \quad a = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}
\end{aligned}
$$
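The fusion can be sketched in PyTorch as below. `fuse_conv_bn` is a hypothetical helper name, not part of the original post; it assumes a plain `nn.Conv2d` followed by `nn.BatchNorm2d` in eval mode, and treats a conv without bias as $b = 0$. Because $a$ holds one value per output channel, $w_{merged} = a w$ scales each output filter separately.

```python
import torch
import torch.nn as nn


def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Return a new Conv2d equivalent to bn(conv(x)) with bn in eval mode."""
    fused = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        groups=conv.groups, bias=True, padding_mode=conv.padding_mode,
    )
    with torch.no_grad():
        # a = gamma / sqrt(sigma^2 + eps), one value per output channel
        a = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        # w_merged = a * w: scale each output channel's filter (dim 0)
        fused.weight.copy_(conv.weight * a.view(-1, 1, 1, 1))
        # b_merged = (b - mu) * a + beta; a conv without bias has b = 0
        b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        fused.bias.copy_((b - bn.running_mean) * a + bn.bias)
    return fused


torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
conv.eval()
bn.eval()

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    ref = bn(conv(x))                 # original Conv -> BN
    out = fuse_conv_bn(conv, bn)(x)   # single fused conv
print(torch.allclose(ref, out, atol=1e-5))
```

With random statistics the fused conv should match `bn(conv(x))` up to float32 rounding. PyTorch also provides `torch.nn.utils.fusion.fuse_conv_bn_eval` for the same purpose; the sketch spells out the math directly instead.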
Uniform Quantization
$$
r = S q
$$
where the constant $S$ is the quantization parameter (the scale), and the integers $q$ are mapped to real numbers $r$.
| Symbol | Meaning | Data type |
|---|---|---|
| $S$ | scale | fp32 |
| $q$ | quantized value | int4 for weights ($w$), int5 for activations |
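A minimal sketch of $r = Sq$ for the int4 weight case. The post does not say how $S$ is chosen; this assumes symmetric per-tensor quantization with $S = \max|r| / (2^{b-1} - 1)$ and round-to-nearest, which is one common choice. `quantize_symmetric` is a hypothetical helper.

```python
import torch


def quantize_symmetric(r: torch.Tensor, num_bits: int):
    """Symmetric uniform quantization r ~= S * q with a single fp32 scale S.

    Assumed scheme: signed integers in [-(2^(b-1)), 2^(b-1) - 1] and
    S = max|r| / (2^(b-1) - 1).
    """
    qmax = 2 ** (num_bits - 1) - 1   # 7 for int4
    qmin = -(2 ** (num_bits - 1))    # -8 for int4
    scale = r.abs().max() / qmax     # fp32 scale S
    # Integer grid values q (kept in a float tensor here for simplicity)
    q = torch.clamp(torch.round(r / scale), qmin, qmax)
    return scale, q


torch.manual_seed(0)
w = torch.randn(8, 3, 3, 3)
S, w_q = quantize_symmetric(w, num_bits=4)
w_hat = S * w_q  # dequantized weights, r = S q

print(w_q.min().item(), w_q.max().item())      # values lie within [-8, 7]
print((w - w_hat).abs().max() <= S / 2 + 1e-6)  # rounding error is at most S/2
```

Calling it with `num_bits=5` gives the int5 grid listed for activations, although activation scales would normally come from calibration data rather than a single tensor, which this sketch does not cover.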
Quantized convolution layer:
$$
Y = X * S w_q + b, \quad \text{where } w = S w_q
$$
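When $S$ is a single scalar, convolution's linearity gives $X * (S w_q) = S (X * w_q)$, so the scale can be applied after a convolution that only sees integer-valued weights. The sketch checks this numerically under the same assumed int4 per-tensor scheme; the shapes and `padding=1` are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8)
w = torch.randn(8, 3, 3, 3)
b = torch.randn(8)

# Per-tensor symmetric int4 quantization of the weights (assumed scheme).
S = w.abs().max() / 7
w_q = torch.clamp(torch.round(w / S), -8, 7)

# Form 1: dequantize first, then convolve: Y = X * (S w_q) + b
y1 = F.conv2d(x, S * w_q, b, padding=1)
# Form 2: convolve with the integer-valued weights, apply S afterwards.
# Valid because convolution is linear and S is one scalar here.
y2 = S * F.conv2d(x, w_q, None, padding=1) + b.view(1, -1, 1, 1)

print(torch.allclose(y1, y2, atol=1e-4))
```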
Quantized convolution and BatchNorm fusion:
$$
\begin{aligned}
Y &= \frac{(X * S w_q + b) - \mu}{\sqrt{\sigma^2 + \epsilon}} \gamma + \beta \\
  &= X * \frac{\gamma S w_q}{\sqrt{\sigma^2 + \epsilon}} + \frac{b - \mu}{\sqrt{\sigma^2 + \epsilon}} \gamma + \beta \\
S_{merged} &= a S, \quad w_{merged} = a w, \quad b_{merged} = (b - \mu) a + \beta, \quad a = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}
\end{aligned}
$$
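The fused quantized form can be checked the same way. Since $a$ has one entry per output channel, $S_{merged} = aS$ turns a per-tensor scale into a per-output-channel scale, while the integer weights $w_q$ are reused unchanged ($w_{merged} = a w = S_{merged} w_q$). The quantization scheme and BN statistics below are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)

# Quantized conv: per-tensor scale S, integer-valued weights w_q, bias b
# (int4 symmetric scheme assumed for illustration).
w = torch.randn(8, 3, 3, 3)
b = torch.randn(8)
S = w.abs().max() / 7
w_q = torch.clamp(torch.round(w / S), -8, 7)

bn = nn.BatchNorm2d(8)
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
bn.eval()

with torch.no_grad():
    # Reference: quantized conv followed by BN
    ref = bn(F.conv2d(x, S * w_q, b, padding=1))

    # a = gamma / sqrt(sigma^2 + eps), one value per output channel
    a = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    S_merged = a * S                                # per-channel scale, shape (8,)
    b_merged = (b - bn.running_mean) * a + bn.bias
    # w_q is unchanged: w_merged = a w = S_merged w_q
    w_merged = S_merged.view(-1, 1, 1, 1) * w_q
    out = F.conv2d(x, w_merged, b_merged, padding=1)

print(torch.allclose(ref, out, atol=1e-4))
```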