0%

FlashAttentions

随便写点东西,准备把 FlashAttention 系列的算法原理推一遍。

本着 Talk is cheap,show me the code 的思想,简单用 python 快速实现了一遍。
这里主要是对算法设计比较好奇,因此跟 cuda 相关的代码优化部分先忽略。搞懂算法实现后,系统同学可以自己想想 gpu 优化怎么做。

FlashAttention

这个系列的开山之作:【FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness】

主要是通过计算换访存的思路,减少 cuda kernel 对 global memory 的访问量,在 sequence length 长,整体计算过程偏 memory bound 的情况下有很好的效果。

Self Attention

先来看下常规的 Self Attention:

$$O = Softmax(Q * K^T) * V$$

假设 Sequence Length 是 L,Head dim 是 D,这个过程在常规实现中需要放在 3 个 cuda kernel 中,伪代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// [L, Dim] * [Dim, L] -> [L, L]
for (r = 1 to L)
for (i = 1 to L)
for (j = 1 to Dim)
X[r, i] += Q[r, j] * K[j, i]

// [L, L]
for (r = 1 to L)
for (i = 1 to L)
e_sum += e^(X[r, i])
for (i = 1 to L)
Softmax[r, i] = e^(X[r, i]) / e_sum

// [L, L] * [L, Dim] -> [L, Dim]
for (r = 1 to L)
for (c = 1 to Dim)
for (i = 1 to L)
O[r, c] += Softmax[r, i] * V[i, c]

对 global memory 的访存量为 $4L * Dim + 3L * L = O(LDim + L^2)$

Show me the code!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def normal_selfattn(q, k, v):
# [L, Dim] * [Dim, L] -> [L, L]
x = np.zeros([L, L], "float32")
for r in range(0, L):
for i in range(0, L):
for j in range(0, Dim):
x[r, i] += q[r, j] * k[i, j]

# [L, L]
softmax = np_softmax(x)

# [L, L] * [L, Dim] -> [L, Dim]
o = np.zeros([L, Dim], "float32")
for r in range(0, L):
for i in range(0, Dim):
for j in range(0, L):
o[r, i] += softmax[r, j] * v[j, i]

return o

Softmax

重新看一下 softmax 是在做什么:

$$Softmax({x_1, x_2, …, x_N}) = \left \lbrace \frac{e^{x_i}}{\sum^N_{j=1}e^{x_j}} \right \rbrace^N_{i=1}$$

值得注意的是,目前通用的 softmax 实现中为了防止数值溢出还需要再额外减掉一个 max:

$$\begin{aligned}
m &= Max^N_{n=0}x_n\\
Softmax({x_1, x_2, …, x_N}) &= \left \lbrace \frac{e^{x_i - m}}{\sum^N_{j=1}e^{x_j - m}} \right \rbrace^N_{i=1}
\end{aligned}
$$

因此这里计算 max 需要一次独立的全局 reduce,计算分母的 sum 再需要一次独立的全局 reduce,最后分别计算每一个元素的 softmax 值。三个步骤之间存在数据依赖。

Show me the code!

1
2
3
4
5
6
7
8
9
10
11
12
def normal_softmax(x):
out = np.array(x)
for r in range(0, L):
maxi = 0
for i in range(0, L):
maxi = max(maxi, x[r, i])
e_sum = 0
for i in range(0, L):
e_sum += np.exp(x[r, i] - maxi)
for i in range(0, L):
out[r, i] = np.exp(x[r, i] - maxi) / e_sum
return out

Online Softmax

【Online normalizer calculation for softmax】 这篇 paper 提出了一种能将上面的 3 步 softmax 合并成 2 步完成的思路。

考虑原始计算步骤中分母求最大值以及求和的部分,这里需要 3 个独立的循环:

$$\begin{align}
m_i &= \max(m_{i-1}, x_i) \\
d_i &= d_{i-1} + e^{x_i - m_N} = \sum^i_{j=1}e^{x_j - m_N} \\
softmax_i &= \frac{e^{x_i - m_N}}{d_N} \\
\end{align}$$

