向 Relay 中添加 Compiler Pass
Compiler Pass 是扩展 Relay 功能集 及优化 Relay 程序的主要接口。通过编写 compiler pass,用户可以基于最终目标,修改 AST 或收集 AST 相关信息。事实上,Relay 内置的一些重要特性(如自动微分和类型推断)都“标准”的 compiler pass。
整体来看,编写 pass 包括两个关键组成部分:
- 创建一个或多个遍历程序的 C++ 类
- 将遍历实现及其在 pass manager API 中的元数据包装,从而方便与 Pass Infrastructure 轻松交互
首先,我们将概述编写 compiler pass 的关键机制。然后通过 Relay 中常量折叠 pass 的具体示例进行演示。
AST 遍历器(Traversers)
用于遍历 Relay 程序的基类是 ExprFunctor
。它提供的公共接口是一个 VisitExpr
方法,该方法接收一个表达式以及零个或多个参数,并返回某种类型的实例。扩展此类时,可以通过覆盖每种表达式类型的 VisitExpr_
实现,来定义 AST 遍历模式。
VisitExpr
和 VisitExpr_
之间的关系与调度有关。每个 VisitExpr_
定义都针对特定类型的表达式,但用户无法每次都得知要访问的节点类型。为了解决这个问题,ExprFunctor
提供了一个 VisitExpr
函数,将给定表达式路由转换为 VisitExpr_
实例进而解决问题。尽管 C++ 已经提供了动态调度,但 ExprFunctor
定义了自己的虚表供 VisitExp
使用。通过定义虚表可以更好地控制调度。例如,定义一个在每次访问之前都打印 "Here" 的 PrintVisitor
遍历器,可以覆盖 VisitExpr
:
void PrintVisitor::VisitExpr(const Expr& expr) {
std::cout << "Here" << std::endl;
ExprFunctor::VisitExpr(expr);
}
ExprFunctor
本身是一个非常通用的类,这就是为什么更多时候你会扩展 ExprVisitor
或 ExprMutator
。这些类扩展了 ExprFunctor
,并提供了 VisitExpr_
的默认实现,这些实现捕获了每种表达式类型的常见遍历模式。有了这些默认的实现,开发者只需针对想要不同行为的表达式类型,提供覆盖的实现。后续章节将针对每个子类进行详细描述。