-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathadaptive_batchnorm.py
119 lines (99 loc) · 3.76 KB
/
adaptive_batchnorm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Implementation of Adaptive BatchNorm
@author: ptrblck
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
# Globals: hyper-parameters and runtime configuration for this demo script.
seed = 2809          # passed to torch.manual_seed below
batch_size = 10      # mini-batch size of both data loaders
lr = 0.01            # SGD learning rate
log_interval = 10    # print the training loss every `log_interval` batches
epochs = 10          # number of train/test rounds

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

torch.manual_seed(seed)
class AdaptiveBatchNorm2d(nn.Module):
    '''
    Adaptive BN implementation using two additional parameters:
    out = a * x + b * bn(x)

    ``a`` and ``b`` are single learnable scalars (stored with shape
    (1, 1, 1, 1) so they broadcast over every element of ``x``).

    Args:
        num_features: number of channels, forwarded to ``nn.BatchNorm2d``.
        eps, momentum, affine: forwarded unchanged to ``nn.BatchNorm2d``.
        a_init: initial value of the skip weight ``a`` (default 1.0).
        b_init: initial value of the BN weight ``b`` (default 0.0).
            With the defaults the layer starts out as the identity
            (out = x) and learns how much normalization to mix in.
    '''
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 a_init=1.0, b_init=0.0):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, eps=eps, momentum=momentum,
                                 affine=affine)
        # torch.FloatTensor(1, 1, 1, 1) would allocate *uninitialized*
        # memory (arbitrary values, possibly NaN/inf), so fill explicitly.
        self.a = nn.Parameter(torch.full((1, 1, 1, 1), float(a_init)))
        self.b = nn.Parameter(torch.full((1, 1, 1, 1), float(b_init)))

    def forward(self, x):
        # Blend the raw input with its batch-normalized version.
        return self.a * x + self.b * self.bn(x)
class MyNet(nn.Module):
    '''
    Small CNN: two conv -> adaptive BN -> max-pool -> ReLU stages followed by
    two fully connected layers; returns log-probabilities over 10 classes.

    fc1 expects 320 = 20 * 4 * 4 features, which is what the conv/pool stack
    produces for a 1 x 28 x 28 input (28 -> 24 -> 12 -> 8 -> 4).
    '''
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1,
                               out_channels=10,
                               kernel_size=5)
        self.conv1_bn = AdaptiveBatchNorm2d(10)
        self.conv2 = nn.Conv2d(in_channels=10,
                               out_channels=20,
                               kernel_size=5)
        self.conv2_bn = AdaptiveBatchNorm2d(20)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1_bn(self.conv1(x)), 2))
        x = F.relu(F.max_pool2d(self.conv2_bn(self.conv2(x)), 2))
        # Flatten per sample and keep the batch dimension explicit.
        # x.view(-1, 320) would silently regroup elements across samples
        # for a wrongly sized input; flatten(x, 1) makes fc1 raise instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
def train(epoch):
    '''
    Run one training epoch over ``train_loader``, printing the loss every
    ``log_interval`` batches.

    Uses the module-level ``model``, ``optimizer``, ``train_loader``,
    ``device`` and ``log_interval``.

    Args:
        epoch: epoch number; only shown in the progress message.
    '''
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        predictions = model(inputs)
        batch_loss = F.nll_loss(predictions, labels)
        batch_loss.backward()
        optimizer.step()

        if step % log_interval != 0:
            continue
        samples_seen = step * len(inputs)
        dataset_size = len(train_loader.dataset)
        percent_done = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, samples_seen, dataset_size, percent_done,
            batch_loss.item()))
def test():
    '''
    Evaluate ``model`` on ``test_loader`` and print the average NLL loss and
    the classification accuracy.

    Uses the module-level ``model``, ``test_loader`` and ``device``.
    '''
    model.eval()
    test_loss = 0.
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum (not average) per-sample losses so dividing by the dataset
            # size below gives the exact per-sample mean, even when the last
            # batch is smaller. `reduction='sum'` replaces the deprecated
            # `size_average=False`.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # Predicted class = index of the max log-probability.
            pred = output.argmax(dim=1)
            # .item() keeps `correct` a plain int so it prints as a number
            # rather than as "tensor(...)".
            correct += pred.eq(target).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))
# Both MNIST splits share the same preprocessing: convert to a tensor, then
# normalize with these mean/std constants (presumably the MNIST training-set
# statistics).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=mnist_transform),
    batch_size=batch_size,
    shuffle=True)
# The evaluation loader is not shuffled: test() only sums losses and correct
# predictions, so sample order does not affect the result and shuffling only
# adds work. download=True makes this split independent of the train split
# having been downloaded first.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True,
                   transform=mnist_transform),
    batch_size=batch_size,
    shuffle=False)

model = MyNet().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.5)

# Alternate one training epoch with one evaluation pass.
for epoch in range(1, epochs + 1):
    train(epoch)
    test()