0%

TVM 拆包(一):Runtime basics

还是应该多写东西,感觉离开学校以后写的越来越少了,真是惭愧。

之后的工作有很多是要跟神经网络的编译执行相关,准备继 TF 拆包之后再来记一下我对 TVM 的探索过程。这个系列的坑会开多久也不一定,毕竟之前 TF 的其实也只做了一点点微小的工作,还有很多方面的内容没有看,另外 TF 要更到 2.0 之后可能有不少地方已经有一些变动了。

(也许有空也会看一下 XLA 吧,不过那个应该是要合到 TF 的拆包系列里面去了)

一些基础的背景可以见前面一篇对 TVM 作者一门课的课程记录:【CSE 599W: Systems for ML】


第一篇先从代码中一些基础的结构开始,最早读代码的时候没看明白一些东西是怎么实现的,为了看懂细节吃了不少的苦头。

参考文档:【TVM Runtime System】

Node & NodeRef

Node 和 NodeRef 这两个类在 TVM 中几乎是所有对象的基类了,Node 是功能本体,NodeRef 可以看成是对 Node 的一个指针引用。举例来说 TVM IR 语法树中的两个结构 Statement 和 Expression 的实际存储对象是 StmtNodeExprNode,从 Node 继承而来,但是在被其他结构用到的时候用的却是 StmtExpr 两个结构,从 NodeRef 继承而来。

