0%

FlashAttentions

随便写点东西,准备把 FlashAttention 系列的算法原理推一遍。

FlashAttention

这个系列的开山之作:【FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness】

主要是通过计算换访存的思路,减少 cuda kernel 对 global memory 的访问量,在 sequence length 长,整体计算过程偏 memory bound 的情况下有很好的效果。

Self Attention

先来看下常规的 Self Attention:

$$O = Softmax(Q * K^T) * V$$

假设 Sequence Length 是 L,Head dim 是 D,这个过程在常规实现中需要放在 3 个 cuda kernel 中,伪代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// [L, Dim] * [Dim, L] -> [L, L]
for (r = 1 to L)
for (i = 1 to L)
for (j = 1 to Dim)
X[r, i] += Q[r, j] * K[j, i]

// [L, L]
for (r = 1 to L)
for (i = 1 to L)
e_sum += e^(X[r, i])
for (i = 1 to L)
Softmax[r, i] = e^(X[r, i]) / e_sum

// [L, L] * [L, Dim] -> [L, Dim]
for (r = 1 to L)
for (c = 1 to Dim)
for (i = 1 to L)
O[r, c] += Softmax[r, i] * V[i, c]

对 global memory 的访存量为 $4L * Dim + 3L * L = O(LDim + L^2)$

Softmax

重新看一下 softmax 是在做什么:

$$Softmax({x_1, x_2, …, x_N}) = \left \lbrace \frac{e^{x_i}}{\sum^N_{j=1}e^{x_j}} \right \rbrace^N_{i=1}$$

值得注意的是,目前通用的 softmax 实现中为了防止数值溢出还需要再额外减掉一个 max:

$$\begin{aligned}
m &= Max^N_{n=0}x_n\\
Softmax({x_1, x_2, …, x_N}) &= \left \lbrace \frac{e^{x_i - m}}{\sum^N_{j=1}e^{x_j - m}} \right \rbrace^N_{i=1}
\end{aligned}
$$

因此这里计算 max 需要一次独立的全局 reduce,计算分母的 sum 再需要一次独立的全局 reduce,最后分别计算每一个元素的 softmax 值。三个步骤之间存在数据依赖。

Online Softmax

【Online normalizer calculation for softmax】 这篇 paper 提出了一种能将上面的 3 步 softmax 合并成 2 步完成的思路。

考虑原始计算步骤中分母求最大值以及求和的部分,这里需要 3 个独立的循环:

$$\begin{align}
m_i &= \max(m_{i-1}, x_i) \\
d_i &= d_{i-1} + e^{x_i - m_N} = \sum^i_{j=1}e^{x_j - m_N} \\
softmax_i &= \frac{e^{x_i - m_N}}{d_N} \\
\end{align}$$

公式中产生数据依赖的原因是 $(2)$ 需要依赖 $m_N$,而 $(3)$ 需要依赖 $m_N$ 和 $d_N$。

如果能有这样一个 $d_i’$:

$$d_i’ = \sum^i_{j=1}e^{x_j - m_i} = d_{i-1}’ + e^{x_i - m_i}$$

则 $(2)$ 对 $m_N$ 的数据依赖就解除了,虽然序列的中间部分的值不相等但 $d_N$ 与 $d_N’$ 是等价的。

而 $d_i’$ 这个序列存在递推性质:

$$
\begin{aligned}
d_i’ &= \sum^i_{j=1}e^{x_j - m_i} \\
&= \sum^{i-1}_{j=1}e^{x_j - m_i} + e^{x_i-m_i} \\
&= \left ({\sum^{i-1}_{j=1}e^{x_j - m_{i-1}}} \right ) * e^{m_{i-1} - m_i} + e^{x_i-m_i} \\
&= d_{i-1}’ * e^{m_{i-1} - m_i} + e^{x_i-m_i} \\
\end{aligned}
$$

这样 $m_i$ 和 $d_i’$ 就可以在一个 kernel 中计算完成了,kernel 可以减少到两个:

$$\begin{aligned}
m_i &= \max(m_{i-1}, x_i) \\
d_i’ &= d_{i-1}’ * e^{m_{i-1} - m_i} + e^{x_i-m_i} \\
softmax_i &= \frac{e^{x_i - m_N}}{d_N’}
\end{aligned}$$

但是这样对于加速来说还是不够,重新推几层:

Index 0 1 2 3
$m_i$ $x_0$ $\max(m_0, x_1)$ $\max(m_1, x_2)$ $\max(m_2, x_3)$
$d_i’$ 1 $e^{x_0-m_1} + e^{x_1-m_1}$ $e^{x_0-m_2} + e^{x_1-m_2} + e^{x_2-m_2}$ $e^{x_0-m_3} + e^{x_1-m_3} + e^{x_2-m_3} + e^{x_3-m_3}$

