Triton 以其低门槛开发和抽象的硬件细节处理,成为开发者的优选。

对于渴望参与 Triton 开源社区建设的开发者来说,优化 Triton 算子是一条理想的路径。优化后的 Triton 算子性能有望匹敌甚至超越 PyTorch 的原生实现。

正如古人云:“工欲善其事,必先利其器”,本文将介绍 Triton 算子优化的利器——自动调优(autotune)。

Triton 算子的实现代码

一个简单的 Triton 算子的实现代码通常分为两个部分。

  1. 计算准备阶段:涉及输入张量的预处理(如转换为连续布局张量),计算输出张量的形状并分配内存,以及设置运行参数(如 grid 和 BLOCK_SIZE)。

  2. 核函数调用:在 GPU 上实现计算逻辑。

以下是一个对三维张量 inp(形状为 [M,N,K])计算在第 1 维(N 所在维)上最大值下标的例子,即 argmax 算子。下文将在本例的基础上讲述 autotune 的用法。

def argmax(inp):
    # 第一部分
    dim = 1                       # 本例是 argmax 的一个特化实现,仅能处理 dim=1 的情形
    N = shape[dim]
    M = prod(shape[:dim])
    K = inp.numel() // M // N
    inp = inp.contiguous()        # 将输入转换为连续布局张量,本例中可以简化核函数实现
    shape = list(inp.shape)
    shape[dim] = 1                # 第 dim 维取最大值下标,因此该维度上输出 shape 归一
    out_index = torch.empty(shape, dtype=torch.int64, device=inp.device)  # 分配输出
    grid = lambda meta: (         # 本例 grid 使用二维,通过表达式指定
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    # 第二部分
    with torch.cuda.device(inp.device):
        argmax_kernel[grid](
            inp,
            out_index,
            M,
            N,
            K,
        )
    return out_index

计算任务的划分

Triton 核函数的编程模式是面向 CTA(也即 Thread Block,线程块)的编程。由于线程及更低层次的硬件细节已被 Triton 隐藏,入门开发者可以专注于算法设计和线程块的划分,二者的关系类似算法与数据结构的关系。

线程块划分的本质

线程块的划分,本质是对计算资源的规划,它与数据张量的切分方案密不可分。在最简单的 pointwise 算子中,输入张量与输出张量形状相同,输入张量中每个元素通过相同的运算规则计算出输出张量的对应元素,此时线程块的划分就是数据张量的划分。较复杂一些的算子中,一个线程块可能处理多个数据块。不过即便是在这些复杂算子中,由 BLOCK_SIZE 线程组成的线程块内,每一步操作(加载、计算、存储等)通常也是以 BLOCK_SIZE 大小的数据块为单位的。多个线程块的组织结构则由 grid 描述。grid 可以是一维、二维或三维,对应线程块的 BLOCK_SIZE 也在不同维度上有所区分。

影响因素

在硬件和算子算法确定的情况下,线程块大小的划分成为影响 GPU 上执行时间的关键因素。划分方案的效率高低,受到多种因素的影响,包括但不限于数据张量的布局、缓存的有效利用、算法的计算顺序以及显卡硬件的具体限制等,更何况这还可能是个多维划分方案。可供决策的因素过多,划分方案的选择看起来并不是个显而易见的过程。

超参数的选择

此外还有一些超参需要选择,如 num_warps(一个线程块中可被调度的 warp 数量)、num_stages(循环流水线深度)等,在 Triton 编程中都可以设置,这使得不同输入尺寸、不同硬件型号下最优的算子超参选择成为一个更复杂的过程。

自动择优运行时参数

Triton 内建了自动调优(autotune)机制,使得核函数在运行时遍历所有被枚举的超参组合,选择最优的超参。我们需要做的,就是给 Triton 核函数增加一个 triton.autotune 的注解,配置上多种超参配置组合以供遍历择优。

上文例子中,核函数 argmax_kernel 的定义如下。triton.autotune 描述了多组 BLOCK_Mnum_warps 的超参组合,其中 key 表示不同的 MN 会选择不同的最优化超参组合。triton.heuristics 描述了如何通过超参中的其他参数用 heur_block_n 函数推导出 BLOCK_N。例中 BLOCK_M 是第 0 维(M 所在维)的 BLOCK_SIZE,BLOCK_N 是第 1 维(N 所在维)的 BLOCK_SIZE。

def heur_block_n(args):
    return min(4096, triton.next_power_of_2(args["N"]))

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 8}, num_warps=4),
        triton.Config({"BLOCK_M": 16}, num_warps=8),
        triton.Config({"BLOCK_M": 32}, num_warps=8),
    ],
    key=[
        "M",
        "N",
    ],
)
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def argmax_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # set offset
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    ...    # 本文关注参数的自动择优配置,因此略去 argmax 的详细计算逻辑

configs 中的多组配置存在一定规律时,也可以用 Python 的 for 表达式来简洁地描述,例如 configs=[triton.Config({"BLOCK_M": m}, num_warps=8) for m in [8, 16, 32]]。花括号中也可以不止遍历 BLOCK_M 的组合,还可以加入 BLOCK_N 的组合来替代通过 triton.heuristics 直接推导,但是这样会导致超参择优耗时的增加。显然,triton.heuristics 的推导表达式是根据经验减少遍历空间的权衡。

例中 grid 为二元组,其中第 0 维是表达式,意为通过 cdiv(M, BLOCK_M) 计算得到,其中 cdiv 是上取整除法。

triton.autotune 还有可选的其他参数可以设置,例如指定择优过程的预热时间或基准测试时间(warmup、rep),核函数调用前后自定义回调函数(pre_hook、post_hook)等,详细用法可以查阅 triton 手册。

通过这种方式,Triton 能够在运行时自动测试不同的超参数组合,找到并应用最佳的性能配置,从而简化了性能优化过程,让开发者能够更专注于算法本身的开发。这种自动化的调优过程不仅节省了手动寻找最优参数的时间和精力,也提高了开发效率和最终产品的性能表现。

与直接设置参数的对比

我们可以通过设置环境变量 export TRITON_PRINT_AUTOTUNING=1 在超参择优过程结束后打印出 autotune 的结果,看看是否与个人经验相符。

如果不使用 autotune 而是直接指定线程块划分,则代码形式如下:

    BLOCK_M = 32
    BLOCK_N = min(4096, triton.next_power_of_2(N))
    num_warps = 8
    grid = (triton.cdiv(M, BLOCK_M), K,)
    with torch.cuda.device(inp.device):
        argmax_kernel[grid](
            inp,
            out_index,
            M,
            N,
            K,
            BLOCK_M,
            BLOCK_N,
            num_warps=num_warps,
        )

可以在不同输入形状下,对比直接指定和 autotune 的算子耗时,来验证参数设置的个人经验是否能接近最优结果。反过来,还可以辅助确定 triton.heuristics 的推导表达式,甚至做到完全使用直接指定超参就能获得较优的性能。考虑到 autotune 本身是有运行时开销的,简化遍历空间甚至替代 autotune 才是我们的最终目的。同时,我们也可以在这一过程中,加深对 GPU 编程性能优化的认知。

图片

扫码回复“Triton”

加入Triton中文社区交流群

Logo

欢迎来到由智源人工智能研究院发起的Triton中文社区,这里是一个汇聚了AI开发者、数据科学家、机器学习爱好者以及业界专家的活力平台。我们致力于成为业内领先的Triton技术交流与应用分享的殿堂,为推动人工智能技术的普及与深化应用贡献力量。

更多推荐