# transformer_I.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

# hyperparameters for the toy setup
embed_dim = 16
num_heads = 8
feed_forward_ratio = 4
size = 10        # number of samples in the toy dataset
seq_len = 24
batch_size = 2
device = "mps" if torch.backends.mps.is_available() else "cpu"  # fall back to CPU when MPS is unavailable

class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dim_head = embed_dim // num_heads
        if self.dim_head * num_heads != embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads, but got embed_dim={embed_dim} and num_heads={num_heads}.")

    def forward(self, q, k, v, attn_mask=None):
        batch_size = q.shape[0]
        seq_len = q.shape[1]
        # split each projection into heads: (batch_size, seq_len, embed_dim) -> (batch_size, num_heads, seq_len, dim_head)
        q = q.reshape(batch_size, seq_len, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
        k = k.reshape(batch_size, seq_len, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
        v = v.reshape(batch_size, seq_len, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
        # scaled dot-product attention; scale by sqrt(dim_head), the per-head key dimension
        attn_logits = q @ k.transpose(-2, -1) / (self.dim_head ** 0.5)  # (batch_size, num_heads, seq_len, seq_len)
        if attn_mask is not None:
            if attn_mask.ndim == 2:    # (seq_len, seq_len)
                attn_mask = attn_mask.unsqueeze(0).unsqueeze(0).repeat(batch_size, self.num_heads, 1, 1)
            elif attn_mask.ndim == 3:  # (batch_size*num_heads, seq_len, seq_len)
                attn_mask = attn_mask.reshape(batch_size, self.num_heads, seq_len, seq_len)
            else:
                raise ValueError(f"attn_mask must be a 2D or 3D tensor, but got a {attn_mask.ndim}D tensor.")
            attn_logits = attn_logits.masked_fill(attn_mask == 0, float("-inf"))
        attn_weights = F.softmax(attn_logits, dim=-1)
        values = attn_weights @ v  # (batch_size, num_heads, seq_len, seq_len) @ (batch_size, num_heads, seq_len, dim_head) -> (batch_size, num_heads, seq_len, dim_head)
        # merge heads back: (batch_size, num_heads, seq_len, dim_head) -> (batch_size, seq_len, embed_dim)
        values = values.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.embed_dim)
        return values, attn_weights
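
# Quick shape sanity check for the attention block (illustrative; the _-prefixed tensors are throwaway placeholders).
_attn = MultiheadAttention(embed_dim, num_heads)
_q = _k = _v = torch.randn(batch_size, seq_len, embed_dim)
_vals, _weights = _attn(_q, _k, _v)
assert _vals.shape == (batch_size, seq_len, embed_dim)              # per-token outputs keep the model width
assert _weights.shape == (batch_size, num_heads, seq_len, seq_len)  # one attention map per head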

class transformer(nn.Module):
    def __init__(self, embed_dim, num_heads, feed_forward_ratio):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dim_head = embed_dim // num_heads
        if self.dim_head * num_heads != embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads, but got embed_dim={embed_dim} and num_heads={num_heads}.")
        self.feed_forward_ratio = feed_forward_ratio
        self.qkv_combined = nn.Linear(embed_dim, 3 * embed_dim)  # single projection producing q, k and v
        #self.atten_layer = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.atten_layer = MultiheadAttention(embed_dim, num_heads)
        self.ff_layers = nn.Sequential(
            nn.Linear(embed_dim, feed_forward_ratio * embed_dim),
            nn.ReLU(),
            nn.Linear(feed_forward_ratio * embed_dim, embed_dim)
        )
        # one LayerNorm per sub-layer (attention and feed-forward)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):  # x: (batch_size, seq_len, embed_dim)
        qkv = self.qkv_combined(x)  # (batch_size, seq_len, embed_dim) -> (batch_size, seq_len, 3*embed_dim)
        q, k, v = qkv.chunk(3, dim=-1)  # each: (batch_size, seq_len, embed_dim)
        #mask = torch.ones(x.shape[1], x.shape[1])  # (seq_len, seq_len)
        #mask = torch.tril(mask, diagonal=0).to(device)  # zero above the diagonal for causal attention
        mask = None
        values, attn_weights = self.atten_layer(q, k, v, attn_mask=mask)  # applies the mask to the logits, then softmax to get attention weights
        x = self.layer_norm1(values + x)             # residual connection around attention, then layer norm
        x = self.layer_norm2(self.ff_layers(x) + x)  # residual connection around the feed-forward block, then layer norm
        return x
model=transformer(embed_dim,num_heads,feed_forward_ratio)
class flipped_dataset(data.Dataset):
    """Toy task: the target sequence is the input sequence reversed."""
    def __init__(self, size, embed_dim, seq_len):
        super().__init__()
        self.len = size
        self.embed_dim = embed_dim
        self.seq_len = seq_len
        self.data = torch.randint(embed_dim, size=(size, seq_len), dtype=torch.float32)  # (size, seq_len), each element is an integer between 0 and embed_dim-1

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        input = self.data[idx]
        output = input.flip(0)  # reverse the sequence along the time axis
        return input, output


sample_dataset = flipped_dataset(size, embed_dim, seq_len)
sample_dataloader = data.DataLoader(sample_dataset, batch_size=batch_size, shuffle=True)
small_data, small_label = next(iter(sample_dataloader))  # (batch_size, seq_len), (batch_size, seq_len)
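# Peek at one sample to see the task (illustrative): the target is simply the input read backwards.
print("sample input :", small_data[0])
print("sample target:", small_label[0])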

class positional_encoding(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):  # x: (batch_size, seq_len, embed_dim)
        batch_size = x.shape[0]
        seq_len = x.shape[1]
        embed_dim = x.shape[2]
        pos = torch.arange(seq_len).reshape(seq_len, 1)  # (seq_len, 1)
        embed_pos = torch.arange(embed_dim)              # (embed_dim,)
        # sinusoidal encoding with rate 10**(-i/embed_dim): sine on odd embedding indices, cosine on even ones
        embed1 = torch.where(embed_pos % 2 == 0, torch.tensor(0), torch.tensor(1)) * torch.sin(pos * (10 ** (-embed_pos / embed_dim).reshape(1, embed_dim)))
        embed2 = torch.where(embed_pos % 2 == 0, torch.tensor(1), torch.tensor(0)) * torch.cos(pos * (10 ** (-embed_pos / embed_dim).reshape(1, embed_dim)))
        embed = embed1 + embed2  # (seq_len, embed_dim)
        embed = embed.reshape(1, seq_len, embed_dim).repeat(batch_size, 1, 1)
        x = x + embed
        return x


pos_enc = positional_encoding()

class Loss_Function(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred_logit, label):  # pred_logit: (batch_size, seq_len, embed_dim), label: (batch_size, seq_len)
        b, s, e = pred_logit.shape  # take the shapes from the tensor rather than the globals, so a smaller final batch also works
        loss = F.cross_entropy(pred_logit.reshape(b * s, e), label.long().reshape(b * s))
        acc = (pred_logit.argmax(dim=-1) == label).float().mean()
        #print(f"Loss: {loss}, Accuracy: {acc}")
        return loss


loss_function = Loss_Function()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# linear warm-up over the first 100 steps, then inverse square-root decay
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: min((step + 1) / 100, 1 / (step + 1) ** 0.5))

def train(model, dataloader, loss_function, optimizer, scheduler, num_epoch=2):
    model = model.to(device)
    model.train()
    for epoch in range(num_epoch):
        epoch_loss = 0
        for data, label in dataloader:
            data_encoded = F.one_hot(data.long(), num_classes=embed_dim).float()  # (batch_size, seq_len) -> (batch_size, seq_len, embed_dim); .long() makes sure the indices are integers, not floats
            data = pos_enc(data_encoded).to(device)  # add positional information before feeding the model
            label = label.to(device)
            optimizer.zero_grad()  # forget the previous gradients
            pred = model(data)
            loss_val = loss_function(pred, label)
            loss_val.backward()    # backpropagation to compute the gradients
            max_grad_norm = 10.0
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # gradient clipping
            optimizer.step()       # update the weights
            scheduler.step()
            epoch_loss += loss_val.item()
        print(f"Epoch {epoch+1} Loss: {epoch_loss/len(dataloader)}")


train(model, sample_dataloader, loss_function, optimizer, scheduler, num_epoch=50)
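
# Quick post-training sanity check (illustrative; the eval_* names are just for this check):
# greedy predictions on one batch compared against the flipped targets.
model.eval()
with torch.no_grad():
    eval_data, eval_label = next(iter(sample_dataloader))
    eval_encoded = pos_enc(F.one_hot(eval_data.long(), num_classes=embed_dim).float()).to(device)
    eval_pred = model(eval_encoded).argmax(dim=-1)  # (batch_size, seq_len) predicted token ids
    print("eval accuracy:", (eval_pred == eval_label.to(device)).float().mean().item())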