"""
Comparison of manual BatchNorm2d layer implementation in Python and
nn.BatchNorm2d
@author: ptrblck
"""
import torch
import torch.nn as nn


def compare_bn(bn1, bn2):
    err = False
    if not torch.allclose(bn1.running_mean, bn2.running_mean):
        print('Diff in running_mean: {} vs {}'.format(
            bn1.running_mean, bn2.running_mean))
        err = True

    if not torch.allclose(bn1.running_var, bn2.running_var):
        print('Diff in running_var: {} vs {}'.format(
            bn1.running_var, bn2.running_var))
        err = True

    if bn1.affine and bn2.affine:
        if not torch.allclose(bn1.weight, bn2.weight):
            print('Diff in weight: {} vs {}'.format(
                bn1.weight, bn2.weight))
            err = True

        if not torch.allclose(bn1.bias, bn2.bias):
            print('Diff in bias: {} vs {}'.format(
                bn1.bias, bn2.bias))
            err = True

    if not err:
        print('All parameters are equal!')
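

# The class below reimplements the BatchNorm2d forward pass by hand. Per channel c,
# training mode computes (matching nn.BatchNorm2d semantics)
#
#     y = (x - mean_c) / sqrt(var_c + eps) * weight_c + bias_c
#
# with the (biased) batch statistics mean_c / var_c, and updates the running buffers as
#
#     running_mean = (1 - momentum) * running_mean + momentum * batch_mean
#     running_var  = (1 - momentum) * running_var  + momentum * batch_var_unbiased
#
# Note that momentum is the weight of the *new* observation (PyTorch's convention)
# and that running_var is updated with the unbiased variance estimate. In eval mode
# the stored running_mean / running_var are used instead of the batch statistics.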
class MyBatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MyBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input
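

# For reference only (not executed): in eval mode the same computation could also be
# written with the functional API, e.g.
#
#     out = torch.nn.functional.batch_norm(
#         x, my_bn.running_mean, my_bn.running_var,
#         weight=my_bn.weight, bias=my_bn.bias,
#         training=False, momentum=0.1, eps=my_bn.eps)
#
# This is just a sketch for comparison; the checks below call the modules directly.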


# Init BatchNorm layers
my_bn = MyBatchNorm2d(3, affine=True)
bn = nn.BatchNorm2d(3, affine=True)

compare_bn(my_bn, bn)  # weight and bias may differ depending on the default initialization

# Load weight and bias
my_bn.load_state_dict(bn.state_dict())
compare_bn(my_bn, bn)

# Run train
for _ in range(10):
    scale = torch.randint(1, 10, (1,)).float()
    bias = torch.randint(-10, 10, (1,)).float()
    x = torch.randn(10, 3, 100, 100) * scale + bias
    out1 = my_bn(x)
    out2 = bn(x)
    compare_bn(my_bn, bn)

    print('Outputs allclose: ', torch.allclose(out1, out2))
    print('Max diff: ', (out1 - out2).abs().max())

# Run eval
my_bn.eval()
bn.eval()
for _ in range(10):
    scale = torch.randint(1, 10, (1,)).float()
    bias = torch.randint(-10, 10, (1,)).float()
    x = torch.randn(10, 3, 100, 100) * scale + bias
    out1 = my_bn(x)
    out2 = bn(x)
    compare_bn(my_bn, bn)

    print('Outputs allclose: ', torch.allclose(out1, out2))
    print('Max diff: ', (out1 - out2).abs().max())
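

# A minimal extra sanity check of the eval-mode formula itself (a sketch; it assumes
# both layers above were left in eval mode): normalizing by hand with the running
# statistics should reproduce bn(x) up to floating-point error.
with torch.no_grad():
    x = torch.randn(4, 3, 8, 8)
    manual = (x - bn.running_mean[None, :, None, None]) / torch.sqrt(
        bn.running_var[None, :, None, None] + bn.eps)
    manual = manual * bn.weight[None, :, None, None] + bn.bias[None, :, None, None]
    print('Formula check max diff: ', (manual - bn(x)).abs().max())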