公式中产生数据依赖的原因是 $(2)$ 需要依赖 $m_N$,而 $(3)$ 需要依赖 $m_N$ 和 $d_N$。

如果能有这样一个 $d_i’$:

$$d_i’ = \sum^i_{j=1}e^{x_j - m_i} = d_{i-1}’ + e^{x_i - m_i}$$

则 $(2)$ 对 $m_N$ 的数据依赖就解除了,虽然序列的中间部分的值不相等但 $d_N$ 与 $d_N’$ 是等价的。

而 $d_i’$ 这个序列存在递推性质:

$$
\begin{aligned}
d_i’ &= \sum^i_{j=1}e^{x_j - m_i} \\
&= \sum^{i-1}_{j=1}e^{x_j - m_i} + e^{x_i-m_i} \\
&= \left ({\sum^{i-1}_{j=1}e^{x_j - m_{i-1}}} \right ) * e^{m_{i-1} - m_i} + e^{x_i-m_i} \\
&= d_{i-1}’ * e^{m_{i-1} - m_i} + e^{x_i-m_i} \\
\end{aligned}
$$

这样 $m_i$ 和 $d_i’$ 就可以在一个 kernel 中计算完成了,kernel 可以减少到两个:

$$\begin{aligned}
m_i &= \max(m_{i-1}, x_i) \\
d_i’ &= d_{i-1}’ * e^{m_{i-1} - m_i} + e^{x_i-m_i} \\
softmax_i &= \frac{e^{x_i - m_N}}{d_N’}
\end{aligned}$$

Show me the code!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def online_softmax_update(m0, d0, m1, d1):
# x 1
# Init value: MIN_M 0
m = max(m0, m1)
d = d0 * np.exp(m0 - m) + d1 * np.exp(m1 - m)
return m, d

def online_softmax(x):
out = np.zeros(x.shape, x.dtype)
for r in range(0, L):
m = MIN_M
d = 0
for i in range(0, L):
m, d = online_softmax_update(m, d, x[r, i], 1)
for i in range(0, L):
out[r, i] = np.exp(x[r, i] - m) / d
return out

但是这样对于加速来说还是不够,重新推几层:

Index 0 1 2 3
$m_i$ $x_0$ $\max(m_0, x_1)$ $\max(m_1, x_2)$ $\max(m_2, x_3)$
$d_i’$ 1 $e^{x_0-m_1} + e^{x_1-m_1}$ $e^{x_0-m_2} + e^{x_1-m_2} + e^{x_2-m_2}$ $e^{x_0-m_3} + e^{x_1-m_3} + e^{x_2-m_3} + e^{x_3-m_3}$

可以发现由于 exp/log 计算的特性,$d_3’$ 除了正常从 $d_0’$、$d_1’$、$d_2’$ 按顺序推出以外,乱序也是可以得到的,例如 2->1->0->3:

$$\begin{aligned}
m_{21} &= \max(x_2, x_1) \\
d_{21}’ &= e^{x_2 - m_{21}} + e^{x_1 - m_{21}}\\
m_{210} &= \max(x_0, m_{21}) = \max(x_0, x_2, x_1)\\
d_{210}’ &= d_{21}’ * e^{m_{21} - m_{210}} + e^{x_0 - m_{210}}\\
&= (e^{x_2 - m_{21}} + e^{x_1 - m_{21}}) * e^{m_{21} - m_{210}} + e^{x_0 - m_{210}}\\
&= e^{x_2 - m_{210}} + e^{x_1 - m_{210}} + e^{x_0 - m_{210}}\\
m_3 = m_{2103} &= \max(x_3, m_{210}) = \max(x_3, x_0, x_2, x_1) \\
d_3’=d_{2103}’&=d_{210}’ * e^{m_{210} - m_3} + e^{x_3 - m_3}\\
&=(e^{x_2 - m_{210}} + e^{x_1 - m_{210}} + e^{x_0 - m_{210}}) * e^{m_{210} - m_3} + e^{x_3 - m_3}\\
&=e^{x_2 - m_3} + e^{x_1 - m_3} + e^{x_0 - m_3} + e^{x_3 - m_3}\\
\end{aligned}$$

