还是应该多写东西,感觉离开学校以后写的越来越少了,真是惭愧。
之后的工作有很多是要跟神经网络的编译执行相关,准备继 TF 拆包之后再来记一下我对 TVM 的探索过程。这个系列的坑会开多久也不一定,毕竟之前 TF 的其实也只做了一点点微小的工作,还有很多方面的内容没有看,另外 TF 要更到 2.0 之后可能有不少地方已经有一些变动了。
(也许有空也会看一下 XLA 吧,不过那个应该是要合到 TF 的拆包系列里面去了)
一些基础的背景可以见前面一篇对 TVM 作者一门课的课程记录:【CSE 599W: Systems for ML】
第一篇先从代码中一些基础的结构开始,最早读代码的时候没看明白一些东西是怎么实现的,为了看懂细节吃了不少的苦头。
参考文档:【TVM Runtime System】
Node & NodeRef
Node 和 NodeRef 这两个类在 TVM 中几乎是所有对象的基类了,Node 是功能本体,NodeRef 可以看成是对 Node 的一个指针引用。举例来说 TVM IR 语法树中的两个结构 Statement 和 Expression 的实际存储对象是 StmtNode
和 ExprNode
,从 Node 继承而来,但是在被其他结构用到的时候用的却是 Stmt
和 Expr
两个结构,从 NodeRef 继承而来。
Node 这个结构本身没什么好看的,看一下 NodeRef 的实现:
1 | /*! \brief Base class of all node reference object */ |
最关键的点在于 NodeRef 对它的 ->
运算符做了一下重载,即返回自己实际代表的实体对象,所以 NodeRef 以及其派生结构虽然实际是个对象,但是在代码里面可以当成指针来用,->
运算符之后直接跟的就是本体 Node 的成员。
1 | /*! \brief Container of all statements */ |
以上面的 Stmt
类为例,其他很多 NodeRef 的派生类结构也会用一个 TVM_DEFINE_NODE_REF
的宏来扩展出这部分的关键代码。
那么问题来了,为什么要设计成这个样子呢。
这里谈一下我自己的理解,如果有问题的话也欢迎看到的同学指正一下。
一方面可能是为了方便下面 C++ runtime 和 Python 部分所有结构的无缝衔接;另一方面从上面的代码也可以看到,NodeRef 重载了 ->
运算符之后返回的是个 const 的指针对象,所以通过 ->
访问到的实际的成员结构都是只读的了,这就保证了 TVM 这套复杂系统里面各种数据结构的安全性。
当然严格的只读访问在某些情况下是不够的,所以 TVM 提供了 CopyOnWrite()
的机制,如果某个 NodeRef 类的定义中包含了 TVM_DEFINE_NODE_REF_COW
这个宏的话,可以通过 NodeRef.CopyOnWrite()
获得一个可修改的 Node 指针,之后对成员内容的修改均通过这个指针来做就可以了。
1 | /*! |
话说 CopyOnWrite()
这个函数名称我感觉可能不是特别确切,也许改成 GetMutablePtr()
之类的会更好点?因为这个实际上并不 Copy,直接调用这个函数返回的是对这个 NodeRef 自己所指代对象的指针,之后的改动也都是对这个 Node 自身做的。
如果确切希望实现 Copy 的语义,则需要像前面注释里面示例的那样,先用另一个 ref2
复制一份 ref
,之后再在 ref2
上进行修改。
PackedFunc
TVM 的整个软件栈涉及到很多高层脚本语言(Python、JavaScript)和 C++ 运行时的交互,因此这里提供了一套 PackedFunc 的基础用来把整个过程方便地串接起来。
第一次看到这种实现时真的是被惊到了,感觉非常神奇。
在 C++ 层面创建一个函数,可以直接进行本地调用:
1 |
|
也可以通过 API 注册之后(要注册到 TVM 的 C++ 运行时库里面去),从 Python 层进行调用:
1 | // register a global packed function in c++ |
1 | import tvm |
反过来 Python 层写好的函数也可以直接从 C++ 层调用:
1 | TVM_REGISTER_GLOBAL("callhello") |
1 | import tvm |
Python 层定义的 callback(msg)
通过 callhello
传递给 C++ 层,C++ 层执行时直接从输入参数中得到了 Python 的函数对象并调用执行。
实现方面,C++ 层的 PackedFunc 是一个对 std::function 对象的封装结构,而 Python 层面tvm.convert
实际上是把 Python 的函数用 ctype 做了一下封装:
1 | TVMPackedCFunc = ctypes.CFUNCTYPE( |
PackedFunc 对函数的输入参数、返回值的解析处理做了比较精巧的处理,最终达到了从 API 层面看上去非常好的使用体验。
Amazing!
有空的话可以试一下把 TVM 里面的这部分内容单独扒出来,这个实现思路真的非常有意思。