自工作以来荒废了好久了,我又开始写博客啦~(不知道这一次能坚持多久…)
最近在比较多的帮团队面试,因为我们这边主要是 focus 在 AI 编译栈以及更偏底层一些的算子实现方面的工作,在面一些编译方向的候选人时,我这边最后一关通常是让他们写一个高效的矩阵乘的实现。
然后结果一直都没有遇到能写出来让我觉得还可以的人…唉,是我要求的太高了吗?
根据候选人的背景和平时的擅长领域,对这个问题我会做一些不同的调整:
常规的是 x86 AVX 或者 ARM Neo 的 SIMD 实现;
对 GPU 更熟的,会让写个 CUDA 的实现,具体考察对 shared memory 这些的理解;
对于我觉得经验应该要更丰富一些的选手,我会再多问一些这里面跟多线程、cache 相关的知识。
并不要求一定要把 code 写的很好,写不出来也没关系,毕竟还真不一定是每个搞这块的人都有自己手撸这些的经验,但是我想看到候选人对这个问题里面一些细节的思考。
面了挺多在这方面有好多年工作经验的候选人了,得到的答案都让我不太满意,里面也不乏简历上写了有比较多的算子调优、优化经验的人,我甚至都开始怀疑是不是我自己的认识出了问题…所以打算自己好好写一写这个问题,这篇先从简单的 SIMD 开始吧,也先不考虑多线程这些问题。
测试设备是我自己的 MBP,一块 Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
的 CPU,虽然架构号还挺高的,不过挺可惜没有 AVX512。
首先的问题是一个达到性能极限的程序大概能跑多快?这里引用一个分析:
里面提到的 cpufp 这个小工具写的挺好的,我自己也经常在机器上跑这个测。不过我现在用的这个 commit 在 mac 上有一点点小问题,稍微 fix 一下以后在我的 mac 上可以跑出来:
1 2 3 4 5 6 7 Thread(s): 1 fma fp32 perf: 135.3979 gflops. fma fp64 perf: 68.7643 gflops. avx fp32 perf: 69.1885 gflops. avx fp64 perf: 34.3289 gflops. sse fp32 perf: 34.0870 gflops. sse fp64 perf: 17.4539 gflops.
这里的 fma 是 avx2 下的乘累加指令,已经是我这块 CPU 上能支持的最高效的计算指令了。暂时先不考虑睿频这些的影响(2.60GHz 加上 AVX2 FMA 的指令特性是能算出一个理论上限的),我们可以把这里测出来的作为能够跑出来的峰值性能的上限。
先找个计算库看看性能情况,OneDNN 就不错,拉下来 build 一下,test 目录中有个 benchdnn 可以直接跑,这里就把目标设的小一点,跑一个 [128, 128] x [128, 128] 的矩阵乘吧:
1 2 3 4 5 6 7 8 9 $ OMP_NUM_THREADS=1 ./benchdnn --matmul --verbose=99 --mode=p 128x128:128x128 run: --matmul --mode=P 128x128:128x128:128x128 oneDNN implementation: gemm:jit oneDNN implementation: gemm:jit Requested: 0.000366211 GB, benchdnn limit : 24 GB, CPU RAM capacity: 32 GB, GPU RAM capacity: 0 GB Output template: perf,%engine%,%impl%,%name%,%prb%,%Gops%,%Gfreq%,%-time%,%-Gflops%,%0time%,%0Gflops% perf,cpu,gemm:jit,,--matmul --mode=P 128x128:128x128:128x128,0.0041943,0,0.036722,114.218,0.0393784,106.513 tests:1 passed:1 skipped:0 mistrusted:0 unimplemented:0 failed:0 listed:0 total perf: min(ms):0.036722 avg(ms):0.0393784
平均的 Gflops 是 106.513,差不多 78% 左右的峰值性能,毕竟是开源的 OneDNN 版本,可能再多做一些参数调整或者换上 blas 以后性能能再好一些,基本上可以认为跑到 80% 左右的性能可以算很不错了。
Have a try on Ansor 然后,嗯…做完 Ansor 之后一直没有好好写点跟它相关的东西(主要是因为我懒),其实还挺惭愧的。
不过也是最近在知乎上看到个帖子,有同学质疑 Anosr 测下来只能跑到一半不到的性能(虽然里面有一些其他方面的原因,这个结论还是让人挺沮丧),就顺便也贴一个随手写的测试结果:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 import numpy as npimport tvmfrom tvm import auto_scheduler, te@auto_scheduler.register_workload def gemm_test (M, N, K ): a = te.placeholder([M, K]) b = te.placeholder([K, N]) k = te.reduce_axis((0 , K), name="k" ) c = te.compute([M, N], lambda i, j: te.sum (a[i, k] * b[k, j], k), name="c" ) return [a, b, c] M = N = K = 128 target = tvm.target.create("llvm -mcpu=skylake" ) dev = tvm.cpu(0 ) log_file = "gemm.log" a_data = np.random.rand(M, K).astype("float32" ) b_data = np.random.rand(K, N).astype("float32" ) c_data = np.dot(a_data, b_data) a_tvm = tvm.nd.array(a_data) b_tvm = tvm.nd.array(b_data) c_tvm = tvm.nd.empty(c_data.shape) task = auto_scheduler.SearchTask(func=gemm_test, args=(M, N, K), target=target) print ("Computational DAG:" )print (task.compute_dag)ts = auto_scheduler.task_scheduler.TaskScheduler([task], load_log_file=log_file) tune_option = auto_scheduler.TuningOptions( num_measure_trials=200 , measure_callbacks=[auto_scheduler.RecordToFile(log_file)], verbose=2 , ) sch, args = task.apply_best(log_file) func = tvm.build(sch, args, target) evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500 ) costs = np.median(evaluator(a_tvm, b_tvm, c_tvm).results) print ("Execution time of this operator: %.3f ms" % (costs * 1e3 ))print ("Gflops: " , 2 * M * N * K / costs / 1e9 )
大概跑个 500 ~ 1000 步回来看结果,这个算子不算大,搜起来非常快,多跑个几分钟就跑完了:
1 2 3 4 5 6 7 8 $ TVM_NUM_THREADS=1 python gemm.py Computational DAG: placeholder = PLACEHOLDER [128, 128] placeholder = PLACEHOLDER [128, 128] c(i, j) += (placeholder[i, k]*placeholder[k, j]) Execution time of this operator: 0.038 ms Gflops: 111.55499503743454
差不多 82% 左右的峰值性能,可能这里面会有一些测量误差,已经差不多算比较满意的成绩了。谦虚一些可以看成跟前面 OneDNN 的结果基本相当吧,在实际的模型里面应用的时候结合点别的优化手段还可以再提升一点点。
What does Ansor do? 用 Ansor/TVM 的好处是所有的一切动作对我们来说都是可控的,想知道它生成出来的 schedule 长啥样,加点代码打出来就好了:
1 2 3 4 inp, _ = load_best_record(log_file, task.workload_key) s = task.compute_dag.infer_bound_from_state(inp.state) print (s)print (tvm.lower(sch, args))
首先是这个 schedule 的 state:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 Placeholder: placeholder, placeholder parallel i.0@ (0,16) for j.0 (0,4) for i.1 (0,2) for j.1 (0,2) c.local auto_unroll: 16 for k.0 (0,16) for k.1 (0,8) for i_c.3 (0,4) vectorize j_c.3 (0,16) c.local = ... for i.2 (0,4) vectorize j.2 (0,16) c = ...
这里面用上的优化策略其实很简单,就是基本的一些常规操作 Tiling、Unroll、Vectorize。
虽然我前面把核数限制到了单核上,但是结果还是带上了 Parallel 的 attribute,这个在我看来好像是个代码里面的 bug … 试了下这个 kernel 在 TVM_NUM_THREADS=4 下面居然还能够跑到 400 左右的 Gflops,多核扩展性看起来也还可以哦。
对应的 TVM ir:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 primfn(placeholder_2: handle, placeholder_3: handle, c_1: handle) -> () attr = {"from_legacy_te_schedule" : True, "global_symbol" : "main" , "tir.noalias" : True} buffers = {c: Buffer(c_2: Pointer(float32), float32, [128, 128], []), placeholder_1: Buffer(placeholder_4: Pointer(float32), float32, [128, 128], []), placeholder: Buffer(placeholder_5: Pointer(float32), float32, [128, 128], [])} buffer_map = {placeholder_2: placeholder, placeholder_3: placeholder_1, c_1: c} { for (i.outer.outer: int32, 0, 16) "parallel" { allocate(c.local: Pointer(local float32x16), float32x16, [4]), storage_scope = local ; for (j.outer.outer: int32, 0, 4) { for (i.outer.inner: int32, 0, 2) { for (j.outer.inner: int32, 0, 2) { c.local[ramp(0, 1, 16)] = broadcast(0f32, 16) c.local[ramp(16, 1, 16)] = broadcast(0f32, 16) c.local[ramp(32, 1, 16)] = broadcast(0f32, 16) c.local[ramp(48, 1, 16)] = broadcast(0f32, 16) for (k.outer: int32, 0, 16) { for (k.inner: int32, 0, 8) { c.local[ramp(0, 1, 16)] = ((float32x16*)c.local[ramp(0 , 1 , 16 )] + (broadcast((float32*)placeholder_5[((((i.outer.outer*1024 ) + (i.outer.inner*512 )) + (k.outer*8)) + k.inner)], 16)*(float32x16*)placeholder_4[ramp(((((k.outer*1024 ) + (k.inner*128 )) + (j.outer.outer*32)) + (j.outer.inner*16)), 1, 16)])) c.local[ramp(16, 1, 16)] = ((float32x16*)c.local[ramp(16 , 1 , 16 )] + (broadcast((float32*)placeholder_5[(((((i.outer.outer*1024 ) + (i.outer.inner*512 )) + (k.outer*8)) + k.inner) + 128)], 16)*(float32x16*)placeholder_4[ramp(((((k.outer*1024 ) + (k.inner*128 )) + (j.outer.outer*32)) + (j.outer.inner*16)), 1, 16)])) c.local[ramp(32, 1, 16)] = ((float32x16*)c.local[ramp(32 , 1 , 16 )] + (broadcast((float32*)placeholder_5[(((((i.outer.outer*1024 ) + (i.outer.inner*512 )) + (k.outer*8)) + k.inner) + 256)], 16)*(float32x16*)placeholder_4[ramp(((((k.outer*1024 ) + (k.inner*128 )) + (j.outer.outer*32)) + (j.outer.inner*16)), 1, 16)])) c.local[ramp(48, 1, 16)] = ((float32x16*)c.local[ramp(48 , 1 , 16 )] + (broadcast((float32*)placeholder_5[(((((i.outer.outer*1024 ) + (i.outer.inner*512 )) + (k.outer*8)) + k.inner) + 384)], 16)*(float32x16*)placeholder_4[ramp(((((k.outer*1024 ) + (k.inner*128 )) + (j.outer.outer*32)) + (j.outer.inner*16)), 1, 16)])) } } for (i.inner: int32, 0, 4) { c_2[ramp((((((i.outer.outer*1024 ) + (i.outer.inner*512 )) + (i.inner*128)) + (j.outer.outer*32)) + (j.outer.inner*16)), 1, 16)] = (float32x16*)c.local[ramp((i.inner*16 ), 1 , 16 )] } } } } } }
这个 schedule 里面最核心的计算块其实是在做一个 [4, 8] x [8, 16] = [4, 16] 子矩阵运算,TVM ir 层面对最内层的 j_c.3 做了向量化,并且对 i_c.3 做了循环展开。
1 2 3 4 for k.1 (0,8) for i_c.3 (0,4) vectorize j_c.3 (0,16) c.local = ...
考虑 A、B、C 三个矩阵各自的访存顺序,C 和 B 都是在 j 方向上连续访问,因为我们 Ansor 默认的策略也是直接在这个维度上做向量化。AVX2 的指令长度是 256 位,对应到 float32 上是 8 个,所以虽然这里搜出来的 vectorize 的长度是 16,其实翻译成指令会是两条 fma。
在这个子矩阵的计算里面,每次是从 A 矩阵读出一个值,broadcast 成 8 份,然后跟 B 矩阵里面连续读出来的 8 个数做对应位置相乘,再累加到输出上。
存放 C 矩阵的输出需要用上 4 * 2 * 8 / 8 一共 8 个 256 位的向量寄存器,B 矩阵的数据可以重复复用,只需要两个寄存器,这里把 A 的 4 个值分别 broadcast 需要 4 个寄存器,一共 14 个寄存器可以解决:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 vector_b_0 = vector_load(b[0-7]) vector_b_1 = vector_load(b[8-15]) vector_a_0 = broadcast(a[0]) vector_a_0 * vector_b_0 -> vector_c_0 vector_a_0 * vector_b_1 -> vector_c_1 vector_a_1 = broadcast(a[1]) vector_a_1 * vector_b_0 -> vector_c_2 vector_a_1 * vector_b_1 -> vector_c_3 vector_a_2 = broadcast(a[2]) vector_a_2 * vector_b_0 -> vector_c_4 vector_a_2 * vector_b_1 -> vector_c_5 vector_a_3 = broadcast(a[3]) vector_a_3 * vector_b_0 -> vector_c_6 vector_a_3 * vector_b_1 -> vector_c_7
嗯 … 看起来 A 这里的寄存器其实还可以再复用到同一个上,可以进一步把寄存器数量压到 11 个,不过这么做也就额外引入了一个数据依赖?况且 16 个寄存器还没有全部用满,这段代码的效率还有再进一步提高的空间。
看下最后生成的汇编指令是什么样子的吧:
1 print (func.get_source("s" ))
从完整的汇编代码里面截取了其中的一小段:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 vbroadcastss -1052(%rcx,%r12), %ymm9 vmovaps -3584(%r9), %ymm10 vmovaps -3552(%r9), %ymm11 vmovaps -3072(%r9), %ymm8 vfmadd231ps %ymm10, %ymm9, %ymm6 vfmadd231ps %ymm9, %ymm11, %ymm7 vbroadcastss -540(%rcx,%r12), %ymm9 vfmadd231ps %ymm10, %ymm9, %ymm4 vfmadd231ps %ymm9, %ymm11, %ymm5 vbroadcastss -28(%rcx,%r12), %ymm9 vfmadd231ps %ymm10, %ymm9, %ymm2 vfmadd231ps %ymm9, %ymm11, %ymm3 vbroadcastss 484(%rcx,%r12), %ymm9 vfmadd231ps %ymm10, %ymm9, %ymm0 vfmadd231ps %ymm11, %ymm9, %ymm1 vbroadcastss -1048(%rcx,%r12), %ymm9 vmovaps -3040(%r9), %ymm10 vfmadd231ps %ymm10, %ymm9, %ymm7 vfmadd231ps %ymm9, %ymm8, %ymm6 vbroadcastss -536(%rcx,%r12), %ymm9 vfmadd231ps %ymm10, %ymm9, %ymm5 vfmadd231ps %ymm9, %ymm8, %ymm4 vbroadcastss -24(%rcx,%r12), %ymm9 vfmadd231ps %ymm10, %ymm9, %ymm3 vfmadd231ps %ymm9, %ymm8, %ymm2 vbroadcastss 488(%rcx,%r12), %ymm9 vfmadd231ps %ymm10, %ymm9, %ymm1 vfmadd231ps %ymm8, %ymm9, %ymm0 vbroadcastss -1044(%rcx,%r12), %ymm8
上面这个是 8 段类似代码里面的两段,可以发现 LLVM 在往下 lower 成汇编指令的时候进一步对计算逻辑里面的 k.1 也做了循环展开。
以第一段举例,如前面所分析的,输出用了 ymm0 ~ ymm7 这 8 个向量寄存器。每一段里面把 A 的数据反复 broadcast 到 ymm9 这一个向量寄存器上,ymm10 和 ymm11 用来存 B 矩阵的数据(好吧…还是只用了 11 个)。
Try more 如果把 M、N、K 都改成 144 再试一下,发现很容易就能搜出一个性能能跑的更高的 kernel:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 Execution time of this operator: 0.047 ms Gflops: 126.21408038450134 Placeholder: placeholder, placeholder for j.0 (0,9) c.local auto_unroll: 64 for i_c.1 (0,24) for k.0 (0,24) for i_c.2 (0,6) for k.1 (0,6) vectorize j_c.3 (0,16) c.local = ... for i.1 (0,144) for j.1 (0,16) c = ...
从汇编里面可以找到这段 schedule 能够跑出 93% 峰值性能的 schedule 的秘密:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 vbroadcastss -2324(%rdx,%rbx), %ymm14 vmovaps (%rbp), %ymm13 vmovaps 32(%rbp), %ymm12 vfmadd231ps %ymm13, %ymm14, %ymm10 vfmadd231ps %ymm14, %ymm12, %ymm11 vbroadcastss -1748(%rdx,%rbx), %ymm14 vfmadd231ps %ymm13, %ymm14, %ymm8 vfmadd231ps %ymm14, %ymm12, %ymm9 vbroadcastss -1172(%rdx,%rbx), %ymm14 vfmadd231ps %ymm13, %ymm14, %ymm6 vfmadd231ps %ymm14, %ymm12, %ymm7 vbroadcastss -596(%rdx,%rbx), %ymm14 vfmadd231ps %ymm13, %ymm14, %ymm4 vfmadd231ps %ymm14, %ymm12, %ymm5 vbroadcastss -20(%rdx,%rbx), %ymm14 vfmadd231ps %ymm13, %ymm14, %ymm2 vfmadd231ps %ymm14, %ymm12, %ymm3 vbroadcastss 556(%rdx,%rbx), %ymm14 vfmadd231ps %ymm13, %ymm14, %ymm0 vbroadcastss -2320(%rdx,%rbx), %ymm13 vfmadd231ps %ymm12, %ymm14, %ymm1
除去最后一个往 ymm13 里面做 broadcast 的指令,这段代码生成得是真的是工整(当然其实太工整也不一定是好事,有时候为了充分利用流水线还得做指令重排把它们弄的“乱”一些,这是另外一个问题了)。跟之前类似,ymm12 和 ymm13 存放 B 矩阵的 16 个连续元素,每次取一个数 broadcast 到 ymm14 中,所有结果累加到 ymm0 ~ ymm11 这 12 个向量寄存器中,整个过程中一共用满了 15 个向量寄存器。
那么问题来了,寄存器用的更多性能就一定好吗?
当然不是……严格来说这得去分析指令流水线,能够打满流水,把所有的计算部件都用上才能发挥出最大的计算性能。
不过如果把问题简化一下,从过往经验上来看,如果我们假定计算的 pattern 已经排布的非常高效了,通常计算优化到最后都很容易最终会 bound 到 memory 上(这一点在 CPU 和 GPU 上都是成立的,为什么 A100 的 TensorCore 要出个 2 比 4 的稀疏?也是因为算力已经压榨到极限了,最后跟不上的反而是访存,这样其实相当于压缩了一倍的访存。Emm…这又是另外一个问题了)。至少如果已知还有资源没用上,当然还是想办法把所有能用的东西都用起来。
这里还有个用满 16 个寄存器跑到 99% 的峰值性能的例子,基本上可以说是把这个游戏玩到底了(虽然不用全部用上其实也有办法达到接近性能峰值,不过这个也是个很好的例子):
Ansor 在这一点上的局限性其实在于 tiling 的每个 split factor 的选择都是跟算子的原始尺寸相关的,为了避免引入一些对 index 的 if else 判断,默认的 split factor 采用的都是每个 axis 的约数。因此在 M == N == K == 128
的 case 中 Ansor 永远不可能搜出来 144 中最后那种 micro kernel 的尺寸(6 * 16 * 6)。这个局限性在最严重时候的体现是我们在业务中曾经遇到过两个尺寸非常接近只在某一个维度上有略微差别的 op,axis 是 59 的矩阵乘会与 60 的有很大的性能差距,原因就是 59 这个质数长度的 axis 在我们的策略中没办法 split,因此始终搜不出一个效率比较高的 micro kernel。
曾经想过的几种策略:
Loop partition:把类似 59 这种拆成两个或者多个循环(exp. 48 + 1),在各自单独的循环中可能能跑出更好的性能,即使不一定是最优的,综合性能超过不拆之前应该还是比较容易的;
Compute padding:浪费一些计算资源,如果我直接用 60 的 schedule 去跑 59,多算的那一点点就抛掉不要了,是不是也能至少跑出更好的性能来?(当然这个实际操作起来可能还有别的麻烦要处理)
希望后面有机会的时候能把这些想法都尝试一下吧。
关于 2 这一条,之前看过一个也很有意思的工作做的更激进:为了提高计算密度,对某些利用率不高的计算过程做有损变换,最后再加一个修正的 stage 把结果调回来。类似把 dilation 卷积变成常规卷积做,最后把结果重新修正回来这种方式,即使浪费了计算量也引入了额外的修正操作,可能在某些情况下还是能有性能收益的。 可惜忘了论文题目了…后面找到再回来补上吧。
Any other? 回到一开始的 SIMD 本身上,向量化矩阵乘还有别的实现方式吗?
当然。
对于一个 A、B 矩阵均为 NN layout(非转置)的矩阵乘运算,上面的做法其实是把这个基础的三重循环:
1 2 3 4 for (i = 0 to n) for (j = 0 to n) for (k = 0 to n) c[i][j] += a[i][k] * b[k][j]
变成了:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 for (i = 0 to n) for (j = 0 to n / v) for (k = 0 to n) for (jj = 0 to v) c[i][j][jj] += a[i][k] * b[k][j][jj] -> for (i = 0 to n) for (j = 0 to n / v) for (k = 0 to n) v_a = broadcast (a[i][k], v) v_b = vector_load (b[k][j], v) v_c = vector_load (c[i][j], v) vector_fma (v_a, v_b, v_c) vector_store (v_c, c[i][j], v)
的过程。
如果对 A、B、C 矩阵加上转置,则可以让这个计算过程根据需要做到各种方向上的数据连续存储,也可以做到更多种实现方式。
可以对 k 方向上做向量化吗?
当然。
很多候选人一看到这个 i,j,k 三重循环的表达式,上来就直接给我把 k 拆了,然后看了一会感觉好像哪里不太对,就进行不下去了。
k 在这里是个 reduce 的方向,常见的向量化部件设计上确实比较少有能够直接处理这种操作的,也有像 arm v8 的 neon 上就有提供了一条叫 sdot 的指令,可以对 int8 的数据做向量点积,最后累加到一个 int32 的寄存器上。这个过程的示例是从一个转置的 B 矩阵开始,做完转置之后 A、B 就都在 k 方向上连续了:
1 2 3 4 for (i = 0 to n) for (j = 0 to n) for (k = 0 to n) c[i][j] += a[i][k] * b[j][k]
然后再往下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 for (i = 0 to n) for (j = 0 to n) temp_sum = 0 for (k = 0 to n / v) for (kk = 0 to v) temp_sum += a[i][k][kk] * b[j][k][kk] c[i][j] = temp_sum -> for (i = 0 to n) for (j = 0 to n) temp_sum = 0 for (k = 0 to n / v) v_a = vector_load (a[i][k], v) v_b = vector_load (b[j][k], v) temp_sum += vector_dot (v_a, v_b) c[i][j] = temp_sum
额,上面这种方式其实我没有实际写过…
如果非要在不支持 dot 指令的硬件上实现 k 方向上的向量化,只是相当于手动实现一下这个 vector_dot 的过程:
1 2 3 4 5 6 7 8 for (i = 0 to n) for (j = 0 to n) v_c = broadcast (0 , v) for (k = 0 to n / v) v_a = vector_load (a[i][k], v) v_b = vector_load (b[j][k], v) vector_fma (v_a, v_b, v_c) c[i][j] = reduce (v_c)
咦,好像看起来也是有可能达到比较高的效率的。注意,虽然最后这里对 v_c 的 v 个元素做 reduce 求和的操作我不确定是不是各硬件平台上就有现成的指令可以直接做,不过可以看到中间最核心的代码块一直是在做高效的向量访存和向量乘加,如果 k 方向上的长度是一个比较大的值(即 n / v
是个比较大的值),则即使最后只是挨个把这 v 个元素做标量加法应该也是能有一定的性能收益的。
似乎 MKL 里面的实现就是比较多的采用了把 NN 矩阵转置成 NT 以后再算的,不知道是不是主要用的这种实现方式。
另:引申一下 tvm 中提供了一条叫 rfactor 的 schedule primitive,是针对 parallel 做的,不过实现方式上跟这个思路也有一些相似点。Reduce axis 在 tvm 中本身并不支持直接做 parallel 或者 vectorize,就把这个过程拆成两步做,可以在其中的一步做上 parallel / vectorize,只要 k 足够大就是能够有性能收益的。