Chenfan Blog

Do cool things that matter.

0%

TVM 拆包(二):IR

这一篇来记一下 TVM 跟语法树 IR 相关的一些内容。

TVM 是在 Halide 的基础上发展而来的一套编译架构。早一点版本的 TVM 其实还能看到 Halide 是作为一个 git 的 submodule 放在 TVM 目录里的,当时 TVM 自身的代码中还有一个 HalideIR 的 namespace,也有不少的结构是直接继承了 Halide IR 里的内容:

1
2
3
4
5
6
7
8
...
using HalideIR::Type;
using HalideIR::Float;
using HalideIR::Bool;
using HalideIR::Int;
using HalideIR::UInt;
using HalideIR::Handle;
...

[INFA][IR] Build and Evolve Low-level IR. Remove HalideIR dep. #3533 这个 PR 之后,Halide 相关的部分逐渐从 TVM 中删掉了,之后可以说从代码层面已经跟 Halide 没有关系了。

目前 TVM IR 的结构基本上还是从 Halide IR 一脉相承,可能以后 TVM IR 进一步演化以后跟 Halide 的差别就更大了。

IR

expr.h 里面定义了 TVM IR 的两个基础结构:Expr 和 Stmt,分别是语法表达式和语法树节点的基类。

Expr 的派生类有加减乘除、IntImm、FloatImm 等等,从文法上可以做一些 symbolic 的处理。

Stmt 的派生类有 AttrStmt(语法树属性节点)、Store(数据存储节点)、Allocate(数据 Buffer 分配节点)等等。

每个 Stmt 结构本身表示一个独立的语法树节点,但是语法树节点之间相互嵌套,通过 Stmt 的 body(Stmt 的通常结构)等成员继续向下查看就能够看到一颗完整的抽象语法树(AST)了。

例如 IfThenElse 这个 Stmt 的结构:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class IfThenElse : public StmtNode {
public:
/*! \brief The condition. */
Expr condition;
/*! \brief The branch to be executed when condition is true. */
Stmt then_case;
/*! \brief The branch to be executed when condition is false, can be null. */
Stmt else_case;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("condition", &condition);
v->Visit("then_case", &then_case);
v->Visit("else_case", &else_case);
}

TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt());

static constexpr const char* _type_key = "IfThenElse";
TVM_DECLARE_NODE_TYPE_INFO(IfThenElse, StmtNode);
};

判断条件是个 Expr,then_case 和 else_case 都是另外两个 Stmt,再展开又是两棵子树。

IRVisitor & IRMutator

TVM 中定义了一些 ir_pass 来处理语法树,通过对语法树的修改和调整来完成编译优化的过程。

各种 ir_pass 的核心结构是 IRVisitor 和 IRMutator,从名称上也可以很容易看出来,IRVisitor 的功能是遍历语法树收集信息,本身对语法树的访问是只读的,然后通过 IRMutator 完成 ir_pass 需要的语法树修改需求。

看下 IRVisitor 的结构:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class TVM_DLL IRVisitor {
public:
/*!
* \brief recursively visit an IR node
*/
virtual void Visit(const NodeRef& node) {
static const FVisit& f = vtable();
if (node.defined()) f(node, this);
}
/*! \brief destructor */
virtual ~IRVisitor() {}
/*! \brief functor type of visitor */
using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
/*! \return internal vtable*/
static FVisit& vtable();
// overloadable visit function.
virtual void Visit_(const Variable* op);
virtual void Visit_(const LetStmt* op);
virtual void Visit_(const AttrStmt* op);
virtual void Visit_(const IfThenElse* op);
virtual void Visit_(const For* op);
...
}

这里面最重要的实现其实是自己构造了一个虚函数表:

1
2
3
IRVisitor::FVisit& IRVisitor::vtable() {  // NOLINT(*)
static FVisit inst; return inst;
}

之后 IRVisitor 的派生类只需要去重载针对不同 Stmt 类型的 Visit_() 函数就好了。

Dump Ast

看 TVM 代码包括 debug 的时候一开始会觉得两眼一抹黑连语法树长啥样都没有个概念,也不知道 build 过程中每个 ir_pass 具体做了什么事情,就很想要有个工具能把语法树打成图看下。

最开始 TVM 里面是找不到这种现成的工具的,后来在论坛里有看到别人提的这方面相关 RFC,不过我后续也没有去关注最后到底有没有收进 repo 里了。其实了解完 IRVisitor 的实现之后,自己写一个语法树的 Dumper 还是挺简单的。

ir_visitor.cc 里面有一个 PostOrderVisit 的示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class IRApplyVisit : public IRVisitor {
public:
explicit IRApplyVisit(std::function<void(const NodeRef&)> f) : f_(f) {}

void Visit(const NodeRef& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
IRVisitor::Visit(node);
f_(node);
}

private:
std::function<void(const NodeRef&)> f_;
std::unordered_set<const Node*> visited_;
};


void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node);
}

