Skip to content

vllm.v1.attention.ops.common

DCPTritonContext

Avoids recompilation of the DCP Triton JIT kernel.

Source code in vllm/v1/attention/ops/common.py
class DCPTritonContext:
    """Caches the compiled DCP Triton JIT kernel so repeat launches skip
    recompilation.

    The first ``call_kernel`` invocation compiles and launches the kernel
    and keeps the compiled handle; subsequent invocations relaunch that
    handle directly.  The constexpr arguments are baked into the first
    compilation, so only the regular arguments are passed afterwards.
    """

    def __init__(self):
        # Compiled kernel handle; populated lazily on the first launch.
        self.inner_kernel = None

    def call_kernel(self, kernel, grid, *regular_args, **const_args):
        cached = self.inner_kernel
        if cached is not None:
            # Relaunch the already-compiled kernel; constexprs are baked in.
            cached[grid](*regular_args)
            return
        # First launch: compile, run, and keep the compiled handle.
        self.inner_kernel = kernel[grid](*regular_args, **const_args)

inner_kernel instance-attribute

inner_kernel = None

__init__

__init__()
Source code in vllm/v1/attention/ops/common.py
def __init__(self):
    # Compiled kernel handle; filled in lazily on the first launch.
    self.inner_kernel = None

call_kernel

call_kernel(kernel, grid, *regular_args, **const_args)
Source code in vllm/v1/attention/ops/common.py
def call_kernel(self, kernel, grid, *regular_args, **const_args):
    # First call compiles the kernel and caches the compiled handle; the
    # constexpr values (const_args) are baked into that compilation, so
    # later launches pass only the regular arguments.
    if self.inner_kernel is None:
        self.inner_kernel = kernel[grid](*regular_args, **const_args)
    else:
        self.inner_kernel[grid](*regular_args)

_correct_dcp_attn_out_kernel

_correct_dcp_attn_out_kernel(
    outputs_ptr,
    new_output_ptr,
    lses_ptr,
    vlse_ptr,
    outputs_stride_B,
    outputs_stride_H,
    outputs_stride_D,
    lses_stride_N,
    lses_stride_B,
    lses_stride_H,
    lse_idx,
    HEAD_DIM: constexpr,
    N_ROUNDED: constexpr,
    IS_BASE_E: constexpr,
)

Apply the all-gathered lses to correct each local rank's attention output. We still need to perform a cross-rank reduction to obtain the final attention output.

Parameters:

Name Type Description Default
outputs_ptr PointerType

Pointer to input tensor of shape [ B, H, D ]

required
lses_ptr PointerType

Pointer to input tensor of shape [ N, B, H ]

required
new_output_ptr PointerType

Pointer to output tensor of shape [ B, H, D ]

required
vlse_ptr PointerType

Pointer to output tensor of shape [ B, H ]

