Skip to content

The position of if-else significantly affects performance, which is unexpected. #6491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wenhaoli-xmu opened this issue Apr 15, 2025 · 0 comments

Comments

@wenhaoli-xmu
Copy link

wenhaoli-xmu commented Apr 15, 2025

Describe the issue

Hi, I'm here again. This time, I bring a very counterintuitive problem.

That is, why is putting the if-else condition inside a loop much faster (~10x) than putting it outside the loop?

Below are my two versions of code:

if-else outside loop
@triton.jit
def _masked_ffn_infer(
        a_ptr,
        w1_ptr,
        w3_ptr,
        u1_ptr,
        u3_ptr,
        c_ptr,
        m_ptr,
        M,
        stride_at, stride_am,
        stride_wt, stride_wn,
        stride_cm, stride_cn,
        T: tl.constexpr,
        N: tl.constexpr,
        TILE_M: tl.constexpr,
        TILE_N: tl.constexpr,
        TILE_K: tl.constexpr,
        GROUP_SIZE: tl.constexpr
):
    """Compute one (TILE_M, TILE_N) tile of a row-masked, SiLU-gated projection.

    A per-row mask loaded from ``m_ptr`` picks the weight pair for each row:
    rows with a nonzero mask use ``w1``/``w3``, rows with a zero mask use
    ``u1``/``u3``. Two accumulators are reduced over ``T`` K-tiles and the
    tile ``silu(acc1) * acc3`` is stored to ``c_ptr`` as bfloat16.

    In this variant the branch on the mask is taken once, outside the
    reduction loop, so each of the three cases owns its own loop.

    Args:
        a_ptr: Activations; addressed as ``row * stride_am + k`` inside a
            K-tile and advanced by ``stride_at`` per K-tile.
        w1_ptr, w3_ptr: Weights for rows whose mask is nonzero; addressed as
            ``col * stride_wn + k`` and advanced by ``stride_wt`` per K-tile.
        u1_ptr, u3_ptr: Weights for rows whose mask is zero; same addressing
            as ``w1_ptr``/``w3_ptr``.
        c_ptr: Output; addressed as ``row * stride_cm + col``.
        m_ptr: Per-row mask, one element per row.
        M: Number of rows.
        stride_cn: Accepted but not used; output columns are addressed with
            unit stride.
        T: Number of K-tiles to reduce over.
        N: Number of output columns.
        TILE_M, TILE_N, TILE_K: Tile sizes.
        GROUP_SIZE: Number of row tiles grouped together when mapping the
            program id to a tile.
    """
    # Grouped mapping from the linear program id to a (pid_m, pid_n) tile.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, TILE_M)
    num_pid_n = tl.cdiv(N, TILE_N)
    num_pid_in_group = GROUP_SIZE * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Unwrapped tile coordinates; these may run past M / N on partial tiles
    # and are what the store mask must be computed from.
    row_ids = pid_m * TILE_M + tl.arange(0, TILE_M)
    col_ids = pid_n * TILE_N + tl.arange(0, TILE_N)

    # Wrapped coordinates keep every load in bounds without a load mask.
    m_range = row_ids % M
    n_range = col_ids % N
    k_range = tl.arange(0, TILE_K)

    m_data = tl.load(m_ptr + m_range).to(tl.int1)

    a_offs = m_range[:, None] * stride_am + k_range[None, :]
    w_offs = n_range[:, None] * stride_wn + k_range[None, :]
    # Store offsets use the unwrapped coordinates so that out-of-range
    # rows/columns are masked off instead of aliasing valid ones.
    c_offs = row_ids[:, None] * stride_cm + col_ids[None, :]

    acc1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    acc3 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)

    if tl.max(m_data) == 0:
        # Every row of the tile is unmasked: only the u-weights are needed.
        for t in range(0, T):
            a = tl.load(a_ptr + a_offs)
            u1 = tl.load(u1_ptr + w_offs)
            u3 = tl.load(u3_ptr + w_offs)
            acc1 = tl.dot(a, u1.T, acc1)
            acc3 = tl.dot(a, u3.T, acc3)
            a_ptr += stride_at
            u1_ptr += stride_wt
            u3_ptr += stride_wt

    elif tl.min(m_data) == 1:
        # Every row of the tile is masked: only the w-weights are needed.
        for t in range(0, T):
            a = tl.load(a_ptr + a_offs)
            w1 = tl.load(w1_ptr + w_offs)
            w3 = tl.load(w3_ptr + w_offs)
            acc1 = tl.dot(a, w1.T, acc1)
            acc3 = tl.dot(a, w3.T, acc3)
            a_ptr += stride_at
            w1_ptr += stride_wt
            w3_ptr += stride_wt

    else:
        # Mixed tile: compute both products and select per row.
        for t in range(0, T):
            a = tl.load(a_ptr + a_offs)
            w1 = tl.load(w1_ptr + w_offs)
            w3 = tl.load(w3_ptr + w_offs)
            u1 = tl.load(u1_ptr + w_offs)
            u3 = tl.load(u3_ptr + w_offs)
            acc1 += tl.where(
                m_data[:, None],
                tl.dot(a, w1.T),
                tl.dot(a, u1.T))
            acc3 += tl.where(
                m_data[:, None],
                tl.dot(a, w3.T),
                tl.dot(a, u3.T))
            a_ptr += stride_at
            w1_ptr += stride_wt
            w3_ptr += stride_wt
            u1_ptr += stride_wt
            u3_ptr += stride_wt

    # silu(acc1) * acc3
    acc1 *= tl.sigmoid(acc1)
    acc1 *= acc3

    # The original mask compared the already-wrapped row index against M,
    # which is always true; mask on the unwrapped coordinates instead.
    store_mask = (row_ids[:, None] < M) & (col_ids[None, :] < N)
    tl.store(
        c_ptr + c_offs,
        acc1.to(tl.bfloat16),
        store_mask)
