这一篇来记一下 TVM 跟语法树 IR 相关的一些内容。
TVM 是在 Halide 的基础上发展而来的一套编译架构。早一点版本的 TVM 其实还能看到 Halide 是作为一个 git 的 submodule 放在 TVM 目录里的,当时 TVM 自身的代码中还有一个 HalideIR
的 namespace,也有不少的结构是直接继承了 Halide IR 里的内容:
1 2 3 4 5 6 7 8 ... using HalideIR::Type;using HalideIR::Float;using HalideIR::Bool;using HalideIR::Int;using HalideIR::UInt;using HalideIR::Handle;...
[INFA][IR] Build and Evolve Low-level IR. Remove HalideIR dep. #3533 这个 PR 之后,Halide 相关的部分逐渐从 TVM 中删掉了,之后可以说从代码层面已经跟 Halide 没有关系了。
目前 TVM IR 的结构基本上还是从 Halide IR 一脉相承,可能以后 TVM IR 进一步演化以后跟 Halide 的差别就更大了。
IR expr.h
里面定义了 TVM IR 的两个基础结构:Expr 和 Stmt,分别是语法表达式和语法树节点的基类。
Expr 的派生类有加减乘除、IntImm、FloatImm 等等,从文法上可以做一些 symbolic 的处理。
Stmt 的派生类有 AttrStmt(语法树属性节点)、Store(数据存储节点)、Allocate(数据 Buffer 分配节点)等等。
每个 Stmt 结构本身表示一个独立的语法树节点,但是语法树节点之间相互嵌套,通过 Stmt 的 body(Stmt 的通常结构)等成员继续向下查看就能够看到一颗完整的抽象语法树(AST)了。
例如 IfThenElse 这个 Stmt 的结构:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 class IfThenElse : public StmtNode { public : Expr condition; Stmt then_case; Stmt else_case; void VisitAttrs (AttrVisitor* v) final { v->Visit ("condition" , &condition); v->Visit ("then_case" , &then_case); v->Visit ("else_case" , &else_case); } TVM_DLL static Stmt make (Expr condition, Stmt then_case, Stmt else_case = Stmt()) ; static constexpr const char * _type_key = "IfThenElse" ; TVM_DECLARE_NODE_TYPE_INFO (IfThenElse, StmtNode); };
判断条件是个 Expr,then_case 和 else_case 都是另外两个 Stmt,再展开又是两棵子树。
IRVisitor & IRMutator TVM 中定义了一些 ir_pass 来处理语法树,通过对语法树的修改和调整来完成编译优化的过程。
各种 ir_pass 的核心结构是 IRVisitor 和 IRMutator,从名称上也可以很容易看出来,IRVisitor 的功能是遍历语法树收集信息,本身对语法树的访问是只读的,然后通过 IRMutator 完成 ir_pass 需要的语法树修改需求。
看下 IRVisitor 的结构:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 class TVM_DLL IRVisitor { public : virtual void Visit (const NodeRef& node) { static const FVisit& f = vtable (); if (node.defined ()) f (node, this ); } virtual ~IRVisitor () {} using FVisit = IRFunctor<void (const NodeRef&, IRVisitor*)>; static FVisit& vtable () ; virtual void Visit_ (const Variable* op) ; virtual void Visit_ (const LetStmt* op) ; virtual void Visit_ (const AttrStmt* op) ; virtual void Visit_ (const IfThenElse* op) ; virtual void Visit_ (const For* op) ; ... }
这里面最重要的实现其实是自己构造了一个虚函数表:
1 2 3 IRVisitor::FVisit& IRVisitor::vtable () { static FVisit inst; return inst; }
之后 IRVisitor 的派生类只需要去重载针对不同 Stmt 类型的 Visit_()
函数就好了。
Dump Ast 看 TVM 代码包括 debug 的时候一开始会觉得两眼一抹黑连语法树长啥样都没有个概念,也不知道 build 过程中每个 ir_pass 具体做了什么事情,就很想要有个工具能把语法树打成图看下。
最开始 TVM 里面是找不到这种现成的工具的,后来在论坛里有看到别人提的这方面相关 RFC,不过我后续也没有去关注最后到底有没有收进 repo 里了。其实了解完 IRVisitor 的实现之后,自己写一个语法树的 Dumper 还是挺简单的。
ir_visitor.cc
里面有一个 PostOrderVisit 的示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 class IRApplyVisit : public IRVisitor { public : explicit IRApplyVisit (std::function<void (const NodeRef&)> f) : f_(f) { } void Visit (const NodeRef& node) final { if (visited_.count (node.get ()) != 0 ) return ; visited_.insert (node.get ()); IRVisitor::Visit (node); f_ (node); } private : std::function<void (const NodeRef&)> f_; std::unordered_set<const Node*> visited_; }; void PostOrderVisit (const NodeRef& node, std::function<void (const NodeRef&)> fvisit) { IRApplyVisit (fvisit).Visit (node); }
按照对语法树后续遍历的顺序对每个语法树节点应用 f_()
函数,不过要实现语法树输出的目标,光靠这个还不够,需要再稍微进行一点点的扩充。
我们通过一个栈来记录下语法树上每一个 stmt 的从属关系,首先扩展一下上面那个 Visitor 来做到在访问 stmt 节点的前后分别调一个外部函数:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 class IRPrePostOrderVisitor : public IRVisitor { public : explicit IRPrePostOrderVisitor (std::function<void (const NodeRef&)> f, std::function<void (const NodeRef&)> e) : f_(f), e_(e) { } void Visit (const NodeRef& node) final { if (visited_.count (node.get ()) != 0 ) return ; visited_.insert (node.get ()); f_ (node); IRVisitor::Visit (node); e_ (node); } private : std::function<void (const NodeRef&)> f_, e_; std::unordered_set<const Node*> visited_; }; void PrePostOrderVisit (const NodeRef& node, std::function<void (const NodeRef&)> fvisit, std::function<void (const NodeRef&)> evisit) { IRPrePostOrderVisitor (fvisit, evisit).Visit (node); }
PrePostOrderVisit()
相应的在 ir_visitor.h
里也要添加一下。
接下来再往 api_pass.cc
里添加一下_PrePostOrderVisit
的函数注册:
1 2 3 4 5 6 7 8 TVM_REGISTER_API ("_PrePostOrderVisit" ).set_body ([](TVMArgs args, TVMRetValue *ret) { Stmt stmt = args[0 ]; PackedFunc f = args[1 ]; PackedFunc e = args[2 ]; ir::PrePostOrderVisit (stmt, f, e); });
这样在 C++ 部分的工作就完成了,之后是 Python 这一层,我们的 dump 目标是语法树,所以找个能直接拿到 stmt 结构的地方,build_module.py
就很不错。
lower()
是从 TVM IR 往能够运行的代码编译的第一步,涉及到多种 ir_pass 的使用,不同 ir_pass 的前后插入 dump ast 的代码可以帮助我们快速搞清楚每个 ir_pass 到底实际做了什么事情。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 def lower (... ): ... for f in lower_phase0: stmt = f(stmt) stmt = ir_pass.StorageFlatten(stmt, binds, 64 , cfg.instrument_bound_checkers) stmt = ir_pass.CanonicalSimplify(stmt) dump_ast(stmt) for f in lower_phase1: stmt = f(stmt) if not simple_mode: stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop) if cfg.disable_vectorize: stmt = ir_pass.SkipVectorize(stmt) else : stmt = ir_pass.VectorizeLoop(stmt) ...
接下来看一下我们需要在 dump_ast()
里面写上什么:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 def dump_ast (stmt ): stack = [] ast_node = [] ast_edge = [] count = [0 ] def pre_func (stmt ): node_idx = count[0 ] count[0 ] += 1 ast_node.append([node_idx, stmt]) if len (stack): ast_edge.append([stack[-1 ], node_idx]) stack.append(node_idx) def post_func (stmt ): del stack[-1 ] _api_internal._PrePostOrderVisit(stmt, pre_func, post_func) with open ("graph.txt" , "w" ) as f: f.write("digraph {\n" ) f.write(" node [shape=matrix]\n" ) for node in ast_node: ast_type = type (node[1 ]) ast_str = str (node[1 ]).replace("\n" , "\\l" ).replace("\\n" , "\\l" ) f.write(" node%d" % (node[0 ])) f.write("[label=\"%s\n%s\"]" % (ast_type, ast_str)) f.write(";\n" ) for edge in ast_edge: f.write(" node%d -> node%d;\n" % (edge[0 ], edge[1 ])) f.write("}\n" )
思路也是很简单的,我们只要在 PrePostOrderVisit 访问 stmt 节点前将节点入栈,然后访问节点结束后将节点退栈就完事了。
这个地方也体现出上一篇中 TVM 特别搞出来的这套跨 Python 和 C++ 混合运行的 PackedFunc 机制的便利性,这里事实上我们是从 Python 层开始,调了一个 C++ 的函数,然后在这个 C++ 的函数里面又回调了两个 Python 的函数,并且这个过程中数据还是存在我们在 Python 层创建的结构上的。
试一下下面这段示例代码:
1 2 3 4 5 6 7 8 9 n = tvm.var("n" ) A = tvm.placeholder((n,), name='A' ) B = tvm.placeholder((n,), name='B' ) C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C" ) s = tvm.create_schedule([C.op]) bx, tx = s[C].split(C.op.axis[0 ], factor=64 ) res = tvm.lower(s, [A, B, C], simple_mode=True )
通过 dump_ast 打出来的 dot 代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 digraph { node [shape=matrix] node0[label="<class 'tvm.stmt.ProducerConsumer'> produce C {\l for (i.outer, 0, ((n + 63)/64)) {\l for (i.inner, 0, 64) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l C[((i.outer*64) + i.inner)] = (A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])\l }\l }\l }\l }\l}\l"]; node1[label="<class 'tvm.stmt.For'> for (i.outer, 0, ((n + 63)/64)) {\l for (i.inner, 0, 64) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l C[((i.outer*64) + i.inner)] = (A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])\l }\l }\l }\l}\l"]; node2[label="<class 'tvm.expr.IntImm'> 0"]; node3[label="<class 'tvm.expr.Div'> ((n + 63)/64)"]; node4[label="<class 'tvm.expr.Add'> (n + 63)"]; node5[label="<class 'tvm.expr.Var'> n"]; node6[label="<class 'tvm.expr.IntImm'> 63"]; node7[label="<class 'tvm.expr.IntImm'> 64"]; node8[label="<class 'tvm.stmt.For'> for (i.inner, 0, 64) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l C[((i.outer*64) + i.inner)] = (A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])\l }\l }\l}\l"]; node9[label="<class 'tvm.expr.IntImm'> 0"]; node10[label="<class 'tvm.expr.IntImm'> 64"]; node11[label="<class 'tvm.stmt.IfThenElse'> if (likely((((i.outer*64) + i.inner) < n))) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l C[((i.outer*64) + i.inner)] = (A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])\l }\l}\l"]; node12[label="<class 'tvm.expr.Call'> likely((((i.outer*64) + i.inner) < n))"]; node13[label="<class 'tvm.expr.LT'> (((i.outer*64) + i.inner) < n)"]; node14[label="<class 'tvm.expr.Add'> ((i.outer*64) + i.inner)"]; node15[label="<class 'tvm.expr.Mul'> (i.outer*64)"]; node16[label="<class 'tvm.expr.Var'> i.outer"]; node17[label="<class 'tvm.expr.IntImm'> 64"]; node18[label="<class 'tvm.expr.Var'> i.inner"]; node19[label="<class 'tvm.stmt.IfThenElse'> if (likely((((i.outer*64) + i.inner) < n))) {\l C[((i.outer*64) + i.inner)] = (A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])\l}\l"]; node20[label="<class 'tvm.expr.Call'> likely((((i.outer*64) + i.inner) < n))"]; node21[label="<class 'tvm.expr.LT'> (((i.outer*64) + i.inner) < n)"]; node22[label="<class 'tvm.expr.Add'> ((i.outer*64) + i.inner)"]; node23[label="<class 'tvm.expr.Mul'> (i.outer*64)"]; node24[label="<class 'tvm.expr.IntImm'> 64"]; node25[label="<class 'tvm.stmt.Store'> C[((i.outer*64) + i.inner)] = (A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])\l"]; node26[label="<class 'tvm.expr.Add'> (A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])"]; node27[label="<class 'tvm.expr.Load'> A[((i.outer*64) + i.inner)]"]; node28[label="<class 'tvm.expr.Add'> ((i.outer*64) + i.inner)"]; node29[label="<class 'tvm.expr.Mul'> (i.outer*64)"]; node30[label="<class 'tvm.expr.IntImm'> 64"]; node31[label="<class 'tvm.expr.UIntImm'> (bool)1"]; node32[label="<class 'tvm.expr.Load'> B[((i.outer*64) + i.inner)]"]; node33[label="<class 'tvm.expr.Add'> ((i.outer*64) + i.inner)"]; node34[label="<class 'tvm.expr.Mul'> (i.outer*64)"]; node35[label="<class 'tvm.expr.IntImm'> 64"]; node36[label="<class 'tvm.expr.UIntImm'> (bool)1"]; node37[label="<class 'tvm.expr.Add'> ((i.outer*64) + i.inner)"]; node38[label="<class 'tvm.expr.Mul'> (i.outer*64)"]; node39[label="<class 'tvm.expr.IntImm'> 64"]; node40[label="<class 'tvm.expr.UIntImm'> (bool)1"]; node0 -> node1; node1 -> node2; node1 -> node3; node3 -> node4; node4 -> node5; node4 -> node6; node3 -> node7; node1 -> node8; node8 -> node9; node8 -> node10; node8 -> node11; node11 -> node12; node12 -> node13; node13 -> node14; node14 -> node15; node15 -> node16; node15 -> node17; node14 -> node18; node11 -> node19; node19 -> node20; node20 -> node21; node21 -> node22; node22 -> node23; node23 -> node24; node19 -> node25; node25 -> node26; node26 -> node27; node27 -> node28; node28 -> node29; node29 -> node30; node27 -> node31; node26 -> node32; node32 -> node33; node33 -> node34; node34 -> node35; node32 -> node36; node25 -> node37; node37 -> node38; node38 -> node39; node25 -> node40; }
通过 GraphViz 这种 dot 可视化工具处理一下:
当然这个实现还是相当简单了,根据每个节点的类型等等还可以再加一些更复杂的判断逻辑,控制一下输出的内容量等等,以及其实可以看到里面还有很多重复的节点也都可以被筛掉。
Relay IR 的整体处理结构跟 TVM IR 一致,用类似的方法也可以把 Relay 那一层的 AST 打出来。