可以发现由于 exp/log 计算的特性,$d_3’$ 除了正常从 $d_0’$、$d_1’$、$d_2’$ 按顺序推出以外,乱序也是可以得到的,例如 2->1->0->3:

$$\begin{aligned}
m_{21} &= \max(x_2, x_1) \\
d_{21}’ &= e^{x_2 - m_{21}} + e^{x_1 - m_{21}}\\
m_{210} &= \max(x_0, m_{21}) = \max(x_0, x_2, x_1)\\
d_{210}’ &= d_{21}’ * e^{m_{21} - m_{210}} + e^{x_0 - m_{210}}\\
&= (e^{x_2 - m_{21}} + e^{x_1 - m_{21}}) * e^{m_{21} - m_{210}} + e^{x_0 - m_{210}}\\
&= e^{x_2 - m_{210}} + e^{x_1 - m_{210}} + e^{x_0 - m_{210}}\\
m_3 = m_{2103} &= \max(x_3, m_{210}) = \max(x_3, x_0, x_2, x_1) \\
d_3’=d_{2103}’&=d_{210}’ * e^{m_{210} - m_3} + e^{x_3 - m_3}\\
&=(e^{x_2 - m_{210}} + e^{x_1 - m_{210}} + e^{x_0 - m_{210}}) * e^{m_{210} - m_3} + e^{x_3 - m_3}\\
&=e^{x_2 - m_3} + e^{x_1 - m_3} + e^{x_0 - m_3} + e^{x_3 - m_3}\\
\end{aligned}$$

如果定义分块计算时 $d_{xy}’ = d_x’ * e^{m_x - m_{xy}} + d_y’ * e^{m_y - m_{xy}}$,则按照 (2->1)->(0->3) 的迭代顺序:

$$\begin{aligned}
m_{21} &= \max(x_2, x_1) \\
d_{21}’ &= e^{x_2 - m_{21}} + e^{x_1 - m_{21}}\\
m_{03} &= \max(x_0, x_3) \\
d_{03}’ &= e^{x_0 - m_{03}} + e^{x_3 - m_{03}}\\
m_3=m_{2103} &= \max(m_{21}, m_{03}) = \max(x_2, x_1, x_0, x_3)\\
d_3’=d_{2103}’ &= d_{21}’ * e^{m_{21} - m_3} + d_{03}’ * e^{m_{03} - m_3} \\
&= (e^{x_2 - m_{21}} + e^{x_1 - m_{21}}) * e^{m_{21} - m_3} + e^{x_0 - m_3} + (e^{x_0 - m_{03}} + e^{x_3 - m_{03}}) * e^{m_{03} - m_3} \\
&= e^{x_2 - m_3} + e^{x_1 - m_3} + e^{x_0 - m_3} + e^{x_3 - m_3}\\
\end{aligned}$$

同样可以得到相同的 $d_N’$。这样,我们就可以得到这种方式最大的一个特性:$m$ 和 $d’$ 的迭代计算操作同时满足交换律和结合律,任意分块分别计算 $m$ 和 $d’$ 之后,将所有子块结果重新聚合在数学上完全等价,即序列中 max 值带来的影响可以延迟到最后一步再被修正。

这样 Online softmax 就可以通过分块并行得到进一步的加速了。

FlashAttention

继续回到一开始的 FlashAttention,我们把 Online Softmax 放进去,这里的 Softmax 是对 [L, L] 结果中的每一行做一维的 softmax:

$$\begin{aligned}
Softmax^L_{r=1} &= Softmax(X_{r, 1}, X_{r, 2}, …, X_{r, L})^L_{r=1}\\
&=\left \lbrace \left \lbrace
\frac{e^{X_{r, i} - M_{r, L}}}{\sum^L_{j=1}e^{X_{r, j} - m_{r, L}}}
\right \rbrace^L_{i=1} \right \rbrace^L_{r=1} \\
\end{aligned}$$

原始实现中:

$$\begin{aligned}
X_{r, i} &= \sum^{Dim}_{j=1}Q[r, j]K[j, i]\\
M_{r, i} &= \max(M_{r, i-1}, X_{r, i}) \\
D_{r, i} &= D_{r, i-1} + e^{X_{r, i} - M_{r, L}} = \sum^i_{j=1}e^{X_{r, j} - M_{r, L}} \\
Softmax_{r, i} &= \frac{e^{X_{r, i} - M_{r, L}}}{D_{r, L}} \\
\end{aligned}$$

同样可以替换成 $D_{r, i}’$:

