-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathconv_rnn.py
72 lines (57 loc) · 1.69 KB
/
conv_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
Combine Conv3d with an RNN Module.
Use windowed frames as inputs.
@author: ptrblck
"""
import torch
import torch.nn as nn
from torch.utils.data import Dataset
class MyModel(nn.Module):
    """Apply a Conv3d stack to fixed-size temporal windows, then an RNN over the windows.

    The input clip is split along its temporal axis (dim 2) into consecutive
    windows of ``window`` frames. Each window goes through
    Conv3d -> MaxPool3d -> ReLU and is flattened into one feature vector.
    The resulting sequence of per-window vectors is fed to a single-layer
    ``nn.RNN`` with ``batch_first=True``.

    Args:
        window: Number of frames per temporal window.
        in_channels: Number of input channels (dim 1 of the input).
        frame_size: ``(height, width)`` of each input frame; used to size the
            RNN input. The defaults reproduce the original hard-coded
            ``input_size=6*16*12*12``.
    """

    def __init__(self, window=16, in_channels=3, frame_size=(24, 24)):
        super().__init__()
        conv_channels = 6
        self.conv_model = nn.Sequential(
            # kernel 3 / stride 1 / padding 1 keeps T, H and W unchanged.
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=conv_channels,
                kernel_size=3,
                stride=1,
                padding=1),
            # Halves H and W (floor division), leaves T untouched.
            nn.MaxPool3d((1, 2, 2)),
            nn.ReLU()
        )
        height, width = frame_size
        # Flattened size of one window after conv_model: C_out * T * floor(H/2) * floor(W/2).
        # This used to be hard-coded as 6*16*12*12, which broke for any window != 16.
        self.rnn = nn.RNN(
            input_size=conv_channels * window * (height // 2) * (width // 2),
            hidden_size=1,
            num_layers=1,
            batch_first=True
        )
        # Placeholder initial state; forward() replaces it on every call.
        self.hidden = torch.zeros(1, 1, 1)
        self.window = window

    def forward(self, x):
        """Run the windowed conv + RNN pipeline.

        Args:
            x: Tensor of shape ``(batch, in_channels, T, height, width)``.
                ``T`` must be a positive multiple of ``window``.

        Returns:
            ``(out, hidden)`` as returned by ``nn.RNN``. ``out`` has shape
            ``(batch, T // window, 1)`` and ``hidden`` has shape ``(1, batch, 1)``.

        Raises:
            ValueError: If ``T`` is zero or not a multiple of ``window``.
        """
        num_frames = x.size(2)
        if num_frames == 0 or num_frames % self.window:
            raise ValueError(
                f"temporal length {num_frames} must be a positive multiple "
                f"of window={self.window}")
        # Fresh zero state for every call, sized for this batch and on x's
        # device/dtype. A fixed torch.zeros(1, 1, 1) only worked for batch 1 on CPU.
        self.hidden = x.new_zeros(
            self.rnn.num_layers, x.size(0), self.rnn.hidden_size)
        activations = [
            self.conv_model(clip).flatten(start_dim=1).unsqueeze(1)
            for clip in x.split(self.window, dim=2)
        ]
        seq = torch.cat(activations, 1)  # (batch, num_windows, features)
        out, hidden = self.rnn(seq, self.hidden)
        return out, hidden
class MyDataset(Dataset):
    '''
    Returns windowed frames from sequential data.

    Item ``i`` is the slice ``data[:, i*frames:(i+1)*frames]``, i.e. consecutive,
    non-overlapping windows along dim 1. A trailing chunk shorter than
    ``frames`` is not served.

    Args:
        frames: Number of frames per item. Must be positive.
        data: Optional tensor to window over along dim 1. Defaults to random
            data of shape ``(3, 2048, 24, 24)``.
    '''

    def __init__(self, frames=512, data=None):
        if frames <= 0:
            raise ValueError(f"frames must be positive, got {frames}")
        self.data = torch.randn(3, 2048, 24, 24) if data is None else data
        self.frames = frames

    def __getitem__(self, index):
        """Return window ``index``; negative indices count from the end.

        Raises:
            IndexError: If ``index`` is out of range. This also makes plain
                ``for item in dataset`` iteration terminate.
        """
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError(
                f"index out of range for dataset of length {length}")
        start = index * self.frames
        return self.data[:, start:start + self.frames]

    def __len__(self):
        # Integer division: __len__ must return an int (true division made
        # len() raise TypeError), and a partial trailing window is dropped.
        return self.data.size(1) // self.frames
# Demo: build the model and the synthetic dataset, then push the first
# window through the model as a batch of one.
model, dataset = MyModel(), MyDataset()
x = dataset[0]
output, hidden = model(x[None])  # x[None] prepends the batch dimension