Weight normalization

Abstract

Weight normalization is a reparameterization of the weight vectors in a neural network that decouples the length of each weight vector from its direction. It does not introduce any dependencies between the examples in a minibatch, and it can be applied to models such as RNNs, DQNs and GANs. It is a useful way to improve the conditioning of the optimization problem and to speed up convergence of stochastic gradient descent.

Introduction

The practical success of first-order gradient-based optimization is highly dependent on the curvature of the objective that is optimized. If the condition number of the Hessian matrix of the objective at the optimum is high, the problem is said to exhibit pathological curvature, and first-order gradient descent will have trouble making progress.

Weight normalization

Consider a standard neuron that computes a weighted sum of its inputs followed by an elementwise nonlinearity:

y = φ(w · x + b)

where w is a k-dimensional weight vector, b is a scalar bias, x is a k-dimensional vector of input features, and y is the scalar output of the neuron.

After associating a loss function to one or more neuron outputs, such a neural network is commonly trained by stochastic gradient descent in the parameters w, b of each neuron. In order to speed up the convergence of this optimization procedure, we reparameterize each weight vector w in terms of a parameter vector v and a scalar parameter g, and perform stochastic gradient descent with respect to those parameters instead.

w = (g / ||v||) v

where ||v|| denotes the Euclidean norm of v. This reparameterization fixes the Euclidean norm of the weight vector: we now have ||w|| = g, independent of the parameters v. By decoupling the norm of the weight vector (g) from its direction (v/||v||), we speed up convergence of stochastic gradient descent.
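
As a quick sanity check, here is a minimal PyTorch sketch of the reparameterization (the tensor size and the value of g are arbitrary choices, not from the paper), verifying that ||w|| equals g no matter what v is:

import torch

# minimal sketch: w = (g / ||v||) * v, so ||w|| = g regardless of v
v = torch.randn(20)            # direction parameter (size is arbitrary)
g = torch.tensor(3.0)          # scale parameter
w = g * v / v.norm()           # reconstructed weight vector
print(w.norm())                # ~3.0, i.e. equal to g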

Gradients

Training proceeds by standard stochastic gradient descent in the new parameters v, g. Differentiating through the reparameterization gives the gradient of a loss function L with respect to v and g:

∇_g L = (∇_w L · v) / ||v||

∇_v L = (g / ||v||) ∇_w L − (g ∇_g L / ||v||²) v

where ∇_w L is the gradient with respect to the weight w as it is used normally. An alternative way to write the gradient with respect to v is

∇_v L = (g / ||v||) M_w ∇_w L,   with   M_w = I − (w wᵀ) / ||w||²

where M_w is the matrix that projects onto the complement of the current weight vector w.

Weight normalization thus accomplishes two things: it scales the weight gradient by **g/||v||**, and it projects the gradient away from the current weight vector.
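
To make these two effects concrete, here is a small autograd sketch (the shapes, the loss and the random seed are arbitrary stand-ins) that checks the v-gradient is the w-gradient scaled by g/||v|| and projected away from w:

import torch

torch.manual_seed(0)
x = torch.randn(20)
v = torch.randn(20, requires_grad=True)
g = torch.tensor(2.0, requires_grad=True)

w = g * v / v.norm()               # reparameterized weight
w.retain_grad()                    # keep the gradient w.r.t. the intermediate w
loss = torch.tanh(w @ x).pow(2)    # an arbitrary scalar loss
loss.backward()

# the gradient w.r.t. v is orthogonal to w (and hence to v)
print(torch.dot(v.grad, w.detach()))                  # ~0

# and equals (g / ||v||) times the w-gradient projected away from w
w_hat = (w / w.norm()).detach()
expected = (g / v.norm()).detach() * (w.grad - torch.dot(w.grad, w_hat) * w_hat)
print(torch.allclose(v.grad, expected, atol=1e-6))    # True (up to float error)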

Because the gradient is projected away from w, the norm of v grows monotonically with the number of weight updates when training a weight-normalized network with standard gradient descent without momentum:

v' = v + Δv,   with   Δv ∝ ∇_v L

Since ∇_v L is proportional to M_w ∇_w L, the update Δv is orthogonal to w, and hence to v. If ||Δv|| / ||v|| = c, the new parameter vector therefore has norm

||v'|| = sqrt(||v||² + c²||v||²) = sqrt(1 + c²) ||v|| ≥ ||v||

The rate of increase depends on the variance of the weight gradient. If the gradients are noisy, c will be large, the norm of v will quickly increase, and the scaling factor g/||v|| will shrink accordingly. If the norm of the gradients is small, sqrt(1 + c²) ≈ 1 and the norm of v stays approximately constant.
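
The norm-growth argument is easy to observe empirically. Below is a small sketch (the least-squares objective, the data and the learning rate are arbitrary stand-ins) running plain SGD without momentum on v and g; the printed norm of v never decreases:

import torch

torch.manual_seed(0)
x = torch.randn(64, 20)                     # arbitrary inputs
t = torch.randn(64)                         # arbitrary targets

v = torch.randn(20, requires_grad=True)
g = torch.tensor(1.0, requires_grad=True)
lr = 0.1

for step in range(5):
    w = g * v / v.norm()                    # recompute w from v and g
    loss = ((x @ w) - t).pow(2).mean()
    loss.backward()
    with torch.no_grad():                   # plain gradient descent, no momentum
        v -= lr * v.grad
        g -= lr * g.grad
        v.grad.zero_()
        g.grad.zero_()
    print(step, v.norm().item())            # grows (or stays equal) every step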

PyTorch weight_norm

import torch.utils.hooks as hooks
from torch.nn.parameter import Parameter


class WeightNorm(object):
    def __init__(self, name, dim):
        self.name = name
        self.dim = dim

    def compute_weight(self, module):
        # express w as (g / ||v||) * v
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        return (g / self.norm(v)) * v

    def norm(self, p):
        """Computes the norm over all dimensions except dim"""
        if self.dim is None:
            return p.norm()
        if self.dim != 0:
            # move the kept dimension to the front
            p = p.transpose(0, self.dim)
        output_size = (p.size(0),) + (1,) * (p.dim() - 1)
        p = p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)
        if self.dim != 0:
            p = p.transpose(0, self.dim)
        return p

    @staticmethod
    def apply(module, name, dim):
        fn = WeightNorm(name, dim)
        weight = getattr(module, name)

        # remove w from the parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as (g / ||v||) * v
        module.register_parameter(name + '_g', Parameter(fn.norm(weight).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute the weight from g and v before every forward() call
        handle = hooks.RemovableHandle(module._forward_pre_hooks)
        module._forward_pre_hooks[handle.id] = fn
        fn.handle = handle
        return fn

    def __call__(self, module, inputs):
        setattr(module, self.name, self.compute_weight(module))


def weight_norm(module, name='weight', dim=0):
    """Applies weight normalization to a parameter in the given module.

    Example:
        >>> import torch.nn as nn
        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        Linear (20 -> 40)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])
    """
    WeightNorm.apply(module, name, dim)
    return module
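
A short usage sketch, assuming the weight_norm function above (which mirrors an older version of torch.nn.utils.weight_norm) is in scope; the layer and batch sizes are arbitrary:

import torch
import torch.nn as nn

m = weight_norm(nn.Linear(20, 40), name='weight')

# the optimizer now sees weight_g and weight_v instead of weight
print([n for n, _ in m.named_parameters()])   # contains 'weight_g' and 'weight_v', not 'weight'

x = torch.randn(8, 20)
y = m(x)            # the forward pre-hook recomputes m.weight from g and v first
print(y.size())     # torch.Size([8, 40])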