Node 这个结构本身没什么好看的,看一下 NodeRef 的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
/*! \brief Base class of all node reference object */
class NodeRef {
public:
/*! \brief type indicate the container type */
using ContainerType = Node;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator==(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool same_as(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator<(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator!=(const NodeRef& other) const;
/*! \return the hash function for NodeRef */
inline size_t hash() const;
/*! \return whether the expression is null */
inline bool defined() const;
/*! \return the internal type index of IRNode */
inline uint32_t type_index() const;
/*! \return the internal node pointer */
inline const Node* get() const;
/*! \return the internal node pointer */
inline const Node* operator->() const;
/*!
* \brief Downcast this ir node to its actual type (e.g. Add, or
* Select). This returns nullptr if the node is not of the requested
* type. Example usage:
*
* if (const Add *add = node->as<Add>()) {
* // This is an add node
* }
* \tparam T the target type, must be subtype of IRNode
*/
template<typename T>
inline const T *as() const;
/*!
* \brief A more powerful version of as that also works with
* intermediate base types.
* \tparam T the target type, must be subtype of IRNode
*/
template<typename T>
inline const T *as_derived() const;
/*! \brief default constructor */
NodeRef() = default;
explicit NodeRef(NodePtr<Node> node) : node_(node) {}
/*! \brief the internal node object, do not touch */
NodePtr<Node> node_;
};

inline const Node* NodeRef::operator->() const {
return node_.get();
}

最关键的点在于 NodeRef 对它的 ->运算符做了一下重载,即返回自己实际代表的实体对象,所以 NodeRef 以及其派生结构虽然实际是个对象,但是在代码里面可以当成指针来用,->运算符之后直接跟的就是本体 Node 的成员。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/*! \brief Container of all statements */
class Stmt : public NodeRef {
public:
TVM_DEFINE_NODE_REF_METHODS(Stmt, NodeRef, StmtNode);
};

/*!
* \brief Macro to define common node ref methods.
* \param TypeName The name of the NodeRef.
* \param BaseTypeName The Base type.
* \param NodeName The node container type.
*/
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;

以上面的 Stmt 类为例,其他很多 NodeRef 的派生类结构也会用一个 TVM_DEFINE_NODE_REF 的宏来扩展出这部分的关键代码。


那么问题来了,为什么要设计成这个样子呢。

这里谈一下我自己的理解,如果有问题的话也欢迎看到的同学指正一下。

一方面可能是为了方便下面 C++ runtime 和 Python 部分所有结构的无缝衔接;另一方面从上面的代码也可以看到,NodeRef 重载了 -> 运算符之后返回的是个 const 的指针对象,所以通过 -> 访问到的实际的成员结构都是只读的了,这就保证了 TVM 这套复杂系统里面各种数据结构的安全性。

当然严格的只读访问在某些情况下是不够的,所以 TVM 提供了 CopyOnWrite() 的机制,如果某个 NodeRef 类的定义中包含了 TVM_DEFINE_NODE_REF_COW 这个宏的话,可以通过 NodeRef.CopyOnWrite() 获得一个可修改的 Node 指针,之后对成员内容的修改均通过这个指针来做就可以了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/*!
* \brief Macro to define CopyOnWrite function in a NodeRef.
* \param NodeName The Type of the Node.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWNodeRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
NodeName* CopyOnWrite() { \
CHECK(node_ != nullptr); \
if (!node_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
NodePtr<Node>(std::move(n)).swap(node_); \
} \
return static_cast<NodeName*>(node_.get()); \
}

话说 CopyOnWrite() 这个函数名称我感觉可能不是特别确切,也许改成 GetMutablePtr() 之类的会更好点?因为这个实际上并不 Copy,直接调用这个函数返回的是对这个 NodeRef 自己所指代对象的指针,之后的改动也都是对这个 Node 自身做的。

如果确切希望实现 Copy 的语义,则需要像前面注释里面示例的那样,先用另一个 ref2 复制一份 ref,之后再在 ref2 上进行修改。

PackedFunc

TVM 的整个软件栈涉及到很多高层脚本语言(Python、JavaScript)和 C++ 运行时的交互,因此这里提供了一套 PackedFunc 的基础用来把整个过程方便地串接起来。

第一次看到这种实现时真的是被惊到了,感觉非常神奇。

在 C++ 层面创建一个函数,可以直接进行本地调用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include <tvm/runtime/packed_func.h>

void MyAdd(TVMArgs args, TVMRetValue* rv) {
// automatically convert arguments to desired type.
int a = args[0];
int b = args[1];
// automatically assign value return to rv
*rv = a + b;
}

void CallPacked() {
PackedFunc myadd = PackedFunc(MyAdd);
// get back 3
int c = myadd(1, 2);
}

也可以通过 API 注册之后(要注册到 TVM 的 C++ 运行时库里面去),从 Python 层进行调用:

1
2
3
// register a global packed function in c++
TVM_REGISTER_GLOBAL("myadd")
.set_body(MyAdd);
1
2
3
4
5
import tvm

myadd = tvm.get_global_func("myadd")
# prints 3
print(myadd(1, 2))

反过来 Python 层写好的函数也可以直接从 C++ 层调用:

1
2
3
4
5
TVM_REGISTER_GLOBAL("callhello")
.set_body([](TVMArgs args, TVMRetValue* rv) {
PackedFunc f = args[0];
f("hello world");
});
1
2
3
4
5
6
7
8
9
10
import tvm

def callback(msg):
print(msg)

# convert to PackedFunc
f = tvm.convert(callback)
callhello = tvm.get_global_func("callhello")
# prints hello world
callhello(f)

Python 层定义的 callback(msg) 通过 callhello 传递给 C++ 层,C++ 层执行时直接从输入参数中得到了 Python 的函数对象并调用执行。


实现方面,C++ 层的 PackedFunc 是一个对 std::function 对象的封装结构,而 Python 层面tvm.convert 实际上是把 Python 的函数用 ctype 做了一下封装:

1
2
3
4
5
6
7
TVMPackedCFunc = ctypes.CFUNCTYPE(
ctypes.c_int,
ctypes.POINTER(TVMValue),
ctypes.POINTER(ctypes.c_int),
ctypes.c_int,
ctypes.c_void_p,
ctypes.c_void_p)

PackedFunc 对函数的输入参数、返回值的解析处理做了比较精巧的处理,最终达到了从 API 层面看上去非常好的使用体验。

Amazing!

有空的话可以试一下把 TVM 里面的这部分内容单独扒出来,这个实现思路真的非常有意思。