-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
Copy pathpointnet2_segmentation.py
144 lines (111 loc) · 4.95 KB
/
pointnet2_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os.path as osp
import sys

import torch
import torch.nn.functional as F
from pointnet2_classification import GlobalSAModule, SAModule
from torchmetrics.functional import jaccard_index

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, knn_interpolate
from torch_geometric.typing import WITH_TORCH_CLUSTER
from torch_geometric.utils import scatter
if not WITH_TORCH_CLUSTER:
    # `quit()` is a convenience added by `site` for the interactive shell and
    # is absent under `python -S`; `sys.exit` is the supported way for a
    # program to stop (the message is printed to stderr, exit status 1).
    sys.exit("This example requires 'torch-cluster'")

category = 'Airplane'  # Pass in `None` to train on all categories.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')

# Training-time augmentation: random jitter plus a random rotation about
# each of the three axes.
transform = T.Compose([
    T.RandomJitter(0.01),
    T.RandomRotate(15, axis=0),
    T.RandomRotate(15, axis=1),
    T.RandomRotate(15, axis=2),
])
pre_transform = T.NormalizeScale()

train_dataset = ShapeNet(path, category, split='trainval', transform=transform,
                         pre_transform=pre_transform)
# The test split gets only the shared pre-transform, no augmentation.
test_dataset = ShapeNet(path, category, split='test',
                        pre_transform=pre_transform)

train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True,
                          num_workers=6)
test_loader = DataLoader(test_dataset, batch_size=12, shuffle=False,
                         num_workers=6)
class FPModule(torch.nn.Module):
    """Feature-propagation block: carries features onto another point set.

    Features ``x`` living on points ``pos`` are interpolated onto the points
    ``pos_skip`` from their ``k`` nearest neighbours (``knn_interpolate``).
    If skip features are given, they are concatenated channel-wise before
    the result is passed through ``nn``.

    Args:
        k: Number of neighbours used by ``knn_interpolate``.
        nn: Module applied to the (optionally concatenated) features.
    """

    def __init__(self, k, nn):
        super().__init__()
        self.k = k
        self.nn = nn

    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
        """Return ``(features, pos_skip, batch_skip)`` for the target points."""
        upsampled = knn_interpolate(x, pos, pos_skip, batch, batch_skip,
                                    k=self.k)
        if x_skip is None:
            features = upsampled
        else:
            features = torch.cat([upsampled, x_skip], dim=1)
        return self.nn(features), pos_skip, batch_skip
class Net(torch.nn.Module):
    """PointNet++ part-segmentation network.

    Two ``SAModule`` levels and a ``GlobalSAModule`` encode the point cloud.
    Three ``FPModule`` levels then propagate features back down to the input
    points, each using the matching encoder output as its skip connection.
    A final MLP scores every input point.

    Args:
        num_classes: Number of output scores produced per point.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Input channels account for both `pos` and node features.
        self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))
        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
        self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))

        # Each FP input width = coarse-level channels + skip-level channels.
        self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))
        self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))
        self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128]))

        # Per-point classifier head; this is the only head `forward` uses.
        self.mlp = MLP([128, 128, 128, num_classes], dropout=0.5, norm=None)

    def forward(self, data):
        """Return per-point log-probabilities (``num_classes`` per point).

        Args:
            data: Batch providing ``x``, ``pos`` and ``batch``.
        """
        # Every stage output is an ``(x, pos, batch)`` triple.
        sa0_out = (data.x, data.pos, data.batch)
        sa1_out = self.sa1_module(*sa0_out)
        sa2_out = self.sa2_module(*sa1_out)
        sa3_out = self.sa3_module(*sa2_out)

        fp3_out = self.fp3_module(*sa3_out, *sa2_out)
        fp2_out = self.fp2_module(*fp3_out, *sa1_out)
        x, _, _ = self.fp1_module(*fp2_out, *sa0_out)

        return self.mlp(x).log_softmax(dim=-1)
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = Net(train_dataset.num_classes)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
def train(*, log_interval=10):
    """Run one training epoch over `train_loader`, updating `model` in place.

    Every `log_interval` batches, prints the mean per-batch loss and the
    node-level accuracy accumulated over that window. The running
    statistics are then reset.

    Args:
        log_interval: Number of batches per progress line (default 10).

    Raises:
        ValueError: If `log_interval` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError(f'log_interval must be >= 1, got {log_interval}')

    model.train()

    total_loss = correct_nodes = total_nodes = 0
    for i, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct_nodes += out.argmax(dim=1).eq(data.y).sum().item()
        total_nodes += data.num_nodes

        if (i + 1) % log_interval == 0:
            # Divide by the same window length used for the modulus, so the
            # printed value stays a per-batch mean.
            mean_loss = total_loss / log_interval
            print(f'[{i+1}/{len(train_loader)}] Loss: {mean_loss:.4f} '
                  f'Train Acc: {correct_nodes / total_nodes:.4f}')
            total_loss = correct_nodes = total_nodes = 0