按照对语法树后续遍历的顺序对每个语法树节点应用 f_() 函数,不过要实现语法树输出的目标,光靠这个还不够,需要再稍微进行一点点的扩充。

我们通过一个栈来记录下语法树上每一个 stmt 的从属关系,首先扩展一下上面那个 Visitor 来做到在访问 stmt 节点的前后分别调一个外部函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class IRPrePostOrderVisitor : public IRVisitor {
public:
explicit IRPrePostOrderVisitor(std::function<void(const NodeRef&)> f,
std::function<void(const NodeRef&)> e) : f_(f), e_(e) {}

void Visit(const NodeRef& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
f_(node);
IRVisitor::Visit(node);
e_(node);
}

private:
std::function<void(const NodeRef&)> f_, e_;
std::unordered_set<const Node*> visited_;
};

void PrePostOrderVisit(const NodeRef& node,
std::function<void(const NodeRef&)> fvisit,
std::function<void(const NodeRef&)> evisit) {
IRPrePostOrderVisitor(fvisit, evisit).Visit(node);
}

PrePostOrderVisit() 相应的在 ir_visitor.h 里也要添加一下。

接下来再往 api_pass.cc 里添加一下_PrePostOrderVisit的函数注册:

1
2
3
4
5
6
7
TVM_REGISTER_API("_PrePostOrderVisit")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Stmt stmt = args[0];
PackedFunc f = args[1];
PackedFunc e = args[2];
ir::PrePostOrderVisit(stmt, f, e);
});

这样在 C++ 部分的工作就完成了,之后是 Python 这一层,我们的 dump 目标是语法树,所以找个能直接拿到 stmt 结构的地方,build_module.py 就很不错。

lower() 是从 TVM IR 往能够运行的代码编译的第一步,涉及到多种 ir_pass 的使用,不同 ir_pass 的前后插入 dump ast 的代码可以帮助我们快速搞清楚每个 ir_pass 到底实际做了什么事情。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def lower(...):
...
for f in lower_phase0:
stmt = f(stmt)
# Phase 1
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.CanonicalSimplify(stmt)

dump_ast(stmt) # <-------------- here

for f in lower_phase1:
stmt = f(stmt)
# Phase 2
if not simple_mode:
stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
if cfg.disable_vectorize:
stmt = ir_pass.SkipVectorize(stmt)
else:
stmt = ir_pass.VectorizeLoop(stmt)
...

接下来看一下我们需要在 dump_ast() 里面写上什么:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def dump_ast(stmt):
stack = []
ast_node = []
ast_edge = []
count = [0]

def pre_func(stmt):
node_idx = count[0]
count[0] += 1

ast_node.append([node_idx, stmt])
if len(stack):
ast_edge.append([stack[-1], node_idx])

stack.append(node_idx)

def post_func(stmt):
del stack[-1]

_api_internal._PrePostOrderVisit(stmt, pre_func, post_func)

with open("graph.txt", "w") as f:
f.write("digraph {\n")
f.write(" node [shape=matrix]\n")
for node in ast_node:
ast_type = type(node[1])
ast_str = str(node[1]).replace("\n", "\\l").replace("\\n", "\\l")
f.write(" node%d" % (node[0]))
f.write("[label=\"%s\n%s\"]" % (ast_type, ast_str))
f.write(";\n")
for edge in ast_edge:
f.write(" node%d -> node%d;\n" % (edge[0], edge[1]))
f.write("}\n")

思路也是很简单的,我们只要在 PrePostOrderVisit 访问 stmt 节点前将节点入栈,然后访问节点结束后将节点退栈就完事了。

这个地方也体现出上一篇中 TVM 特别搞出来的这套跨 Python 和 C++ 混合运行的 PackedFunc 机制的便利性,这里事实上我们是从 Python 层开始,调了一个 C++ 的函数,然后在这个 C++ 的函数里面又回调了两个 Python 的函数,并且这个过程中数据还是存在我们在 Python 层创建的结构上的。


试一下下面这段示例代码:

1
2
3
4
5
6
7
8
9
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")

s = tvm.create_schedule([C.op])
bx, tx = s[C].split(C.op.axis[0], factor=64)

res = tvm.lower(s, [A, B, C], simple_mode=True)

