@@ -209,6 +209,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.head_dim = config.head_dim
         self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
         self.attention_qkv_bias = config.attention_qkv_bias
+        self.use_conv2d = False
 
         self.wqs = nn.ModuleList(
             [
@@ -253,9 +254,25 @@ def forward(
         in_cache_state = kwargs.get("in_cache_state")
         out_cache_state = kwargs.get("out_cache_state")
 
+        bsz, seq_len, dim = x.shape
+        if self.use_conv2d:
+            x = x.reshape(bsz, seq_len, 1, dim).transpose(1, 3)
+
         new_qs = [self.wqs[i](x) for i in range(self.n_heads)]
         new_ks = [self.wks[i](x) for i in range(self.n_kv_heads)]
         new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]
+
+        if self.use_conv2d:
+
+            def from_conv2ds(ts):
+                return [
+                    t.reshape(bsz, self.head_dim, seq_len).transpose(1, 2) for t in ts
+                ]
+
+            new_qs = from_conv2ds(new_qs)
+            new_ks = from_conv2ds(new_ks)
+            new_vs = from_conv2ds(new_vs)
+
         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
 
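Note: the reshape/transpose in this hunk turns the [bsz, seq_len, dim] activation into an NCHW tensor of shape [bsz, dim, 1, seq_len] so a 1x1 nn.Conv2d can stand in for each per-head nn.Linear projection, and from_conv2ds undoes the layout change afterwards. A minimal standalone sketch of that equivalence (made-up sizes; it mirrors the weight transfer done by transfer_weight() later in this diff and is not part of the change itself):

import torch
import torch.nn as nn

bsz, seq_len, dim, head_dim = 2, 8, 16, 4
linear = nn.Linear(dim, head_dim, bias=False)
conv = nn.Conv2d(dim, head_dim, 1, bias=False)
# Same weight transfer as transfer_weight(): [head_dim, dim] -> [head_dim, dim, 1, 1]
conv.weight.data.copy_(linear.weight[:, :, None, None])

x = torch.randn(bsz, seq_len, dim)
ref = linear(x)                                            # [bsz, seq_len, head_dim]
x4d = x.reshape(bsz, seq_len, 1, dim).transpose(1, 3)      # [bsz, dim, 1, seq_len]
out = conv(x4d)                                            # [bsz, head_dim, 1, seq_len]
out = out.reshape(bsz, head_dim, seq_len).transpose(1, 2)  # [bsz, seq_len, head_dim]
assert torch.allclose(ref, out, atol=1e-5)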
@@ -281,7 +298,14 @@ def forward(
             heads.append(attn @ all_vs[kv_idx])
 
         y = torch.cat(heads, dim=-1)
-        y = self.wo(y)
+        if self.use_conv2d:
+            y = (
+                self.wo(y.reshape(bsz, seq_len, 1, -1).transpose(1, 3))
+                .transpose(1, 3)
+                .reshape(bsz, seq_len, -1)
+            )
+        else:
+            y = self.wo(y)
         return y, {"out_cache_state": out_cache_state}
 
     def load_weights_from_attention_mha(self, other: AttentionMHA):
@@ -299,3 +323,44 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
             )
 
         self.wo.weight.data.copy_(other.wo.weight)
+
+    def linear_to_conv2d(self):
+        def transfer_weight(linear, conv2d):
+            conv2d.weight.data.copy_(linear.weight[:, :, None, None])
+            return conv2d
+
+        self.wqs = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wqs
+            ]
+        )
+        self.wks = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wks
+            ]
+        )
+        self.wvs = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wvs
+            ]
+        )
+        self.wo = transfer_weight(
+            self.wo,
+            nn.Conv2d(
+                self.n_heads * self.head_dim, self.dim, 1, bias=self.attention_qkv_bias
+            ),
+        )
+
+        self.use_conv2d = True
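A rough usage sketch of the new conversion path, assuming the module is built and weight-loaded as before. The class name, the forward-call signature, and the variables attn, config, layer_id, rope, mha, x, freqs_cos, and freqs_sin are placeholders, not names confirmed by this diff:

attn = SplitHeadAttention(config, layer_id, rope)  # hypothetical class name; use the class this diff modifies
attn.load_weights_from_attention_mha(mha)          # copy weights from the fused AttentionMHA module
attn.linear_to_conv2d()                            # swap the nn.Linear projections for 1x1 nn.Conv2d
y, attn_state = attn(x, freqs_cos, freqs_sin)      # forward() now takes the use_conv2d branch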