intel-xpu-backend-for-triton绕过pytorch直接调用Unified-Runtime
背景
- 一般情况下triton kernel依赖pytorch
- 是否可以直接调用Unified-Runtime的API分配设备内存,直接给triton kernel使用呢
方法
- 将Unified-Runtime中的Mem,Queue等API封装为python接口
- 将UR分配的64bit设备地址直接传到triton kernel中
- 在triton kernel中将该地址转为具体类型的指针(如float32)
- 遗留问题:因为triton kernel为异步,需要为UR增加urDeviceSynchronize接口;下面的demo采用time.sleep规避
步骤
安装intel-xpu-backend-for-triton
docker stop triton_xpu_ipcx
docker rm triton_xpu_ipcx
docker run --shm-size=32g -it --privileged --net=host \
-v $PWD:/home -w /home \
--name triton_xpu_ipcx intel/deep-learning-essentials:latest /bin/bash
docker start triton_xpu_ipcx
docker exec -ti triton_xpu_ipcx bash
git clone https://github.com/intel/intel-xpu-backend-for-triton.git
cd intel-xpu-backend-for-triton/
git checkout 197f5f843fd17deab1de1df1c6f17e60978ecdce
source /opt/intel/oneapi/setvars.sh --force
export PATH=$PATH:/opt/intel/oneapi/compiler/2025.0/bin/compiler
apt install intel-ocloc -y
scripts/install-pytorch.sh --source --force-reinstall
export MAX_JOBS=16
scripts/compile-triton.sh --llvm --triton
运行测试用例
cat > xpu_triton_ur.py <<-'EOF'
# Import necessary libraries
import ctypes
from ctypes import *
import enum
import sys
import numpy as np
# Import Triton for GPU kernel execution
import triton
import triton.language as tl
# Define a Triton kernel for scaling operation
@triton.jit
def triton_scale_kernel(
input_ptr_addr, # Pointer to input data in device memory
output_ptr_addr, # Pointer to output data in device memory
scale, # Scaling factor to apply
n_elements, # Total number of elements to process
BLOCK_SIZE: tl.constexpr, # Number of elements processed by each thread block
):
# Get the program ID (determines which block this thread is in)
pid = tl.program_id(axis=0)
# Calculate the starting index for this block
block_start = pid * BLOCK_SIZE
# Calculate the offsets for all elements this block will process
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to avoid out-of-bounds memory access
mask = offsets < n_elements
# Convert raw pointers to proper float32 pointers
input_ptr = input_ptr_addr.to(tl.pointer_type(tl.float32))
output_ptr = output_ptr_addr.to(tl.pointer_type(tl.float32))
# Load input data from device memory
input_data = tl.load(input_ptr + offsets, mask=mask, other=0)
# Perform the scaling operation
output_data = input_data * scale
# Store the results back to device memory
tl.store(output_ptr + offsets, output_data, mask=mask)
# Python wrapper function for the Triton kernel
def triton_scale(_input, _output, scale, n_elements):
# Define the grid size (number of thread blocks)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# Launch the kernel with specified grid and block size
triton_scale_kernel[grid](_input, _output, scale, n_elements, BLOCK_SIZE=1024)
# Try loading the UR (Unified Runtime) loader library
try:
lib = cdll.LoadLibrary("libur_loader.so")
except OSError:
lib = cdll.LoadLibrary("ur_loader.so")
# Custom exception for UR API errors
class URException(Exception):
pass
# Function to check UR API call results
def _check_result(result):
success_codes = [0] # Assuming UR_RESULT_SUCCESS is 0
allowed_errors = [] # Add specific error codes if needed
if result not in success_codes + allowed_errors:
raise URException(f"UR API call failed with error code: {result}")
# Enum definitions for UR API
class ur_device_init_flag_t(enum.IntEnum):
UR_DEVICE_INIT_FLAG_GPU = 0x1
UR_DEVICE_INIT_FLAG_CPU = 0x2
UR_DEVICE_INIT_FLAG_FPGA = 0x4
UR_DEVICE_INIT_FLAG_MCA = 0x8
UR_DEVICE_INIT_FLAG_VPU = 0x10
ur_device_init_flags_t = c_uint32
class ur_device_type_t(enum.IntEnum):
UR_DEVICE_TYPE_GPU = 3
UR_DEVICE_TYPE_CPU = 4
# ...其他类型根据需要添加
class ur_mem_flag_t(enum.IntEnum):
UR_MEM_FLAG_READ_WRITE = 0x1
UR_MEM_FLAG_WRITE_ONLY = 0x2
UR_MEM_FLAG_READ_ONLY = 0x4
# ...其他标志
ur_mem_flags_t = c_uint32
# Structure definitions for UR API
class ur_context_properties_t(Structure):
_fields_ = [
("stype", c_uint32), # Structure type identifier
("pNext", c_void_p), # Pointer to extension structure
("flags", c_uint32), # Context creation flags
]
class ur_buffer_properties_t(Structure):
_fields_ = [
("stype", c_uint32), # Structure type identifier
("pNext", c_void_p), # Pointer to extension structure
("pHost", c_void_p), # Optional host pointer
]
# Handle types for UR objects
ur_adapter_handle_t = c_void_p
ur_platform_handle_t = c_void_p
ur_device_handle_t = c_void_p
ur_context_handle_t = c_void_p
ur_mem_handle_t = c_void_p
ur_queue_handle_t = c_void_p
ur_native_handle_t = c_void_p
# Function prototypes and Python wrappers for UR API
# urLoaderInit - Initialize the UR loader
lib.urLoaderInit.argtypes = [ur_device_init_flags_t, c_void_p]
lib.urLoaderInit.restype = c_int
def urLoaderInit(device_flags=0, hLoaderConfig=None):
result = lib.urLoaderInit(device_flags, hLoaderConfig)
_check_result(result)
# urAdapterGet - Get available adapters
lib.urAdapterGet.argtypes = [c_uint32, POINTER(ur_adapter_handle_t), POINTER(c_uint32)]
lib.urAdapterGet.restype = c_int
def urAdapterGet(NumEntries, phAdapters, pNumAdapters):
result = lib.urAdapterGet(NumEntries, phAdapters, pNumAdapters)
_check_result(result)
# urPlatformGet - Get platforms for given adapters
lib.urPlatformGet.argtypes = [POINTER(ur_adapter_handle_t), c_uint32, c_uint32,
POINTER(ur_platform_handle_t), POINTER(c_uint32)]
lib.urPlatformGet.restype = c_int
def urPlatformGet(phAdapters, NumAdapters, NumEntries, phPlatforms, pNumPlatforms):
result = lib.urPlatformGet(phAdapters, NumAdapters, NumEntries, phPlatforms, pNumPlatforms)
_check_result(result)
# urDeviceGet - Get devices for a platform
lib.urDeviceGet.argtypes = [ur_platform_handle_t, c_int, c_uint32,
POINTER(ur_device_handle_t), POINTER(c_uint32)]
lib.urDeviceGet.restype = c_int
def urDeviceGet(hPlatform, DeviceType, NumEntries, phDevices, pNumDevices):
result = lib.urDeviceGet(hPlatform, DeviceType, NumEntries, phDevices, pNumDevices)
_check_result(result)
# urContextCreate - Create a context for devices
lib.urContextCreate.argtypes = [c_uint32, POINTER(ur_device_handle_t),
POINTER(ur_context_properties_t), POINTER(ur_context_handle_t)]
lib.urContextCreate.restype = c_int
def urContextCreate(DeviceCount, phDevices, pProperties, phContext):
result = lib.urContextCreate(DeviceCount, phDevices, pProperties, phContext)
_check_result(result)
# urMemBufferCreate - Create a device memory buffer
lib.urMemBufferCreate.argtypes = [ur_context_handle_t, ur_mem_flags_t, c_size_t,
POINTER(ur_buffer_properties_t), POINTER(ur_mem_handle_t)]
lib.urMemBufferCreate.restype = c_int
def urMemBufferCreate(hContext, flags, size, pProperties, phBuffer):
result = lib.urMemBufferCreate(hContext, flags, size, pProperties, phBuffer)
_check_result(result)
# urMemRelease - Release a memory buffer
lib.urMemRelease.argtypes = [ur_mem_handle_t]
lib.urMemRelease.restype = c_int
def urMemRelease(hMem):
result = lib.urMemRelease(hMem)
_check_result(result)
# urQueueCreate - Create a command queue
lib.urQueueCreate.argtypes = [ur_context_handle_t, ur_device_handle_t, c_void_p,
POINTER(ur_queue_handle_t)]
lib.urQueueCreate.restype = c_int
def urQueueCreate(hContext, hDevice, pProperties, phQueue):
result = lib.urQueueCreate(hContext, hDevice, pProperties, phQueue)
_check_result(result)
# urQueueFinish - Wait for all commands in queue to complete
lib.urQueueFinish.argtypes = [ur_queue_handle_t]
lib.urQueueFinish.restype = c_int
def urQueueFinish(hQueue):
result = lib.urQueueFinish(hQueue)
_check_result(result)
# urEnqueueMemBufferWrite - Write data to device memory
lib.urEnqueueMemBufferWrite.argtypes = [ur_queue_handle_t, ur_mem_handle_t, c_bool,
c_size_t, c_size_t, c_void_p, c_uint32,
c_void_p, c_void_p]
lib.urEnqueueMemBufferWrite.restype = c_int
def urEnqueueMemBufferWrite(hQueue, hBuffer, blocking, offset, size, pSrc,
numEventsInWaitList, phEventWaitList, phEvent):
result = lib.urEnqueueMemBufferWrite(hQueue, hBuffer, blocking, offset, size, pSrc,
numEventsInWaitList, phEventWaitList, phEvent)
_check_result(result)
# urEnqueueMemBufferRead - Read data from device memory
lib.urEnqueueMemBufferRead.argtypes = [ur_queue_handle_t, ur_mem_handle_t, c_bool,
c_size_t, c_size_t, c_void_p, c_uint32,
c_void_p, c_void_p]
lib.urEnqueueMemBufferRead.restype = c_int
def urEnqueueMemBufferRead(hQueue, hBuffer, blocking, offset, size, pDst,
numEventsInWaitList, phEventWaitList, phEvent):
result = lib.urEnqueueMemBufferRead(hQueue, hBuffer, blocking, offset, size, pDst,
numEventsInWaitList, phEventWaitList, phEvent)
_check_result(result)
# urMemGetNativeHandle - Get native handle for memory object
lib.urMemGetNativeHandle.argtypes = [ur_mem_handle_t, ur_device_handle_t, POINTER(ur_native_handle_t)]
lib.urMemGetNativeHandle.restype = c_int
def urMemGetNativeHandle(hMem, hDevice, phNativeMem):
result = lib.urMemGetNativeHandle(hMem, hDevice, phNativeMem)
_check_result(result)
# Resource release functions
# urContextRelease - Release a context
lib.urContextRelease.argtypes = [ur_context_handle_t]
lib.urContextRelease.restype = c_int
def urContextRelease(hContext):
result = lib.urContextRelease(hContext)
_check_result(result)
# urAdapterRelease - Release an adapter
lib.urAdapterRelease.argtypes = [ur_adapter_handle_t]
lib.urAdapterRelease.restype = c_int
def urAdapterRelease(hAdapter):
result = lib.urAdapterRelease(hAdapter)
_check_result(result)
# urLoaderTearDown - Shutdown the UR loader
lib.urLoaderTearDown.argtypes = []
lib.urLoaderTearDown.restype = c_int
def urLoaderTearDown():
result = lib.urLoaderTearDown()
_check_result(result)
def main():
try:
# Initialize the UR loader
urLoaderInit()
# Get available adapters
adapter_count = ctypes.c_uint32()
# First call gets the count
urAdapterGet(0, None, ctypes.byref(adapter_count))
# Allocate array for adapters
adapters = (ur_adapter_handle_t * adapter_count.value)()
# Second call gets the adapters
urAdapterGet(adapter_count.value, adapters, None)
# Get platforms for the adapters
platform_count = ctypes.c_uint32()
# First call gets the count
urPlatformGet(adapters, adapter_count.value, 1, None, ctypes.byref(platform_count))
# Allocate array for platforms
platforms = (ur_platform_handle_t * platform_count.value)()
# Second call gets the platforms
urPlatformGet(adapters, adapter_count.value, platform_count.value, platforms, None)
# Process each platform
for platform in platforms:
# Get GPU devices for this platform
device_count = ctypes.c_uint32()
# First call gets the count
urDeviceGet(platform, ur_device_type_t.UR_DEVICE_TYPE_GPU.value, 0, None, ctypes.byref(device_count))
# Allocate array for devices
devices = (ur_device_handle_t * device_count.value)()
# Second call gets the devices
urDeviceGet(platform, ur_device_type_t.UR_DEVICE_TYPE_GPU.value, device_count.value, devices, None)
# Process each device (just the first one in this example)
for i in range(device_count.value):
device = devices[i]
# Create array with single device for context creation
device_array = (ur_device_handle_t * 1)(device)
# Create context for this device
hContext = ur_context_handle_t()
urContextCreate(1, device_array, None, ctypes.byref(hContext))
# Create input and output buffers in device memory
n_elements = 32 # Number of elements in our test array
# Create input buffer
dA = ur_mem_handle_t()
urMemBufferCreate(hContext, ur_mem_flag_t.UR_MEM_FLAG_READ_WRITE.value,
n_elements * ctypes.sizeof(ctypes.c_float), None, ctypes.byref(dA))
# Get native handle for Triton to use
dA_ptr = ur_native_handle_t()
urMemGetNativeHandle(dA, device, byref(dA_ptr))
# Create output buffer
dB = ur_mem_handle_t()
urMemBufferCreate(hContext, ur_mem_flag_t.UR_MEM_FLAG_READ_WRITE.value,
n_elements * ctypes.sizeof(ctypes.c_float), None, ctypes.byref(dB))
dB_ptr = ur_native_handle_t()
urMemGetNativeHandle(dB, device, byref(dB_ptr))
# Create command queue for this device
queue = ur_queue_handle_t()
urQueueCreate(hContext, device, None, ctypes.byref(queue))
# Prepare host memory
host_A = np.ones(n_elements, dtype=np.float32)*1.2 # Input array
host_B = np.empty(n_elements, dtype=np.float32) # Output array
# Copy input data to device
src_ptr = host_A.ctypes.data_as(ctypes.c_void_p)
urEnqueueMemBufferWrite(queue, dA, True, 0, n_elements * ctypes.sizeof(ctypes.c_float),
src_ptr, 0, None, None)
# Execute the Triton kernel to scale the data
scale = 10.0 # Scaling factor
triton_scale(dA_ptr.value, dB_ptr.value, scale, n_elements)
# TODO: Replace with proper synchronization
# Currently using sleep as a temporary solution
import time
time.sleep(2)
# Read results back from device
dst_ptr = host_B.ctypes.data_as(ctypes.c_void_p)
urEnqueueMemBufferRead(queue, dB, True, 0, n_elements * ctypes.sizeof(ctypes.c_float),
dst_ptr, 0, None, None)
# Wait for all commands to complete
urQueueFinish(queue)
# Compute ground truth for verification
gt = host_A * scale
# Calculate mean squared error between actual and expected results
mse = np.mean((gt - host_B) ** 2)
print(f"MSE:{mse}")
print(host_B)
# Clean up resources
urMemRelease(dA)
urMemRelease(dB)
urContextRelease(hContext)
break # Just process first device
break # Just process first platform
# Release adapters
for adapter in adapters:
urAdapterRelease(adapter)
# Shutdown UR loader
urLoaderTearDown()
except URException as e:
print(f"Error: {e}")
sys.exit(1)
if __name__ == "__main__":
main()
EOF
python xpu_triton_ur.py

欢迎来到由智源人工智能研究院发起的Triton中文社区,这里是一个汇聚了AI开发者、数据科学家、机器学习爱好者以及业界专家的活力平台。我们致力于成为业内领先的Triton技术交流与应用分享的殿堂,为推动人工智能技术的普及与深化应用贡献力量。
更多推荐
所有评论(0)