随便写点东西,准备把 FlashAttention 系列的算法原理推一遍。
本着 Talk is cheap,show me the code 的思想,简单用 python 快速实现了一遍。
这里主要是对算法设计比较好奇,因此跟 cuda 相关的代码优化部分先忽略。搞懂算法实现后,系统同学可以自己想想 gpu 优化怎么做。
FlashAttention
这个系列的开山之作:【FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness】
主要是通过计算换访存的思路,减少 cuda kernel 对 global memory 的访问量,在 sequence length 长,整体计算过程偏 memory bound 的情况下有很好的效果。
Self Attention
先来看下常规的 Self Attention:
$$O = Softmax(Q * K^T) * V$$
假设 Sequence Length 是 L,Head dim 是 D,这个过程在常规实现中需要放在 3 个 cuda kernel 中,伪代码:
1 | // [L, Dim] * [Dim, L] -> [L, L] |
对 global memory 的访存量为 $4L * Dim + 3L * L = O(LDim + L^2)$
Show me the code!
1 | def normal_selfattn(q, k, v): |
Softmax
重新看一下 softmax 是在做什么:
$$Softmax({x_1, x_2, …, x_N}) = \left \lbrace \frac{e^{x_i}}{\sum^N_{j=1}e^{x_j}} \right \rbrace^N_{i=1}$$
值得注意的是,目前通用的 softmax 实现中为了防止数值溢出还需要再额外减掉一个 max:
$$\begin{aligned}
m &= Max^N_{n=0}x_n\\
Softmax({x_1, x_2, …, x_N}) &= \left \lbrace \frac{e^{x_i - m}}{\sum^N_{j=1}e^{x_j - m}} \right \rbrace^N_{i=1}
\end{aligned}
$$
因此这里计算 max 需要一次独立的全局 reduce,计算分母的 sum 再需要一次独立的全局 reduce,最后分别计算每一个元素的 softmax 值。三个步骤之间存在数据依赖。
Show me the code!
1 | def normal_softmax(x): |
Online Softmax
【Online normalizer calculation for softmax】 这篇 paper 提出了一种能将上面的 3 步 softmax 合并成 2 步完成的思路。
考虑原始计算步骤中分母求最大值以及求和的部分,这里需要 3 个独立的循环:
$$\begin{align}
m_i &= \max(m_{i-1}, x_i) \\
d_i &= d_{i-1} + e^{x_i - m_N} = \sum^i_{j=1}e^{x_j - m_N} \\
softmax_i &= \frac{e^{x_i - m_N}}{d_N} \\
\end{align}$$
公式中产生数据依赖的原因是 $(2)$ 需要依赖 $m_N$,而 $(3)$ 需要依赖 $m_N$ 和 $d_N$。
如果能有这样一个 $d_i’$:
$$d_i’ = \sum^i_{j=1}e^{x_j - m_i} = d_{i-1}’ + e^{x_i - m_i}$$
则 $(2)$ 对 $m_N$ 的数据依赖就解除了,虽然序列的中间部分的值不相等但 $d_N$ 与 $d_N’$ 是等价的。
而 $d_i’$ 这个序列存在递推性质:
$$
\begin{aligned}
d_i’ &= \sum^i_{j=1}e^{x_j - m_i} \\
&= \sum^{i-1}_{j=1}e^{x_j - m_i} + e^{x_i-m_i} \\
&= \left ({\sum^{i-1}_{j=1}e^{x_j - m_{i-1}}} \right ) * e^{m_{i-1} - m_i} + e^{x_i-m_i} \\
&= d_{i-1}’ * e^{m_{i-1} - m_i} + e^{x_i-m_i} \\
\end{aligned}
$$
这样 $m_i$ 和 $d_i’$ 就可以在一个 kernel 中计算完成了,kernel 可以减少到两个:
$$\begin{aligned}
m_i &= \max(m_{i-1}, x_i) \\
d_i’ &= d_{i-1}’ * e^{m_{i-1} - m_i} + e^{x_i-m_i} \\
softmax_i &= \frac{e^{x_i - m_N}}{d_N’}
\end{aligned}$$
Show me the code!
1 | def online_softmax_update(m0, d0, m1, d1): |
但是这样对于加速来说还是不够,重新推几层:
Index | 0 | 1 | 2 | 3 |
---|---|---|---|---|
$m_i$ | $x_0$ | $\max(m_0, x_1)$ | $\max(m_1, x_2)$ | $\max(m_2, x_3)$ |
$d_i’$ | 1 | $e^{x_0-m_1} + e^{x_1-m_1}$ | $e^{x_0-m_2} + e^{x_1-m_2} + e^{x_2-m_2}$ | $e^{x_0-m_3} + e^{x_1-m_3} + e^{x_2-m_3} + e^{x_3-m_3}$ |
可以发现由于 exp/log 计算的特性,$d_3’$ 除了正常从 $d_0’$、$d_1’$、$d_2’$ 按顺序推出以外,乱序也是可以得到的,例如 2->1->0->3:
$$\begin{aligned}
m_{21} &= \max(x_2, x_1) \\
d_{21}’ &= e^{x_2 - m_{21}} + e^{x_1 - m_{21}}\\
m_{210} &= \max(x_0, m_{21}) = \max(x_0, x_2, x_1)\\
d_{210}’ &= d_{21}’ * e^{m_{21} - m_{210}} + e^{x_0 - m_{210}}\\
&= (e^{x_2 - m_{21}} + e^{x_1 - m_{21}}) * e^{m_{21} - m_{210}} + e^{x_0 - m_{210}}\\
&= e^{x_2 - m_{210}} + e^{x_1 - m_{210}} + e^{x_0 - m_{210}}\\
m_3 = m_{2103} &= \max(x_3, m_{210}) = \max(x_3, x_0, x_2, x_1) \\
d_3’=d_{2103}’&=d_{210}’ * e^{m_{210} - m_3} + e^{x_3 - m_3}\\
&=(e^{x_2 - m_{210}} + e^{x_1 - m_{210}} + e^{x_0 - m_{210}}) * e^{m_{210} - m_3} + e^{x_3 - m_3}\\
&=e^{x_2 - m_3} + e^{x_1 - m_3} + e^{x_0 - m_3} + e^{x_3 - m_3}\\
\end{aligned}$$
如果定义分块计算时 $d_{xy}’ = d_x’ * e^{m_x - m_{xy}} + d_y’ * e^{m_y - m_{xy}}$,则按照 (2->1)->(0->3) 的迭代顺序:
$$\begin{aligned}
m_{21} &= \max(x_2, x_1) \\
d_{21}’ &= e^{x_2 - m_{21}} + e^{x_1 - m_{21}}\\
m_{03} &= \max(x_0, x_3) \\
d_{03}’ &= e^{x_0 - m_{03}} + e^{x_3 - m_{03}}\\
m_3=m_{2103} &= \max(m_{21}, m_{03}) = \max(x_2, x_1, x_0, x_3)\\
d_3’=d_{2103}’ &= d_{21}’ * e^{m_{21} - m_3} + d_{03}’ * e^{m_{03} - m_3} \\
&= (e^{x_2 - m_{21}} + e^{x_1 - m_{21}}) * e^{m_{21} - m_3} + e^{x_0 - m_3} + (e^{x_0 - m_{03}} + e^{x_3 - m_{03}}) * e^{m_{03} - m_3} \\
&= e^{x_2 - m_3} + e^{x_1 - m_3} + e^{x_0 - m_3} + e^{x_3 - m_3}\\
\end{aligned}$$
同样可以得到相同的 $d_N’$。这样,我们就可以得到这种方式最大的一个特性:$m$ 和 $d’$ 的迭代计算操作同时满足交换律和结合律,任意分块分别计算 $m$ 和 $d’$ 之后,将所有子块结果重新聚合在数学上完全等价,即序列中 max 值带来的影响可以延迟到最后一步再被修正。
之前在看很多资料的时候并没有提到这一点,突然就一下跳到 block 分块之后的实现了,看得人一脸懵逼。
这样 Online softmax 就可以通过分块并行得到进一步的加速了。
Show me the code!
1 | def online_block_softmax(x): |
FlashAttention
继续回到一开始的 FlashAttention,我们把 Online Softmax 放进去,这里的 Softmax 是对 [L, L] 结果中的每一行做一维的 softmax:
$$\begin{aligned}
Softmax^L_{r=1} &= Softmax(X_{r, 1}, X_{r, 2}, …, X_{r, L})^L_{r=1}\\
&=\left \lbrace \left \lbrace
\frac{e^{X_{r, i} - M_{r, L}}}{\sum^L_{j=1}e^{X_{r, j} - m_{r, L}}}
\right \rbrace^L_{i=1} \right \rbrace^L_{r=1} \\
\end{aligned}$$
原始实现中:
$$\begin{aligned}
X_{r, i} &= \sum^{Dim}_{j=1}Q[r, j]K[j, i]\\
M_{r, i} &= \max(M_{r, i-1}, X_{r, i}) \\
D_{r, i} &= D_{r, i-1} + e^{X_{r, i} - M_{r, L}} = \sum^i_{j=1}e^{X_{r, j} - M_{r, L}} \\
Softmax_{r, i} &= \frac{e^{X_{r, i} - M_{r, L}}}{D_{r, L}} \\
\end{aligned}$$
同样可以替换成 $D_{r, i}’$:
$$\begin{aligned}
D_{r, i}’ &= D_{r, i-1}’ * e^{M_{r, i-1} - M_{r, i}} + e^{X_{r, i}-M_{r, i}} \\
Softmax_{r, i} &= \frac{e^{X_{r, i} - M_{r, L}}}{D_{r, L}’} \\
O_{r, c} &= \sum^L_{i=1}(Softmax_{r, i} * V[i, c]) \\
\end{aligned}$$
Show me the code!
1 | def flashattn_0(q, k, v): |
把 $O_{r, c}$ 的累加过程拆开看:
$$\begin{aligned}
SubSum_{r, c, i} &= SubSum_{r, c, i-1} + Softmax_{r, i} * V[i, c]\\
&=SubSum_{r, c, i-1} + \frac{e^{X_{r, i} - M_{r, L}}}{D_{r, L}’} * V[i, c]\\
&=\sum^i_{j=1}\frac{e^{X_{r, j} - M_{r, L}}}{D_{r, L}’}V[j, c]
\end{aligned}$$
可以发现 $SubSum_{r,c,i}$ 也是依赖于 $M_{r,L}$ 和 ${D_{r,L}’}$,运用与 online softmax 相似的方式,可以在这里增加一个 $SubSum_{r,c,i}’$:
$$\begin{aligned}
SubSum_{r,c,i}’ &= \sum^i_{j=1}\frac{e^{X_{r, j} - M_{r, i}}}{D_{r, i}’}V[j, c]\\
&=\sum^{i-1}_{j=1}\frac{e^{X_{r, j} - M_{r, i}}}{D_{r, i}’}V[j, c] + \frac{e^{X_{r, i} - M_{r, i}}}{D_{r, i}’}V[i, c]\\
&=\left (\sum^{i-1}_{j=1}\frac{e^{X_{r, j} - M_{r, i-1}}}{D_{r, i-1}’}V[j, c] \right) * \frac{e^{M_{r, i-1} - M_{r, i}}D_{r,i-1}’}{D_{r,i}’} + \frac{e^{X_{r, i} - M_{r, i}}}{D_{r, i}’}V[i, c]\\
&=SubSum_{r,c,i-1}’*\frac{e^{M_{r, i-1} - M_{r, i}}D_{r,i-1}’}{D_{r,i}’} + \frac{e^{X_{r, i} - M_{r, i}}}{D_{r, i}’}V[i, c]\\
\end{aligned}$$
最终整理一下,在一个 $i = (1, L)$ 的循环中可以完成:
$$\begin{aligned}
X_{r, i} &= \sum^{Dim}_{j=1}Q[r, j]K[j, i]\\
M_{r, i} &= \max(M_{r, i-1}, X_{r, i})\\
D_{r, i}’ &= D_{r, i-1}’ * e^{M_{r, i-1} - M_{r, i}} + e^{X_{r, i}-M_{r, i}}\\
SubSum_{r,c,i}’ &=SubSum_{r,c,i-1}’*\frac{e^{M_{r, i-1} - M_{r, i}}D_{r,i-1}’}{D_{r,i}’} + \frac{e^{X_{r, i} - M_{r, i}}}{D_{r, i}’}V[i, c]\\
\end{aligned}$$
最终的输出结果为:
$$O_{r, c} = SubSum_{r,c,L}$$
这个就是论文里面的 Algorithm 1:
附录部分还有个包含了 mask 这些额外操作的完整版 Algorithm 2,这里就不再重复写了。
写成伪代码是:
1 | for (r = 1 to L) |
Show me the code!
1 | def flashattn_update(m, d, m0, d0, s0, m1, d1, s1): |
当然,与 Online softmax 一样,$D_{r, i}’$、$SubSum_{r,c,i}’$ 也具有分块满足交换律和结合律的特性:
$$\begin{aligned}
D_{r, xy}’ &= D_{r, x}’ * e^{M_{r, x} - M_{r, xy}} + D_{r, y}’ * e^{M_{r, y} - M_{r, xy}}\\
SubSum_{r,c,xy}’ &= SubSum_{r,c,x}’ * \frac{e^{M_{r, x}-M_{r, xy}}D_{r, x}’}{D_{r, xy}’} + SubSum_{r,c,y}’ * \frac{e^{M_{r, y}-M_{r, xy}}D_{r, y}’}{D_{r, xy}’}\\
\end{aligned}$$
则在常规矩阵乘法计算中可以用到的 tiling 分块策略也同样可以用在这个算法上得到终极加速了。
Show me the code!
1 | def flashattn_1_block(q, k, v): |
FlashAttention-2
【FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning】
2 相比 1 在算法上改动其实很小,注意到前面公式里面 $SubSum$ 的递推式需要反复乘除 $D’$,我们来简单推一下,原递推式可以简写成:
$$
S[i] = S[i-1] * E * D[i-1] / D[i] + E * V / D[i]
$$
把 $D[i]$ 乘到左边:
$$
S[i] * D[i] = S[i-1] * E * D[i-1] + E * V
$$
记:
$$
SS[i] = S[i] * D[i]
$$
则上式可改写成:
$$\begin{aligned}
SS[i] &= SS[i-1] * E + E * V\\
S[i] &= SS[i-1] / D[i]
\end{aligned}$$
这也就是论文中的更新版本了:
Show me the code!
1 | # m, d, m0, d0, s0, m1, d1, s1): |
FlashDecoding & FlashDecoding++
【Flash-Decoding for long-context inference】
这两篇后续在算法本身上没有什么特别的改进,主要是针对 LLM 的具体算子尺寸情况更精细地做了 cuda block 的划分和优化,这里也不再详细记录了。
完整的 test python 文件:softmax_test.py