Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 23 additions & 32 deletions python/mlx/nn/layers/positional_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,47 +115,38 @@ def __call__(self, x):


class ALiBi(Module):
_alibi_mask_key = None
_alibi_mask = None

@classmethod
@staticmethod
def create_alibi_matrix(
cls,
q_sequence_length: int,
k_sequence_length: int,
num_heads: int,
offset: int,
dtype=mx.float32,
):
if (
q_sequence_length,
k_sequence_length,
num_heads,
offset,
dtype,
) != cls._alibi_mask_key:
x1 = mx.arange(offset, q_sequence_length)
x2 = mx.arange(0, k_sequence_length)
distance_matrix = -mx.abs(
mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
)
alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads)
alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
cls._alibi_mask_key = (
q_sequence_length,
k_sequence_length,
num_heads,
offset,
dtype,
)
cls._alibi_mask = alibi_mask

return cls._alibi_mask
x1 = mx.arange(offset, q_sequence_length)
x2 = mx.arange(0, k_sequence_length)
distance_matrix = -mx.abs(
mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
)
alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads, dtype=dtype)
alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
return alibi_mask

@staticmethod
def create_alibi_slope(num_heads):
x = (2**8) ** (1 / num_heads)
out = mx.power(x, -mx.arange(1, num_heads + 1))
def create_alibi_slope(num_heads, dtype):
def get_slopes(n: int):
if math.log2(n).is_integer():
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
return [start * start**i for i in range(n)]
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)

slopes = get_slopes(num_heads)
out = mx.array(slopes, dtype=dtype)
return mx.expand_dims(out, axis=(-1, -2))

def __call__(self, attention_scores, offset=0, mask=None):
Expand Down
Loading