Skip to content

Commit 2bdcae4

Browse files
sxu authored and facebook-github-bot committed
Fix static attention mask update (#9101)
Summary: The range-based for loops over the attention masks iterated by value (`auto it`), so each iteration operated on a copy of the mask and the updates did not take effect; iterate by reference (`auto& it`) instead. Delete the copy and move constructors and assignment operators of StaticKVCache and StaticAttentionMask, as they are not needed. Also add the missing deallocate call in the mask's destructor.
Differential Revision: D70914174
1 parent 1a34e56 commit 2bdcae4

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

examples/models/llama/runner/static_attention_io_manager.h

+16-2
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,11 @@ class StaticKVCache {
3838
reset();
3939
}
4040

41+
StaticKVCache(const StaticKVCache& other) = delete;
42+
StaticKVCache& operator=(const StaticKVCache& other) = delete;
43+
StaticKVCache(StaticKVCache&& other) = delete;
44+
StaticKVCache& operator=(StaticKVCache&& other) = delete;
45+
4146
~StaticKVCache() {
4247
allocator_.deallocate(data_, data_size_);
4348
}
@@ -200,6 +205,15 @@ class StaticAttentionMask {
200205
reset();
201206
}
202207

208+
StaticAttentionMask(const StaticAttentionMask& other) = delete;
209+
StaticAttentionMask& operator=(const StaticAttentionMask& other) = delete;
210+
StaticAttentionMask(StaticAttentionMask&& other) = delete;
211+
StaticAttentionMask& operator=(StaticAttentionMask&& other) = delete;
212+
213+
~StaticAttentionMask() {
214+
allocator_.deallocate(data_, data_size_);
215+
}
216+
203217
/**
204218
* Reset the mask to the state where the cache contains no valid data.
205219
*/
@@ -315,7 +329,7 @@ class StaticAttentionIOManager {
315329
input_pos_ += update_len;
316330
kCaches_.update(method, k_cache_output_indices, update_len);
317331
vCaches_.update(method, v_cache_output_indices, update_len);
318-
for (auto it : attentionMasks_) {
332+
for (auto& it : attentionMasks_) {
319333
it.second.updateCacheMask(update_len);
320334
}
321335
}
@@ -324,7 +338,7 @@ class StaticAttentionIOManager {
324338
input_pos_ = 0;
325339
kCaches_.reset();
326340
vCaches_.reset();
327-
for (auto it : attentionMasks_) {
341+
for (auto& it : attentionMasks_) {
328342
it.second.reset();
329343
}
330344
}

0 commit comments

Comments
 (0)