New Medium WR - Remove a redundant op while creating block masks, -220 ms #157
Remove a redundant tensor op in block mask calculation
The change is super simple. The function `create_blockmasks` performs a series of tensor operations to ultimately create block masks containing the indices of "partial" blocks and "full" blocks, which indicate to flex attention where to apply the `document_causal` mask_mod and where to skip it, respectively. One of those tensor ops is redundant, and removing it saves about 220 ms overall.
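For context, a mask_mod of this kind marks a (query, key) pair as allowed only when the key token is not in the future and both tokens belong to the same document. A minimal sketch of what `document_causal` plausibly looks like (the `docs` values below are hypothetical; only the name `document_causal` and its role come from this description):

```python
import torch

# Hypothetical per-token document ids; in the real code these are derived
# from the positions of end-of-document tokens in the input sequence.
docs = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])

def document_causal(b, h, q_idx, kv_idx):
    # Allow attention only within the causal window...
    causal_mask = q_idx >= kv_idx
    # ...and only between tokens of the same document.
    document_mask = docs[q_idx] == docs[kv_idx]
    return causal_mask & document_mask

# Token 4 (doc 1) may attend to token 3 (doc 1) but not to token 2 (doc 0).
print(document_causal(0, 0, torch.tensor(4), torch.tensor(3)))  # tensor(True)
print(document_causal(0, 0, torch.tensor(4), torch.tensor(2)))  # tensor(False)
```

"Partial" blocks are the ones where this mask_mod actually has to be evaluated element-wise; "full" blocks lie entirely inside the mask, so flex attention can skip the check for them.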
Timing
We can see in the record that the run completes in 1496063 ms. When I timed it against the unmodified code, it finished in 1496283 ms, so this change should save about 220 ms (1496283 − 1496063 = 220 ms).
The redundant operation
To create block masks, we first aim to create two masks, `blockmask_any` and `blockmask_all`, and later find the partial blocks by computing `blockmask_any & ~blockmask_all`.

Focusing on `blockmask_any`, it is computed by applying the causal block mask to `document_blockmask_any`, which is itself the AND of two comparisons: `docs_low[:, None] <= docs_high` and `docs_high[:, None] >= docs_low`. The idea is that this is exactly equal to applying the causal block mask to `docs_low[:, None] <= docs_high` alone (see the sketch below).
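Here is a minimal, self-contained sketch of the before/after. The name `causal_blockmask_any` and the example `docs` values are assumptions for illustration; only `docs_low`, `docs_high`, `document_blockmask_any`, and `blockmask_any` come from this description:

```python
import torch

BLOCK_SIZE = 2
# Hypothetical non-decreasing per-token document ids (10 tokens, 5 blocks).
docs = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 3])
num_blocks = len(docs) // BLOCK_SIZE

# First and last document id covered by each block of BLOCK_SIZE tokens.
docs_low = docs.view(-1, BLOCK_SIZE)[:, 0]
docs_high = docs.view(-1, BLOCK_SIZE)[:, -1]

# Causal mask over blocks: block i may attend to block j only if j <= i.
block_idx = torch.arange(num_blocks)
causal_blockmask_any = block_idx[:, None] >= block_idx

# Before: document_blockmask_any is the AND of two comparisons.
document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
blockmask_any_before = causal_blockmask_any & document_blockmask_any

# After: under the causal mask the second comparison is always True, so drop it.
blockmask_any_after = causal_blockmask_any & (docs_low[:, None] <= docs_high)

assert torch.equal(blockmask_any_before, blockmask_any_after)
```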
Thus the `&` operation used to compute `document_blockmask_any` is redundant.

Why is the op redundant?
This is a little confusing but here is my attempt:
Consider the `(i, j)` element in the matrix `document_blockmask_any`. Since we are going to apply the causal mask later to compute `blockmask_any`, all elements where `j > i` will become `False`, so we only care about `j <= i`.

If `j <= i`, the ending doc of block i must be greater than or equal to the starting doc of block j, by definition, since block j starts at or before block i. So `docs_high[:, None] >= docs_low` must evaluate to `True` for `j <= i`.

Another way to view this: `docs_high[i] >= docs_low[j]` for all `j <= i`; therefore, for `j <= i`, computing `docs_high[:, None] >= docs_low` is redundant.
Note that, on the other hand, if `j <= i`, then `docs_low[i]` is not guaranteed to be less than or equal to `docs_high[j]`. This can be seen via a counter-example: take a `docs` vector which, with a block size of 2, yields `docs_low` and `docs_high` such that for `i = 4, j = 1` (i.e. `j <= i`), `docs_low[i] = 3` while `docs_high[j] = 2` (one such vector is constructed in the sketch below). We can also verify in such an example that for all `j <= i`, `docs_high[i] >= docs_low[j]`.
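The original example values aren't reproduced above, but here is one hypothetical `docs` vector (an assumption for illustration) that matches the quoted numbers and checks both claims:

```python
import torch

BLOCK_SIZE = 2
# Non-decreasing document ids for 10 tokens, i.e. 5 blocks of size 2.
docs = torch.tensor([0, 1, 1, 2, 2, 2, 2, 3, 3, 3])

docs_low = docs.view(-1, BLOCK_SIZE)[:, 0]    # tensor([0, 1, 2, 2, 3])
docs_high = docs.view(-1, BLOCK_SIZE)[:, -1]  # tensor([1, 2, 2, 3, 3])

# Counter-example: docs_low[i] <= docs_high[j] can fail even though j <= i.
i, j = 4, 1
assert docs_low[i] == 3
assert docs_high[j] == 2
assert not (docs_low[i] <= docs_high[j])

# But docs_high[i] >= docs_low[j] holds for every j <= i.
num_blocks = len(docs) // BLOCK_SIZE
for bi in range(num_blocks):
    for bj in range(bi + 1):
        assert docs_high[bi] >= docs_low[bj]
```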