required
Source code in vllm/v1/attention/ops/common.py
@triton.jit
def _correct_dcp_attn_out_kernel(
    outputs_ptr,
    new_output_ptr,
    lses_ptr,
    vlse_ptr,
    outputs_stride_B,
    outputs_stride_H,
    outputs_stride_D,
    lses_stride_N,
    lses_stride_B,
    lses_stride_H,
    lse_idx,
    HEAD_DIM: tl.constexpr,
    N_ROUNDED: tl.constexpr,
    IS_BASE_E: tl.constexpr,
):
    """
    Apply the all-gathered lses to correct each local rank's attention
    output. We still need to perform a cross-rank reduction to obtain the
    final attention output.

    One program handles one (batch, head) pair: it combines the N
    per-rank LSEs into a single global LSE (stored to vlse_ptr) and
    rescales this rank's output by exp(local_lse - global_lse).

    Args:
        outputs_ptr (triton.PointerType):
            Pointer to input tensor of shape [ B, H, D ]
        lses_ptr (triton.PointerType):
            Pointer to input tensor of shape [ N, B, H ]
        new_output_ptr (triton.PointerType):
            Pointer to output tensor of shape [ B, H, D ]
        vlse_ptr (triton.PointerType):
            Pointer to output tensor of shape [ B, H ]

    NOTE(review): the LSE load over [0, N_ROUNDED) below is unmasked, so
    this assumes N_ROUNDED equals the true N, and tl.arange requires it
    to be a power of two -- confirm callers only use power-of-two DCP
    world sizes.
    """
    # 64-bit indices so batch/head * stride products cannot overflow int32.
    batch_idx = tl.program_id(axis=0).to(tl.int64)
    head_idx = tl.program_id(axis=1).to(tl.int64)
    d_offsets = tl.arange(0, HEAD_DIM)
    num_n_offsets = tl.arange(0, N_ROUNDED)

    # shape = [N]
    lse_offsets = (
        num_n_offsets * lses_stride_N
        + batch_idx * lses_stride_B
        + head_idx * lses_stride_H
    )

    # calc final lse (numerically stable log-sum-exp over the N ranks)
    lse = tl.load(lses_ptr + lse_offsets)
    # Map NaN (lse != lse) and +inf entries to -inf so they contribute 0.
    lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
    lse_max = tl.max(lse, axis=0)
    # If every entry is -inf, shift by 0 to avoid an inf - inf = NaN below.
    lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
    lse -= lse_max
    if IS_BASE_E:
        lse_exp = tl.exp(lse)
        lse_acc = tl.sum(lse_exp, axis=0)
        lse = tl.log(lse_acc)
    else:
        # LSE values recorded in base 2 (producers using exp2/log2 math).
        lse_exp = tl.exp2(lse)
        lse_acc = tl.sum(lse_exp, axis=0)
        lse = tl.log2(lse_acc)
    lse += lse_max

    # Store the combined LSE for this (batch, head).
    lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
    tl.store(vlse_ptr + lse_offsets, lse)

    # shape = [D]
    output_offsets = (
        batch_idx * outputs_stride_B
        + head_idx * outputs_stride_H
        + d_offsets * outputs_stride_D
    )

    # correct output: scale by exp(local_lse - global_lse) for rank lse_idx
    lse_offset = (
        lse_idx * lses_stride_N + batch_idx * lses_stride_B + head_idx * lses_stride_H
    )
    lse_tmp = tl.load(lses_ptr + lse_offset)
    lse_finally = lse_tmp - lse
    # A NaN/+inf difference means this rank had no valid contribution;
    # force the scale factor to exp(-inf) = 0.
    lse_finally = tl.where(
        (lse_finally != lse_finally) | (lse_finally == float("inf")),
        -float("inf"),
        lse_finally,
    )
    factor = tl.exp(lse_finally) if IS_BASE_E else tl.exp2(lse_finally)
    output = tl.load(outputs_ptr + output_offsets)
    output = output * factor

    tl.store(new_output_ptr + output_offsets, output)

_dcp_correct_attn_out

_dcp_correct_attn_out(
    out: Tensor,
    lses: Tensor,
    dcp_rank: int,
    ctx: DCPTritonContext,
    is_lse_base_on_e: bool = True,
) -> tuple[Tensor, Tensor]

Correct the attention output using the all-gathered lses.

Parameters:

Name Type Description Default
out Tensor

Tensor of shape [ B, H, D ]

required
lses Tensor

Tensor of shape [ N, B, H ]

required
dcp_rank int

Current rank in the DCP group

required
ctx DCPTritonContext

Triton context to avoid recompilation

required

Returns:

Type Description
tuple[Tensor, Tensor]

Tuple of (out, lse) with corrected attention and final log-sum-exp.

