| | import math |
| | from typing import Optional, Tuple, TypeVar |
| | import torch.nn as nn |
| | import torch |
| | import triton |
| |
|
| | from functools import lru_cache |
| |
|
| |
|
| | from .triton_flash_blocksparse_attn import get_local_strided_sparse_attention_op, _get_sparse_attn_mask, blocksparse_flash_attn_padded_fwd, blocksparse_flash_attn_varlen_fwd |
| |
|
| |
|
| | Layout = Tuple[torch.LongTensor, torch.LongTensor] |
| |
|
| |
|
def create_sparse_attn_mask(
    n_heads: int,
    max_seq_len: int,
    max_seq_len_k: int,
    dtype: torch.dtype,
    device: torch.device,
    BLOCK: int,
    local_blocks: int,
    vert_stride: int,
    homo_head: bool,
    return_dense: bool
) -> Tuple[Layout, torch.Tensor]:
    """Build the block-sparse layout and block mask via ``_get_sparse_attn_mask``.

    Every argument is forwarded by keyword to ``_get_sparse_attn_mask``;
    ``max_seq_len`` is passed as ``q_len`` and ``max_seq_len_k`` as ``N_CTX``,
    all other arguments keep their names.

    Returns:
        A 2-tuple ``(layout, block_sparse_pattern)``: the first two of the
        three values produced by ``_get_sparse_attn_mask``. The third value is
        discarded (presumably the dense mask controlled by ``return_dense``
        -- confirm against ``_get_sparse_attn_mask``).
    """
    layout, block_sparse_pattern, _ = _get_sparse_attn_mask(
        n_heads=n_heads,
        q_len=max_seq_len,
        N_CTX=max_seq_len_k,
        dtype=dtype,
        device=device,
        BLOCK=BLOCK,
        local_blocks=local_blocks,
        vert_stride=vert_stride,
        homo_head=homo_head,
        return_dense=return_dense
    )
    # Only two values are returned; callers unpack exactly two.
    return layout, block_sparse_pattern