如果定义分块计算时 $d_{xy}’ = d_x’ * e^{m_x - m_{xy}} + d_y’ * e^{m_y - m_{xy}}$,则按照 (2->1)->(0->3) 的迭代顺序:

$$\begin{aligned}
m_{21} &= \max(x_2, x_1) \\
d_{21}’ &= e^{x_2 - m_{21}} + e^{x_1 - m_{21}}\\
m_{03} &= \max(x_0, x_3) \\
d_{03}’ &= e^{x_0 - m_{03}} + e^{x_3 - m_{03}}\\
m_3=m_{2103} &= \max(m_{21}, m_{03}) = \max(x_2, x_1, x_0, x_3)\\
d_3’=d_{2103}’ &= d_{21}’ * e^{m_{21} - m_3} + d_{03}’ * e^{m_{03} - m_3} \\
&= (e^{x_2 - m_{21}} + e^{x_1 - m_{21}}) * e^{m_{21} - m_3} + e^{x_0 - m_3} + (e^{x_0 - m_{03}} + e^{x_3 - m_{03}}) * e^{m_{03} - m_3} \\
&= e^{x_2 - m_3} + e^{x_1 - m_3} + e^{x_0 - m_3} + e^{x_3 - m_3}\\
\end{aligned}$$

同样可以得到相同的 $d_N’$。这样,我们就可以得到这种方式最大的一个特性:$m$ 和 $d’$ 的迭代计算操作同时满足交换律和结合律,任意分块分别计算 $m$ 和 $d’$ 之后,将所有子块结果重新聚合在数学上完全等价,即序列中 max 值带来的影响可以延迟到最后一步再被修正。

之前在看很多资料的时候并没有提到这一点,突然就一下跳到 block 分块之后的实现了,看得人一脸懵逼。

这样 Online softmax 就可以通过分块并行得到进一步的加速了。

Show me the code!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def online_block_softmax(x):
assert L % BLK == 0
out = np.array(x)
for r in range(0, L):

