| | #include <torch/library.h> |
| |
|
| | #include "pytorch_shim.h" |
| | #include "registration.h" |
| | #include "torch_binding.h" |
| |
|
| | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { |
| | ops.def("fwd(Tensor! q," |
| | " Tensor k," |
| | " Tensor v," |
| | " Tensor? k_new," |
| | " Tensor? v_new," |
| | " Tensor? q_v," |
| | " Tensor!? out," |
| | " Tensor? cu_seqlens_q," |
| | " Tensor? cu_seqlens_k," |
| | " Tensor? cu_seqlens_k_new," |
| | " Tensor? seqused_q," |
| | " Tensor? seqused_k," |
| | " int? max_seqlen_q," |
| | " int? max_seqlen_k," |
| | " Tensor? page_table," |
| | " Tensor? kv_batch_idx," |
| | " Tensor? leftpad_k," |
| | " Tensor? rotary_cos," |
| | " Tensor? rotary_sin," |
| | " Tensor? seqlens_rotary," |
| | " Tensor? q_descale," |
| | " Tensor? k_descale," |
| | " Tensor? v_descale," |
| | " float softmax_scale," |
| | " bool is_causal," |
| | " int window_size_left," |
| | " int window_size_right," |
| | " float softcap," |
| | " bool is_rotary_interleaved," |
| | " Tensor? scheduler_metadata," |
| | " int num_splits," |
| | " bool? pack_gqa," |
| | " int sm_margin," |
| | " Tensor? s_aux_) -> Tensor[]"); |
| | ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd)); |
| |
|
| | ops.def("bwd(Tensor dout," |
| | " Tensor q," |
| | " Tensor k," |
| | " Tensor v," |
| | " Tensor out," |
| | " Tensor softmax_lse," |
| | " Tensor!? dq," |
| | " Tensor!? dk," |
| | " Tensor!? dv," |
| | " Tensor? cu_seqlens_q," |
| | " Tensor? cu_seqlens_k," |
| | " Tensor? seqused_q," |
| | " Tensor? seqused_k," |
| | " int? max_seqlen_q," |
| | " int? max_seqlen_k," |
| | " float softmax_scale," |
| | " bool is_causal," |
| | " int window_size_left," |
| | " int window_size_right," |
| | " float softcap," |
| | " bool deterministic," |
| | " int sm_margin) -> Tensor[]"); |
| | ops.impl("bwd", torch::kCUDA, make_pytorch_shim(&mha_bwd)); |
| |
|
| | ops.def("fwd_combine(Tensor out_partial," |
| | " Tensor lse_partial," |
| | " Tensor!? out," |
| | " ScalarType? out_dtype) -> Tensor[]"); |
| | ops.impl("fwd_combine", torch::kCUDA, make_pytorch_shim(&mha_combine)); |
| |
|
| | ops.def("get_scheduler_metadata(" |
| | " int batch_size," |
| | " int max_seqlen_q," |
| | " int max_seqlen_k," |
| | " int num_heads," |
| | " int num_heads_k," |
| | " int headdim," |
| | " int headdim_v," |
| | " ScalarType qkv_dtype," |
| | " Tensor seqused_k," |
| | " Tensor? cu_seqlens_q," |
| | " Tensor? cu_seqlens_k," |
| | " Tensor? cu_seqlens_k_new," |
| | " Tensor? seqused_q," |
| | " Tensor? leftpad_k," |
| | " int? page_size," |
| | " int max_seqlen_k_new," |
| | " bool is_causal," |
| | " int window_size_left," |
| | " int window_size_right," |
| | " bool has_softcap," |
| | " int num_splits," |
| | " bool? pack_gqa," |
| | " int sm_margin) -> Tensor"); |
| | ops.impl("get_scheduler_metadata", torch::kCUDA, make_pytorch_shim(&mha_fwd_get_scheduler_metadata)); |
| | } |
| |
|
| | REGISTER_EXTENSION(TORCH_EXTENSION_NAME) |
| |
|