def _part_iou(pred, target, num_parts):
    """Return the mean per-part IoU of a single shape as a 0-d tensor.

    ``pred`` and ``target`` are 1-D integer tensors of part indices in
    ``[0, num_parts)``. A part absent from both ``pred`` and ``target``
    scores 1.0, which is the ``absent_score=1.0`` convention this example
    previously requested from torchmetrics.
    """
    intersection = torch.bincount(target[pred == target], minlength=num_parts)
    union = (torch.bincount(pred, minlength=num_parts)
             + torch.bincount(target, minlength=num_parts) - intersection)
    iou = intersection.float() / union.clamp(min=1).float()
    iou[union == 0] = 1.0  # Part absent from both prediction and target.
    return iou.mean()


@torch.no_grad()
def test(loader):
    """Evaluate `model` on `loader` and return the mean IoU as a float.

    For each shape, predictions are restricted to the part labels of the
    shape's category, and those labels are remapped to ``0..num_parts-1``.
    The mean part IoU is then computed. Shape IoUs are averaged within each
    category, and the per-category means of the categories present are
    averaged into one score.
    """
    model.eval()

    # NOTE(review): in current PyG, ShapeNet stores `data.category` as a
    # position in the dataset's selected `categories` list (not in the global
    # `seg_classes` order); fall back to the global order when the dataset
    # does not expose `categories`. Confirm for the installed PyG version.
    category_names = getattr(loader.dataset, 'categories', None)
    if category_names is None:
        category_names = list(ShapeNet.seg_classes.keys())
    parts_by_category = {
        name: torch.tensor(ShapeNet.seg_classes[name], device=device)
        for name in category_names
    }

    # Maps a global part label to its position within its category's parts;
    # sized to cover every part label that appears in `seg_classes`.
    num_labels = 1 + max(max(p) for p in ShapeNet.seg_classes.values())
    y_map = torch.empty(num_labels, dtype=torch.long, device=device)

    ious, categories = [], []
    for data in loader:
        data = data.to(device)
        outs = model(data)

        sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
        for out, y, category in zip(outs.split(sizes), data.y.split(sizes),
                                    data.category.tolist()):
            part = parts_by_category[category_names[category]]
            y_map[part] = torch.arange(part.size(0), device=device)

            pred = out[:, part].argmax(dim=-1)
            ious.append(_part_iou(pred, y_map[y], part.size(0)))

        # One category entry per shape, aligned with `ious`.
        categories.append(data.category)

    iou = torch.stack(ious)
    category = torch.cat(categories, dim=0)

    per_category = scatter(iou, category, reduce='mean')  # Per-category IoU.
    # `scatter` zero-fills category indices with no samples; average only the
    # categories that actually occur so empty slots do not lower the score.
    return float(per_category[category.unique()].mean())  # Global IoU.
# Train for 30 epochs, reporting the test-split IoU after each one.
for epoch in range(1, 30 + 1):
    train()
    iou = test(test_loader)
    print('Epoch: {:02d}, Test IoU: {:.4f}'.format(epoch, iou))