m = MIN_M
d = 0
for b in range(0, L // BLK):
# Calculate block
mm = MIN_M
dd = 0
for i in range(0, BLK):
mm, dd = online_softmax_update(mm, dd, x[r, b * BLK + i], 1)

# Merge to total
m, d = online_softmax_update(m, d, mm, dd)

for i in range(0, L):
out[r, i] = np.exp(x[r, i] - m) / d
return out

FlashAttention

继续回到一开始的 FlashAttention,我们把 Online Softmax 放进去,这里的 Softmax 是对 [L, L] 结果中的每一行做一维的 softmax:

$$\begin{aligned}
Softmax^L_{r=1} &= Softmax(X_{r, 1}, X_{r, 2}, …, X_{r, L})^L_{r=1}\\
&=\left \lbrace \left \lbrace
\frac{e^{X_{r, i} - M_{r, L}}}{\sum^L_{j=1}e^{X_{r, j} - m_{r, L}}}
\right \rbrace^L_{i=1} \right \rbrace^L_{r=1} \\
\end{aligned}$$

原始实现中:

$$\begin{aligned}
X_{r, i} &= \sum^{Dim}_{j=1}Q[r, j]K[j, i]\\
M_{r, i} &= \max(M_{r, i-1}, X_{r, i}) \\
D_{r, i} &= D_{r, i-1} + e^{X_{r, i} - M_{r, L}} = \sum^i_{j=1}e^{X_{r, j} - M_{r, L}} \\
Softmax_{r, i} &= \frac{e^{X_{r, i} - M_{r, L}}}{D_{r, L}} \\
\end{aligned}$$

同样可以替换成 $D_{r, i}’$:

$$\begin{aligned}
D_{r, i}’ &= D_{r, i-1}’ * e^{M_{r, i-1} - M_{r, i}} + e^{X_{r, i}-M_{r, i}} \\
Softmax_{r, i} &= \frac{e^{X_{r, i} - M_{r, L}}}{D_{r, L}’} \\
O_{r, c} &= \sum^L_{i=1}(Softmax_{r, i} * V[i, c]) \\
\end{aligned}$$

Show me the code!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def flashattn_0(q, k, v):
# [L, Dim] * [Dim, L] -> [L, L]
x = np.zeros([L, L], "float32")
for r in range(0, L):
for i in range(0, L):
for j in range(0, Dim):
x[r, i] += q[r, j] * k[i, j]

# [L, L] -> [L, L] * [L, Dim] -> [L, Dim]
o = np.zeros([L, Dim], "float32")
for r in range(0, L):
m = MIN_M
d = 0
for i in range(0, L):
m, d = online_softmax_update(m, d, x[r, i], 1)

softmax = np.zeros([L], "float32")
for i in range(0, L):
softmax[i] = np.exp(x[r, i] - m) / d
for c in range(0, Dim):
for i in range(0, L):
o[r, c] += softmax[i] * v[i, c]

return o

把 $O_{r, c}$ 的累加过程拆开看:

$$\begin{aligned}
SubSum_{r, c, i} &= SubSum_{r, c, i-1} + Softmax_{r, i} * V[i, c]\\
&=SubSum_{r, c, i-1} + \frac{e^{X_{r, i} - M_{r, L}}}{D_{r, L}’} * V[i, c]\\
&=\sum^i_{j=1}\frac{e^{X_{r, j} - M_{r, L}}}{D_{r, L}’}V[j, c]
\end{aligned}$$

可以发现 $SubSum_{r,c,i}$ 也是依赖于 $M_{r,L}$ 和 ${D_{r,L}’}$,运用与 online softmax 相似的方式,可以在这里增加一个 $SubSum_{r,c,i}’$:

$$\begin{aligned}
SubSum_{r,c,i}’ &= \sum^i_{j=1}\frac{e^{X_{r, j} - M_{r, i}}}{D_{r, i}’}V[j, c]\\
&=\sum^{i-1}_{j=1}\frac{e^{X_{r, j} - M_{r, i}}}{D_{r, i}’}V[j, c] + \frac{e^{X_{r, i} - M_{r, i}}}{D_{r, i}’}V[i, c]\\
&=\left (\sum^{i-1}_{j=1}\frac{e^{X_{r, j} - M_{r, i-1}}}{D_{r, i-1}’}V[j, c] \right) * \frac{e^{M_{r, i-1} - M_{r, i}}D_{r,i-1}’}{D_{r,i}’} + \frac{e^{X_{r, i} - M_{r, i}}}{D_{r, i}’}V[i, c]\\
&=SubSum_{r,c,i-1}’*\frac{e^{M_{r, i-1} - M_{r, i}}D_{r,i-1}’}{D_{r,i}’} + \frac{e^{X_{r, i} - M_{r, i}}}{D_{r, i}’}V[i, c]\\
\end{aligned}$$

最终整理一下,在一个 $i = (1, L)$ 的循环中可以完成:

$$\begin{aligned}
X_{r, i} &= \sum^{Dim}_{j=1}Q[r, j]K[j, i]\\
M_{r, i} &= \max(M_{r, i-1}, X_{r, i})\\
D_{r, i}’ &= D_{r, i-1}’ * e^{M_{r, i-1} - M_{r, i}} + e^{X_{r, i}-M_{r, i}}\\
SubSum_{r,c,i}’ &=SubSum_{r,c,i-1}’*\frac{e^{M_{r, i-1} - M_{r, i}}D_{r,i-1}’}{D_{r,i}’} + \frac{e^{X_{r, i} - M_{r, i}}}{D_{r, i}’}V[i, c]\\
\end{aligned}$$

最终的输出结果为:

$$O_{r, c} = SubSum_{r,c,L}$$

这个就是论文里面的 Algorithm 1:

Algorithm 1

附录部分还有个包含了 mask 这些额外操作的完整版 Algorithm 2,这里就不再重复写了。

写成伪代码是:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
for (r = 1 to L)
for (i = 1 to L)
// [L, Dim] * [Dim, L] -> [L, L]
for (j = 1 to Dim)
X[r, i] += Q[r, j] * K[j, i]

// [L, L]
M[r, i] = max(M[r, i-1], X[r, i])

// [L, L]
D'[r, i] = D'[r, i-1] * e(...) + e(...)

// [L, Dim]
for (c = 1 to Dim)
O[r, c] += O[r, c] * e(...) * D'[r, i-1] / D'[r, i] + e(...) / D'[r, i] * V[i, c]

Show me the code!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def flashattn_update(m, d, m0, d0, s0, m1, d1, s1):
# | | | | | |
# | | | x v 1
# Init value: MIN_M 0 0
s = s0 * np.exp(m0 - m) * d0 / d + s1 * np.exp(m1 - m) * d1 / d
return s


def flashattn_1(q, k, v):
# [L, Dim] * [Dim, L] -> [L, L]
x = np.zeros([L, L], "float32")
for r in range(0, L):
for i in range(0, L):
for j in range(0, Dim):
x[r, i] += q[r, j] * k[i, j]

# [L, L] -> [L, L] * [L, Dim] -> [L, Dim]
o = np.zeros([L, Dim], "float32")
for r in range(0, L):
m = []
d = []
for i in range(0, L):
mm, dd = online_softmax_update(
m[-1] if i > 0 else MIN_M, d[-1] if i > 0 else 0, x[r, i], 1
)
m.append(mm)
d.append(dd)

for c in range(0, Dim):
s = 0
for i in range(0, L):
s = flashattn_update(
m[i],
d[i],
m[i - 1] if i > 0 else MIN_M,
d[i - 1] if i > 0 else 0,
s,
x[r, i],
v[i, c],
1,
)
o[r, c] = s
return o

当然,与 Online softmax 一样,$D_{r, i}’$、$SubSum_{r,c,i}’$ 也具有分块满足交换律和结合律的特性:

$$\begin{aligned}
D_{r, xy}’ &= D_{r, x}’ * e^{M_{r, x} - M_{r, xy}} + D_{r, y}’ * e^{M_{r, y} - M_{r, xy}}\\
SubSum_{r,c,xy}’ &= SubSum_{r,c,x}’ * \frac{e^{M_{r, x}-M_{r, xy}}D_{r, x}’}{D_{r, xy}’} + SubSum_{r,c,y}’ * \frac{e^{M_{r, y}-M_{r, xy}}D_{r, y}’}{D_{r, xy}’}\\
\end{aligned}$$

则在常规矩阵乘法计算中可以用到的 tiling 分块策略也同样可以用在这个算法上得到终极加速了。

Show me the code!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def flashattn_1_block(q, k, v):
assert L % BLK == 0
# [L, Dim] * [Dim, L] -> [L, L]
x = np.zeros([L, L], "float32")
for r in range(0, L):
for i in range(0, L):
for j in range(0, Dim):
x[r, i] += q[r, j] * k[i, j]

# [L, L] -> [L, L] * [L, Dim] -> [L, Dim]
o = np.zeros([L, Dim], "float32")
for r in range(0, L):
m = np.zeros([L // BLK], "float32")
d = np.zeros([L // BLK], "float32")
mm = np.zeros([L], "float32")
dd = np.zeros([L], "float32")
for b in range(0, L // BLK):
# Calculate block
for i in range(0, BLK):
mm[b * BLK + i], dd[b * BLK + i] = online_softmax_update(
mm[b * BLK + i - 1] if i > 0 else MIN_M,
dd[b * BLK + i - 1] if i > 0 else 0,
x[r, b * BLK + i],
1,
)

# Merge to total
m[b], d[b] = online_softmax_update(
m[b - 1] if b > 0 else MIN_M,
d[b - 1] if i > 0 else 0,
mm[(b + 1) * BLK - 1],
dd[(b + 1) * BLK - 1],
)

for c in range(0, Dim):
s = 0
for b in range(0, L // BLK):
# Calculate block
ss = 0
for i in range(0, BLK):
ss = flashattn_update(
mm[b * BLK + i],
dd[b * BLK + i],
mm[b * BLK + i - 1] if i > 0 else MIN_M,
dd[b * BLK + i - 1] if i > 0 else 0,
ss,
x[r, b * BLK + i],
v[b * BLK + i, c],
1,
)

# Merge to total
s = flashattn_update(
m[b],
d[b],
m[b - 1] if b > 0 else MIN_M,
d[b - 1] if b > 0 else 0,
s,
mm[(b + 1) * BLK - 1],
dd[(b + 1) * BLK - 1],
ss,
)
o[r, c] = s
return o

FlashAttention-2

【FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning】

2 相比 1 在算法上改动其实很小,注意到前面公式里面 $SubSum$ 的递推式需要反复乘除 $D’$,我们来简单推一下,原递推式可以简写成:

$$
S[i] = S[i-1] * E * D[i-1] / D[i] + E * V / D[i]
$$

把 $D[i]$ 乘到左边:

$$
S[i] * D[i] = S[i-1] * E * D[i-1] + E * V
$$

记:

$$
SS[i] = S[i] * D[i]
$$

则上式可改写成:

$$\begin{aligned}
SS[i] &= SS[i-1] * E + E * V\\
S[i] &= SS[i-1] / D[i]
\end{aligned}$$

这也就是论文中的更新版本了:

FlashAttention2 Algorithm 1

Show me the code!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#                      m, d, m0, d0, s0, m1, d1, s1):
def flashattn_2_update(m, m0, s0, m1, d1, s1):
# | | | | |
# | | x v 1
# Init value: MIN_M 0
s = s0 * np.exp(m0 - m) + s1 * np.exp(m1 - m) * d1
return s


def flashattn_2_block(q, k, v):
assert L % BLK == 0
# [L, Dim] * [Dim, L] -> [L, L]
x = np.zeros([L, L], "float32")
for r in range(0, L):
for i in range(0, L):
for j in range(0, Dim):
x[r, i] += q[r, j] * k[i, j]

# [L, L] -> [L, L] * [L, Dim] -> [L, Dim]
o = np.zeros([L, Dim], "float32")
for r in range(0, L):
m = np.zeros([L // BLK], "float32")
d = np.zeros([L // BLK], "float32")
mm = np.zeros([L], "float32")
dd = np.zeros([L], "float32")
for b in range(0, L // BLK):
# Calculate block
for i in range(0, BLK):
mm[b * BLK + i], dd[b * BLK + i] = online_softmax_update(
mm[b * BLK + i - 1] if i > 0 else MIN_M,
dd[b * BLK + i - 1] if i > 0 else 0,
x[r, b * BLK + i],
1,
)

# Merge to total
m[b], d[b] = online_softmax_update(
m[b - 1] if b > 0 else MIN_M,
d[b - 1] if i > 0 else 0,
mm[(b + 1) * BLK - 1],
dd[(b + 1) * BLK - 1],
)

for c in range(0, Dim):
s = 0
for b in range(0, L // BLK):
# Calculate block
ss = 0
for i in range(0, BLK):
ss = flashattn_2_update(
mm[b * BLK + i],
# dd[b * BLK + i],
mm[b * BLK + i - 1] if i > 0 else MIN_M,
# dd[b * BLK + i - 1] if i > 0 else 0,
ss,
x[r, b * BLK + i],
v[b * BLK + i, c],
1,
)

# Merge to total
s = flashattn_2_update(
m[b],
# d[b],
m[b - 1] if b > 0 else MIN_M,
# d[b - 1] if b > 0 else 0,
s,
mm[(b + 1) * BLK - 1],
dd[(b + 1) * BLK - 1],
ss / dd[(b + 1) * BLK - 1],
)
o[r, c] = s / d[L // BLK - 1]
return o

FlashDecoding & FlashDecoding++

【Flash-Decoding for long-context inference】

这两篇后续在算法本身上没有什么特别的改进,主要是针对 LLM 的具体算子尺寸情况更精细地做了 cuda block 的划分和优化,这里也不再详细记录了。

完整的 test python 文件:softmax_test.py