Source code in vllm/v1/attention/ops/common.py
def _dcp_correct_attn_out(
    out: torch.Tensor,
    lses: torch.Tensor,
    dcp_rank: int,
    ctx: DCPTritonContext,
    is_lse_base_on_e: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Correct the attention output using the all-gathered lses.

    Args:
        out: Tensor of shape [ B, H, D ]
        lses: Tensor of shape [ N, B, H ]
        dcp_rank: Current rank in the DCP group
        ctx: Triton context to avoid recompilation
        is_lse_base_on_e: True when LSEs use the natural log base,
            False when they are base-2.

    Returns:
        Tuple of (out, lse) with corrected attention and final log-sum-exp.
    """
    if ctx is None:
        ctx = DCPTritonContext()

    # --- Normalize both tensors to 3-D views ---
    if out.ndim == 4 and out.shape[1] == 1:
        out = out.squeeze(1)
    assert out.ndim == 3, f"expected out [B,H,D] or [B,1,H,D], got {tuple(out.shape)}"

    if lses.ndim == 4 and lses.shape[-1] == 1:
        lses = lses.squeeze(-1)
    if lses.ndim == 4 and lses.shape[1] == 1:
        lses = lses.squeeze(1)
    assert lses.ndim == 3, (
        f"expected lses [N,B,H] (optionally with a 1-sized extra dim), "
        f"got {tuple(lses.shape)}"
    )

    batch, heads, head_dim = out.shape
    num_ranks = lses.shape[0]

    # Strides of the normalized 3-D views.  The kernel addresses `vlse_ptr`
    # with lses_stride_B/H, so the LSE output buffer must share the B/H
    # stride layout of a slice of `lses`.
    out_stride_b, out_stride_h, out_stride_d = out.stride()
    lse_stride_n, lse_stride_b, lse_stride_h = lses.stride()

    # Match `lses`' B/H strides so kernel writes land correctly even when
    # `lses` is a non-contiguous view (e.g., after a 4-D -> 3-D squeeze).
    lse = torch.empty_strided(
        (batch, heads),
        (lse_stride_b, lse_stride_h),
        device=lses.device,
        dtype=lses.dtype,
    )

    # One program per (batch, head) pair; output is corrected in place.
    ctx.call_kernel(
        _correct_dcp_attn_out_kernel,
        (batch, heads, 1),
        out,
        out,
        lses,
        lse,
        out_stride_b,
        out_stride_h,
        out_stride_d,
        lse_stride_n,
        lse_stride_b,
        lse_stride_h,
        dcp_rank,
        HEAD_DIM=head_dim,
        N_ROUNDED=num_ranks,
        IS_BASE_E=is_lse_base_on_e,
    )
    return out, lse

_pack_seq_kernel

_pack_seq_kernel(
    x_ptr,
    out_ptr,
    lengths_ptr,
    N: constexpr,
    D: constexpr,
    Lmax: constexpr,
    PAD_VALUE: constexpr,
    BLOCK_T: constexpr,
    BLOCK_D: constexpr,
)
Source code in vllm/v1/attention/ops/common.py
@triton.jit
def _pack_seq_kernel(
    x_ptr,  # [N, D]
    out_ptr,  # [B, Lmax, D]
    lengths_ptr,  # *i32, [B]
    N: tl.constexpr,
    D: tl.constexpr,
    Lmax: tl.constexpr,
    PAD_VALUE: tl.constexpr,
    BLOCK_T: tl.constexpr,  # timesteps per program
    BLOCK_D: tl.constexpr,  # features per program
):
    """Copy one (batch, time-block, feature-block) tile from the flat
    [N, D] input into the padded [B, Lmax, D] output, filling positions
    past the sequence's length with PAD_VALUE.
    """
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # Compute start index and sequence length from cumulative lengths
    # (O(pid_b) scalar loads per program; acceptable for small batch sizes).
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax

    # compute input row indices for valid (b, t)
    in_row = in_start + off_t
    valid_row = (off_t < seq_len) & t_mask

    # Pointers
    # x_ptr: row-major [N, D]
    x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]

    # out_ptr: row-major [B, Lmax, D]
    out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :]

    # Initialize with PAD (cast will occur as needed based on out_ptr dtype)
    d_mask = off_d[None, :] < D
    pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
    tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)

    # Load & write only where within seq_len; this second store overwrites
    # the PAD values for the valid (t < seq_len) rows.
    x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)

_unpack_seq_triton_kernel

_unpack_seq_triton_kernel(
    packed_ptr,
    out_ptr,
    lengths_ptr,
    B: constexpr,
    Lmax: constexpr,
    D: constexpr,
    BLOCK_T: constexpr,
    BLOCK_D: constexpr,
)
Source code in vllm/v1/attention/ops/common.py
@triton.jit
def _unpack_seq_triton_kernel(
    packed_ptr,  # [B, Lmax, D]
    out_ptr,  # [N, D]
    lengths_ptr,  # *i32, [B]
    B: tl.constexpr,
    Lmax: tl.constexpr,
    D: tl.constexpr,
    BLOCK_T: tl.constexpr,  # timesteps per program
    BLOCK_D: tl.constexpr,  # features per program
):
    """Copy one (batch, time-block, feature-block) tile from the padded
    [B, Lmax, D] tensor back into the flat [N, D] layout, skipping the
    padded tail of each sequence (inverse of _pack_seq_kernel).
    """
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # bounds: compute start from cumulative lengths
    # (O(pid_b) scalar loads per program; acceptable for small batch sizes)
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax
    valid_row = (off_t < seq_len) & t_mask

    # compute output row indices for valid (b, t)
    out_row = in_start + off_t

    # Pointers
    # packed_ptr: row-major [B, Lmax, D]
    packed_row_ptr = packed_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :]

    # out_ptr: row-major [N, D]
    out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]

    # Load from packed tensor and store to output (padded rows are masked
    # out on both sides, so out_ptr never receives PAD values).
    d_mask = off_d[None, :] < D
    packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)

dcp_prepare_query

dcp_prepare_query(query: Tensor) -> Tensor

Prepare query for DCP decode attention by all-gathering across TP.

Source code in vllm/v1/attention/ops/common.py
def dcp_prepare_query(query: torch.Tensor) -> torch.Tensor:
    """
    Prepare query for DCP decode attention: all-gather the query across
    the TP group along the head dimension (dim=1), so each DCP rank
    attends with the full head set over its local KV shard.
    """
    tp_group = get_tp_group()
    return tp_group.all_gather(query, dim=1)

dcp_reduce_output

dcp_reduce_output(
    attn_output: Tensor,
    attn_lse: Tensor,
    ctx: DCPTritonContext | None = None,
    return_lse: bool = False,
) -> Tensor | tuple[Tensor, Tensor]

Reduce DCP partial attention outputs across the DCP group and scatter back to TP-local heads.

Each DCP rank holds attention computed over its local KV shard for all heads (after the TP all-gather in dcp_prepare_query). This function combines the partial results using the LSE correction and distributes the final output so each rank ends up with only its TP-local heads.

Source code in vllm/v1/attention/ops/common.py
def dcp_reduce_output(
    attn_output: torch.Tensor,
    attn_lse: torch.Tensor,
    ctx: DCPTritonContext | None = None,
    return_lse: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Reduce DCP partial attention outputs across the DCP group and
    scatter back to TP-local heads.

    Each DCP rank holds attention computed over its local KV shard for all
    heads (after the TP all-gather in ``dcp_prepare_query``).  This function
    combines the partial results using the LSE correction and distributes
    the final output so each rank ends up with only its TP-local heads.
    """
    tp_group = get_tp_group()
    dcp_group = get_dcp_group()
    ctx = DCPTritonContext() if ctx is None else ctx

    # Gather every rank's LSE into [world_size, *attn_lse.shape] and use
    # them to rescale this rank's partial output.
    gathered = dcp_group.all_gather(attn_lse.contiguous(), dim=0)
    lses = gathered.reshape((dcp_group.world_size,) + attn_lse.shape)
    attn_output, lse = _dcp_correct_attn_out(
        attn_output, lses, dcp_group.rank_in_group, ctx
    )

    # Sum partial outputs across DCP ranks, then give each rank its
    # TP-local head slice.
    if get_pcp_group().world_size == 1:
        # PCP=1 ⇒ DCP group == TP group: a single reduce-scatter performs
        # both the cross-rank sum and the TP head scatter.
        attn_output = dcp_group.reduce_scatter(attn_output, dim=1)
    else:
        # PCP>1 ⇒ DCP ⊂ TP: DCP peers share TP head assignments, so
        # reduce-scatter would mis-distribute heads.  All-reduce first,
        # then slice out this rank's TP-local heads.
        attn_output = dcp_group.all_reduce(attn_output)
        heads_per_rank = attn_output.shape[1] // tp_group.world_size
        start = heads_per_rank * tp_group.rank_in_group
        attn_output = attn_output[:, start : start + heads_per_rank].contiguous()

    if not return_lse:
        return attn_output
    heads_per_rank = lse.shape[1] // tp_group.world_size
    start = heads_per_rank * tp_group.rank_in_group
    return attn_output, lse[:, start : start + heads_per_rank]

pack_seq_triton

pack_seq_triton(
    x: Tensor,
    lengths: Tensor,
    pad_value: float = -float("inf"),
    block_t: int = 64,
    block_d: int = 64,
) -> Tensor

Pack sequences of different lengths into a batched tensor.

Parameters:

Name Type Description Default
x Tensor

[N, ...] - input tensor where N is total number of tokens

required
lengths Tensor

[B] - sequence lengths for each batch

required
pad_value float

value to use for padding

-float('inf')
block_t int

block size for time dimension

64
block_d int

block size for feature dimension

64

Returns:

Name Type Description
packed Tensor

[B, Lmax, ...] - packed tensor

Source code in vllm/v1/attention/ops/common.py
def pack_seq_triton(
    x: torch.Tensor,
    lengths: torch.Tensor,
    pad_value: float = -float("inf"),
    block_t: int = 64,
    block_d: int = 64,
) -> torch.Tensor:
    """
    Pack sequences of different lengths into a padded, batched tensor.

    Args:
        x: [N, ...] - input tensor where N is total number of tokens
        lengths: [B] - sequence lengths for each batch
        pad_value: value to use for padding
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        packed: [B, Lmax, ...] - packed tensor
    """
    # Collapse any trailing dims so the kernel works on a flat [N, D] view.
    orig_shape = x.shape
    if x.ndim > 2:
        n_tokens = orig_shape[0]
        flat = x.reshape(n_tokens, -1)
        feat = flat.shape[1]
    else:
        n_tokens, feat = x.shape
        flat = x

    num_seqs = lengths.numel()
    max_len = int(lengths.max().item())

    # Per-sequence start offsets are recomputed inside the kernel from the
    # lengths, so no prefix-sum tensor is built here.
    packed = torch.empty((num_seqs, max_len, feat), device=x.device, dtype=x.dtype)

    grid = (num_seqs, triton.cdiv(max_len, block_t), triton.cdiv(feat, block_d))
    _pack_seq_kernel[grid](
        flat,
        packed,
        lengths.int(),
        n_tokens,
        feat,
        max_len,
        PAD_VALUE=float(pad_value),
        BLOCK_T=block_t,
        BLOCK_D=block_d,
        num_warps=4,
        num_stages=2,
    )

    # Re-expand the collapsed trailing dims on the packed result.
    if x.ndim > 2:
        packed = packed.reshape((num_seqs, max_len) + orig_shape[1:])

    return packed

unpack_seq_triton

unpack_seq_triton(
    packed_tensor: Tensor,
    lengths: Tensor,
    block_t: int = 64,
    block_d: int = 64,
) -> Tensor

Unpack a packed decode query tensor back to the original format. Efficient Triton implementation.

Parameters:

Name Type Description Default
packed_tensor Tensor

[B, Lmax, ...] - packed tensor from pack_seq_triton

required
lengths Tensor

[B] - sequence lengths for each batch

required
block_t int

block size for time dimension

64
block_d int

block size for feature dimension

64

Returns:

Name Type Description
unpacked_tensor Tensor

[N, ...] where N = sum(lengths)

Source code in vllm/v1/attention/ops/common.py
def unpack_seq_triton(
    packed_tensor: torch.Tensor,
    lengths: torch.Tensor,
    block_t: int = 64,
    block_d: int = 64,
) -> torch.Tensor:
    """
    Unpack a packed decode query tensor back to the original flat format.
    Efficient Triton implementation (inverse of ``pack_seq_triton``).

    Args:
        packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
        lengths: [B] - sequence lengths for each batch
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        unpacked_tensor: [N, ...] where N = sum(lengths)
    """
    # Collapse trailing dims so the kernel sees a flat [B, Lmax, D] view.
    orig_shape = packed_tensor.shape
    if packed_tensor.ndim > 3:
        num_seqs, max_len = orig_shape[:2]
        flat = packed_tensor.reshape(num_seqs, max_len, -1)
        feat = flat.shape[2]
    else:
        num_seqs, max_len, feat = packed_tensor.shape
        flat = packed_tensor

    # Total token count across all sequences.
    total = int(lengths.sum().item())

    out = torch.empty(
        (total, feat), device=packed_tensor.device, dtype=packed_tensor.dtype
    )

    grid = (num_seqs, triton.cdiv(max_len, block_t), triton.cdiv(feat, block_d))
    _unpack_seq_triton_kernel[grid](
        flat,
        out,
        lengths.int(),
        num_seqs,
        max_len,
        feat,
        BLOCK_T=block_t,
        BLOCK_D=block_d,
        num_warps=4,
        num_stages=2,
    )

    # Re-expand trailing dims (first dim is now the flat token dim).
    if packed_tensor.ndim > 3:
        out = out.reshape((total,) + orig_shape[2:])

    return out