if-else inside loop
@triton.jit
def _masked_ffn_infer(
        a_ptr,
        w1_ptr,
        w3_ptr,
        u1_ptr,
        u3_ptr,
        c_ptr,
        m_ptr,
        M,
        stride_at, stride_am,
        stride_wt, stride_wn,
        stride_cm, stride_cn,
        T: tl.constexpr,
        N: tl.constexpr,
        TILE_M: tl.constexpr,
        TILE_N: tl.constexpr,
        TILE_K: tl.constexpr,
        GROUP_SIZE: tl.constexpr
):
    """Compute one (TILE_M, TILE_N) tile of a row-masked, SiLU-gated projection.

    A per-row mask loaded from ``m_ptr`` picks the weight pair for each row:
    rows with a nonzero mask use ``w1``/``w3``, rows with a zero mask use
    ``u1``/``u3``. Two accumulators are reduced over ``T`` K-tiles and the
    tile ``silu(acc1) * acc3`` is stored to ``c_ptr`` as bfloat16.

    In this variant a single reduction loop is used and the branch on the
    mask is evaluated inside every iteration.

    Args:
        a_ptr: Activations; addressed as ``row * stride_am + k`` inside a
            K-tile and advanced by ``stride_at`` per K-tile.
        w1_ptr, w3_ptr: Weights for rows whose mask is nonzero; addressed as
            ``col * stride_wn + k`` and advanced by ``stride_wt`` per K-tile.
        u1_ptr, u3_ptr: Weights for rows whose mask is zero; same addressing
            as ``w1_ptr``/``w3_ptr``.
        c_ptr: Output; addressed as ``row * stride_cm + col``.
        m_ptr: Per-row mask, one element per row.
        M: Number of rows.
        stride_cn: Accepted but not used; output columns are addressed with
            unit stride.
        T: Number of K-tiles to reduce over.
        N: Number of output columns.
        TILE_M, TILE_N, TILE_K: Tile sizes.
        GROUP_SIZE: Number of row tiles grouped together when mapping the
            program id to a tile.
    """
    # Grouped mapping from the linear program id to a (pid_m, pid_n) tile.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, TILE_M)
    num_pid_n = tl.cdiv(N, TILE_N)
    num_pid_in_group = GROUP_SIZE * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Unwrapped tile coordinates; these may run past M / N on partial tiles
    # and are what the store mask must be computed from.
    row_ids = pid_m * TILE_M + tl.arange(0, TILE_M)
    col_ids = pid_n * TILE_N + tl.arange(0, TILE_N)

    # Wrapped coordinates keep every load in bounds without a load mask.
    m_range = row_ids % M
    n_range = col_ids % N
    k_range = tl.arange(0, TILE_K)

    m_data = tl.load(m_ptr + m_range).to(tl.int1)

    a_offs = m_range[:, None] * stride_am + k_range[None, :]
    w_offs = n_range[:, None] * stride_wn + k_range[None, :]
    # Store offsets use the unwrapped coordinates so that out-of-range
    # rows/columns are masked off instead of aliasing valid ones.
    c_offs = row_ids[:, None] * stride_cm + col_ids[None, :]

    acc1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    acc3 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)

    for t in range(0, T):
        # The activation tile is shared by all three branches, so it is
        # loaded exactly once per iteration.
        a = tl.load(a_ptr + a_offs)

        if tl.max(m_data) == 0:
            # Every row of the tile is unmasked: only the u-weights are needed.
            u1 = tl.load(u1_ptr + w_offs)
            u3 = tl.load(u3_ptr + w_offs)
            acc1 = tl.dot(a, u1.T, acc1)
            acc3 = tl.dot(a, u3.T, acc3)

        elif tl.min(m_data) == 1:
            # Every row of the tile is masked: only the w-weights are needed.
            w1 = tl.load(w1_ptr + w_offs)
            w3 = tl.load(w3_ptr + w_offs)
            acc1 = tl.dot(a, w1.T, acc1)
            acc3 = tl.dot(a, w3.T, acc3)

        else:
            # Mixed tile: compute both products and select per row.
            w1 = tl.load(w1_ptr + w_offs)
            w3 = tl.load(w3_ptr + w_offs)
            u1 = tl.load(u1_ptr + w_offs)
            u3 = tl.load(u3_ptr + w_offs)
            acc1 += tl.where(
                m_data[:, None],
                tl.dot(a, w1.T),
                tl.dot(a, u1.T))
            acc3 += tl.where(
                m_data[:, None],
                tl.dot(a, w3.T),
                tl.dot(a, u3.T))

        # All pointers advance every iteration regardless of the branch taken.
        a_ptr += stride_at
        w1_ptr += stride_wt
        w3_ptr += stride_wt
        u1_ptr += stride_wt
        u3_ptr += stride_wt

    # silu(acc1) * acc3
    acc1 *= tl.sigmoid(acc1)
    acc1 *= acc3

    # The original mask compared the already-wrapped row index against M,
    # which is always true; mask on the unwrapped coordinates instead.
    store_mask = (row_ids[:, None] < M) & (col_ids[None, :] < N)
    tl.store(
        c_ptr + c_offs,
        acc1.to(tl.bfloat16),
        store_mask)

Given the redundant computation in the if-else-inside-loop version, I would expect it to be slower in any other programming language, but that is not what happens in Triton.

Do you have some ideas? Appreciate it! ☺️☺️

Environment details

Triton: 3.1.0
GPU: A100

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

1 participant