| |
|
| |
|
class BlockSparseAttentionLayer(nn.Module):
    """Attention layer that dispatches to block-sparse flash-attention ops.

    ``forward`` picks one of three paths:

    * no padding / sequence-length arguments: the op returned by
      ``get_local_strided_sparse_attention_op``;
    * 3-dim ``k``: ``blocksparse_flash_attn_varlen_fwd``;
    * 4-dim ``k``: ``blocksparse_flash_attn_padded_fwd``.

    The sparse layout and block mask used by the last two paths are built
    lazily on first use and rebuilt when the dtype or device of ``q`` changes.
    """

    def __init__(
        self,
        n_heads: int,
        max_seq_len: int,
        sparse_block_size: int,
        local_blocks: int,
        vert_stride: int,
        kernel_block_size: Optional[int] = None,
        homo_head: bool = False,
        active_head_range: Optional[Tuple[int, int]] = None
    ) -> None:
        """Store the sparsity configuration; no tensors are created here.

        Args:
            n_heads: Forwarded to the mask builder / attention op.
            max_seq_len: Forwarded as both the query and key maximum length
                when building the mask, and as ``max_seqlen`` to the kernels.
            sparse_block_size: Block size of the sparsity pattern.
            local_blocks: Forwarded to the mask builder / attention op.
            vert_stride: Forwarded to the mask builder / attention op.
            kernel_block_size: Block size handed to the kernels. Falls back
                to ``sparse_block_size`` when ``None`` (or otherwise falsy).
            homo_head: Forwarded to the mask builder / attention op. When
                true, ``active_head_range`` is ignored.
            active_head_range: Optional ``(start, end)`` head slice applied
                to the layout and block mask after they are built.
        """
        super().__init__()

        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.sparse_block_size = sparse_block_size
        self.kernel_block_size = kernel_block_size or sparse_block_size
        self.local_blocks = local_blocks
        self.vert_stride = vert_stride
        self.homo_head = homo_head
        self.active_head_range = active_head_range

        # Populated lazily by _initialize_internals on the first
        # variable-length / padded forward call.
        self._sparse_block_mask = None
        self._sparse_layout = None
        self._dtype = None
        self._device = None

    def prune_blocksparse_layout_to_heads(self, h_start: int, h_end: int) -> None:
        """Restrict the block mask and both layout tensors to ``[h_start, h_end)``.

        Slices along the first dimension of ``_sparse_block_mask`` and of the
        two tensors held in ``_sparse_layout``.
        """
        self._sparse_block_mask = self._sparse_block_mask[h_start: h_end]
        # ``Layout`` is a tuple, so it cannot be updated by item assignment;
        # rebuild the pair instead (this also works if it happens to be a list).
        layout = self._sparse_layout
        self._sparse_layout = (
            layout[0][h_start: h_end],
            layout[1][h_start: h_end],
        )

    def _initialize_internals(
        self,
        dtype: torch.dtype,
        device: torch.device
    ) -> None:
        """Build the sparse layout and block mask for ``dtype`` / ``device``.

        Raises:
            AssertionError: If ``sparse_block_size`` is not a multiple of
                ``kernel_block_size``, if ``kernel_block_size`` is not a power
                of two that is at least 16, or if ``active_head_range`` is
                used and does not have exactly two entries.
        """
        # Validate before touching any cached state: otherwise a failed check
        # would leave the mask/dtype/device populated and the next forward
        # call would skip initialisation (and these checks) entirely.
        assert self.sparse_block_size % self.kernel_block_size == 0, f"The sparse block size must be a multiple of {self.kernel_block_size}. Found {self.sparse_block_size}."
        assert self.kernel_block_size >=16 and math.log2(self.kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {self.kernel_block_size} is given"

        self._dtype, self._device = dtype, device
        self._sparse_layout, self._sparse_block_mask = create_sparse_attn_mask(
            n_heads=self.n_heads,
            max_seq_len=self.max_seq_len,
            max_seq_len_k=self.max_seq_len,
            dtype=dtype,
            device=device,
            BLOCK=self.sparse_block_size,
            local_blocks=self.local_blocks,
            vert_stride=self.vert_stride,
            homo_head=self.homo_head,
            return_dense=False,
        )
        if (not self.homo_head) and (self.active_head_range is not None):
            assert len(self.active_head_range) == 2, "\"active_head_range\" should be a tuple of start/end index of the heads."
            h_start, h_end = self.active_head_range
            self.prune_blocksparse_layout_to_heads(h_start=h_start, h_end=h_end)

        if self.sparse_block_size // self.kernel_block_size > 1:
            _mul = self.sparse_block_size // self.kernel_block_size
            # Expand every sparse block into a (_mul x _mul) tile of kernel
            # blocks, then zero the entries above the diagonal.
            # NOTE(review): only the block mask is refined here;
            # ``_sparse_layout`` stays as returned by create_sparse_attn_mask
            # -- confirm the kernels expect that.
            self._sparse_block_mask = torch.kron(self._sparse_block_mask, self._sparse_block_mask.new_ones(_mul, _mul))
            num_sparse_blocks = self._sparse_block_mask.size(-1)
            block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
            self._sparse_block_mask *= block_causal_mask.type_as(self._sparse_block_mask)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sm_scale: float,
        *,
        # Arguments for the padded (4-dim) path.
        left_paddings: Optional[torch.LongTensor] = None,
        seqlens: Optional[torch.LongTensor] = None,
        # Arguments for the variable-length (3-dim) path.
        cu_seqlens_k: Optional[torch.LongTensor] = None,
        cu_seqlens_q: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run block-sparse attention on ``q``, ``k``, ``v``.

        Args:
            q, k, v: Attention inputs, forwarded unchanged to the chosen op.
            sm_scale: Scale factor forwarded unchanged to the chosen op
                (presumably the softmax scale -- confirm against the kernels).
            left_paddings: Forwarded to the padded kernel.
            seqlens: Forwarded to the padded kernel.
            cu_seqlens_k: Required when ``k`` is 3-dim; forwarded to the
                variable-length kernel.
            cu_seqlens_q: Forwarded to the variable-length kernel.

        Returns:
            The result of the selected attention op.

        Raises:
            AssertionError: If any of the keyword-only arguments is given
                while autograd is enabled, if ``k`` is 3-dim without
                ``cu_seqlens_k``, or if ``k`` is 4-dim with neither
                ``left_paddings`` nor ``seqlens``.
            ValueError: If ``k`` is neither 3-dim nor 4-dim on the
                variable-length / padded path.
        """
        if left_paddings is None and seqlens is None and cu_seqlens_k is None and cu_seqlens_q is None:
            # No length information supplied: use the plain local-strided op.
            blocksparse_op = get_local_strided_sparse_attention_op(
                n_heads=self.n_heads,
                max_seq_len=self.max_seq_len,
                sparse_block_size=self.sparse_block_size,
                kernel_block_size=self.kernel_block_size,
                local_blocks=self.local_blocks,
                vert_stride=self.vert_stride,
                homo_head=self.homo_head,
                device=q.device,
                inference=not self.training
            )
            return blocksparse_op(q, k, v, sm_scale)

        assert not torch.is_grad_enabled(), "Variable Length Inference / Batched inference is not supported during training. Please run it in a torch.no_grad() context"

        # (Re)build the cached layout when missing or when dtype/device changed.
        if self._sparse_block_mask is None or (self._dtype != q.dtype) or (self._device != q.device):
            self._initialize_internals(dtype=q.dtype, device=q.device)

        if k.dim() == 3:
            assert cu_seqlens_k is not None
            return blocksparse_flash_attn_varlen_fwd(
                q=q,
                k=k,
                v=v,
                cu_seqlens_k=cu_seqlens_k,
                cu_seqlens_q=cu_seqlens_q,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )
        if k.dim() == 4:
            assert not (left_paddings is None and seqlens is None), "Either left_paddings or seqlens must be provided for batched inference."
            return blocksparse_flash_attn_padded_fwd(
                q=q,
                k=k,
                v=v,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                left_paddings=left_paddings,
                seqlens=seqlens,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )
        raise ValueError('q/k/v must be either 3 dim for variable-length input or 4 dim for fixed-length.')
| |
|