Triton 的 Python API 里除了 triton.jit  还有  triton.autotune triton heuristics triton.Config 等接口用于调优以生成性能更好的 kernel。

1.Triton 调优简介

triton.autotune 是一个装饰器,用于对 triton.jit 装饰的函数进行自动调优。

triton.autotune(*configs*, *key*, *prune_configs_by=None*, *reset_to_zero=None*, *restore_value=None*, *pre_hook=None*, *post_hook=None*, *warmup=None*, *rep=None*, *use_cuda_graph=False*, *do_bench=None*)
@triton.autotune(configs=[
    triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4),
    triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8),
  ],
  key=['x_size']# the two above configs will be evaluated anytime
                # the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
    BLOCK_SIZE = META['BLOCK_SIZE']

可以发现该装饰器有两个重要的参数,一个是 configs,它是一个包含 triton 配置类型的列表 (list[triton.Config])。另一个是 字符串列表 (list[str]),包含了可以触发调优的参数的名称。在运行的过程中,我们打开环境变量 TRITON_PRINT_AUTOTUNING=1 便可以打印出调优过程中的 kernel 信息以及最优配置。

triton.Config 类型用来表示自动调优将尝试的配置:

class triton.Config(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, maxnreg=None, pre_hook=None)
  • kwargs 是 meta 参数的字典,用于作为关键词参数传递给 kernel

  • num_warps kernel 编译时的 warp 数量

  • num_stages multi-stages 的数量

triton.heuristics 用于指定特定的 meta 参数的计算方法,减少自动调优的开销。


@triton
.heuristics(values={'BLOCK_SIZE': 

lambda

 args: 2 ** int(math.ceil(math.log2(args[1])))})

@triton
.jit

def

 kernel(x_ptr, x_size, **META):
    BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size

参数为一个 [字符串-函数] 的字典 (dict[str, Callable[[list[Any]], Any]])

Parameters:values – a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. each such function takes a list of positional arguments as input.

2. 矩阵乘法优化

首先我们来看一下 triton 官方的 GEMM 的 kernel 代码:

def get_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        ...
    ]
@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr, # Pointers to matrices
        M, N, K, # Matrix dimensions
        # strides
        stride_am, stride_ak,

        stride_bk, stride_bn,
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
        GROUP_SIZE_M: tl.constexpr,  #
        ACTIVATION: tl.constexpr  #
):

get_autotune_config 函数返回一个 list[triton.Config],该列表包含了针对矩阵乘法中 MNK 三个维度大小的自动调优配置。通过这些配置,Triton 可以根据不同的输入尺寸自动选择最佳的计算策略,从而最大化计算效率。对于这一部分,读者可以参考前文中的相关定义,理解调优的基本原理。此处不再详细展开。

Triton 的教程中进一步对 L2 Cache 进行优化,核心方法是通过 GROUP_M 技术来优化内存访问模式。GROUP_M 方法的主要作用是将计算中的操作分组,从而优化数据在共享内存和 L2 Cache 中的布局,减少内存访问冲突。这一优化策略在矩阵乘法中尤为重要,因为大规模矩阵计算会涉及大量的内存访问,优化内存访问顺序能够显著提高性能。

关于 GROUP_M 方法的具体细节,官方文档已进行了详细说明,本文不再赘述。在此,我们通过下图来直观地理解这一优化策略的效果,从而帮助读者更好地掌握 L2 Cache 优化的基本原理和实践。

Triton 通过配置 num_stages 参数来增加计算中的流水线阶段,从而进一步提升计算性能。具体来说,num_stages 参数可以控制计算任务的并行度和流水线的深度。增加流水线阶段能够使得计算任务在执行过程中更多的并行化,从而减少数据依赖和等待时间。

下图解释了为什么 multi-stage 可以加速矩阵乘法以及在多阶段流水线中,计算任务如何被拆分为多个阶段,并行执行,从而加速整体计算过程。每个阶段的计算可以在不同的计算单元上独立进行,从而充分利用 GPU 的计算资源,避免了因计算步骤之间的依赖而带来的延迟。Triton 的 ir 最终会生成 cp.async.cg.shared.global 的指令。

然而,值得注意的是,num_stages 的值并非越大越好。如果 num_stages 设置得过高,可能会导致共享内存的需求激增,最终因为共享内存不足而导致编译失败。为了避免这种情况,建议在设置 num_stages 时,结合具体的硬件资源和任务需求进行合理选择。

Triton 会通过 shared memory swizzling 来避免 bank 冲突。最终实现和 cublas 性能相当的 kernel。

Triton 在内存访问优化方面的另一个关键技术是 shared memory。通过 swizzling 技术,Triton 可以在计算过程中动态调整共享内存的访问模式,避免 bank冲突(bank conflicts),从而提高内存带宽利用率。

bank 冲突发生在多个线程同时访问同一个内存bank时,会显著降低内存访问效率。Triton 通过 swizzling 会重新排列数据在共享内存中的布局,使得线程的访问能够更均匀地分布到不同的内存 bank 上,从而避免 bank 冲突。

这种优化技术使得 Triton 在进行大规模矩阵计算时,能够有效提高内存访问效率,并且保证与 cuBLAS 相当的性能水平。通过这些技术的结合,Triton 实现了对矩阵乘法等计算密集型任务的显著加速,提供了与 cuBLAS 相当的性能。

本文转载自智源FlagOpen公众号,作者张博。

扫码回复“Triton”

加入Triton中文社区交流群

Logo

欢迎来到由智源人工智能研究院发起的Triton中文社区,这里是一个汇聚了AI开发者、数据科学家、机器学习爱好者以及业界专家的活力平台。我们致力于成为业内领先的Triton技术交流与应用分享的殿堂,为推动人工智能技术的普及与深化应用贡献力量。

更多推荐