通过 dump_ast 打出来的 dot 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
digraph {
node [shape=matrix]
node0[label="<class 'tvm.stmt.ProducerConsumer'>
produce C {\l for (i.outer, 0, ((n + 63)/64)) {\l for (i.inner, 0, 64) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l C[((i.outer*64) + i.inner)] = (A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])\l }\l }\l }\l }\l}\l"];
node1[label="<class 'tvm.stmt.For'>
for (i.outer, 0, ((n + 63)/64)) {\l for (i.inner, 0, 64) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l C[((i.outer*64) + i.inner)] = (A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])\l }\l }\l }\l}\l"];
node2[label="<class 'tvm.expr.IntImm'>
0"];
node3[label="<class 'tvm.expr.Div'>
((n + 63)/64)"];
node4[label="<class 'tvm.expr.Add'>
(n + 63)"];
node5[label="<class 'tvm.expr.Var'>
n"];
node6[label="<class 'tvm.expr.IntImm'>
63"];
node7[label="<class 'tvm.expr.IntImm'>
64"];
node8[label="<class 'tvm.stmt.For'>
for (i.inner, 0, 64) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l C[((i.outer*64) + i.inner)] = (A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])\l }\l }\l}\l"];
node9[label="<class 'tvm.expr.IntImm'>
0"];
node10[label="<class 'tvm.expr.IntImm'>
64"];
node11[label="<class 'tvm.stmt.IfThenElse'>
if (likely((((i.outer*64) + i.inner) < n))) {\l if (likely((((i.outer*64) + i.inner) < n))) {\l C[((i.outer*64) + i.inner)] = (A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])\l }\l}\l"];
node12[label="<class 'tvm.expr.Call'>
likely((((i.outer*64) + i.inner) < n))"];
node13[label="<class 'tvm.expr.LT'>
(((i.outer*64) + i.inner) < n)"];
node14[label="<class 'tvm.expr.Add'>
((i.outer*64) + i.inner)"];
node15[label="<class 'tvm.expr.Mul'>
(i.outer*64)"];
node16[label="<class 'tvm.expr.Var'>
i.outer"];
node17[label="<class 'tvm.expr.IntImm'>
64"];
node18[label="<class 'tvm.expr.Var'>
i.inner"];
node19[label="<class 'tvm.stmt.IfThenElse'>
if (likely((((i.outer*64) + i.inner) < n))) {\l C[((i.outer*64) + i.inner)] = (A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])\l}\l"];
node20[label="<class 'tvm.expr.Call'>
likely((((i.outer*64) + i.inner) < n))"];
node21[label="<class 'tvm.expr.LT'>
(((i.outer*64) + i.inner) < n)"];
node22[label="<class 'tvm.expr.Add'>
((i.outer*64) + i.inner)"];
node23[label="<class 'tvm.expr.Mul'>
(i.outer*64)"];
node24[label="<class 'tvm.expr.IntImm'>
64"];
node25[label="<class 'tvm.stmt.Store'>
C[((i.outer*64) + i.inner)] = (A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])\l"];
node26[label="<class 'tvm.expr.Add'>
(A[((i.outer*64) + i.inner)] + B[((i.outer*64) + i.inner)])"];
node27[label="<class 'tvm.expr.Load'>
A[((i.outer*64) + i.inner)]"];
node28[label="<class 'tvm.expr.Add'>
((i.outer*64) + i.inner)"];
node29[label="<class 'tvm.expr.Mul'>
(i.outer*64)"];
node30[label="<class 'tvm.expr.IntImm'>
64"];
node31[label="<class 'tvm.expr.UIntImm'>
(bool)1"];
node32[label="<class 'tvm.expr.Load'>
B[((i.outer*64) + i.inner)]"];
node33[label="<class 'tvm.expr.Add'>
((i.outer*64) + i.inner)"];
node34[label="<class 'tvm.expr.Mul'>
(i.outer*64)"];
node35[label="<class 'tvm.expr.IntImm'>
64"];
node36[label="<class 'tvm.expr.UIntImm'>
(bool)1"];
node37[label="<class 'tvm.expr.Add'>
((i.outer*64) + i.inner)"];
node38[label="<class 'tvm.expr.Mul'>
(i.outer*64)"];
node39[label="<class 'tvm.expr.IntImm'>
64"];
node40[label="<class 'tvm.expr.UIntImm'>
(bool)1"];
node0 -> node1;
node1 -> node2;
node1 -> node3;
node3 -> node4;
node4 -> node5;
node4 -> node6;
node3 -> node7;
node1 -> node8;
node8 -> node9;
node8 -> node10;
node8 -> node11;
node11 -> node12;
node12 -> node13;
node13 -> node14;
node14 -> node15;
node15 -> node16;
node15 -> node17;
node14 -> node18;
node11 -> node19;
node19 -> node20;
node20 -> node21;
node21 -> node22;
node22 -> node23;
node23 -> node24;
node19 -> node25;
node25 -> node26;
node26 -> node27;
node27 -> node28;
node28 -> node29;
node29 -> node30;
node27 -> node31;
node26 -> node32;
node32 -> node33;
node33 -> node34;
node34 -> node35;
node32 -> node36;
node25 -> node37;
node37 -> node38;
node38 -> node39;
node25 -> node40;
}

通过 GraphViz 这种 dot 可视化工具处理一下:

Vis Result

当然这个实现还是相当简单了,根据每个节点的类型等等还可以再加一些更复杂的判断逻辑,控制一下输出的内容量等等,以及其实可以看到里面还有很多重复的节点也都可以被筛掉。

Relay IR 的整体处理结构跟 TVM IR 一致,用类似的方法也可以把 Relay 那一层的 AST 打出来。