$$\begin{aligned}
D_{r, i}’ &= D_{r, i-1}’ * e^{M_{r, i-1} - M_{r, i}} + e^{X_{r, i}-M_{r, i}} \\
Softmax_{r, i} &= \frac{e^{X_{r, i} - M_{r, L}}}{D_{r, L}’} \\
O_{r, c} &= \sum^L_{i=1}(Softmax_{r, i} * V[i, c]) \\
\end{aligned}$$

把 $O_{r, c}$ 的累加过程拆开看:

$$\begin{aligned}
SubSum_{r, c, i} &= SubSum_{r, c, i-1} + Softmax_{r, i} * V[i, c]\\
&=SubSum_{r, c, i-1} + \frac{e^{X_{r, i} - M_{r, L}}}{D_{r, L}’} * V[i, c]\\
&=\sum^i_{j=1}\frac{e^{X_{r, j} - M_{r, L}}}{D_{r, L}’}V[j, c]
\end{aligned}$$

可以发现 $SubSum_{r,c,i}$ 也是依赖于 $M_{r,L}$ 和 ${D_{r,L}’}$,运用与 online softmax 相似的方式,可以在这里增加一个 $SubSum_{r,c,i}’$:

$$\begin{aligned}
SubSum_{r,c,i}’ &= \sum^i_{j=1}\frac{e^{X_{r, j} - M_{r, i}}}{D_{r, i}’}V[j, c]\\
&=\sum^{i-1}_{j=1}\frac{e^{X_{r, j} - M_{r, i}}}{D_{r, i}’}V[j, c] + \frac{e^{X_{r, i} - M_{r, i}}}{D_{r, i}’}V[i, c]\\
&=\left (\sum^{i-1}_{j=1}\frac{e^{X_{r, j} - M_{r, i-1}}}{D_{r, i-1}’}V[j, c] \right) * \frac{e^{M_{r, i-1} - M_{r, i}}D_{r,i-1}’}{D_{r,i}’} + \frac{e^{X_{r, i} - M_{r, i}}}{D_{r, i}’}V[i, c]\\
&=SubSum_{r,c,i-1}’*\frac{e^{M_{r, i-1} - M_{r, i}}D_{r,i-1}’}{D_{r,i}’} + \frac{e^{X_{r, i} - M_{r, i}}}{D_{r, i}’}V[i, c]\\
\end{aligned}$$

最终整理一下,在一个 $i = (1, L)$ 的循环中可以完成:

$$\begin{aligned}
X_{r, i} &= \sum^{Dim}_{j=1}Q[r, j]K[j, i]\\
M_{r, i} &= \max(M_{r, i-1}, X_{r, i})\\
D_{r, i}’ &= D_{r, i-1}’ * e^{M_{r, i-1} - M_{r, i}} + e^{X_{r, i}-M_{r, i}}\\
SubSum_{r,c,i}’ &=SubSum_{r,c,i-1}’*\frac{e^{M_{r, i-1} - M_{r, i}}D_{r,i-1}’}{D_{r,i}’} + \frac{e^{X_{r, i} - M_{r, i}}}{D_{r, i}’}V[i, c]\\
\end{aligned}$$

最终的输出结果为:

$$O_{r, c} = SubSum_{r,c,L}$$

写成伪代码是:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
for (r = 1 to L)
for (i = 1 to L)
// [L, Dim] * [Dim, L] -> [L, L]
for (j = 1 to Dim)
X[r, i] += Q[r, j] * K[j, i]

// [L, L]
M[r, i] = max(M[r, i-1], X[r, i])

// [L, L]
D'[r, i] = D'[r, i-1] * e(...) + e(...)

// [L, Dim]
for (c = 1 to Dim)
O[r, c] += O[r, c] * e(...) * D'[r, i-1] / D'[r, i] + e(...) / D'[r, i] * V[i, c]

当然,与 Online softmax 一样,$D_{r, i}’$、$SubSum_{r,c,i}’$ 也具有分块满足交换律和结合律的特性:

$$\begin{aligned}
D_{r, xy}’ &= D_{r, x}’ * e^{M_{r, x} - M_{r, xy}} + D_{r, y}’ * e^{M_{r, y} - M_{r, xy}}\\
SubSum_{r,c,xy}’ &= SubSum_{r,c,x}’ * \frac{e^{M_{r, x}-M_{r, xy}}D_{r, x}}{D_{r, xy}’} + SubSum_{r,c,y}’ * \frac{e^{M_{r, y}-M_{r, xy}}D_{r, y}}{D_{r, xy}’}\\
\end{aligned}$$

则在常规矩阵乘法计算中可以用到的 tiling 分块策略也同样可以用在这个算法上得到终极加速了。

FlashAttention-2

【FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning】

待续…