diff --git a/ops/Makefile b/ops/Makefile deleted file mode 100644 index 7bf1b86d5500..000000000000 --- a/ops/Makefile +++ /dev/null @@ -1,19 +0,0 @@ -.PHONY: pre-release bdist-wheel - -# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) -export PYTHONPATH = src - -bdist-wheel: - cd csrc && python setup.py build && cd .. && python setup.py bdist_wheel - -pre-release: - python utils/release.py - -pre-patch: - python utils/release.py --patch - -post-release: - python utils/release.py --post_release - -post-patch: - python utils/release.py --post_release --patch diff --git a/ops/README.md b/ops/README.md deleted file mode 100644 index f084c6f94bed..000000000000 --- a/ops/README.md +++ /dev/null @@ -1,72 +0,0 @@ -# PaddleNLP Kernel 库 -> paddlenlp-kernel 是一个专为 PaddleNLP 量身打造的 GPU 算子库,它集成了一系列常用的自然语言处理(NLP)算子,并提供了 CUDA 和 Triton 两种高效的实现方式,旨在充分利用 GPU 的卓越计算能力,为 NLP 任务加速。 - -当前支持的算子包括: -- mamba1 和 mamba2 算子 -- fast_ln 和 fused_ln 算子 -- ml-cross-entropy 算子 -- inf_cl 算子 - -# 安装指南 - -## 编译 cuda 算子 -要编译 `CUDA` 算子,请执行以下命令: -```bash -cd csrc -rm -rf build dist *.egg-info # 清除旧的构建文件和目录 -python setup.py build # 开始新的编译过程 -``` - -## 打包 wheel -编译完成后,您可以将 `CUDA` 算子打包成 `Wheel` 包以便安装: -```bash -python setup.py bdist_wheel -``` - -## 安装 wheel -使用 pip 命令安装刚刚生成的 Wheel 包: -```bash -pip install dist/*.whl -``` - -## 使用 paddlenlp_kernel 库 -以下是如何在代码中使用 `CUDA` 和 `Triton` 算子的示例: -```python -# 导入并使用 CUDA 算子 -from paddlenlp_kernel.cuda.selective_scan import selective_scan_fn -xxx = selective_scan_fn(xxx) - -# 导入并使用 Triton 算子 -from paddlenlp_kernel.triton.inf_cl import cal_flash_loss -xxx = cal_flash_loss(xxx) -``` - -# 测试 - -要测试 `CUDA` 和 `Triton` 算子,请分别运行以下命令: -```bash -pytest -v tests/cuda # 测试 CUDA 算子 -pytest -v tests/triton # 测试 Triton 算子 -``` - -通过上述步骤,您将能够顺利安装并测试 `paddlenlp_kernel` 库,享受 GPU 加速带来的高效 NLP 算子体验。 - -# 注意 - -推荐用户使用以下版本的库: -- paddlepaddle-gpu >= 3.0.0b2 -- triton >= 3.0.0 - -由于 `Triton` 库原本依赖于 `PyTorch`,为了方便 `Paddle` 用户使用 `Triton`,您可以按照以下步骤替换 `Triton` 库的部分源码,使其与 `Paddle` 兼容: - -```bash -python -m pip install git+https://github.com/zhoutianzi666/UseTritonInPaddle.git -# 只需执行一次以下命令,之后即可在任意终端中使用 Triton,无需重复执行 -python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compatible_with_paddle()" -``` - -# 参考资料 -- https://github.com/state-spaces/mamba -- https://github.com/Dao-AILab/causal-conv1d -- https://github.com/apple/ml-cross-entropy -- https://github.com/DAMO-NLP-SG/Inf-CLIP diff --git a/ops/csrc/causal_conv1d/causal_conv1d.cpp b/ops/csrc/causal_conv1d/causal_conv1d.cpp deleted file mode 100644 index ac167293bd95..000000000000 --- a/ops/csrc/causal_conv1d/causal_conv1d.cpp +++ /dev/null @@ -1,468 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#include -#include -#include - -#include "causal_conv1d.h" - -#define CHECK_SHAPE(x, ...) PD_CHECK(x.dims() == common::make_ddim({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ - if (ITYPE == paddle::DataType::FLOAT16) { \ - using input_t = phi::dtype::float16; \ - __VA_ARGS__(); \ - } else if (ITYPE == paddle::DataType::BFLOAT16) { \ - using input_t = phi::dtype::bfloat16; \ - __VA_ARGS__(); \ - } else if (ITYPE == paddle::DataType::FLOAT32) { \ - using input_t = float; \ - __VA_ARGS__(); \ - } else { \ - PADDLE_THROW(#NAME, " not implemented for input type '", ITYPE, "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ - if (WTYPE == paddle::DataType::FLOAT16) { \ - using weight_t = phi::dtype::float16; \ - __VA_ARGS__(); \ - } else if (WTYPE == paddle::DataType::BFLOAT16) { \ - using weight_t = phi::dtype::bfloat16; \ - __VA_ARGS__(); \ - } else if (WTYPE == paddle::DataType::FLOAT32) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - PADDLE_THROW(#NAME, " not implemented for weight type '", WTYPE, "'"); \ - } - -template -void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template -void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); - -template -void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template -void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); - -template -void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); - -void set_conv_params_fwd(ConvParamsBase ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t width, - // device pointers - const paddle::Tensor x, - const paddle::Tensor weight, - const paddle::Tensor out, - void* bias_ptr, - bool silu_activation) { - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.batch = batch; - params.dim = dim; - params.seqlen = seqlen; - params.width = width; - - params.silu_activation = silu_activation; - - // Set the pointers and strides. - params.x_ptr = const_cast(x.data()); - params.weight_ptr = const_cast(weight.data()); - params.bias_ptr = const_cast(bias_ptr); - params.out_ptr = const_cast(out.data()); - // All stride are in elements, not bytes. - params.x_batch_stride = x.strides()[0]; - params.x_c_stride = x.strides()[1]; - params.x_l_stride = x.strides()[x.strides().size() - 1]; - params.weight_c_stride = weight.strides()[0]; - params.weight_width_stride = weight.strides()[1]; - params.out_batch_stride = out.strides()[0]; - params.out_c_stride = out.strides()[1]; - params.out_l_stride = out.strides()[out.strides().size() - 1]; -} - - -void set_conv_params_bwd(ConvParamsBwd ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t width, - // device pointers - const paddle::Tensor x, - const paddle::Tensor weight, - void* bias_ptr, - const paddle::Tensor dout, - const paddle::Tensor dx, - const paddle::Tensor dweight, - void* dbias_ptr, - bool silu_activation) { - // Pass in "dout" instead of "out", we're not gonna use "out" at all. - set_conv_params_fwd(params, batch, dim, seqlen, width, - x, weight, dout, bias_ptr, silu_activation); - - // Set the pointers and strides. - params.dout_ptr = const_cast(dout.data()); - params.dx_ptr = const_cast(dx.data()); - params.dweight_ptr = const_cast(dweight.data()); - params.dbias_ptr = const_cast(dbias_ptr); - // All stride are in elements, not bytes. - params.dout_batch_stride = dout.strides()[0]; - params.dout_c_stride = dout.strides()[1]; - params.dout_l_stride = dout.strides()[2]; - params.dweight_c_stride = dweight.strides()[0]; - params.dweight_width_stride = dweight.strides()[1]; - params.dx_batch_stride = dx.strides()[0]; - params.dx_c_stride = dx.strides()[1]; - params.dx_l_stride = dx.strides()[2]; -} - -paddle::Tensor -causal_conv1d_fwd(const paddle::Tensor &x, const paddle::Tensor &weight, - const std::optional &bias_, - const std::optional &seq_idx_, - const std::optional &initial_states_, - std::optional &final_states_out_, - bool silu_activation) { - auto input_type = x.dtype(); - auto weight_type = weight.dtype(); - PD_CHECK(input_type == paddle::DataType::FLOAT32 || input_type == paddle::DataType::FLOAT16 || input_type == paddle::DataType::BFLOAT16); - PD_CHECK(weight_type == paddle::DataType::FLOAT32 || weight_type == paddle::DataType::FLOAT16 || weight_type == paddle::DataType::BFLOAT16); - - PD_CHECK(x.is_gpu()); - PD_CHECK(weight.is_gpu()); - - const auto sizes = x.dims(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int width = weight.dims()[weight.dims().size() - 1]; - - CHECK_SHAPE(x, batch_size, dim, seqlen); - CHECK_SHAPE(weight, dim, width); - - PD_CHECK(x.strides()[2] == 1 || x.strides()[1] == 1); - const bool is_channel_last = x.strides()[1] == 1 && x.strides()[2] > 1; - - if (is_channel_last) { - PD_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); - PD_CHECK(x.strides()[2] % 8 == 0 and x.strides()[0] % 8 == 0, "causal_conv1d with channel last layout requires strides (x.strides()[0] and x.strides()[2]) to be multiples of 8"); - } - PD_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); - - if (bias_.has_value()) { - auto bias = bias_.value(); - PD_CHECK(bias.dtype() == weight_type); - PD_CHECK(bias.is_gpu()); - PD_CHECK(bias.strides()[bias.strides().size() - 1] == 1); - CHECK_SHAPE(bias, dim); - } - - if (seq_idx_.has_value()) { - PD_CHECK(is_channel_last, "seq_idx is only supported for channel last layout"); - auto seq_idx = seq_idx_.value(); - PD_CHECK(seq_idx.dtype() == paddle::DataType::INT32 || seq_idx.dtype() == paddle::DataType::INT64); - PD_CHECK(seq_idx.is_gpu()); - // PD_CHECK(seq_idx.is_contiguous()); - CHECK_SHAPE(seq_idx, batch_size, seqlen); - } - - paddle::Tensor out = paddle::empty_like(x); - // NOTE: new added - if (is_channel_last) { - out = paddle::experimental::as_strided(out, {batch_size, dim, seqlen}, {dim * seqlen, 1, dim}); - } - - ConvParamsBase params; - set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, - bias_.has_value() ? const_cast(bias_.value().data()) : nullptr, - silu_activation); - - if (seq_idx_.has_value()) { - params.seq_idx_ptr = const_cast(seq_idx_.value().data()); - } else { - params.seq_idx_ptr = nullptr; - } - - if (initial_states_.has_value()) { - PD_CHECK(is_channel_last, "initial_states is only supported for channel last layout"); - auto initial_states = initial_states_.value(); - PD_CHECK(initial_states.dtype() == input_type); - PD_CHECK(initial_states.is_gpu()); - CHECK_SHAPE(initial_states, batch_size, dim, width - 1); - PD_CHECK(initial_states.strides()[1] == 1); - params.initial_states_ptr = const_cast(initial_states.data()); - params.initial_states_batch_stride = initial_states.strides()[0]; - params.initial_states_c_stride = initial_states.strides()[1]; - params.initial_states_l_stride = initial_states.strides()[2]; - } else { - params.initial_states_ptr = nullptr; - } - - if (final_states_out_.has_value()) { - PD_CHECK(is_channel_last, "final_states is only supported for channel last layout"); - auto final_states = final_states_out_.value(); - PD_CHECK(final_states.dtype() == input_type); - PD_CHECK(final_states.is_gpu()); - CHECK_SHAPE(final_states, batch_size, dim, width - 1); - PD_CHECK(final_states.strides()[1] == 1); - params.final_states_ptr = const_cast(final_states.data()); - params.final_states_batch_stride = final_states.strides()[0]; - params.final_states_c_stride = final_states.strides()[1]; - params.final_states_l_stride = final_states.strides()[2]; - } else { - params.final_states_ptr = nullptr; - } - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - auto stream = x.stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.dtype(), "causal_conv1d_fwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.dtype(), "causal_conv1d_fwd", [&] { - if (!is_channel_last) { - causal_conv1d_fwd_cuda(params, stream); - } else { - causal_conv1d_channellast_fwd_cuda(params, stream); - } - }); - }); - return out; -} - -std::vector -causal_conv1d_bwd(const paddle::Tensor &x, const paddle::Tensor &weight, - const std::optional &bias_, - paddle::Tensor &dout, - const std::optional &seq_idx_, - const std::optional &initial_states_, - const std::optional &dfinal_states_, - std::optional &dx_, - bool return_dinitial_states, - bool silu_activation) { - auto input_type = x.dtype(); - auto weight_type = weight.dtype(); - PD_CHECK(input_type == paddle::DataType::FLOAT32 || input_type == paddle::DataType::FLOAT16 || input_type == paddle::DataType::BFLOAT16); - PD_CHECK(weight_type == paddle::DataType::FLOAT32 || weight_type == paddle::DataType::FLOAT16 || weight_type == paddle::DataType::BFLOAT16); - - PD_CHECK(x.is_gpu()); - PD_CHECK(weight.is_gpu()); - PD_CHECK(dout.is_gpu()); - - const auto sizes = x.dims(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int width = weight.dims()[weight.dims().size() - 1]; - - PD_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); - - CHECK_SHAPE(x, batch_size, dim, seqlen); - CHECK_SHAPE(weight, dim, width); - CHECK_SHAPE(dout, batch_size, dim, seqlen); - - PD_CHECK(x.strides()[2] == 1 || x.strides()[1] == 1); - const bool is_channel_last = x.strides()[1] == 1 && x.strides()[2] > 1; - // NOTE: 由于缺少contiguous算子,所以在外面做。 - // if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); } - // if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose({0, 2, 1}).contiguous().transpose({0, 2, 1}); } - - if (is_channel_last) { - PD_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); - PD_CHECK(x.strides()[2] % 8 == 0 and x.strides()[0] % 8 == 0, "causal_conv1d with channel last layout requires strides (x.strides()[0] and x.strides()[2]) to be multiples of 8"); - PD_CHECK(dout.strides()[2] % 8 == 0 and dout.strides()[0] % 8 == 0, "causal_conv1d with channel last layout requires strides (dout.strides()[0] and dout.strides()[2]) to be multiples of 8"); - } - - if (bias_.has_value()) { - auto bias = bias_.value(); - PD_CHECK(bias.dtype() == weight_type); - PD_CHECK(bias.is_gpu()); - PD_CHECK(bias.strides()[bias.strides().size() - 1] == 1); - CHECK_SHAPE(bias, dim); - } - - if (seq_idx_.has_value()) { - PD_CHECK(is_channel_last, "seq_idx only supported for channel last layout"); - auto seq_idx = seq_idx_.value(); - PD_CHECK(seq_idx.dtype() == paddle::DataType::INT32 || seq_idx.dtype() == paddle::DataType::INT64); - PD_CHECK(seq_idx.is_gpu()); - // PD_CHECK(seq_idx.is_contiguous()); - CHECK_SHAPE(seq_idx, batch_size, seqlen); - } - - paddle::Tensor dx; - if (dx_.has_value()) { - dx = dx_.value(); - PD_CHECK(dx.dtype() == input_type); - PD_CHECK(dx.is_gpu()); - CHECK_SHAPE(dx, batch_size, dim, seqlen); - if (!is_channel_last) { PD_CHECK(dx.strides()[2] == 1); } - if (is_channel_last) { PD_CHECK(dx.strides()[1] == 1); } - } else { - dx = paddle::empty_like(x); - if (is_channel_last) { - dx = paddle::experimental::as_strided(dx, {batch_size, dim, seqlen}, {dim * seqlen, 1, dim}); - } - } - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - // make sure dweight and dbias dtype paddle::DataType::FLOAT32 - paddle::Tensor dweight = paddle::experimental::zeros_like(weight, paddle::DataType::FLOAT32); - paddle::Tensor dbias; - if (bias_.has_value()) { dbias = paddle::experimental::zeros_like(bias_.value(), paddle::DataType::FLOAT32); } - - ConvParamsBwd params; - set_conv_params_bwd(params, batch_size, dim, seqlen, width, - x, weight, bias_.has_value() ? const_cast(bias_.value().data()) : nullptr, - dout, dx, dweight, bias_.has_value() ? const_cast(dbias.data()) : nullptr, - silu_activation); - - if (seq_idx_.has_value()) { - params.seq_idx_ptr = const_cast(seq_idx_.value().data()); - } else { - params.seq_idx_ptr = nullptr; - } - - if (initial_states_.has_value()) { - PD_CHECK(is_channel_last, "initial_states is only supported for channel last layout"); - auto initial_states = initial_states_.value(); - PD_CHECK(initial_states.dtype() == input_type); - PD_CHECK(initial_states.is_gpu()); - CHECK_SHAPE(initial_states, batch_size, dim, width - 1); - PD_CHECK(initial_states.strides()[1] == 1); - params.initial_states_ptr = const_cast(initial_states.data()); - params.initial_states_batch_stride = initial_states.strides()[0]; - params.initial_states_c_stride = initial_states.strides()[1]; - params.initial_states_l_stride = initial_states.strides()[2]; - } else { - params.initial_states_ptr = nullptr; - } - - if (dfinal_states_.has_value()) { - PD_CHECK(is_channel_last, "dfinal_states is only supported for channel last layout"); - auto dfinal_states = dfinal_states_.value(); - PD_CHECK(dfinal_states.dtype() == input_type); - PD_CHECK(dfinal_states.is_gpu()); - CHECK_SHAPE(dfinal_states, batch_size, dim, width - 1); - params.dfinal_states_ptr = const_cast(dfinal_states.data()); - params.dfinal_states_batch_stride = dfinal_states.strides()[0]; - params.dfinal_states_c_stride = dfinal_states.strides()[1]; - params.dfinal_states_l_stride = dfinal_states.strides()[2]; - } else { - params.dfinal_states_ptr = nullptr; - } - - paddle::Tensor dinitial_states; - if (return_dinitial_states) { - dinitial_states = paddle::experimental::transpose(paddle::empty({batch_size, width - 1, dim}, x.dtype(), x.place()), {0, 2, 1}); - PD_CHECK(dinitial_states.strides()[1] == 1); - params.dinitial_states_ptr = const_cast(dinitial_states.data()); - params.dinitial_states_batch_stride = dinitial_states.strides()[0]; - params.dinitial_states_c_stride = dinitial_states.strides()[1]; - params.dinitial_states_l_stride = dinitial_states.strides()[2]; - } else { - params.dinitial_states_ptr = nullptr; - } - - auto stream = dx.stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.dtype(), "causal_conv1d_bwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.dtype(), "causal_conv1d_bwd", [&] { - if (!is_channel_last) { - causal_conv1d_bwd_cuda(params, stream); - } else { - causal_conv1d_channellast_bwd_cuda(params, stream); - } - }); - }); - return {dx, dweight.cast(weight.dtype()), bias_.has_value() ? dbias.cast(bias_.value().dtype()) : dbias, dinitial_states}; -} - -paddle::Tensor -causal_conv1d_update(const paddle::Tensor &x, - const paddle::Tensor &conv_state, - const paddle::Tensor &weight, - const std::optional &bias_, - bool silu_activation, - const std::optional &cache_seqlens_ - ) { - auto input_type = x.dtype(); - auto weight_type = weight.dtype(); - PD_CHECK(input_type == paddle::DataType::FLOAT32 || input_type == paddle::DataType::FLOAT16 || input_type == paddle::DataType::BFLOAT16); - PD_CHECK(weight_type == paddle::DataType::FLOAT32 || weight_type == paddle::DataType::FLOAT16 || weight_type == paddle::DataType::BFLOAT16); - PD_CHECK(conv_state.dtype() == input_type); - - PD_CHECK(x.is_gpu()); - PD_CHECK(conv_state.is_gpu()); - PD_CHECK(weight.is_gpu()); - - const auto sizes = x.dims(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int width = weight.dims()[weight.dims().size() - 1]; - const int conv_state_len = conv_state.dims()[2]; - PD_CHECK(conv_state_len >= width - 1); - - CHECK_SHAPE(x, batch_size, dim, seqlen); - CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len); - CHECK_SHAPE(weight, dim, width); - - PD_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); - - if (bias_.has_value()) { - auto bias = bias_.value(); - PD_CHECK(bias.dtype() == weight_type); - PD_CHECK(bias.is_gpu()); - PD_CHECK(bias.strides()[bias.strides().size() - 1] == 1); - CHECK_SHAPE(bias, dim); - } - - paddle::Tensor out = paddle::empty_like(x); - - ConvParamsBase params; - set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, - bias_.has_value() ? const_cast(bias_.value().data()) : nullptr, - silu_activation); - params.conv_state_ptr = const_cast(conv_state.data()); - params.conv_state_len = conv_state_len; - // All stride are in elements, not bytes. - params.conv_state_batch_stride = conv_state.strides()[0]; - params.conv_state_c_stride = conv_state.strides()[1]; - params.conv_state_l_stride = conv_state.strides()[2]; - - if (cache_seqlens_.has_value()) { - auto cache_seqlens = cache_seqlens_.value(); - PD_CHECK(cache_seqlens.dtype() == paddle::DataType::INT32 || cache_seqlens.dtype() == paddle::DataType::INT64); - PD_CHECK(cache_seqlens.is_gpu()); - PD_CHECK(cache_seqlens.strides()[cache_seqlens.dims().size() - 1] == 1); - CHECK_SHAPE(cache_seqlens, batch_size); - params.cache_seqlens = cache_seqlens.data(); - } else { - params.cache_seqlens = nullptr; - } - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - auto stream = x.stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.dtype(), "causal_conv1d_update", [&] { - DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.dtype(), "causal_conv1d_update", [&] { - causal_conv1d_update_cuda(params, stream); - }); - }); - return out; -} - -PYBIND11_MODULE(causal_conv1d_cuda_pd, m) { - m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward"); - m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward"); - m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update"); -} diff --git a/ops/csrc/causal_conv1d/causal_conv1d.h b/ops/csrc/causal_conv1d/causal_conv1d.h deleted file mode 100644 index df9d38ef3691..000000000000 --- a/ops/csrc/causal_conv1d/causal_conv1d.h +++ /dev/null @@ -1,77 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct ConvParamsBase { - using index_t = uint32_t; - - int batch, dim, seqlen, width; - bool silu_activation; - - index_t x_batch_stride; - index_t x_c_stride; - index_t x_l_stride; - index_t weight_c_stride; - index_t weight_width_stride; - index_t out_batch_stride; - index_t out_c_stride; - index_t out_l_stride; - - int conv_state_len; - index_t conv_state_batch_stride; - index_t conv_state_c_stride; - index_t conv_state_l_stride; - - // Common data pointers. - void *__restrict__ x_ptr; - void *__restrict__ weight_ptr; - void *__restrict__ bias_ptr; - void *__restrict__ out_ptr; - - void *__restrict__ conv_state_ptr; - int32_t *__restrict__ cache_seqlens; - - void *__restrict__ seq_idx_ptr; - - // No __restrict__ since initial_states could be the same as final_states. - void * initial_states_ptr; - index_t initial_states_batch_stride; - index_t initial_states_l_stride; - index_t initial_states_c_stride; - - void * final_states_ptr; - index_t final_states_batch_stride; - index_t final_states_l_stride; - index_t final_states_c_stride; -}; - -struct ConvParamsBwd: public ConvParamsBase { - index_t dx_batch_stride; - index_t dx_c_stride; - index_t dx_l_stride; - index_t dweight_c_stride; - index_t dweight_width_stride; - index_t dout_batch_stride; - index_t dout_c_stride; - index_t dout_l_stride; - - // Common data pointers. - void *__restrict__ dx_ptr; - void *__restrict__ dweight_ptr; - void *__restrict__ dbias_ptr; - void *__restrict__ dout_ptr; - - void * dinitial_states_ptr; - index_t dinitial_states_batch_stride; - index_t dinitial_states_l_stride; - index_t dinitial_states_c_stride; - - void * dfinal_states_ptr; - index_t dfinal_states_batch_stride; - index_t dfinal_states_l_stride; - index_t dfinal_states_c_stride; -}; diff --git a/ops/csrc/causal_conv1d/causal_conv1d_bwd.cu b/ops/csrc/causal_conv1d/causal_conv1d_bwd.cu deleted file mode 100644 index 45d8dfeddde6..000000000000 --- a/ops/csrc/causal_conv1d/causal_conv1d_bwd.cu +++ /dev/null @@ -1,627 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#include -#include - - -#ifndef USE_ROCM - #include - #include - #include -#else - #include - namespace cub = hipcub; -#endif - -#include "causal_conv1d.h" -#include "causal_conv1d_common.h" -#include "static_switch.h" - -template -struct Causal_conv1d_bwd_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr bool kSiluAct = kSiluAct_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : 8; - static_assert(kWidth <= kNElts); - // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits - // (since then we'd have 8 values of float, and each round we can exchange 4 floats). - static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t); - static constexpr bool kIsVecLoad = kIsVecLoad_; - using vec_t = typename BytesToType::Type; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = cub::BlockStore; - using BlockReduceFloatT = cub::BlockReduce; - static constexpr int kSmemIOSize = kIsVecLoad - ? 0 - : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); - static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1); - static constexpr int kSmemSize = custom_max({kSmemExchangeSize, - int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize); -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_bwd_kernel(ConvParamsBwd params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr bool kSiluAct = Ktraits::kSiluAct; - static constexpr int kNElts = Ktraits::kNElts; - constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds; - static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. - extern __shared__ char smem_[]; - auto& smem_load = reinterpret_cast(smem_); - auto& smem_load_vec = reinterpret_cast(smem_); - auto& smem_store = reinterpret_cast(smem_); - auto& smem_store_vec = reinterpret_cast(smem_); - vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - vec_t *smem_exchange_x = reinterpret_cast(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds; - auto& smem_reduce_float = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride - + dim_id * params.x_c_stride; - weight_t *weight = reinterpret_cast(params.weight_ptr) + dim_id * params.weight_c_stride; - input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride - + dim_id * params.dout_c_stride; - input_t *dx = reinterpret_cast(params.dx_ptr) + batch_id * params.dx_batch_stride - + dim_id * params.dx_c_stride; - float *dweight = reinterpret_cast(params.dweight_ptr) + dim_id * params.dweight_c_stride; - float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[dim_id]); - - // Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0. - if (tidx == 0) { - if constexpr (!kSiluAct) { - input_t zeros[kNElts] = {input_t(0)}; - smem_exchange[0] = reinterpret_cast(zeros)[0]; - } else { - float zeros[kNElts] = {input_t(0)}; - #pragma unroll - for (int r = 0; r < kNExchangeRounds; ++r) { - smem_exchange[r * kNThreads] = reinterpret_cast(zeros)[r]; - } - } - } - - float weight_vals[kWidth]; - #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; } - - float dweight_vals[kWidth] = {input_t(0)}; - float dbias_val = 0; - - constexpr int kChunkSize = kNThreads * kNElts; - const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; - x += (n_chunks - 1) * kChunkSize; - dout += (n_chunks - 1) * kChunkSize; - dx += (n_chunks - 1) * kChunkSize; - for (int chunk = n_chunks - 1; chunk >= 0; --chunk) { - input_t x_vals_load[2 * kNElts] = {input_t(0)}; - input_t dout_vals_load[2 * kNElts] = {input_t(0)}; - if constexpr(kIsVecLoad) { - typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); - typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(dout), *reinterpret_cast(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts); - } else { - __syncthreads(); - typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); - __syncthreads(); - typename Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize); - } - float dout_vals[2 * kNElts], x_vals[2 * kNElts]; - if constexpr (!kSiluAct) { - __syncthreads(); - // Thread 0 don't write yet, so that thread kNThreads - 1 can read - // the first elements of the next chunk. - if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast(dout_vals_load)[0]; } - __syncthreads(); - reinterpret_cast(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0]; - __syncthreads(); - // Now thread 0 can write the first elements of the current chunk. - if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast(dout_vals_load)[0]; } - #pragma unroll - for (int i = 0; i < 2 * kNElts; ++i) { - dout_vals[i] = float(dout_vals_load[i]); - x_vals[i] = float(x_vals_load[i]); - } - } else { - if (tidx == 0 && chunk > 0) { - if constexpr(kIsVecLoad) { - reinterpret_cast(x_vals_load)[0] = reinterpret_cast(x)[-1]; - } else { - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; } - } - } - } - __syncthreads(); - smem_exchange_x[tidx] = reinterpret_cast(x_vals_load)[1]; - __syncthreads(); - if (tidx > 0) { reinterpret_cast(x_vals_load)[0] = smem_exchange_x[tidx - 1]; } - #pragma unroll - for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } - // Recompute the output - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - float out_val = bias_val; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; - } - float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val)); - dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val - * (1.0f + out_val * (1.0f - out_sigmoid_val)); - } - // Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange - // if input_t is 16 bits (since then we'd have 8 values of float) - __syncthreads(); - // Thread 0 don't write yet, so that thread kNThreads - 1 can read - // the first elements of the next chunk. - if (tidx > 0) { - #pragma unroll - for (int r = 0; r < kNExchangeRounds; ++r) { - smem_exchange[r * kNThreads + tidx] = reinterpret_cast(dout_vals)[r]; - } - } - __syncthreads(); - #pragma unroll - for (int r = 0; r < kNExchangeRounds; ++r) { - reinterpret_cast(dout_vals)[kNExchangeRounds + r] - = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)]; - } - __syncthreads(); - // Now thread 0 can write the first elements of the current chunk. - if (tidx == 0) { - #pragma unroll - for (int r = 0; r < kNExchangeRounds; ++r) { - smem_exchange[r * kNThreads + tidx] = reinterpret_cast(dout_vals)[r]; - } - } - } - dout -= kChunkSize; - x -= kChunkSize; - - #pragma unroll - for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; } - - float dx_vals[kNElts] = {input_t(0)}; - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1]; - } - } - - input_t dx_vals_store[kNElts]; - #pragma unroll - for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; } - if constexpr(kIsVecLoad) { - typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(dx), reinterpret_cast(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); - } else { - typename Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize); - } - dx -= kChunkSize; - - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1]; - } - } - } - - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - __syncthreads(); - dweight_vals[w] = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]); - if (tidx == 0) { - atomicAdd(&reinterpret_cast(dweight)[w * params.dweight_width_stride], dweight_vals[w]); - } - } - if (params.bias_ptr != nullptr) { - __syncthreads(); - dbias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val); - if (tidx == 0) { - atomicAdd(&reinterpret_cast(params.dbias_ptr)[dim_id], dbias_val); - } - } -} - -template -void causal_conv1d_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) { - static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; - BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { - BOOL_SWITCH(params.silu_activation, kSiluAct, [&] { - using Ktraits = Causal_conv1d_bwd_kernel_traits; - constexpr int kSmemSize = Ktraits::kSmemSize; - dim3 grid(params.batch, params.dim); - auto kernel = &causal_conv1d_bwd_kernel; - - if (kSmemSize >= 48 * 1024) { - #ifndef USE_ROCM - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize); - #else - // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. - cudaFuncSetAttribute( - (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize); - std::cerr << "Warning (causal_conv1d bwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; - #endif - } - - - kernel<<>>(params); - }); - }); -} - -template -void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream); - } -} - -template -struct Causal_conv1d_channellast_bwd_kernel_traits { - // The cache line is 128 bytes, and we try to read 16 bytes per thread. - // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. - // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 - // threads). Each each load is 16 x 32|64 elements in the L x C dimensions. - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr bool kSiluAct = kSiluAct_; - static constexpr int kNThreads = kNThreads_; - static_assert(kNThreads % 32 == 0); - static constexpr int kNWarps = kNThreads / 32; - static constexpr int kWidth = kWidth_; - static constexpr int kChunkSizeL = kChunkSizeL_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : 8; - static constexpr int kNEltsPerRow = 128 / kNBytes; - static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now - static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); - static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now - static_assert(kNColsPerWarp * kNThreadsPerRow == 32); - static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; - static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; - static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); - static constexpr bool kIsVecLoad = kIsVecLoad_; - using vec_t = typename BytesToType::Type; - // using BlockLoadT = cub::BlockLoad; - // using BlockStoreT = cub::BlockStore; - // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), - // sizeof(typename BlockStoreT::TempStorage)}); - // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr bool kSiluAct = Ktraits::kSiluAct; - constexpr int kNElts = Ktraits::kNElts; - constexpr int kNWarp = Ktraits::kNWarps; - constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; - constexpr int kLPerLoad = Ktraits::kNColsPerLoad; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. - __shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts]; - __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts]; - - const int batch_id = blockIdx.x; - const int chunk_l_id = blockIdx.y; - const int chunk_c_id = blockIdx.z; - const int tid = threadIdx.x; - const int l_idx = tid / kNThreadsPerC; - const int c_idx = tid % kNThreadsPerC; - input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - weight_t *weight = reinterpret_cast(params.weight_ptr) - + chunk_c_id * kChunkSizeC * params.weight_c_stride; - input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - input_t *dx = reinterpret_cast(params.dx_ptr) + batch_id * params.dx_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - float *dweight = reinterpret_cast(params.dweight_ptr) - + chunk_c_id * kChunkSizeC * params.dweight_c_stride; - int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast(params.seq_idx_ptr) - + batch_id * params.seqlen + chunk_l_id * kChunkSizeL; - input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr - : reinterpret_cast(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - input_t *dinitial_states = params.dinitial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr - : reinterpret_cast(params.dinitial_states_ptr) + batch_id * params.dinitial_states_batch_stride + l_idx * params.dinitial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - input_t *dfinal_states = params.dfinal_states_ptr == nullptr ? nullptr - : reinterpret_cast(params.dfinal_states_ptr) + batch_id * params.dfinal_states_batch_stride + chunk_c_id * kChunkSizeC; - - #pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t dout_vals_load[kNElts] = {input_t(0)}; - input_t x_vals_load[kNElts] = {input_t(0)}; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(dout_vals_load)[0] = *reinterpret_cast(dout + l * kLPerLoad * params.dout_l_stride); - reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); - } - reinterpret_cast(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(dout_vals_load)[0]; - reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; - } - // Load the elements from the previous chunk or next chunk that are needed for convolution. - if (l_idx < kWidth - 1) { - input_t dout_vals_load[kNElts] = {input_t(0)}; - input_t x_vals_load[kNElts] = {input_t(0)}; - if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(dout_vals_load)[0] = *reinterpret_cast(dout + kChunkSizeL * params.dout_l_stride); - } - if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 - && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); - } else if (initial_states != nullptr - && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(initial_states); - } - reinterpret_cast(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast(dout_vals_load)[0]; - reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; - } - // Need to load (kWdith - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs - if constexpr (kSiluAct) { - if (l_idx < kWidth - 1) { - input_t x_vals_load[kNElts] = {input_t(0)}; - if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + kChunkSizeL * params.x_l_stride); - } - reinterpret_cast(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; - } - } - - __syncthreads(); - - constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); - static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); - constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; - static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); - // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity - static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); - static_assert((kLPerThread & (kLPerThread - 1)) == 0); - static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); - static_assert(kNThreadsPerRow <= 32); - - const int row_idx = tid / kNThreadsPerRow; - const int col_idx = tid % kNThreadsPerRow; - - float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); - float weight_vals[kWidth] = {input_t(0)}; - if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; - } - } - float dout_vals[kLPerThread + kWidth - 1]; - float x_vals[kWidth - 1 + kLPerThread + kWidth - 1]; - #pragma unroll - for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { - dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]); - x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); - } - - int seq_idx_thread[kWidth - 1 + kLPerThread + kWidth - 1]; - if constexpr (kHasSeqIdx) { - #pragma unroll - for (int i = 0; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) { - const int l_idx = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1); - seq_idx_thread[i] = l_idx >= 0 && l_idx < params.seqlen ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1; - } - } - - if constexpr (kSiluAct) { // Recompute the output - #pragma unroll - for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) { - x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); - } - #pragma unroll - for (int i = 0; i < kLPerThread + kWidth - 1; ++i) { - float out_val = bias_val; - const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - if constexpr (!kHasSeqIdx) { - out_val += weight_vals[w] * x_vals[i + w]; - } else { - out_val += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f; - } - } - float out_val_sigmoid = 1.f / (1.f + expf(-out_val)); - dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid)); - } - } - - float dweight_vals[kWidth] = {input_t(0)}; - SumOp sum_op; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { - if constexpr (!kHasSeqIdx) { - dweight_vals[w] += x_vals[i + w] * dout_vals[i]; - } else { - dweight_vals[w] += seq_idx_thread[i + w] == seq_idx_thread[kWidth - 1 + i] ? x_vals[i + w] * dout_vals[i] : 0.f; - } - } - dweight_vals[w] = Allreduce::run(dweight_vals[w], sum_op); - if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) { - atomicAdd(&reinterpret_cast(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]); - } - } - - if (params.bias_ptr != nullptr) { - float dbias_val = 0.f; - for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; } - dbias_val = Allreduce::run(dbias_val, sum_op); - if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) { - atomicAdd(&reinterpret_cast(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val); - } - } - - float dx_vals[kLPerThread] = {input_t(0)}; - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { - const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - if constexpr (!kHasSeqIdx) { - dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w]; - } else { - dx_vals[i] += seq_idx_thread[kWidth - 1 + i + w] == seq_idx_cur ? weight_vals[kWidth - 1 - w] * dout_vals[i + w] : 0.f; - } - } - // if (dfinal_states != nullptr) { - if constexpr (kHasDfinalStates) { - if (chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i >= params.seqlen - kWidth + 1 - && chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i < params.seqlen - && chunk_c_id * kChunkSizeC + row_idx < params.dim) { - dx_vals[i] += float(dfinal_states[((chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i) - (params.seqlen - kWidth + 1)) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]); - } - } - } - - float dxinit_vals[kWidth - 1] = {input_t(0)}; - static_assert(kLPerThread >= kWidth - 1); // So only threads with col_idx == 0 need to handle dinitial_states - if (dinitial_states != nullptr && col_idx == 0) { - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - dxinit_vals[i] += i + w - (kWidth - 1) >= 0 ? weight_vals[kWidth - 1 - w] * dout_vals[i + w - (kWidth - 1)] : 0.f; - } - // chunk_l_id must be 0 because dinitial_states != nullptr - // if (dfinal_states != nullptr) { - if constexpr (kHasDfinalStates) { - if (i >= params.seqlen) { - dxinit_vals[i] += float(dfinal_states[(i - params.seqlen) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]); - } - } - } - } - - __syncthreads(); - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { x_smem[kWidth - 1 + col_idx * kLPerThread + i][row_idx] = dx_vals[i]; } - if (dinitial_states != nullptr && col_idx == 0) { - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { x_smem[i][row_idx] = dxinit_vals[i]; } - } - __syncthreads(); - - #pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t dx_vals_store[kNElts]; - reinterpret_cast(dx_vals_store)[0] = reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx]; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - *reinterpret_cast(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast(dx_vals_store)[0]; - } - } - if (dinitial_states != nullptr - && l_idx < kWidth - 1 - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - input_t dxinit_vals_store[kNElts]; - reinterpret_cast(dxinit_vals_store)[0] = reinterpret_cast(x_smem[l_idx])[c_idx]; - *reinterpret_cast(dinitial_states) = reinterpret_cast(dxinit_vals_store)[0]; - } - -} - -template -void causal_conv1d_channellast_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) { - BOOL_SWITCH(params.silu_activation, kSiluAct, [&] { - BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { - BOOL_SWITCH(params.dfinal_states_ptr != nullptr, kHasDfinalStates, [&] { - BOOL_SWITCH(params.seqlen <= 128, kChunkSizeL64, [&] { - // kChunkSizeL = 128 is slightly faster than 64 when seqlen is larger - static constexpr int kChunk = kChunkSizeL64 ? 64 : 128; - using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; - const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; - dim3 grid(params.batch, n_chunks_L, n_chunks_C); - dim3 block(Ktraits::kNThreads); - auto kernel = &causal_conv1d_channellast_bwd_kernel; - // if (kSmemSize >= 48 * 1024) { - // C10_CUDA_CHECK(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - // } - // kernel<<>>(params); - kernel<<>>(params); - }); - }); - }); - }); -} - -template -void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream); - } -} - -template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); - -#if defined(CUDA_BFLOAT16_AVAILABLE) -template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -#endif \ No newline at end of file diff --git a/ops/csrc/causal_conv1d/causal_conv1d_common.h b/ops/csrc/causal_conv1d/causal_conv1d_common.h deleted file mode 100644 index b2b8c1ca9294..000000000000 --- a/ops/csrc/causal_conv1d/causal_conv1d_common.h +++ /dev/null @@ -1,115 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#if defined(CUDA_BFLOAT16_AVAILABLE) - #ifndef USE_ROCM - #include - - template - __device__ inline T shuffle_xor(T val, int offset) { - return __shfl_xor_sync(uint32_t(-1), val, offset); - } - - constexpr size_t custom_max(std::initializer_list ilist) - { - return std::max(ilist); - } - - template - constexpr T constexpr_min(T a, T b) { - return std::min(a, b); - } - - #else - #include - - template - __device__ inline T shuffle_xor(T val, int offset) { - return __shfl_xor(val, offset); - } - constexpr size_t custom_max(std::initializer_list ilist) - { - return *std::max_element(ilist.begin(), ilist.end()); - } - - template - constexpr T constexpr_min(T a, T b) { - return a < b ? a : b; - } - #endif -#else - template - __device__ inline T shuffle_xor(T val, int offset) { - return __shfl_xor_sync(uint32_t(-1), val, offset); - } - - constexpr size_t custom_max(std::initializer_list ilist) - { - return std::max(ilist); - } - - template - constexpr T constexpr_min(T a, T b) { - return std::min(a, b); - } -#endif -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template struct BytesToType {}; - -template<> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template<> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template<> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template<> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template<> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SumOp { -__device__ inline T operator()(T const & x, T const & y) { return x + y; } -}; - -template -struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ inline T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, shuffle_xor(x, OFFSET)); - return Allreduce::run(x, op); - } -}; - -template<> -struct Allreduce<2> { -template -static __device__ inline T run(T x, Operator &op) { - x = op(x, shuffle_xor(x, 1)); - return x; -} -}; diff --git a/ops/csrc/causal_conv1d/causal_conv1d_fwd.cu b/ops/csrc/causal_conv1d/causal_conv1d_fwd.cu deleted file mode 100644 index 8e75e10679bd..000000000000 --- a/ops/csrc/causal_conv1d/causal_conv1d_fwd.cu +++ /dev/null @@ -1,397 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#include - -#ifndef USE_ROCM - #include - #include -#else - #include - namespace cub = hipcub; -#endif - -#include "causal_conv1d.h" -#include "causal_conv1d_common.h" -#include "static_switch.h" - -template -struct Causal_conv1d_fwd_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : 8; - static_assert(kWidth <= kNElts); - static constexpr bool kIsVecLoad = kIsVecLoad_; - using vec_t = typename BytesToType::Type; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = cub::BlockStore; - static constexpr int kSmemIOSize = kIsVecLoad - ? 0 - : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); - static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; - static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_fwd_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNElts = Ktraits::kNElts; - static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. - extern __shared__ char smem_[]; - auto& smem_load = reinterpret_cast(smem_); - auto& smem_load_vec = reinterpret_cast(smem_); - auto& smem_store = reinterpret_cast(smem_); - auto& smem_store_vec = reinterpret_cast(smem_); - vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int channel_id = blockIdx.y; - input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride - + channel_id * params.x_c_stride; - weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + channel_id * params.out_c_stride; - float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); - - // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. - if (tidx == 0) { - input_t zeros[kNElts] = {input_t(0)}; - smem_exchange[kNThreads - 1] = reinterpret_cast(zeros)[0]; - } - - float weight_vals[kWidth]; - #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } - - constexpr int kChunkSize = kNThreads * kNElts; - const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; - for (int chunk = 0; chunk < n_chunks; ++chunk) { - input_t x_vals_load[2 * kNElts] = {input_t(0)}; - if constexpr(kIsVecLoad) { - typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); - } else { - __syncthreads(); - typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); - } - x += kChunkSize; - __syncthreads(); - // Thread kNThreads - 1 don't write yet, so that thread 0 can read - // the last elements of the previous chunk. - if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } - __syncthreads(); - reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; - __syncthreads(); - // Now thread kNThreads - 1 can write the last elements of the current chunk. - if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } - - float x_vals[2 * kNElts]; - #pragma unroll - for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } - - float out_vals[kNElts]; - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - out_vals[i] = bias_val; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; - } - } - - if (params.silu_activation) { - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); - } - } - - input_t out_vals_store[kNElts]; - #pragma unroll - for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } - if constexpr(kIsVecLoad) { - typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); - } else { - typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); - } - out += kChunkSize; - } -} - -template -void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; - BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { - using Ktraits = Causal_conv1d_fwd_kernel_traits; - constexpr int kSmemSize = Ktraits::kSmemSize; - dim3 grid(params.batch, params.dim); - - auto kernel = &causal_conv1d_fwd_kernel; - - if (kSmemSize >= 48 * 1024) { - #ifndef USE_ROCM - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize); - #else - // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. - cudaFuncSetAttribute( - (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize); - std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; - #endif - } - kernel<<>>(params); - - }); -} - -template -void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); - } -} - -template -struct Causal_conv1d_channellast_fwd_kernel_traits { - // The cache line is 128 bytes, and we try to read 16 bytes per thread. - // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. - // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 - // threads). Each each load is 16 x 32|64 elements in the L x C dimensions. - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static_assert(kNThreads % 32 == 0); - static constexpr int kNWarps = kNThreads / 32; - static constexpr int kWidth = kWidth_; - static constexpr int kChunkSizeL = kChunkSizeL_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : 8; - static constexpr int kNEltsPerRow = 128 / kNBytes; - static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now - static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); - static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now - static_assert(kNColsPerWarp * kNThreadsPerRow == 32); - static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; - static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; - static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); - static constexpr bool kIsVecLoad = kIsVecLoad_; - using vec_t = typename BytesToType::Type; - // using BlockLoadT = cub::BlockLoad; - // using BlockStoreT = cub::BlockStore; - // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), - // sizeof(typename BlockStoreT::TempStorage)}); - // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNElts = Ktraits::kNElts; - constexpr int kNWarp = Ktraits::kNWarps; - constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; - constexpr int kLPerLoad = Ktraits::kNColsPerLoad; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. - __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts]; - - const int batch_id = blockIdx.x; - const int chunk_l_id = blockIdx.y; - const int chunk_c_id = blockIdx.z; - const int tid = threadIdx.x; - const int l_idx = tid / kNThreadsPerC; - const int c_idx = tid % kNThreadsPerC; - input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - weight_t *weight = reinterpret_cast(params.weight_ptr) - + chunk_c_id * kChunkSizeC * params.weight_c_stride; - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast(params.seq_idx_ptr) - + batch_id * params.seqlen + chunk_l_id * kChunkSizeL; - input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr - : reinterpret_cast(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - // The last L-chunk will also have enough info to write to final states, since it also contain a few x values - // from the previous L-chunk. - input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr - : reinterpret_cast(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - - #pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t x_vals_load[kNElts] = {input_t(0)}; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); - } - reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; - } - // Load the elements from the previous chunk that are needed for convolution. - if (l_idx < kWidth - 1) { - input_t x_vals_load[kNElts] = {input_t(0)}; - if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 - && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); - } else if (initial_states != nullptr - && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(initial_states); - } - reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; - } - - __syncthreads(); - - if (final_states != nullptr - && l_idx < kWidth - 1 - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1) - // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx] - *reinterpret_cast(final_states) = reinterpret_cast(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx]; - } - - constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); - static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); - constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; - static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); - // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity - static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); - static_assert((kLPerThread & (kLPerThread - 1)) == 0); - static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); - static_assert(kNThreadsPerRow <= 32); - - const int row_idx = tid / kNThreadsPerRow; - const int col_idx = tid % kNThreadsPerRow; - - float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); - float weight_vals[kWidth] = {input_t(0)}; - if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; - } - } - float x_vals[kWidth - 1 + kLPerThread]; - #pragma unroll - for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { - x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); - } - int seq_idx_thread[kWidth - 1 + kLPerThread]; - if constexpr (kHasSeqIdx) { - #pragma unroll - for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { - seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1; - } - } - - float out_vals[kLPerThread]; - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { - out_vals[i] = bias_val; - const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - if constexpr (!kHasSeqIdx) { - out_vals[i] += weight_vals[w] * x_vals[i + w]; - } else { - out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f; - } - } - if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } - } - - __syncthreads(); - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; } - __syncthreads(); - - #pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t out_vals_store[kNElts]; - reinterpret_cast(out_vals_store)[0] = reinterpret_cast(x_smem[l * kLPerLoad + l_idx])[c_idx]; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - *reinterpret_cast(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast(out_vals_store)[0]; - } - } - -} - -template -void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { - using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; - const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; - dim3 grid(params.batch, n_chunks_L, n_chunks_C); - dim3 block(Ktraits::kNThreads); - auto kernel = &causal_conv1d_channellast_fwd_kernel; - // if (kSmemSize >= 48 * 1024) { - // C10_CUDA_CHECK(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - // } - // kernel<<>>(params); - kernel<<>>(params); - }); -} - -template -void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream); - } -} - -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); - -#if defined(CUDA_BFLOAT16_AVAILABLE) -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -#endif \ No newline at end of file diff --git a/ops/csrc/causal_conv1d/causal_conv1d_update.cu b/ops/csrc/causal_conv1d/causal_conv1d_update.cu deleted file mode 100644 index 03cd9250642a..000000000000 --- a/ops/csrc/causal_conv1d/causal_conv1d_update.cu +++ /dev/null @@ -1,130 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include - -#include "causal_conv1d.h" -#include "causal_conv1d_common.h" -#include "static_switch.h" - -template -struct Causal_conv1d_update_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_update_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int channel_id = blockIdx.y * kNThreads + tidx; - if (channel_id >= params.dim) return; - - input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride - + channel_id * params.x_c_stride; - input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride - + channel_id * params.conv_state_c_stride; - weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + channel_id * params.out_c_stride; - float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); - - int state_len = params.conv_state_len; - int advance_len = params.seqlen; - int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0; - int update_idx = cache_seqlen - (kWidth - 1); - update_idx = update_idx < 0 ? update_idx + state_len : update_idx; - - float weight_vals[kWidth] = {0}; - #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } - - float x_vals[kWidth] = {0}; - if constexpr (!kIsCircularBuffer) { - #pragma unroll 2 - for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) { - conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride]; - } - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { - input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride]; - if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) { - conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val; - } - x_vals[i] = float(state_val); - } - } else { - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) { - input_t state_val = conv_state[update_idx * params.conv_state_l_stride]; - x_vals[i] = float(state_val); - } - } - #pragma unroll 2 - for (int i = 0; i < params.seqlen; ++i) { - input_t x_val = x[i * params.x_l_stride]; - if constexpr (!kIsCircularBuffer) { - if (i < advance_len && state_len - advance_len + i >= 0) { - conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val; - } - } else { - conv_state[update_idx * params.conv_state_l_stride] = x_val; - ++update_idx; - update_idx = update_idx >= state_len ? update_idx - state_len : update_idx; - } - x_vals[kWidth - 1] = float(x_val); - float out_val = bias_val; - #pragma unroll - for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; } - if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } - out[i * params.out_l_stride] = input_t(out_val); - // Shift the input buffer by 1 - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; } - } -} - -template -void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - using Ktraits = Causal_conv1d_update_kernel_traits; - dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); - auto kernel = params.cache_seqlens == nullptr - ? &causal_conv1d_update_kernel - : &causal_conv1d_update_kernel; - kernel<<>>(params); -} - -template -void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); - } -} - -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); - -#if defined(CUDA_BFLOAT16_AVAILABLE) -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -#endif \ No newline at end of file diff --git a/ops/csrc/causal_conv1d/static_switch.h b/ops/csrc/causal_conv1d/static_switch.h deleted file mode 100644 index 5af502904da6..000000000000 --- a/ops/csrc/causal_conv1d/static_switch.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h -// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h - -#pragma once - -/// @param COND - a boolean expression to switch by -/// @param CONST_NAME - a name given for the constexpr bool variable. -/// @param ... - code to execute for true and false -/// -/// Usage: -/// ``` -/// BOOL_SWITCH(flag, BoolConst, [&] { -/// some_function(...); -/// }); -/// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - static constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - static constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() diff --git a/ops/csrc/fast_ln/ln.h b/ops/csrc/fast_ln/ln.h deleted file mode 100644 index 8a051263f768..000000000000 --- a/ops/csrc/fast_ln/ln.h +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ - -/*This code is copied from NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. */ - -#pragma once - -#include -#include - -#include -#include -#include - -namespace layer_norm { - -template -struct LaunchParams { - size_t workspace_bytes; - size_t barrier_size; - - cudaDeviceProp *props; - - cudaStream_t stream; - - Params params; -}; - -struct ParamsBase { - ParamsBase() - : ctas_per_col(0), - rows(0), - cols(0), - x(nullptr), - mean(nullptr), - invvar(nullptr), - scale(nullptr), - workspace(nullptr), - barrier(nullptr) {} - - // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. - int ctas_per_col; - - // Input is interpreted as matrix. We normalize across columns. - int rows; - int cols; - - // Common data pointers. - void *x; - void *mean; - void *invvar; - void *scale; - - // Multi-CTA workspace in gmem. - void *workspace; - - // Multi-CTA sync barriers in gmem. - int *barrier; -}; - -struct FwdParams : public ParamsBase { - FwdParams() : ParamsBase(), y(nullptr), bias(nullptr), epsilon(0.f) {} - - // Output of LN FWD. - void *y; - void *bias; - float epsilon; -}; - -struct BwdParams : public ParamsBase { - BwdParams() - : ParamsBase(), - dy(nullptr), - dbias_part(nullptr), - dscale_part(nullptr), - dx(nullptr), - dbias(nullptr), - dscale(nullptr) {} - - // Input: gradient wrt. LN FWD output. - void *dy; - - // Workspace for Wgrad pre-reduction. - void *dbias_part; - void *dscale_part; - - // Output: Dgrad. - void *dx; - // Output: Wgrad. - void *dbias; - void *dscale; -}; - -using FwdFunction = std::function &, const bool)>; -using BwdFunction = std::function &, const bool)>; -using FunctionKey = uint64_t; -using FwdRegistry = std::unordered_map; -using BwdRegistry = std::unordered_map; - -extern FwdRegistry FWD_FUNCS; -extern BwdRegistry BWD_FUNCS; - -using fp32 = float; -using fp16 = half; -using bf16 = nv_bfloat16; - -template -struct TypeToIdTrait {}; - -template <> -struct TypeToIdTrait { - constexpr static uint32_t Value = 0; -}; - -template <> -struct TypeToIdTrait { - constexpr static uint32_t Value = 1; -}; - -template <> -struct TypeToIdTrait { - constexpr static uint32_t Value = 2; -}; - -template -struct Type2KeyTrait { - constexpr static uint32_t Value = TypeToIdTrait::Value << Significant; -}; - -template -struct WeightType2KeyTrait : public Type2KeyTrait {}; - -template -struct InputType2KeyTrait : public Type2KeyTrait {}; - -template -struct OutputType2KeyTrait : public Type2KeyTrait {}; - -template -struct ComputeType2KeyTrait : public Type2KeyTrait {}; - -template -struct Types2KeyTrait { - constexpr static uint32_t Value = WeightType2KeyTrait::Value | - InputType2KeyTrait::Value | - OutputType2KeyTrait::Value | - ComputeType2KeyTrait::Value; - constexpr static inline uint64_t get(const uint64_t hidden_size) { - constexpr uint64_t type_key = Value; - return (type_key << 32) | hidden_size; - } -}; - -template -struct FwdRegistrar { - FwdRegistrar(FwdFunction f) { // NOLINT - uint64_t key = - Types2KeyTrait::get(HIDDEN_SIZE); - FWD_FUNCS.insert({key, f}); - } -}; - -template -struct BwdRegistrar { - BwdRegistrar(BwdFunction f) { // NOLINT - uint64_t key = - Types2KeyTrait::get(HIDDEN_SIZE); - BWD_FUNCS.insert({key, f}); - } -}; - -} // namespace layer_norm diff --git a/ops/csrc/fast_ln/ln_api.cpp b/ops/csrc/fast_ln/ln_api.cpp deleted file mode 100644 index 45d77190a499..000000000000 --- a/ops/csrc/fast_ln/ln_api.cpp +++ /dev/null @@ -1,576 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ - -/*This code is copied from NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. */ - -#include "paddle/extension.h" -#include "ln.h" // NOLINT - -#ifdef CUSTOM_OP_WITH_SPMD -#include "paddle/phi/api/ext/spmd_infer.h" -#include "paddle/phi/infermeta/spmd_rules/rules.h" -#endif - -/* - -Supported Type combinations: - -input compute weights output -======================================= -fp32 fp32 fp32 fp32 -fp16 fp32 fp16 fp16 -bf16 fp32 bf16 bf16 -fp32 fp32 fp16 fp16 -fp32 fp32 bf16 bf16 - -Remarks: -Output type = Weight type -Compute always in FP32 - -*/ - -namespace layer_norm { - -// Create registries and provide runtime versions of config hash functions. - -FwdRegistry FWD_FUNCS; -BwdRegistry BWD_FUNCS; - -uint32_t get_type_id(paddle::DataType dtype) { - if (dtype == paddle::DataType::FLOAT16) { - return TypeToIdTrait::Value; - } else if (dtype == paddle::DataType::BFLOAT16) { - return TypeToIdTrait::Value; - } else if (dtype == paddle::DataType::FLOAT32) { - return TypeToIdTrait::Value; - } else { - PD_CHECK(false, "Type not supported: ", dtype); - } -} - -uint64_t get_key(paddle::DataType weight_type, - paddle::DataType input_type, - paddle::DataType output_type, - paddle::DataType compute_type, - uint64_t hidden_size) { - uint64_t type_key = - get_type_id(weight_type) | (get_type_id(input_type) << 2) | // NOLINT - (get_type_id(output_type) << 4) | (get_type_id(compute_type) << 6); - uint64_t launcher_key = (type_key << 32) | hidden_size; - return launcher_key; -} - -} // namespace layer_norm - -layer_norm::FwdFunction &get_fwd_launcher(paddle::DataType weight_type, - paddle::DataType input_type, - paddle::DataType output_type, - paddle::DataType compute_type, - uint32_t hidden_size) { - auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key( - weight_type, input_type, output_type, compute_type, hidden_size)); - if (iter != layer_norm::FWD_FUNCS.end()) { - return iter->second; - } else { - PD_CHECK(false, - "FWD: Unsupported hidden_size or types: ", - hidden_size, - weight_type, - input_type, - output_type, - compute_type); - } -} - -layer_norm::BwdFunction &get_bwd_launcher(paddle::DataType weight_type, - paddle::DataType input_type, - paddle::DataType output_type, - paddle::DataType compute_type, - uint32_t hidden_size) { - auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key( - weight_type, input_type, output_type, compute_type, hidden_size)); - if (iter != layer_norm::BWD_FUNCS.end()) { - return iter->second; - } else { - PD_CHECK(false, - "BWD: Unsupported hidden_size or types: ", - hidden_size, - weight_type, - input_type, - output_type, - compute_type); - } -} - -static cudaDeviceProp GetDevicePropImpl() { - int device = -1; - PD_CHECK(cudaGetDevice(&device) == cudaSuccess); - cudaDeviceProp prop; - PD_CHECK(cudaGetDeviceProperties(&prop, device) == cudaSuccess); - return prop; -} - -static cudaDeviceProp *GetDeviceProp() { - static auto prop = GetDevicePropImpl(); - return ∝ -} - -void LaunchNormFwd(const cudaStream_t& stream, - const paddle::Place& place, - const void* x_ptr, - const void* scale_ptr, - const void* bias_ptr, - void* y_ptr, - void* mean_ptr, - void* invvar_ptr, - const paddle::DataType weight_type, - const paddle::DataType input_type, - const paddle::DataType output_type, - const paddle::DataType compute_type, - const uint32_t hidden_size, - const int64_t rows, - const int64_t cols, - const float epsilon) { - layer_norm::LaunchParams launch_params; - - launch_params.props = GetDeviceProp(); - launch_params.stream = stream; - - // Request the kernel launcher. - auto launcher = get_fwd_launcher( - weight_type, input_type, output_type, compute_type, hidden_size); - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - - // Set the kernel runtime parameters. - layer_norm::FwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = const_cast(x_ptr); - params.scale = const_cast(scale_ptr); - params.bias = const_cast(bias_ptr); - params.y = y_ptr; - params.mean = mean_ptr; - params.invvar = invvar_ptr; - params.epsilon = epsilon; - - paddle::Tensor workspace, barrier; - if (launch_params.barrier_size > 0) { - barrier = paddle::zeros( - {launch_params.barrier_size}, paddle::DataType::INT32, place); - workspace = paddle::empty( - {launch_params.workspace_bytes}, paddle::DataType::UINT8, place); - params.workspace = workspace.data(); - params.barrier = barrier.data(); - } - - launcher(launch_params, false); -} - -std::vector LnFwd(const paddle::Tensor &x, - const paddle::Tensor &scale, - const paddle::Tensor &bias, - const float epsilon) { - auto input_type = x.type(); - auto weight_type = scale.type(); - auto output_type = weight_type; - auto compute_type = paddle::DataType::FLOAT32; - - PD_CHECK(bias.type() == weight_type); - - PD_CHECK(!x.is_cpu()); - PD_CHECK(!scale.is_cpu()); - PD_CHECK(!bias.is_cpu()); - - auto sizes = x.shape(); - PD_CHECK(sizes.size() >= 2); - - std::vector row_sizes(sizes.begin(), sizes.begin() + sizes.size() - 1); - - const int cols = sizes[sizes.size() - 1]; - const int rows = x.numel() / cols; - auto hidden_size = scale.numel(); - - PD_CHECK(scale.shape() == bias.shape()); - PD_CHECK(hidden_size == cols); - - PD_CHECK(epsilon >= 0.f); - - auto place = x.place(); - - auto y = paddle::empty(sizes, output_type, place); - - auto mean = paddle::empty({row_sizes}, compute_type, place); - auto invvar = paddle::empty({row_sizes}, compute_type, place); - - LaunchNormFwd(x.stream(), - place, - /* x_ptr */ x.data(), - /* scale_ptr */ scale.data(), - /* bias_ptr */ bias.data(), - /* y_ptr */ y.data(), - /* mean_ptr */ mean.data(), - /* invvar_ptr */ invvar.data(), - weight_type, - input_type, - output_type, - compute_type, - hidden_size, - rows, - cols, - epsilon); - - return {y, mean, invvar}; -} - -std::vector RMSLnFwd(const paddle::Tensor &x, - const paddle::Tensor &scale, - const float epsilon) { - auto input_type = x.type(); - auto weight_type = scale.type(); - auto output_type = weight_type; - auto compute_type = paddle::DataType::FLOAT32; - - PD_CHECK(!x.is_cpu()); - PD_CHECK(!scale.is_cpu()); - - auto sizes = x.shape(); - PD_CHECK(sizes.size() >= 2); - - int rows = 1; - for (size_t i = 0; i + 1 < sizes.size(); ++i) { - rows *= sizes[i]; - } - - const int cols = sizes[sizes.size() - 1]; - auto hidden_size = scale.numel(); - - PD_CHECK(hidden_size == cols); - PD_CHECK(epsilon >= 0.f); - - auto place = x.place(); - - auto y = paddle::empty(sizes, output_type, place); - auto invvar = paddle::empty({rows}, compute_type, place); - - LaunchNormFwd(x.stream(), - place, - /* x_ptr */ x.data(), - /* scale_ptr */ scale.data(), - /* bias_ptr */ nullptr, - /* y_ptr */ y.data(), - /* mean_ptr */ nullptr, - /* invvar_ptr */ invvar.data(), - weight_type, - input_type, - output_type, - compute_type, - hidden_size, - rows, - cols, - epsilon); - - return {y, invvar}; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void LaunchNormBwd(const cudaStream_t& stream, - const paddle::Place& place, - const void* x_ptr, - const void* scale_ptr, - const void* mean_ptr, - const void* invvar_ptr, - const void* dy_ptr, - void* dx_ptr, - void* dscale_ptr, - void* dbias_ptr, - const paddle::DataType weight_type, - const paddle::DataType input_type, - const paddle::DataType output_type, - const paddle::DataType compute_type, - const uint32_t hidden_size, - const int64_t rows, - const int64_t cols, - const float epsilon) { - layer_norm::LaunchParams launch_params; - launch_params.stream = stream; - launch_params.props = GetDeviceProp(); - - auto launcher = get_bwd_launcher( - weight_type, input_type, output_type, compute_type, hidden_size); - - launcher(launch_params, true); - - paddle::Tensor dscale_part, dbias_part; - - dscale_part = paddle::empty( - {launch_params.params.ctas_per_col, hidden_size}, compute_type, place); - if (dbias_ptr) { - dbias_part = paddle::empty({launch_params.params.ctas_per_col, hidden_size}, compute_type, place); - } - - layer_norm::BwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = const_cast(x_ptr); - params.scale = const_cast(scale_ptr); - params.mean = const_cast(mean_ptr); - params.invvar = const_cast(invvar_ptr); - params.dy = const_cast(dy_ptr); - params.dx = dx_ptr; - params.dscale = dscale_ptr; - params.dbias = dbias_ptr; - params.dscale_part = dscale_part.data(); - params.dbias_part = dbias_ptr ? dbias_part.data() : nullptr; - - paddle::Tensor workspace, barrier; - if (launch_params.barrier_size > 0) { - barrier = paddle::zeros( - {launch_params.barrier_size}, paddle::DataType::INT32, place); - workspace = paddle::empty( - {launch_params.workspace_bytes}, paddle::DataType::UINT8, place); - params.workspace = workspace.data(); - params.barrier = barrier.data(); - } - - launcher(launch_params, false); -} - -std::vector LnBwd(const paddle::Tensor &x, - const paddle::Tensor &scale, - const paddle::Tensor &mean, - const paddle::Tensor &invvar, - const paddle::Tensor &dy, - const float epsilon) { - auto input_type = x.type(); - auto weight_type = scale.type(); - auto output_type = weight_type; - auto compute_type = paddle::DataType::FLOAT32; - - PD_CHECK(dy.dtype() == output_type); - PD_CHECK(mean.dtype() == compute_type); - PD_CHECK(invvar.dtype() == compute_type); - - PD_CHECK(!x.is_cpu()); - PD_CHECK(!dy.is_cpu()); - PD_CHECK(!mean.is_cpu()); - PD_CHECK(!invvar.is_cpu()); - PD_CHECK(!scale.is_cpu()); - - auto sizes = x.shape(); - PD_CHECK(sizes.size() >= 2); - PD_CHECK(dy.shape() == sizes); - - int64_t rows = 1; - for (size_t i = 0; i + 1 < sizes.size(); ++i) { - rows *= sizes[i]; - } - auto cols = sizes[sizes.size() - 1]; - - auto hidden_size = scale.numel(); - - PD_CHECK(mean.numel() == rows); - PD_CHECK(mean.shape() == invvar.shape()); - - PD_CHECK(scale.numel() == cols); - - auto dx = paddle::empty_like(x); - auto dscale = paddle::empty_like(scale); - auto dbias = paddle::empty_like(scale); - - auto place = x.place(); - - LaunchNormBwd(x.stream(), - place, - /* x_ptr */ x.data(), - /* scale_ptr */ scale.data(), - /* mean_ptr */ mean.data(), - /* invvar_ptr */ invvar.data(), - /* dy_ptr */ dy.data(), - /* dx_ptr */ dx.data(), - /* dscale_ptr */ dscale.data(), - /* dbias_ptr */ dbias.data(), - weight_type, - input_type, - output_type, - compute_type, - hidden_size, - rows, - cols, - epsilon); - - return {dx, dscale, dbias}; -} - -std::vector RMSLnBwd(const paddle::Tensor &x, - const paddle::Tensor &scale, - const paddle::Tensor &invvar, - const paddle::Tensor &dy, - const float epsilon) { - auto input_type = x.type(); - auto weight_type = scale.type(); - auto output_type = weight_type; - auto compute_type = paddle::DataType::FLOAT32; - - PD_CHECK(dy.dtype() == output_type); - PD_CHECK(invvar.dtype() == compute_type); - - PD_CHECK(!x.is_cpu()); - PD_CHECK(!dy.is_cpu()); - PD_CHECK(!invvar.is_cpu()); - PD_CHECK(!scale.is_cpu()); - - auto sizes = x.shape(); - PD_CHECK(sizes.size() >= 2); - PD_CHECK(dy.shape() == sizes); - - int64_t rows = 1; - for (size_t i = 0; i + 1 < sizes.size(); ++i) { - rows *= sizes[i]; - } - auto cols = sizes[sizes.size() - 1]; - - auto hidden_size = scale.numel(); - - PD_CHECK(scale.numel() == cols); - - auto dx = paddle::empty_like(x); - auto dscale = paddle::empty_like(scale); - - auto place = x.place(); - - LaunchNormBwd(x.stream(), - place, - /* x_ptr */ x.data(), - /* scale_ptr */ scale.data(), - /* mean_ptr */ nullptr, - /* invvar_ptr */ invvar.data(), - /* dy_ptr */ dy.data(), - /* dx_ptr */ dx.data(), - /* dscale_ptr */ dscale.data(), - /* dbias_ptr */ nullptr, - weight_type, - input_type, - output_type, - compute_type, - hidden_size, - rows, - cols, - epsilon); - - return {dx, dscale}; -} - -std::vector> LnFwdInferShape( - std::vector x_shape, - std::vector scale_shape, - std::vector bias_shape, - float epsilon) { - std::vector row_shape(x_shape.begin(), x_shape.begin() + x_shape.size() - 1); - return {x_shape, row_shape, row_shape}; -} - -std::vector> RMSLnFwdInferShape( - std::vector x_shape, - std::vector scale_shape, - float epsilon) { - int64_t rows = 1; - for (size_t i = 0; i + 1 < x_shape.size(); ++i) { - rows *= x_shape[i]; - } - return {x_shape, {rows}}; -} - -std::vector LnFwdInferDtype(paddle::DataType x_dtype, - paddle::DataType scale_dtype, - paddle::DataType bias_dtype) { - return {x_dtype, paddle::DataType::FLOAT32, paddle::DataType::FLOAT32}; -} - -std::vector RMSLnFwdInferDtype(paddle::DataType x_dtype, - paddle::DataType scale_dtype) { - return {x_dtype, paddle::DataType::FLOAT32}; -} - -std::vector> LnBwdInferShape( - std::vector x_shape, - std::vector scale_shape, - std::vector mean_shape, - std::vector invvar_shape, - std::vector dy_shape, - float epsilon) { - return {x_shape, scale_shape, scale_shape}; -} - -std::vector> RMSLnBwdInferShape( - std::vector x_shape, - std::vector scale_shape, - std::vector invvar_shape, - std::vector dy_shape, - float epsilon) { - return {x_shape, scale_shape}; -} - -std::vector LnBwdInferDtype(paddle::DataType x_dtype, - paddle::DataType scale_dtype, - paddle::DataType mean_dtype, - paddle::DataType invvar_dtype, - paddle::DataType dy_dtype) { - return {x_dtype, scale_dtype, scale_dtype}; -} - -PD_BUILD_OP(fast_ln) - .Inputs({"x", "scale", "bias"}) - .Outputs({"y", "mean", "invvar"}) - .Attrs({"epsilon: float"}) - .SetKernelFn(PD_KERNEL(LnFwd)) - .SetInferShapeFn(PD_INFER_SHAPE(LnFwdInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(LnFwdInferDtype)) -#ifdef CUSTOM_OP_WITH_SPMD - .SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::FastLnInferSpmd)) -#endif -; - -PD_BUILD_GRAD_OP(fast_ln) - .Inputs({"x", "scale", "mean", "invvar", paddle::Grad("y")}) - .Outputs({paddle::Grad("x"), paddle::Grad("scale"), paddle::Grad("bias")}) - .Attrs({"epsilon: float"}) - .SetKernelFn(PD_KERNEL(LnBwd)) - .SetInferShapeFn(PD_INFER_SHAPE(LnBwdInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(LnBwdInferDtype)) -#ifdef CUSTOM_OP_WITH_SPMD - .SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::FastLnGradInferSpmd)) -#endif -; - -PD_BUILD_OP(fast_rms_norm) - .Inputs({"x", "scale"}) - .Outputs({"y", "invvar"}) - .Attrs({"epsilon: float"}) - .SetKernelFn(PD_KERNEL(RMSLnFwd)) - .SetInferShapeFn(PD_INFER_SHAPE(RMSLnFwdInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(RMSLnFwdInferDtype)); - -PD_BUILD_GRAD_OP(fast_rms_norm) - .Inputs({"x", "scale", "invvar", paddle::Grad("y")}) - .Outputs({paddle::Grad("x"), paddle::Grad("scale")}) - .Attrs({"epsilon: float"}) - .SetKernelFn(PD_KERNEL(RMSLnBwd)) - .SetInferShapeFn(PD_INFER_SHAPE(RMSLnBwdInferShape)); diff --git a/ops/csrc/fast_ln/ln_bwd_kernels.h b/ops/csrc/fast_ln/ln_bwd_kernels.h deleted file mode 100644 index 49ff81792b31..000000000000 --- a/ops/csrc/fast_ln/ln_bwd_kernels.h +++ /dev/null @@ -1,354 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ - -/*This code is copied from NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. */ - -#pragma once - -#include "ln.h" // NOLINT -#include "ln_utils.h" // NOLINT - -namespace layer_norm { - -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel( - layer_norm::BwdParams params) { - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { COLS = Ktraits::COLS }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using compute_t = typename Ktraits::compute_t; - using index_t = typename Ktraits::index_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - using Reducer = typename Ktraits::Reducer; - using reduce_t = typename Reducer::Type; - - extern __shared__ char smem_[]; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / Ktraits::WARPS_N; - const index_t warp_n = warp % Ktraits::WARPS_N; - const index_t tid_r = warp_n * THREADS_PER_WARP + lane; - - const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); - - Cvec dzy_sum[LDGS]; - Cvec dz_sum[LDGS]; - - memset(dzy_sum, 0, sizeof(dzy_sum)); - memset(dz_sum, 0, sizeof(dz_sum)); - - compute_t *smem_wgrad = reinterpret_cast(smem_); - char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; - - Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); - - Sum sum; - bool is_rmsnorm = params.mean == nullptr; - constexpr float rn = 1.f / static_cast(COLS); - Wvec gamma[LDGS]; - index_t idx = c; -#pragma unroll - for (int it = 0; it < LDGS; it++) { - gamma[it].load_from(params.scale, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } -#pragma unroll 1 - for (int row = r; row < params.rows; - row += params.ctas_per_col * ROWS_PER_CTA) { - const compute_t mu_r = is_rmsnorm ? static_cast(0.) : static_cast(params.mean)[row]; - const compute_t rs_r = static_cast(params.invvar)[row]; - Ivec x[LDGS]; - Ovec dz[LDGS]; - index_t idx = row * Ktraits::VEC_COLS + c; -#pragma unroll - for (int it = 0; it < LDGS; it++) { - dz[it].load_from(params.dy, idx); - x[it].load_from(params.x, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - - compute_t dy[LDGS * NUM_ELTS]; - compute_t y[LDGS * NUM_ELTS]; - - compute_t mdy_local = 0.f; - compute_t mdyy_local = 0.f; -#pragma unroll - for (int it = 0; it < LDGS; it++) { -#pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t x_tmp = x[it].data.elt[jt]; - compute_t y_tmp = rs_r * (x_tmp - mu_r); - compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]); - dy_tmp *= compute_t(dz[it].data.elt[jt]); - compute_t dz_tmp = dz[it].data.elt[jt]; - - mdy_local += dy_tmp; - mdyy_local += dy_tmp * y_tmp; - - dy[it * NUM_ELTS + jt] = dy_tmp; - y[it * NUM_ELTS + jt] = y_tmp; - - dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; - dz_sum[it].data.elt[jt] += dz_tmp; - } - } - - reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); - if (is_rmsnorm) { - mdy_local = 0.f; - } else { - mdy_local = layer_norm::Get<0>::of(result) * rn; - } - mdyy_local = layer_norm::Get<1>::of(result) * rn; - Ivec dx[LDGS]; - idx = row * Ktraits::VEC_COLS + c; -#pragma unroll - for (int it = 0; it < LDGS; it++) { -#pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t dy_tmp = dy[it * NUM_ELTS + jt]; - compute_t y_tmp = y[it * NUM_ELTS + jt]; - compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); - dx[it].data.elt[jt] = dx_tmp; - } - dx[it].store_to(params.dx, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - } // end: grid stride loop - - if (WARPS_M == 1) { - idx = r * Ktraits::VEC_COLS + c; -#pragma unroll - for (int it = 0; it < LDGS; it++) { - if (params.dbias) { - dz_sum[it].store_to(params.dbias_part, idx); - } - dzy_sum[it].store_to(params.dscale_part, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - } else { - static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, - "Multiple rows per CTA not supported for Multi-CTA."); - // Finalize reduction of part dgamma and dbeta for this CTA - // by reducing over the rows held across the WARPS_M warps - - // Assumption: blockSize divides hidden size. - enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; - static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); - - idx = warp_m * Ktraits::VEC_COLS + tid_r; -#pragma unroll - for (int it = 0; it < LDGS; it++) { - dz_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - compute_t cta_dz_sum[NUM_RES]; - memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES); - for (int it = 0; it < ROWS_PER_CTA; it++) { - for (int jt = 0; jt < NUM_RES; jt++) { - cta_dz_sum[jt] += - smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - __syncthreads(); - - idx = warp_m * Ktraits::VEC_COLS + tid_r; -#pragma unroll - for (int it = 0; it < LDGS; it++) { - dzy_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - compute_t cta_dzy_sum[NUM_RES]; - memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); - for (int it = 0; it < ROWS_PER_CTA; it++) { - for (int jt = 0; jt < NUM_RES; jt++) { - cta_dzy_sum[jt] += - smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - - compute_t *dgamma_part = - static_cast(params.dscale_part) + bidm * COLS + tidx; - for (int jt = 0; jt < NUM_RES; jt++) { - *dgamma_part = cta_dzy_sum[jt]; - dgamma_part += Ktraits::THREADS_PER_CTA; - } - - if (params.dbias) { - compute_t *dbeta_part = - static_cast(params.dbias_part) + bidm * COLS + tidx; - for (int jt = 0; jt < NUM_RES; jt++) { - *dbeta_part = cta_dz_sum[jt]; - dbeta_part += Ktraits::THREADS_PER_CTA; - } - } - } -} - -template -__global__ -__launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_kernel( - BwdParams params) { - using compute_t = typename Kernel_traits::compute_t; - using weight_t = typename Kernel_traits::weight_t; - using index_t = typename Kernel_traits::index_t; - using Reducer = typename Kernel_traits::Reducer; - using reduce_t = typename Reducer::Type; - - Sum sum; - enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; - - __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; - - constexpr uint32_t bidm = 0; - - const uint32_t bidn = blockIdx.x; - const uint32_t tidx = threadIdx.x; - const uint32_t warp = tidx / THREADS_PER_WARP; - const uint32_t lane = tidx % THREADS_PER_WARP; - - Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); - - const uint32_t c = bidn * THREADS_PER_WARP + lane; - const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; - constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; - for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; - col += COL_STRIDE, col_out += COL_STRIDE / 2) { - // Each thread sums over NUM_ELT columns. - Vec dbeta_local, dgamma_local; - memset(&dgamma_local, 0, sizeof(dgamma_local)); - memset(&dbeta_local, 0, sizeof(dbeta_local)); - for (uint32_t row = warp; row < params.ctas_per_col; - row += Kernel_traits::ROWS_PER_CTA) { - index_t idx = row * Kernel_traits::COLS + col; - - Vec dbeta_part, dgamma_part; - if (params.dbias) { - dbeta_part.load_from(params.dbias_part, idx); - } else { - dbeta_part.init(0.); - } - dgamma_part.load_from(params.dscale_part, idx); -#pragma unroll - for (int it = 0; it < NUM_ELT; it++) { - dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; - dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; - } - } - - void *smem_gamma = smem_; - void *smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; - - const int write_row = warp; - const int write_col = lane ^ write_row; - const int write_idx = write_row * THREADS_PER_WARP + write_col; - - dgamma_local.store_to(smem_gamma, write_idx); - dbeta_local.store_to(smem_beta, write_idx); - - __syncthreads(); - - // It would be probably safe to reuse the first row of smem_beta and - // smem_gamma - void *smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; - void *smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + - Kernel_traits::SMEM_BYTES_OUTPUT]; - - // More than one iter iff ROWS_PER_CTA < 32. - for (int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA) { - const int read_row = lane; - const int read_col = w ^ read_row; - const int read_idx = read_row * THREADS_PER_WARP + read_col; - - memset(&dbeta_local, 0, sizeof(dbeta_local)); - memset(&dgamma_local, 0, sizeof(dgamma_local)); - - // Load beta and gamma transposed - if (read_row < Kernel_traits::ROWS_PER_CTA) { - dbeta_local.load_from(smem_beta, read_idx); - dgamma_local.load_from(smem_gamma, read_idx); - } - -// Call reducer on the loaded value(s) and convert. -#pragma unroll - for (int it = 0; it < NUM_ELT; it++) { - compute_t b_i = dbeta_local.data.elt[it]; - compute_t g_i = dgamma_local.data.elt[it]; - b_i = reducer.allreduce(b_i, sum); - g_i = reducer.allreduce(g_i, sum); - - dgamma_local.data.elt[it] = g_i; - dbeta_local.data.elt[it] = b_i; - } - - // Leader stores the result at the current column. - if (lane == 0) { - dgamma_local.store_to(smem_gamma_out, w); - dbeta_local.store_to(smem_beta_out, w); - } - } - - // All writes done. - __syncthreads(); - - // Pack and store: 2-wide stores with half the threads. - if (warp == Kernel_traits::ROWS_PER_CTA - 1 && - lane < THREADS_PER_WARP / 2) { - using src_t = typename TypeToVec2::Type; - using dst_t = typename TypeToVec2::Type; - Vec dbeta_vec2, dgamma_vec2; - Vec dbeta_out2, dgamma_out2; - - dgamma_vec2.load_from(smem_gamma_out, lane); - dbeta_vec2.load_from(smem_beta_out, lane); -#pragma unroll - for (int it = 0; it < NUM_ELT; it++) { - dgamma_out2.data.elt[it] = - Converter::convert(dgamma_vec2.data.elt[it]); - dbeta_out2.data.elt[it] = - Converter::convert(dbeta_vec2.data.elt[it]); - } - dgamma_out2.store_to(params.dscale, col_out); - if (params.dbias) { - dbeta_out2.store_to(params.dbias, col_out); - } - } - } -} -} // namespace layer_norm diff --git a/ops/csrc/fast_ln/ln_bwd_semi_cuda_kernel.cu b/ops/csrc/fast_ln/ln_bwd_semi_cuda_kernel.cu deleted file mode 100644 index 35144c69555c..000000000000 --- a/ops/csrc/fast_ln/ln_bwd_semi_cuda_kernel.cu +++ /dev/null @@ -1,273 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ - -/*This code is copied from NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. */ - -#include "ln.h" // NOLINT -#include "ln_bwd_kernels.h" // NOLINT -#include "ln_kernel_traits.h" // NOLINT -#include "ln_utils.h" // NOLINT - -using namespace layer_norm; // NOLINT - -template -void launch_(LaunchParams &launch_params, // NOLINT - const bool configure_params) { - using KernelTraits = KernelTraits; - auto kernel = &ln_bwd_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, - kernel, - KernelTraits::THREADS_PER_CTA, - KernelTraits::SMEM_BYTES); - launch_params.params.ctas_per_col = - launch_params.props->multiProcessorCount * ctas_per_sm / - KernelTraits::CTAS_PER_ROW; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (KernelTraits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = - launch_params.params.ctas_per_col * KernelTraits::WARPS_M * - KernelTraits::CTAS_PER_ROW * sizeof(typename KernelTraits::reduce_t) * - 2; - } - return; - } - - if (KernelTraits::SMEM_BYTES >= 48 * 1024) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - KernelTraits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - - if (KernelTraits::CTAS_PER_ROW == 1) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(KernelTraits::CTAS_PER_ROW * ctas_per_col); - dim3 block(KernelTraits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; // NOLINT - cudaLaunchCooperativeKernel((void *)kernel, // NOLINT - grid, - block, - (void **)¶ms_, // NOLINT - KernelTraits::SMEM_BYTES, - stream); - } - - using KernelTraitsF = - layer_norm::KernelTraitsFinalize; - - auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; - kernel_f<<>>( - launch_params.params); -} - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, -// BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER(768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER(768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER(1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); - -REGISTER_BWD_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER(3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER(3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); - -REGISTER_BWD_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER(8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); - -REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); - -REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4); -REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4); -REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4); -REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4); -REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4); - -REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); diff --git a/ops/csrc/fast_ln/ln_fwd_cuda_kernel.cu b/ops/csrc/fast_ln/ln_fwd_cuda_kernel.cu deleted file mode 100644 index 626edbac3bca..000000000000 --- a/ops/csrc/fast_ln/ln_fwd_cuda_kernel.cu +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ - -/*This code is copied from NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. */ - -#include "ln.h" // NOLINT -#include "ln_fwd_kernels.h" // NOLINT -#include "ln_kernel_traits.h" // NOLINT -#include "ln_utils.h" // NOLINT - -using namespace layer_norm; // NOLINT - -template -void launch_(LaunchParams &launch_params, // NOLINT - const bool configure_params) { - using KernelTraits = KernelTraits; - auto kernel = &ln_fwd_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, - kernel, - KernelTraits::THREADS_PER_CTA, - KernelTraits::SMEM_BYTES_FWD); - launch_params.params.ctas_per_col = - launch_params.props->multiProcessorCount * ctas_per_sm / - KernelTraits::CTAS_PER_ROW; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (KernelTraits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = - launch_params.params.ctas_per_col * KernelTraits::WARPS_M * - KernelTraits::CTAS_PER_ROW * - sizeof(typename KernelTraits::Stats::stats_t) * 2; - } - return; - } - - if (KernelTraits::SMEM_BYTES_FWD >= 48 * 1024) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - KernelTraits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - - if (KernelTraits::CTAS_PER_ROW == 1) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(KernelTraits::CTAS_PER_ROW * ctas_per_col); - dim3 block(KernelTraits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; // NOLINT - cudaLaunchCooperativeKernel((void *)kernel, // NOLINT - grid, - block, - (void **)¶ms_, // NOLINT - KernelTraits::SMEM_BYTES_FWD, - stream); - } -} - -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, -// BYTES_PER_LDG - -REGISTER_FWD_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(768, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(768, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); -REGISTER_FWD_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_LAUNCHER(3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); -REGISTER_FWD_LAUNCHER(3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4); - -REGISTER_FWD_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4); - -REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8); - -REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8); - -REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4); - -REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4); - -REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16); -REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16); -REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16); -REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16); -REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16); -REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16); diff --git a/ops/csrc/fast_ln/ln_fwd_kernels.h b/ops/csrc/fast_ln/ln_fwd_kernels.h deleted file mode 100644 index b99ec330982d..000000000000 --- a/ops/csrc/fast_ln/ln_fwd_kernels.h +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ - -/*This code is copied from NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. */ - -#pragma once - -#include "ln.h" // NOLINT -#include "ln_utils.h" // NOLINT - -namespace layer_norm { - -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel( - FwdParams params) { - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::NUM_ELTS }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using output_t = typename Ktraits::output_t; - using index_t = typename Ktraits::index_t; - using compute_t = typename Ktraits::compute_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - - using Stats = typename Ktraits::Stats; - using stats_t = typename Stats::stats_t; - - extern __shared__ char smem_[]; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / WARPS_N; - const index_t warp_n = warp % WARPS_N; - - const index_t r = bidm * ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); - - compute_t *mu_ptr = static_cast(params.mean); - compute_t *rs_ptr = static_cast(params.invvar); - - Wvec gamma[LDGS]; - Wvec beta[LDGS]; - index_t idx = c; - if (params.bias) { -#pragma unroll - for (int it = 0; it < LDGS; it++) { - gamma[it].load_from(params.scale, idx); - beta[it].load_from(params.bias, idx); - idx += VEC_COLS_PER_LDG; - } - } else { -#pragma unroll - for (int it = 0; it < LDGS; it++) { - gamma[it].load_from(params.scale, idx); - beta[it].init(0.); - idx += VEC_COLS_PER_LDG; - } - } - - constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); - bool is_rmsnorm = mu_ptr == nullptr; - - for (int row = r; row < params.rows; - row += params.ctas_per_col * ROWS_PER_CTA) { - Ivec x[LDGS]; - index_t idx = row * Ktraits::VEC_COLS + c; - compute_t xf[LDGS * NUM_ELTS]; -#pragma unroll - for (int it = 0; it < LDGS; it++) { - x[it].load_from(params.x, idx); -#pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t x_ij = compute_t(x[it].data.elt[jt]); - xf[it * NUM_ELTS + jt] = x_ij; - } - idx += VEC_COLS_PER_LDG; - } - - stats_t s = stats.compute(xf, rn, is_rmsnorm); - - compute_t mu = layer_norm::Get<0>::of(s); - compute_t m2 = layer_norm::Get<1>::of(s); - - if (mu_ptr && bidn == 0 && warp_n == 0 && lane == 0) { - mu_ptr[row] = mu; - } - - compute_t rs = rsqrtf(rn * m2 + params.epsilon); - - if (bidn == 0 && warp_n == 0 && lane == 0) { - rs_ptr[row] = rs; - } - - Ovec z[LDGS]; - idx = row * Ktraits::VEC_COLS + c; -#pragma unroll - for (int it = 0; it < LDGS; it++) { -#pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - output_t y_ij; - if (is_rmsnorm) { - y_ij = output_t(rs * xf[it * NUM_ELTS + jt]); - } else { - y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu)); - } - output_t g_ij = gamma[it].data.elt[jt]; - output_t b_ij = beta[it].data.elt[jt]; - z[it].data.elt[jt] = (g_ij * y_ij + b_ij); - } - z[it].store_to(params.y, idx); - idx += VEC_COLS_PER_LDG; - } - } -} - -} // namespace layer_norm diff --git a/ops/csrc/fast_ln/ln_kernel_traits.h b/ops/csrc/fast_ln/ln_kernel_traits.h deleted file mode 100644 index dc539b65195e..000000000000 --- a/ops/csrc/fast_ln/ln_kernel_traits.h +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ - -/*This code is copied from NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. */ - -#pragma once - -#include "ln_bwd_kernels.h" // NOLINT -#include "ln_fwd_kernels.h" // NOLINT -#include "ln_utils.h" // NOLINT - -namespace layer_norm { -template -struct KernelTraitsBase { - using weight_t = weight_t_; - using input_t = input_t_; - using output_t = output_t_; - using compute_t = compute_t_; - using index_t = index_t_; - - enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; - enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; - enum { THREADS_PER_WARP = 32 }; -}; - -template > -struct KernelTraitsFinalize : public Base { - enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; - static_assert((int)ROWS_PER_CTA <= (int)Base::THREADS_PER_WARP); // NOLINT - // Bytes per global load from the input. - enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; - // Number of elements fetched by a global load. - enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; - // Bytes per global store of the weights. - enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) }; - static_assert( - sizeof(BYTES_PER_LDG) == 4, - "Conflict-free smem transpose only implemented for 4B compute type!"); - static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, - "We assume one warp per row!"); - // The total number of BYTES_PER_LDG-wide words in a hidden vector. - enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; - static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); - - // Shared memory size to transpose the CTA result. - enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; - // Shared memory size to coalesce the CTA result. - enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; - // Shared memory requirement per CTA. - enum { - SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT - }; - - // The type of the reducer. - using Reducer = layer_norm::Reducer; - - // Condition for the whole CTA to participate in syncthreads. - static_assert(COLS % Base::THREADS_PER_WARP == 0); - enum { CTAS = COLS / Base::THREADS_PER_WARP }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template > -struct KernelTraits : public Base { - using input_t = typename Base::input_t; - using weight_t = typename Base::weight_t; - using compute_t = typename Base::compute_t; - using output_t = typename Base::output_t; - using index_t = typename Base::index_t; - - enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; - enum { WARPS_M = WARPS_M_ }; - enum { WARPS_N = WARPS_N_ }; - enum { COLS = HIDDEN_SIZE_ }; - enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; - enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; - enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; - - enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; - enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; - enum { ROWS_PER_CTA = WARPS_M }; - - enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; - enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; - // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed - enum { - SMEM_BYTES_WGRAD = - CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA* COLS * sizeof(compute_t) - }; - static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); - - using reduce_t = typename layer_norm::TypeToVec2::Type; - using Reducer = layer_norm::Reducer; - - enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; - enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; - - using Ivec = layer_norm::Vec; - using Ovec = layer_norm::Vec; - using Wvec = layer_norm::Vec; - using Cvec = layer_norm::Vec; - enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; - - // Assume that each thread can handle the same number of elements in the - // output and weights as in the input. - static_assert(sizeof(input_t) >= sizeof(output_t)); - static_assert(sizeof(input_t) >= sizeof(weight_t)); - // The number of columns fetched per load from input: one per thread. - enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; - // The total number of vectorized loads/stores per hidden vector. - enum { VEC_COLS = COLS / ELTS_PER_LDG }; - // The number of loads per thread for the input. - enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; - static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); - // static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, - // ""); - - using Stats = layer_norm::Stats; - enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; -}; - -} // namespace layer_norm diff --git a/ops/csrc/fast_ln/ln_utils.h b/ops/csrc/fast_ln/ln_utils.h deleted file mode 100644 index 627a3ab06bfd..000000000000 --- a/ops/csrc/fast_ln/ln_utils.h +++ /dev/null @@ -1,823 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ - -/*This code is copied from NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. */ -#pragma once - -#include -#include -#include - -#include // NOLINT -#include // NOLINT - -#include "ln.h" // NOLINT - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -constexpr uint32_t THREADS_PER_WARP = 32; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline void check_cuda_(cudaError_t status, const char *file, int line) { - if (status != cudaSuccess) { - fprintf(stderr, - "CUDA Error: %s %s %d\n", - cudaGetErrorString(status), - file, - line); - exit(status); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CHECK_CUDA(ans) \ - { check_cuda_((ans), __FILE__, __LINE__); } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define DIVUP(x, y) (((x) + ((y)-1)) / (y)) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, \ - WTYPE, \ - ITYPE, \ - OTYPE, \ - CTYPE, \ - CTAS_PER_ROW, \ - WARPS_M, \ - WARPS_N, \ - BYTES_PER_LDG) \ - void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_(launch_params, configure_params); \ - } \ - static FwdRegistrar \ - reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_BWD_LAUNCHER(HIDDEN_SIZE, \ - WTYPE, \ - ITYPE, \ - OTYPE, \ - CTYPE, \ - CTAS_PER_ROW, \ - WARPS_M, \ - WARPS_N, \ - BYTES_PER_LDG, \ - BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_(launch_params, configure_params); \ - } \ - static BwdRegistrar \ - reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 operator+(const float2 &a, const float2 &b) { - return {a.x + b.x, a.y + b.y}; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void operator+=(float2 &a, const float2 &b) { // NOLINT - a.x += b.x; - a.y += b.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Sum { - inline __device__ Sum() {} - inline __device__ T operator()(const T &a, const T &b) { return a + b; } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) { - return __shfl_xor_sync(uint32_t(-1), x, idx); -} - -template <> -inline __device__ float2 warp_shuffle_xor(const float2 &x, - uint32_t idx) { - return {warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx)}; -} - -template -inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) { - return __shfl_down_sync(uint32_t(-1), x, idx); -} - -template <> -inline __device__ float2 warp_shuffle_down(const float2 &x, - uint32_t idx) { - return {warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx)}; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace layer_norm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct uint16 { - uint4 u; - uint4 v; - uint4 s; - uint4 t; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct uint8 { - uint4 u; - uint4 v; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BytesToType {}; - -template <> -struct BytesToType<64> { - using Type = uint16; - static_assert(sizeof(Type) == 64); -}; - -template <> -struct BytesToType<32> { - using Type = uint8; - static_assert(sizeof(Type) == 32); -}; - -template <> -struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template <> -struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template <> -struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template <> -struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template <> -struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TypeToVec2 {}; - -template <> -struct TypeToVec2 { - using Type = float2; -}; - -template <> -struct TypeToVec2 { - using Type = half2; -}; - -template <> -struct TypeToVec2 { - using Type = nv_bfloat162; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Get { - template - static inline __device__ R of(const T &vec); -}; - -template <> -template -inline __device__ R Get<0>::of(const T &vec) { - return vec.x; -} - -template <> -template -inline __device__ R Get<1>::of(const T &vec) { - return vec.y; -} - -template <> -template -inline __device__ R Get<2>::of(const T &vec) { - return vec.z; -} - -template <> -template -inline __device__ R Get<3>::of(const T &vec) { - return vec.w; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Converter { - static inline __device__ Dst convert(const Src &from) { return Dst(from); } -}; - -template <> -struct Converter { - static inline __device__ half2 convert(const float2 &x) { - return __float22half2_rn(x); - } -}; - -template <> -struct Converter { - static inline __device__ nv_bfloat162 convert(const float2 &x) { -#if __CUDA_ARCH__ >= 800 - return __float22bfloat162_rn(x); -#else - union { - nv_bfloat162 raw; - nv_bfloat16 x; - nv_bfloat16 y; - } tmp; - tmp.x = __float2bfloat16_rn(x.x); - tmp.y = __float2bfloat16_rn(x.y); - return tmp.raw; -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Zeros { - static inline __device__ T get() { return T(0.f); } -}; - -template <> -struct Zeros { - static inline __device__ float2 get() { return make_float2(0.f, 0.f); } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Vec { - enum { BYTES = NUM_ELT * sizeof(Elt_type) }; - - using Vec_type = typename BytesToType::Type; - - using Alias_type = union { - Vec_type vec; - Elt_type elt[NUM_ELT]; - }; - - Alias_type data; - - inline __device__ void init(Elt_type value) { -#pragma unroll - for (int it = 0; it < NUM_ELT; it++) { - this->data.elt[it] = value; - } - } - - template - inline __device__ void to(Vec &other) { // NOLINT -#pragma unroll - for (int it = 0; it < NUM_ELT; it++) { - other.data.elt[it] = S(this->data.elt[it]); - } - } - - template - inline __device__ void assign(const Op &op) { -#pragma unroll - for (int it = 0; it < NUM_ELT; it++) { - this->data.elt[it] = op(it); - } - } - - inline __device__ void load_from(const void *base_ptr, const size_t idx) { - this->data.vec = static_cast(base_ptr)[idx]; - } - - inline __device__ void store_to(void *base_ptr, const size_t idx) { - static_cast(base_ptr)[idx] = this->data.vec; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct InterCTASync { - template - inline __device__ InterCTASync(Params ¶ms, // NOLINT - uint32_t bidm, - uint32_t bidn) - : phase_counter_(0), - b0_(params.barrier + bidm) // The barrier for this group of CTAs. - , - b1_(params.barrier + bidm + params.ctas_per_col) { - } // The barrier for this group of CTAs. - - inline __device__ void spin_wait_(int *barrier, int step, int expected) { - asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), - "r"(step)); - for (int found = -1; found != expected;) { - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" - : "=r"(found) - : "l"(barrier)); - } - } - - inline __device__ void sync() { - // ALL THREADS MUST ENTER! - - // We switch barrier every iteration. - int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; - // We decrement every other iteration. - bool dec = phase_counter_ & 0x2; - int step = dec ? -1 : 1; - int expected = dec ? 0 : CTAS_PER_ROW; - // There are only 4 phases: up/down for b0/b1. - phase_counter_ = (phase_counter_ + 1) & 0x3; - - if (threadIdx.x == 0) { - spin_wait_(barrier, step, expected); - } - // CTA waits for thread 0 - __syncthreads(); - } - - int phase_counter_; - int *b0_; - int *b1_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Reducer : public Reducer { - using InterCTASync = InterCTASync; - using Base = Reducer; - using Type = typename Base::Type; - - enum { SMEM_BYTES = Base::SMEM_BYTES }; - - enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; - enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; - - // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW - // to get total) - enum { - WORKSPACE_BYTES_PER_GROUP = - Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES - }; - - template - inline __device__ Reducer(Params ¶ms, // NOLINT - uint32_t bidm, - uint32_t bidn, - uint32_t warp_m, - uint32_t warp_n, - uint32_t lane, - void *smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), - inter_cta_(params, bidm, bidn), - bidn_(bidn) // CTA id within the group. - , - w0_(static_cast(params.workspace) + - (bidm * WARPS_M + warp_m) * CTAS_PER_ROW), - w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {} - - template - inline __device__ T allreduce(T data, Op &op) { // NOLINT - data = Base::reduce(data, op); - // We switch workspace every iteration. - T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - - // Warp leaders 0 hold the CTA-local results. - if (this->warp_n_ == 0 && this->lane_ == 0) { - workspace[bidn_] = data; - } - inter_cta_.sync(); - static_assert(CTAS_PER_ROW <= 32); - T total = Zeros::get(); - if (this->lane_ < CTAS_PER_ROW) { - total = workspace[this->lane_]; - } - total = Reducer::allreduce_(total, op); - - return total; - } - - InterCTASync inter_cta_; - - T *w0_; - T *w1_; - int bidn_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Reducer { - using Type = T; - enum { SMEM_BYTES = 0 }; - enum { WORKSPACE_BYTES_PER_GROUP = 0 }; - - enum { THREADS_PER_WARP = 32 }; - - template - inline __device__ Reducer(Params ¶ms, // NOLINT - uint32_t bidm, - uint32_t bidn, - uint32_t warp_m, - uint32_t warp_n, - uint32_t lane, - void *smem) - : warp_n_(warp_n), lane_(lane) {} - - template - static inline __device__ T allreduce_(T data, Op &op) { -#pragma unroll - for (int it = 1; it < THREADS_PER_WARP; it *= 2) { - data = op(data, warp_shuffle_xor(data, it)); - } - return data; - } - - template - inline __device__ T allreduce(T data, Op &op) { // NOLINT - return allreduce_(data, op); - } - - template - inline __device__ T reduce(T data, Op &op) { // NOLINT -// only lane 0 holds the result! -#pragma unroll - for (int it = THREADS_PER_WARP / 2; it > 0; it /= 2) { - data = op(data, warp_shuffle_down(data, it)); - } - return data; - } - int warp_n_; - int lane_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Reducer : public Reducer { - using Base = Reducer; - - using Type = T; - - enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; - enum { WORKSPACE_BYTES_PER_GROUP = 0 }; - - enum { THREADS_PER_WARP = 32 }; - - template - inline __device__ Reducer(Params ¶ms, // NOLINT - uint32_t bidm, - uint32_t bidn, - uint32_t warp_m, - uint32_t warp_n, - uint32_t lane, - void *smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) { - smem0_ = &static_cast(smem)[warp_m * WARPS_N]; // NOLINT - smem1_ = smem0_ + WARPS_M * WARPS_N; - } - - template - inline __device__ T allreduce(T data, Op &op) { // NOLINT - T *smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - data = Base::reduce(data, op); - if (this->lane_ == 0) { - smem[this->warp_n_] = data; - } - __syncthreads(); - T out = Zeros::get(); -#pragma unroll - for (int it = 0; it < WARPS_N; it++) { - out = op(out, smem[it]); - } - return out; - } - - template - inline __device__ T reduce(T data, Op &op) { // NOLINT - T *smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - // only intra-CTA group leader holds the result! - data = Base::reduce(data, op); - if (this->lane_ == 0) { - smem[this->warp_n_] = data; - } - __syncthreads(); - T out = Zeros::get(); - if (this->warp_n_ == 0 && this->lane_ == 0) { -#pragma unroll - for (int it = 0; it < WARPS_N; it++) { - out = op(out, smem[it]); - } - } - return out; - } - - T *smem0_; - T *smem1_; - bool use0_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void warp_chan_upd_dynamic(T &m_a, // NOLINT - T &m2_a, // NOLINT - T &n_a, // NOLINT - int num_active) { - // Assume at least leftmost is valid and init: step = next_pow2(num_active) / - // 2 (might get NaN otherwise) - int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); - -#pragma unroll - for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) { - // Exchange - T n_b = warp_shuffle_down(n_a, step); - T m_b = warp_shuffle_down(m_a, step); - T m2_b = warp_shuffle_down(m2_a, step); - - // Update - const T n_ab = n_a + n_b; // We can handle one of them being 0, not both. - const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise - // this would simplify :( - const T delta = m_a - m_b; - const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; - const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; - - n_a = n_ab; - m_a = m_ab; - m2_a = m2_ab; - } - // Intra-warp broadcast (only lane 0 has valid stats). - m_a = __shfl_sync(uint32_t(-1), m_a, 0); - m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Stats { - // This could be done generically with the Reducer. But then we would have to - // exchange 3 instead of 2 fields. - - using InterCTASync = InterCTASync; - using BlockStats = Stats; - using stats_t = typename BlockStats::stats_t; - - enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; - - template - inline __device__ Stats(Params ¶ms, // NOLINT - uint32_t bidm, - uint32_t bidn, - uint32_t warp_m, - uint32_t warp_n, - uint32_t lane, - void *smem) - : inter_cta_(params, bidm, bidn), - block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), - bidn_(bidn) // CTA id within the group. - , - w0_(static_cast(params.workspace) + - (bidm * WARPS_M + warp_m) * CTAS_PER_ROW), - w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW), - warp_n_(warp_n), - lane_(lane) {} - - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn, bool is_rmsnorm) { - constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; - constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); - stats_t block_stats = block_stats_.compute(elts, block_rn, is_rmsnorm); - - stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - - if (warp_n_ == 0 && lane_ == 0) { - workspace[bidn_] = block_stats; - } - - // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. - inter_cta_.sync(); - - T n = Zeros::get(); - T m = Zeros::get(); - T m2 = Zeros::get(); - - // Assume CTA group size in N less than 32, such that we can finalize with a - // single warp. - static_assert(CTAS_PER_ROW <= 32); - - // Every warp does the final reduction locally. - if (lane_ < CTAS_PER_ROW) { - stats_t result = workspace[lane_]; - n = ELTS_PER_ROW_PER_CTA; - m = layer_norm::Get<0>::of(result); - m2 = layer_norm::Get<1>::of(result); - } - - warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); - - return {m, m2}; - } - - InterCTASync inter_cta_; - BlockStats block_stats_; - - stats_t *w0_; - stats_t *w1_; - int bidn_; - int warp_n_; - int lane_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Stats { - using WarpStats = Stats; - using stats_t = typename WarpStats::stats_t; - - enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; - - template - inline __device__ Stats(Params ¶ms, // NOLINT - uint32_t bidm, - uint32_t bidn, - uint32_t warp_m, - uint32_t warp_n, - uint32_t lane, - void *smem) - : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), - use0_(true) { - smem0_ = static_cast(smem) + warp_m * WARPS_N; - smem1_ = smem0_ + WARPS_M * WARPS_N; - } - - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn, bool is_rmsnorm) { - stats_t *smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - // Compute warp local for all WARPS_N - constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP); - stats_t warp_stats = warp_stats_.compute(elts, warp_rn, is_rmsnorm); - - // Each warp warp leader stores its stats - const auto warp_n = warp_stats_.reducer_.warp_n_; - const auto lane = warp_stats_.reducer_.lane_; - if (lane == 0) { - smem[warp_n] = warp_stats; - } - __syncthreads(); - - T n = Zeros::get(); - T m = Zeros::get(); - T m2 = Zeros::get(); - - // Assume that there are less than 32 warps, such that we can finalize with - // a single warp - static_assert(WARPS_N <= 32); - if (lane < WARPS_N) { - stats_t result = smem[lane]; - n = N * THREADS_PER_WARP; - m = layer_norm::Get<0>::of(result); - m2 = layer_norm::Get<1>::of(result); - } - - warp_chan_upd_dynamic(m, m2, n, WARPS_N); - - return {m, m2}; - } - WarpStats warp_stats_; - stats_t *smem0_; - stats_t *smem1_; - bool use0_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Stats { - using stats_t = typename TypeToVec2::Type; - // The simple Warp reducer. - using Reducer = Reducer; - - enum { SMEM_BYTES = 0 }; - - template - inline __device__ Stats(Params ¶ms, // NOLINT - uint32_t bidm, - uint32_t bidn, - uint32_t warp_m, - uint32_t warp_n, - uint32_t lane, - void *smem) - : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {} - - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn, bool is_rmsnorm) { - auto sum = Sum(); - - T m = Zeros::get(); -#pragma unroll - for (int it = 0; it < N; it++) { - m += elts[it]; - } - m = reducer_.allreduce(m, sum) * rn; - - T m2 = Zeros::get(); -#pragma unroll - for (int it = 0; it < N; it++) { - if (is_rmsnorm) { - m2 += elts[it] * elts[it]; - } else { - T diff = (elts[it] - m); - m2 += diff * diff; - } - } - m2 = reducer_.allreduce(m2, sum); - - return {m, m2}; - } - - Reducer reducer_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm diff --git a/ops/csrc/fp8/deep_gemm/__init__.py b/ops/csrc/fp8/deep_gemm/__init__.py deleted file mode 100644 index 67e14edc807d..000000000000 --- a/ops/csrc/fp8/deep_gemm/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -from . import jit -from .jit_kernels import ( - ceil_div, - gemm_fp8_fp8_bf16_nt, - get_col_major_tma_aligned_tensor, - get_m_alignment_for_contiguous_layout, - get_num_sms, - m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, - m_grouped_gemm_fp8_fp8_bf16_nt_masked, - set_num_sms, -) -from .utils import bench, calc_diff, get_cuda_home diff --git a/ops/csrc/fp8/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/ops/csrc/fp8/deep_gemm/include/deep_gemm/fp8_gemm.cuh deleted file mode 100644 index cc74f21bce49..000000000000 --- a/ops/csrc/fp8/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// The file has been adapted from DeepSeek DeepEP project -// Copyright (c) 2025 DeepSeek -// Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunknown-attributes" -#pragma once - -#include -#include - -#include -#include -#include - -#include "mma_utils.cuh" -#include "scheduler.cuh" -#include "tma_utils.cuh" -#include "utils.cuh" - -namespace deep_gemm { - -enum class Layout { - RowMajor, - ColMajor -}; - -template -__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { - DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group"); - return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; -} - -template -__global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) -fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, - uint32_t shape_m, - const __grid_constant__ CUtensorMap tensor_map_a, - const __grid_constant__ CUtensorMap tensor_map_b, - const __grid_constant__ CUtensorMap tensor_map_scales_a, - const __grid_constant__ CUtensorMap tensor_map_d) { -#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) - // Scaling checks - DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); - - // Types - using WGMMA = typename FP8MMASelector::type; - using Barrier = cutlass::arch::ClusterTransactionBarrier; - - // Shared memory - static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); - static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); - static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); - static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); - static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); - static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); - static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier); - - // Configs - constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; - constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); - constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; - constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages); - const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_id(); - - // Prefetch TMA descriptors at very beginning - if (threadIdx.x == kNumMathThreads) { - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); - } - __syncwarp(); - - // Align to 1024 bytes for swizzle-128B - extern __shared__ __align__(1024) uint8_t smem_buffer[]; - DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); - - // Data on shared memory - auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); - __nv_fp8_e4m3* smem_a[kNumStages]; - __nv_fp8_e4m3* smem_b[kNumStages]; - float* smem_scales_a[kNumStages]; - float* smem_scales_b; - - // TMA Barrier for both divisible and non-divisible cases - Barrier* full_barriers[kNumStages]; - Barrier* empty_barriers[kNumStages]; - - // Fill shared memory pointers - #pragma unroll - for (int i = 0; i < kNumStages; ++ i) { - smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); - smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); - smem_scales_a[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE); - } - smem_scales_b = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)); - - // Fill barriers - auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_b) + SMEM_SCALES_B_SIZE); - #pragma unroll - for (int i = 0; i < kNumStages; ++ i) { - full_barriers[i] = barrier_start_ptr + i; - empty_barriers[i] = barrier_start_ptr + kNumStages + i; - } - - // Initialize barriers - DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); - if (threadIdx.x == kNumMathThreads) { - #pragma unroll - for (int i = 0; i < kNumStages; ++ i) { - full_barriers[i]->init(1); - empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); - } - - // Make initialized barrier visible in async proxy - cutlass::arch::fence_view_async_shared(); - (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); - } - - // Synchronize all threads to make barrier visible in normal memory model - (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); - - // For pipeline unrolling - struct DivisibleK {}; - struct NotDivisibleK {}; - auto launch_k_iterations = [](const auto& func) { - if constexpr (SHAPE_K % kFullKOfAllStages == 0) { - for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter) - func(k_iter, DivisibleK{}); - } else { - for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) - func(k_iter, DivisibleK{}); - func(kNumIterations - 1, NotDivisibleK{}); - } - }; - - // Register reconfigurations - constexpr int kNumTMARegisters = 40; - constexpr int kNumMathRegisters = 232; - - // Block scheduler - uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, grouped_layout); - - if (threadIdx.x >= kNumMathThreads) { - // TMA warp-group for loading data - cutlass::arch::warpgroup_reg_dealloc(); - - // NOTES: only one thread (or warp) will be used - if (threadIdx.x == kNumMathThreads) { - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - - // Issue TMA A with broadcasting - auto& full_barrier = *full_barriers[s]; - int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), - smem_scales_a[s], m_block_idx * BLOCK_M, - scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); - - // Issue TMA B without broadcasting - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - full_barriers[s]->arrive(); - } - }); - } - - // To safely deconstruct distributed shared barriers, we need another round of empty waits - if constexpr (kNumTMAMulticast > 1) { - #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); - } - } - } else { - // Math warp-groups for WGMMA - cutlass::arch::warpgroup_reg_alloc(); - - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); - const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; - - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - // Decide the number of scales B to load - DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); - uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; - if constexpr (not kMustUseUniformedScaleB) { - num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; - num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; - } - uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); - - // Load B scales with math warp-groups - // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks - if (threadIdx.x >= 32) { - auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); - auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; - #pragma unroll - for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) - st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); - } - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Accumulation for WGMMA or CUDA promotion - float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; - - // Empty barrier arrival - auto empty_barrier_arrive = [&](int s) { - if constexpr (kNumTMAMulticast == 1) { - lane_idx == 0 ? empty_barriers[s]->arrive() : void(); - } else { - lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); - } - }; - - // Launch MMAs - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - - #pragma unroll - for (int s = 0; s < kNumInnerStages; ++ s) { - // Read B scales - float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; - // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks - if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); - - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); - - // Commit WGMMA instructions - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival - empty_barrier_arrive(s); - - // Promote with scales - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; - final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; - } - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - empty_barrier_arrive(s); - } - }); - - // Write back to shared memory using STSM - DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); - #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), - __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), - __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), - __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) - ); - } - if constexpr (WGMMA::kNumAccum % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 - ); - } - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Use TMA store to write back to global memory - if (threadIdx.x == 0) { - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, - scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - cute::tma_store_arrive(); - cute::tma_store_wait<0>(); - } - __syncwarp(); - } - } -#else - if (blockIdx.x == 0 and threadIdx.x == 0) - DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); -#endif -} - -template -class Gemm { -private: - using Barrier = cuda::barrier; - -public: - Gemm() = default; - - static void run(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, - uint32_t shape_m, - const CUtensorMap& tma_a_desc, - const CUtensorMap& tma_b_desc, - const CUtensorMap& tma_scales_a_desc, - const CUtensorMap& tma_d_desc, - cudaStream_t stream, - int num_sms, uint32_t smem_size) { - // NOTES: we must use 4 warps to do TMA, because `setmaxnreg.aligned` requires 4 warps - constexpr uint32_t kNumTMAThreads = 128; - constexpr uint32_t kNumMathThreadsPerGroup = 128; - auto kernel = fp8_gemm_kernel; - DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); - - // Cluster launch - cudaLaunchConfig_t config; - config.gridDim = num_sms; - config.blockDim = get_num_threads_per_sm(BLOCK_M); - config.dynamicSmemBytes = smem_size; - config.stream = stream; - - // Clusters for TMA multicast - // NOTES: `>= 4` cluster size will cause performance degradation - cudaLaunchAttribute attr; - attr.id = cudaLaunchAttributeClusterDimension; - attr.val.clusterDim = {kNumTMAMulticast, 1, 1}; - config.attrs = &attr; - config.numAttrs = 1; - - // Launch - auto status = cudaLaunchKernelEx(&config, kernel, - gmem_d, scales_b, grouped_layout, - shape_m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc); - DG_HOST_ASSERT(status == cudaSuccess); - } - - template - static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) { - return make_2d_tma_desc(global_address, Layout::RowMajor, - shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K); - } - - template - static CUtensorMap make_2d_tma_b_desc(T* global_address) { - return make_2d_tma_desc(global_address, Layout::ColMajor, - SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N); - } - - template - static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) { - return make_2d_tma_desc(global_address, Layout::RowMajor, - shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N, - min(BLOCK_M, shape_m), BLOCK_N, - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); - } - - template - static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) { - // Make TMA aligned to 16 bytes - constexpr uint32_t kAlignment = 16 / sizeof(T); - shape_m = ceil_div(shape_m, kAlignment) * kAlignment; - - return make_2d_tma_desc(global_address, Layout::ColMajor, - shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1, - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); - } - - template - static CUtensorMap make_2d_tma_desc( - T* global_address, Layout layout, - uint32_t gmem_rows, uint32_t gmem_cols, - uint32_t smem_rows, uint32_t smem_cols, - CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) { - if (layout == Layout::RowMajor) { - uint64_t gmem_dim[2] = {gmem_cols, gmem_rows}; - uint32_t smem_dim[2] = {smem_cols, smem_rows}; - return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type); - } else { - uint64_t gmem_dim[2] = {gmem_rows, gmem_cols}; - uint32_t smem_dim[2] = {smem_rows, smem_cols}; - return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type); - } - } -}; - -}; // namespace deep_gemm - -#pragma clang diagnostic pop \ No newline at end of file diff --git a/ops/csrc/fp8/deep_gemm/include/deep_gemm/mma_utils.cuh b/ops/csrc/fp8/deep_gemm/include/deep_gemm/mma_utils.cuh deleted file mode 100644 index c1b9dc5148e5..000000000000 --- a/ops/csrc/fp8/deep_gemm/include/deep_gemm/mma_utils.cuh +++ /dev/null @@ -1,903 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// The file has been adapted from DeepSeek DeepEP project -// Copyright (c) 2025 DeepSeek -// Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -#pragma once - -#include - -#include "utils.cuh" - -namespace deep_gemm { - -struct SM90_64x16x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 16; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x24x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %14, 0;\n" - "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - " %12," - " %13," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 24; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x32x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 32; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x40x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %22, 0;\n" - "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " %21," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 40; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x48x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 48; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x56x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %30, 0;\n" - "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}, " - " %28," - " %29," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 56; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x64x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}, " - " %32," - " %33," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 64; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x72x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %38, 0;\n" - "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}, " - " %36," - " %37," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 72; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x80x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}, " - " %40," - " %41," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 80; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x88x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %46, 0;\n" - "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}, " - " %44," - " %45," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 88; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x96x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}, " - " %48," - " %49," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 96; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x104x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - float& d48, float& d49, float& d50, float& d51, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %54, 0;\n" - "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}, " - " %52," - " %53," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - d[48], d[49], d[50], d[51], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 104; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x112x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}, " - " %56," - " %57," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 112; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x120x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, - float& d56, float& d57, float& d58, float& d59, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %62, 0;\n" - "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}, " - " %60," - " %61," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], - d[56], d[57], d[58], d[59], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 120; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x128x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, - float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}, " - " %64," - " %65," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], - d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 128; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -struct SM90_64x192x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, - float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, - float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, - float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79, - float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87, - float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}, " - " %96," - " %97," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], - d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], - d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], - d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79], - d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87], - d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 192; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -template -struct SM90_U32x2_STSM_N { - __device__ __forceinline__ static void - copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { - const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; - asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" - :: "l"(smem_dst), "r"(src[0]), "r"(src[1])); - } -}; - -template -struct SM90_U32x4_STSM_N { - __device__ __forceinline__ static void - copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) { - const uint32_t src[4] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1), - *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; - asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" - :: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); - } -}; - -__device__ void warpgroup_arrive() { - asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); -} - -__device__ void warpgroup_commit_batch() { - asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); -} - -__device__ void warpgroup_fence_operand(float& reg) { - asm volatile("" : "+f"(reg) :: "memory"); -} - -__forceinline__ __device__ uint32_t get_lane_id() { - uint32_t lane_id; - asm("mov.u32 %0, %laneid;" : "=r"(lane_id)); - return lane_id; -} - -__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) { - uint32_t ret; - asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); - return ret; -} - -__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) { - int4 ret; - asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr)); - return ret; -} - -__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) { - float ret; - asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); - return ret; -} - -__device__ __forceinline__ void st_shared(const float* ptr, float val) { - asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val)); -} - -__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { - asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val)); -} - -template -__device__ void warpgroup_wait() { - DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); - asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); -} - -union GmmaDescriptor { - __host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {} - - __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {} - - __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {} - - __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {} - - __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept { - desc_ = t.desc_; - return *this; - } - - __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept { - desc_ = t.desc_; - return *this; - } - - uint64_t desc_; - uint32_t reg32_[2]; - uint16_t reg16_[4]; - - struct { - uint16_t start_address_: 14, : 2; - uint16_t leading_byte_offset_: 14, : 2; - uint16_t stride_byte_offset_: 14, : 2; - uint8_t : 1, base_offset_: 3, : 4; - uint8_t : 6, layout_type_: 2; - } bitfield; - - // Decay to an `uint64_t` - __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; } -}; - -template -__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type, - int leading_byte_offset = 0, - int stride_byte_offset = 1024) { - GmmaDescriptor desc; - auto uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); - desc.bitfield.start_address_ = uint_ptr >> 4; - desc.bitfield.layout_type_ = layout_type; - desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; - desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; - desc.bitfield.base_offset_ = 0; - return desc; -} - -template -struct FP8MMASelector { - static constexpr auto select_type() { - if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS(); - if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS(); - if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS(); - if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS(); - if constexpr (N == 48) return SM90_64x48x32_F32E4M3E4M3_SS(); - if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS(); - if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS(); - if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS(); - if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS(); - if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS(); - if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS(); - if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS(); - if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS(); - if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS(); - if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS(); - if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS(); - } - - using type = decltype(select_type()); -}; - -} // namespace deep_gemm \ No newline at end of file diff --git a/ops/csrc/fp8/deep_gemm/include/deep_gemm/scheduler.cuh b/ops/csrc/fp8/deep_gemm/include/deep_gemm/scheduler.cuh deleted file mode 100644 index 35cbcf1e3a1e..000000000000 --- a/ops/csrc/fp8/deep_gemm/include/deep_gemm/scheduler.cuh +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// The file has been adapted from DeepSeek DeepEP project -// Copyright (c) 2025 DeepSeek -// Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -#include "utils.cuh" - -namespace deep_gemm { - -enum class GemmType { - Normal, - GroupedContiguous, - GroupedMasked -}; - -#pragma clang diagnostic push -#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" -template -struct Scheduler { - int current_iter = -1; - uint32_t num_aligned_m_blocks; - - // For normal GEMM - // Maybe not used in the masked grouped GEMM - uint32_t num_blocks; - - // For grouped GEMM - int* grouped_layout; - // Only used for masked layout - uint32_t curr_group_idx, curr_cumsum; - - __device__ __forceinline__ explicit Scheduler(const uint32_t shape_m, - int* grouped_layout = nullptr) { - num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M); - if constexpr (kGemmType == GemmType::Normal) { - num_blocks = num_aligned_m_blocks * kNumNBlocks; - } else if (kGemmType == GemmType::GroupedContiguous) { - num_blocks = num_aligned_m_blocks * kNumNBlocks; - this->grouped_layout = grouped_layout; - } else if (kGemmType == GemmType::GroupedMasked) { - curr_group_idx = curr_cumsum = 0; - this->grouped_layout = grouped_layout; - } - } - - __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { - DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); - - // Swizzle for better L2 usages - auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup; - auto group_idx = block_idx / num_blocks_per_group; - auto first_n_block_idx = group_idx * kNumNBlocksPerGroup; - auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx); - auto in_group_idx = block_idx % num_blocks_per_group; - m_block_idx = in_group_idx / num_n_blocks_in_group; - n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group; - } - - template - __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, - const uint32_t& block_idx, const uint32_t& m_block_idx=0) { - if constexpr (kGemmType == GemmType::Normal) { - return block_idx * block_size; - } else if (kGemmType == GemmType::GroupedContiguous) { - auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M); - return offset * shape_dim + block_idx * block_size; - } else if (kGemmType == GemmType::GroupedMasked) { - return curr_group_idx * shape_dim + block_idx * block_size; - } - } - - __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { - const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x; - - if constexpr (kGemmType == GemmType::GroupedMasked) { - uint32_t num_m_blocks; - while (true) { - // End of the task - if (curr_group_idx == kNumGroups) - return false; - - // Within current group - num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + curr_group_idx)), BLOCK_M); - auto current_m_block_cumsum = curr_cumsum + num_m_blocks; - if (next_block_idx < current_m_block_cumsum * kNumNBlocks) - break; - - // Move to check the next group - curr_group_idx ++, curr_cumsum = current_m_block_cumsum; - } - - get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx); - } else { - if (next_block_idx >= num_blocks) - return false; - - get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx); - } - return true; - } -}; -#pragma clang diagnostic pop - -} // namespace deep_gemm \ No newline at end of file diff --git a/ops/csrc/fp8/deep_gemm/include/deep_gemm/tma_utils.cuh b/ops/csrc/fp8/deep_gemm/include/deep_gemm/tma_utils.cuh deleted file mode 100644 index 47f23fa01ad1..000000000000 --- a/ops/csrc/fp8/deep_gemm/include/deep_gemm/tma_utils.cuh +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// The file has been adapted from DeepSeek DeepEP project -// Copyright (c) 2025 DeepSeek -// Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "utils.cuh" - -namespace deep_gemm { - -template -constexpr CUtensorMapDataType get_CUtensorMapDataType() { - if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT16; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT32; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_UINT64; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_INT32; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_INT64; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; - } else if constexpr (std::is_same::value) { - return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; - } -} - -PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { - // Get pointer to `cuTensorMapEncodeTiled` - cudaDriverEntryPointQueryResult driver_status; - void* cuTensorMapEncodeTiled_ptr = nullptr; - -/* -#if CUDA_VERSION >= 12050 - cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000, - cudaEnableDefault, &driver_status); -#else -*/ - cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, - cudaEnableDefault, &driver_status); -//#endif - - if (driver_status != cudaDriverEntryPointSuccess) - throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess"); - return reinterpret_cast(cuTensorMapEncodeTiled_ptr); -} - -template -CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2], - uint64_t stride_in_bytes, uint32_t smem_dim[2], - CUtensorMapSwizzle swizzle_type, - PFN_cuTensorMapEncodeTiled encode_func = nullptr) { - CUtensorMap tensor_map{}; - constexpr uint32_t rank = 2; - uint64_t global_stride[rank - 1] = {stride_in_bytes}; - uint32_t elem_strides[rank] = {1, 1}; - - if (encode_func == nullptr) - encode_func = get_cuTensorMapEncodeTiled(); - - auto result = encode_func( - &tensor_map, get_CUtensorMapDataType::type>(), rank, - global_address, gmem_dim, global_stride, smem_dim, elem_strides, - CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type, - CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, - CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - DG_HOST_ASSERT(result == CUDA_SUCCESS); - return tensor_map; -} - -template -__device__ __forceinline__ void -tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, - int32_t const& crd_0, int32_t const& crd_1) { - constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); - if constexpr (kNumTMAMulticast == 1) { - cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1); - } else if (cute::block_rank_in_cluster() == 0) { - cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1); - } -} - -} // namespace deep_gemm \ No newline at end of file diff --git a/ops/csrc/fp8/deep_gemm/include/deep_gemm/utils.cuh b/ops/csrc/fp8/deep_gemm/include/deep_gemm/utils.cuh deleted file mode 100644 index c21d16e513c2..000000000000 --- a/ops/csrc/fp8/deep_gemm/include/deep_gemm/utils.cuh +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// The file has been adapted from DeepSeek DeepEP project -// Copyright (c) 2025 DeepSeek -// Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -#pragma once - -#include - -#ifdef __CLION_IDE__ -__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); } -#define printf host_device_printf -#endif - -class AssertionException : public std::exception { -private: - std::string message{}; - -public: - explicit AssertionException(const std::string& message) : message(message) {} - - const char *what() const noexcept override { return message.c_str(); } -}; - -#ifndef DG_HOST_ASSERT -#define DG_HOST_ASSERT(cond) \ -do { \ - if (not (cond)) { \ - printf("Assertion failed: %s:%d, condition: %s\n", \ - __FILE__, __LINE__, #cond); \ - throw AssertionException("Assertion failed: " #cond); \ - } \ -} while (0) -#endif - -#ifndef DG_DEVICE_ASSERT -#define DG_DEVICE_ASSERT(cond) \ -do { \ - if (not (cond)) { \ - printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ - asm("trap;"); \ - } \ -} while (0) -#endif - -#ifndef DG_STATIC_ASSERT -#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason) -#endif - -template -__device__ __host__ constexpr T ceil_div(T a, T b) { - return (a + b - 1) / b; -} \ No newline at end of file diff --git a/ops/csrc/fp8/deep_gemm/jit/__init__.py b/ops/csrc/fp8/deep_gemm/jit/__init__.py deleted file mode 100644 index cb04fd0007f8..000000000000 --- a/ops/csrc/fp8/deep_gemm/jit/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -from .compiler import build, get_nvcc_compiler -from .runtime import Runtime -from .template import cpp_format, generate diff --git a/ops/csrc/fp8/deep_gemm/jit/compiler.py b/ops/csrc/fp8/deep_gemm/jit/compiler.py deleted file mode 100644 index ea6714be1f5a..000000000000 --- a/ops/csrc/fp8/deep_gemm/jit/compiler.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -import functools -import hashlib -import os -import re -import subprocess -import uuid -from typing import Tuple - -from ..utils import get_cuda_home -from . import interleave_ffma -from .runtime import Runtime, RuntimeCache -from .template import typename_map - -runtime_cache = RuntimeCache() - - -def hash_to_hex(s: str) -> str: - md5 = hashlib.md5() - md5.update(s.encode("utf-8")) - return md5.hexdigest()[0:12] - - -@functools.lru_cache(maxsize=None) -def get_jit_include_dir() -> str: - return f"{os.path.dirname(os.path.abspath(__file__))}/../include" - - -@functools.lru_cache(maxsize=None) -def get_deep_gemm_version() -> str: - # Update include directories - include_dir = f"{get_jit_include_dir()}/deep_gemm" - assert os.path.exists(include_dir), f"Cannot find GEMM include directory {include_dir}" - md5 = hashlib.md5() - for filename in filter(lambda x: x.endswith(".cuh"), sorted(os.listdir(include_dir))): - with open(f"{include_dir}/{filename}", "rb") as f: - md5.update(f.read()) - - # Update `interleave_ffma.py` - with open(f"{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py", "rb") as f: - md5.update(f.read()) - return md5.hexdigest()[0:12] - - -@functools.lru_cache(maxsize=None) -def get_nvcc_compiler() -> Tuple[str, str]: - paths = [] - if os.getenv("DG_NVCC_COMPILER"): - paths.append(os.getenv("DG_NVCC_COMPILER")) - CUDA_HOME = get_cuda_home() - paths.append(f"{CUDA_HOME}/bin/nvcc") - - # Try to find the first available NVCC compiler - least_version_required = "12.3" - version_pattern = re.compile(r"release (\d+\.\d+)") - for path in paths: - if os.path.exists(path): - match = version_pattern.search(os.popen(f"{path} --version").read()) - version = match.group(1) - assert match, f"Cannot get the version of NVCC compiler {path}" - assert ( - version >= least_version_required - ), f"NVCC {path} version {version} is lower than {least_version_required}" - return path, version - raise RuntimeError("Cannot find any available NVCC compiler") - - -@functools.lru_cache(maxsize=None) -def get_default_user_dir(): - if "DG_CACHE_DIR" in os.environ: - path = os.getenv("DG_CACHE_DIR") - os.makedirs(path, exist_ok=True) - return path - return os.path.expanduser("~") + "/.deep_gemm" - - -@functools.lru_cache(maxsize=None) -def get_tmp_dir(): - return f"{get_default_user_dir()}/tmp" - - -@functools.lru_cache(maxsize=None) -def get_cache_dir(): - return f"{get_default_user_dir()}/cache" - - -def make_tmp_dir(): - tmp_dir = get_tmp_dir() - os.makedirs(tmp_dir, exist_ok=True) - return tmp_dir - - -def put(path, data, is_binary=False): - # Write and do POSIX atomic replace - tmp_file_path = f"{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}" - with open(tmp_file_path, "wb" if is_binary else "w") as f: - f.write(data) - os.replace(tmp_file_path, path) - - -def build(name: str, arg_defs: tuple, code: str) -> Runtime: - # Compiler flags - nvcc_flags = [ - "-std=c++17", - "-shared", - "-O3", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "-gencode=arch=compute_90a,code=sm_90a", - "--ptxas-options=--register-usage-level=10" + (",--verbose" if "DG_PTXAS_VERBOSE" in os.environ else ""), - # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases - "--diag-suppress=177,174,940", - ] - cxx_flags = ["-fPIC", "-O3", "-Wno-deprecated-declarations", "-Wno-abi"] - flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}'] - include_dirs = [get_jit_include_dir()] - - # Build signature - enable_sass_opt = get_nvcc_compiler()[1] <= "12.8" and int(os.getenv("DG_DISABLE_FFMA_INTERLEAVE", 0)) == 0 - signature = f"{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}" - name = f"kernel.{name}.{hash_to_hex(signature)}" - path = f"{get_cache_dir()}/{name}" - - # Check runtime cache or file system hit - global runtime_cache - if runtime_cache[path] is not None: - if os.getenv("DG_JIT_DEBUG", None): - print(f"Using cached JIT runtime {name} during build") - return runtime_cache[path] - - # Write the code - os.makedirs(path, exist_ok=True) - args_path = f"{path}/kernel.args" - src_path = f"{path}/kernel.cu" - put(args_path, ", ".join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs])) - put(src_path, code) - - # Compile into a temporary SO file - so_path = f"{path}/kernel.so" - tmp_so_path = f"{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so" - # Compile - command = [get_nvcc_compiler()[0], src_path, "-o", tmp_so_path, *flags, *[f"-I{d}" for d in include_dirs]] - if os.getenv("DG_JIT_DEBUG", None) or os.getenv("DG_JIT_PRINT_NVCC_COMMAND", False): - print(f"Compiling JIT runtime {name} with command {command}") - return_code = subprocess.check_call(command) - assert return_code == 0, f"Failed to compile {src_path}" - - # Interleave FFMA reuse - if enable_sass_opt: - interleave_ffma.process(tmp_so_path) - - # Atomic replace SO file - os.replace(tmp_so_path, so_path) - - # Put cache and return - runtime_cache[path] = Runtime(path) - return runtime_cache[path] diff --git a/ops/csrc/fp8/deep_gemm/jit/interleave_ffma.py b/ops/csrc/fp8/deep_gemm/jit/interleave_ffma.py deleted file mode 100644 index 0a5919b6b87b..000000000000 --- a/ops/csrc/fp8/deep_gemm/jit/interleave_ffma.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -import argparse -import mmap -import os -import re -import subprocess - -from ..utils import get_cuda_home - - -def run_cuobjdump(file_path): - CUDA_HOME = get_cuda_home() - command = [f"{CUDA_HOME}/bin/cuobjdump", "-sass", file_path] - result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - assert result.returncode == 0 - return result.stdout - - -def extract_ffma(sass): - lines = sass.splitlines() - collected = [] - current = [] - - arch_name, func_name = "N/A", "N/A" - skip_next_line = False - for line in lines: - if "code for" in line: - arch_name = line.lstrip().lstrip("code for ").rstrip() - elif "Function :" in line: - func_name = line.lstrip().lstrip("Function :").rstrip() - elif "FFMA" in line: - current.append(line) - skip_next_line = True - elif skip_next_line: - current.append(line) - skip_next_line = False - else: - if len(current) >= 16: - assert len(current) % 2 == 0 - collected.append((f"{arch_name}::{func_name}", current)) - current = [] - - if os.getenv("DG_PRINT_REG_REUSE", None): - print(f"Found {len(collected)} FFMA segments") - return collected - - -def extract_hex_from_line(line): - match = re.search(r"/\*\s*(0x[0-9a-fA-F]+)\s*\*/", line) - assert match - return int(match.group(1), 16) - - -def validate(m, offset, le_bytes, num_lines): - assert len(le_bytes) == num_lines // 2 - assert m[offset : offset + 16] == le_bytes[0] - for i in range(1, num_lines // 2): - if m[offset + i * 16 : offset + i * 16 + 16] != le_bytes[i]: - return False - return True - - -def parse_registers(line): - line = re.sub(r"/\*.*?\*/", "", line) - line = line.replace(";", "") - tokens = line.strip().split(",") - registers = [] - for token in tokens: - token = token.strip() - words = token.split() - for word in words: - if word.startswith("R"): - reg = word.split(".")[0] - registers.append(reg) - return registers - - -def modify_segment(m, name, ffma_lines): - num_lines = len(ffma_lines) - assert num_lines % 2 == 0 - - le_bytes, new_le_bytes = [], [] - reused_list = [] - dst_reg_set = set() - last_reused, last_dst_reg = False, "" - num_changed = 0 - for i in range(num_lines // 2): - dst_reg = parse_registers(ffma_lines[i * 2])[-2] - low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1] - low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line) - le_bytes.append(low_hex.to_bytes(8, "little") + high_hex.to_bytes(8, "little")) - reused = (high_hex & 0x0800000000000000) != 0 - if reused: - is_first_occurred = dst_reg not in dst_reg_set - if is_first_occurred or (last_reused and dst_reg == last_dst_reg): - # Modify the `reuse` and `yield` bits - assert high_hex & 0x0800200000000000, f"{hex(high_hex)}" - high_hex ^= 0x0800200000000000 - reused = False - num_changed += 1 - else: - reused_list.append(i) - dst_reg_set.add(dst_reg) - new_le_bytes.append(low_hex.to_bytes(8, "little") + high_hex.to_bytes(8, "little")) - last_reused, last_dst_reg = reused, dst_reg - if os.getenv("DG_PRINT_REG_REUSE", None): - print(f" > segment `{name}` new reused list ({num_changed} changed): {reused_list}") - - # Find the offset - offsets = [] - offset = m.find(le_bytes[0]) - while offset != -1: - offsets.append(offset) - offset = m.find(le_bytes[0], offset + 1) - offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets)) - - # Replace with `new_le_bytes` - for offset in offsets: - for i in range(num_lines // 2): - m[offset + i * 16 : offset + i * 16 + 16] = new_le_bytes[i] - - -def process(path): - if os.getenv("DG_PRINT_REG_REUSE", None): - print(f"Processing {path}") - output = run_cuobjdump(path) - segments = extract_ffma(output) - with open(path, "r+b") as f: - mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE) - for segment in segments: - modify_segment(mm, *segment) - mm.close() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Interleave FFMA reg reuse") - parser.add_argument("--so", help="Path to the SO file") - args = parser.parse_args() - - process(args.so) diff --git a/ops/csrc/fp8/deep_gemm/jit/runtime.py b/ops/csrc/fp8/deep_gemm/jit/runtime.py deleted file mode 100644 index 258409f758c8..000000000000 --- a/ops/csrc/fp8/deep_gemm/jit/runtime.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -import ctypes -import os -from typing import Optional - -import paddle -from paddle import Tensor - -from .template import map_ctype - - -class Runtime: - def __init__(self, path: str) -> None: - self.path = path - self.lib = None - self.args = None - - assert self.is_path_valid(self.path) - - @staticmethod - def is_path_valid(path: str) -> bool: - # Exists and is a directory - if not os.path.exists(path) or not os.path.isdir(path): - return False - - # Contains all necessary files - files = ["kernel.cu", "kernel.args", "kernel.so"] - return all(os.path.exists(os.path.join(path, file)) for file in files) - - def __call__(self, *args) -> int: - # Load SO file - if self.lib is None or self.args is None: - self.lib = ctypes.CDLL(os.path.join(self.path, "kernel.so")) - with open(os.path.join(self.path, "kernel.args"), "r") as f: - self.args = eval(f.read()) - - # Check args and launch - assert len(args) == len(self.args), f"Expected {len(self.args)} arguments, got {len(args)}" - cargs = [] - for arg, (name, dtype) in zip(args, self.args): - if isinstance(arg, Tensor): - assert arg.dtype == dtype, f"Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`" - else: - assert isinstance(arg, dtype), f"Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`" - cargs.append(map_ctype(arg)) - - return_code = ctypes.c_int(0) - self.lib.launch(*cargs, ctypes.byref(return_code)) - return return_code.value - - -class RuntimeCache: - def __init__(self) -> None: - self.cache = {} - - def __getitem__(self, path: str) -> Optional[Runtime]: - # In Python runtime - if path in self.cache: - return self.cache[path] - - # Already compiled - if os.path.exists(path) and Runtime.is_path_valid(path): - runtime = Runtime(path) - self.cache[path] = runtime - return runtime - return None - - def __setitem__(self, path, runtime) -> None: - self.cache[path] = runtime diff --git a/ops/csrc/fp8/deep_gemm/jit/template.py b/ops/csrc/fp8/deep_gemm/jit/template.py deleted file mode 100644 index 1b7bfb9877b4..000000000000 --- a/ops/csrc/fp8/deep_gemm/jit/template.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -import copy -import ctypes -import os -from typing import Any, Dict, Iterable, Tuple - -import paddle -from paddle import Tensor - -# Name map for Python `eval` -typename_map: Dict[Any, str] = { - **{t: t.__name__ for t in (bool, int, float)}, - paddle.int32: "paddle.int32", - paddle.float32: "paddle.float32", - paddle.bfloat16: "paddle.bfloat16", - paddle.float8_e4m3fn: "paddle.float8_e4m3fn", - paddle.device.cuda.Stream: "paddle.device.cuda.Stream", -} -# `ctype` map for Python casting -ctype_map: Dict[Any, Any] = { - **{t: getattr(ctypes, f"c_{t.__name__}") for t in (bool, int, float)}, - **{ - t: ctypes.c_void_p - for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream) - }, -} - - -# Type map for both Python API and source code usages -genc_map = { - bool: ("bool", "bool"), - int: ("int", "int"), - float: ("float", "float"), - paddle.int32: ("void*", "int*"), - paddle.float32: ("void*", "float*"), - paddle.bfloat16: ("void*", "__nv_bfloat16*"), - paddle.float8_e4m3fn: ("void*", "__nv_fp8_e4m3*"), - paddle.device.cuda.Stream: ("void*", "cudaStream_t"), -} - - -def map_ctype(value: Any) -> Any: - ctype = ctype_map[value.dtype if isinstance(value, Tensor) else type(value)] - if isinstance(value, Tensor): - return ctype(value.data_ptr()) - if isinstance(value, paddle.device.cuda.Stream): - return ctype(value.cuda_stream) - return ctype(value) - - -def cpp_format(template: str, keys: Dict[str, Any]) -> str: - # We don't use `str.format` because it's not safe for C++ {} braces - new_template = copy.deepcopy(template) - for key, value in keys.items(): - new_template = new_template.replace(f"{{{key}}}", f"{value}") - return new_template - - -def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str: - # Common prefix - code = "// DeepGEMM auto-generated JIT CUDA source file\n\n" - - # Includes - preload_sys_includes = ["", "", "", ""] - preload_package_includes = ['"cutlass/cutlass.h"'] - - assert isinstance(includes, list) or isinstance(includes, tuple) - sys_includes = sorted( - list(set(preload_sys_includes + [include for include in includes if include.startswith("<")])) - ) - package_includes = sorted( - list(set(preload_package_includes + [include for include in includes if include.startswith('"')])) - ) - code += "\n".join(f"#include {include}" for include in sys_includes) + "\n\n" - code += "\n".join(f"#include {include}" for include in package_includes) + "\n\n" - - # Function signature - raw = "__raw_" - get_def = lambda n, t: f"{genc_map[t][0]} " + (raw if genc_map[t][0] != genc_map[t][1] else "") + n - code += f'extern "C" void launch(' - code += ", ".join( - [get_def(*arg_def) for arg_def in arg_defs] - + [ - "int& __return_code", - ] - ) - code += ") {\n" - - # Cast raw types - code += " // Cast raw types (if needed)\n" - for arg_name, arg_type in arg_defs: - if genc_map[arg_type][0] != genc_map[arg_type][1]: - code += f" auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n" - - # Function body - code += "\n".join([((" " if line else "") + line) for line in body.split("\n")]) - - # End the function - code += "}\n\n" - - # Debug print - if os.getenv("DG_JIT_DEBUG", None): - print(f"Generated code:\n{code}") - - return code diff --git a/ops/csrc/fp8/deep_gemm/jit_kernels/__init__.py b/ops/csrc/fp8/deep_gemm/jit_kernels/__init__.py deleted file mode 100644 index 61c49354bd35..000000000000 --- a/ops/csrc/fp8/deep_gemm/jit_kernels/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -from .gemm import gemm_fp8_fp8_bf16_nt -from .m_grouped_gemm import ( - m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, - m_grouped_gemm_fp8_fp8_bf16_nt_masked, -) -from .utils import ( - ceil_div, - get_col_major_tma_aligned_tensor, - get_m_alignment_for_contiguous_layout, - get_num_sms, - set_num_sms, -) diff --git a/ops/csrc/fp8/deep_gemm/jit_kernels/gemm.py b/ops/csrc/fp8/deep_gemm/jit_kernels/gemm.py deleted file mode 100644 index 166f10f28477..000000000000 --- a/ops/csrc/fp8/deep_gemm/jit_kernels/gemm.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -import functools -from typing import Tuple - -import paddle -from paddle import Tensor - -from .tuner import jit_tuner -from .utils import ( - ceil_div, - get_col_major_tma_aligned_tensor, - get_m_alignment_for_contiguous_layout, - get_num_sms, -) - -# C++ code templates -includes = ('"deep_gemm/fp8_gemm.cuh"',) -template = """ -using namespace deep_gemm; - -// Templated args from Python JIT call -constexpr auto N = {N}, K = {K}; -constexpr auto BLOCK_M = {BLOCK_M}; -constexpr auto BLOCK_N = {BLOCK_N}; -constexpr auto kNumStages = {NUM_STAGES}; -constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; - -// Make a templated GEMM -using GemmType = Gemm; - -// Launch kernel -auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); -auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs); -auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m); -auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m); -GemmType::run(out, rhs_scales, nullptr, - m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, - stream, num_sms, smem_size); -""" - - -def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool: - if num_tma_multicast == 1: - return True - return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0 - - -def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int: - smem_d = block_m * block_n * 2 - smem_a_per_stage = block_m * block_k - smem_scales_a_per_stage = block_m * 4 - smem_b_per_stage = block_n * block_k - smem_scales_b = ceil_div(k, block_k) * 4 - smem_barrier = num_stages * 8 * 2 - - smem_size = 0 - smem_size += smem_d - smem_size += num_stages * smem_a_per_stage - smem_size += num_stages * smem_scales_a_per_stage - smem_size += num_stages * smem_b_per_stage - smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 - smem_size += smem_barrier - return smem_size - - -def get_best_configs( - m: int, n: int, k: int, num_groups: int, num_sms: int, is_grouped_contiguous: bool = False -) -> Tuple[int, int, int, int, int]: - if not is_grouped_contiguous: - # TODO: for some cases, smaller M block is better, add them into tuning space - block_ms = (64 if m <= 64 else 128,) - else: - block_ms = (get_m_alignment_for_contiguous_layout(),) - block_ns = tuple(range(16, 129, 8)) - - fix_wave_saturate = lambda x: num_sms if x == 0 else x - get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) - get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) - - # Decide block sizes by waves - best_block_m, best_block_n = None, None - for block_m in block_ms: - for block_n in block_ns: - success = False - num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) - if best_block_m is None or best_block_n is None: - success = True - elif num_waves < best_num_waves: - success = True - elif num_waves == best_num_waves: - # Check last wave utilization - util = get_last_wave_util(block_m, block_n) - best_util = get_last_wave_util(best_block_m, best_block_n) - success = util > best_util or ( - util == best_util - and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n)) - ) - best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) - assert best_block_m is not None and best_block_n is not None - - # Always pick the longest one - # NOTES: for double B scales, the best number of stages may be reduced - best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 - for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): - best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) - if best_smem_size <= sm90_capacity: - best_num_stages = num_stages - break - assert best_num_stages is not None - - # Decide the number of TMA multicast - best_num_tma_multicast = 1 - if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1: - best_num_tma_multicast = 2 - - return best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size - - -@functools.lru_cache() -def auto_tuning_with_compilation(m, n, k): - global includes, template - num_sms = get_num_sms() - block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms) - runtime = jit_tuner.compile_and_tune( - m, - n, - k, - name="gemm_fp8_fp8_bf16_nt", - keys={ - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "K": k, - "N": n, - "NUM_STAGES": num_stages, - "NUM_TMA_MULTICAST": num_tma_multicast, - }, - space=(), - includes=includes, - arg_defs=( - ("lhs", paddle.float8_e4m3fn), - ("lhs_scales", paddle.float32), - ("rhs", paddle.float8_e4m3fn), - ("rhs_scales", paddle.float32), - ("out", paddle.bfloat16), - ("m", int), - ("stream", paddle.device.cuda.Stream), - ("num_sms", int), - ("smem_size", int), - ), - template=template, - ) - return runtime, num_sms, smem_size - - -def gemm_fp8_fp8_bf16_nt(lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor) -> None: - """ - Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. - LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. - RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, - this function will do a transposing with a set of slow Paddle operations. - - Arguments: - lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`, - the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[n, k]`. - the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`. - out: the BF16 output tensor of shape `[m, n]`, representing the result. - """ - lhs, lhs_scales = lhs - rhs, rhs_scales = rhs - m, k = lhs.shape - n, k_ = rhs.shape - m_, n_ = out.shape - assert n % 64 == 0 and k % 128 == 0 - - # Type and shape checks - assert m == m_ and n == n_ and k == k_ - assert n > 0 and k > 0 - assert lhs_scales.shape == [m, (k + 127) // 128] - assert rhs_scales.shape == [(n + 127) // 128, (k + 127) // 128] - assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32 - assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32 - assert out.dtype == paddle.bfloat16 - assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous() - - # LHS scales must be transposed for TMA load, but not for RHS scales - # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels - lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert rhs_scales.is_contiguous() - - # Do nothing if `m` is zero - if m == 0: - return - runtime, num_sms, smem_size = auto_tuning_with_compilation(m, n, k) - args = (lhs, lhs_scales, rhs, rhs_scales, out, m, paddle.device.cuda.current_stream(), num_sms, smem_size) - # Run the kernel. - runtime(*args) diff --git a/ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py b/ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py deleted file mode 100644 index f5d1cc18c2c1..000000000000 --- a/ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -import functools -from typing import Tuple - -import paddle -from paddle import Tensor - -from .gemm import get_best_configs -from .tuner import jit_tuner -from .utils import get_col_major_tma_aligned_tensor, get_num_sms - -# C++ code templates -includes = ('"deep_gemm/fp8_gemm.cuh"',) -template = """ -using namespace deep_gemm; - -// Templated args from Python JIT call -constexpr auto N = {N}, K = {K}; -constexpr auto BLOCK_M = {BLOCK_M}; -constexpr auto BLOCK_N = {BLOCK_N}; -constexpr auto kNumStages = {NUM_STAGES}; -constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; - -// Make a templated grouped GEMM -using GemmType = Gemm; - -// Launch kernel -auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); -auto tma_b_desc = GemmType::make_2d_tma_b_desc(rhs); -auto tma_scales_a_desc = GemmType::make_2d_tma_scales_a_desc(lhs_scales, m); -auto tma_d_desc = GemmType::make_2d_tma_d_desc(out, m); -GemmType::run(out, rhs_scales, grouped_layout, - m, - tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc, - stream, num_sms, smem_size); -""" - - -@functools.lru_cache() -def auto_tuning_with_compilation_grouped_gemm_contiguous(m, n, k, num_groups, num_sms): - global includes, template - block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs( - m, n, k, 1, num_sms, is_grouped_contiguous=True - ) - runtime = jit_tuner.compile_and_tune( - m, - n, - k, - name="m_grouped_gemm_fp8_fp8_bf16_nt", - keys={ - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "GEMM_TYPE": "GroupedContiguous", - "K": k, - "N": n, - "NUM_GROUPS": num_groups, - "NUM_STAGES": num_stages, - "NUM_TMA_MULTICAST": num_tma_multicast, - }, - space=(), - includes=includes, - arg_defs=( - ("lhs", paddle.float8_e4m3fn), - ("lhs_scales", paddle.float32), - ("rhs", paddle.float8_e4m3fn), - ("rhs_scales", paddle.float32), - ("out", paddle.bfloat16), - ("grouped_layout", paddle.int32), - ("m", int), - ("num_groups", int), - ("stream", paddle.device.cuda.Stream), - ("num_sms", int), - ("smem_size", int), - ), - template=template, - ) - return runtime, num_sms, smem_size - - -def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, m_indices: Tensor -) -> None: - """ - Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. - LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. - RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, - this function will do a transposing with a set of slow Paddle operations. - On the M axis, inputs are grouped into several batches, of which batch sizes aligned to - `get_m_alignment_for_contiguous_layout()` (128). - - Arguments: - lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`, - the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`. - the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. - out: the BF16 output tensor of shape `[m_sum, n]`, representing the result. - m_indices: a tensor of shape `[m_sum]` with type `paddle.int32`. - `m_indices[i]` records the group which the j-th row of the LHS belong to, - which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`. - Values of `m_indices` in every-m-alignment-block must also be the same. - `-1` in this tensor indicates no RHS matrix selected, the kernel will skip the computation for that aligned block. - """ - lhs, lhs_scales = lhs - rhs, rhs_scales = rhs - m, k = lhs.shape - num_groups, n, k_ = rhs.shape - m_, n_ = out.shape - m__ = m_indices.numel() - # Type and shape checks - assert m == m_ == m__ and k == k_ and n == n_ - assert lhs_scales.shape == [m, (k + 127) // 128] - assert rhs_scales.shape == [num_groups, (n + 127) // 128, (k + 127) // 128] - assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32 - assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32 - assert out.dtype == paddle.bfloat16 - assert m_indices.dtype == paddle.int32 - assert lhs.is_contiguous() and rhs.is_contiguous() - assert out.is_contiguous() and m_indices.is_contiguous() - - # LHS scales must be transposed for TMA load, but not for RHS scales - lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert rhs_scales.is_contiguous() - - # Do nothing if `m` is zero - if m == 0: - return - # Auto-tuning with compilation - global includes, template - num_sms = get_num_sms() - runtime, num_sms, smem_size = auto_tuning_with_compilation_grouped_gemm_contiguous(m, n, k, num_groups, num_sms) - - args = ( - lhs, - lhs_scales, - rhs, - rhs_scales, - out, - m_indices, - m, - num_groups, - paddle.device.cuda.current_stream(), - num_sms, - smem_size, - ) - runtime(*args) - - -def m_grouped_gemm_fp8_fp8_bf16_nt_masked( - lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, masked_m: Tensor, expected_m: int -) -> None: - """ - Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. - LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. - RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, - this function will do a transposing with a set of slow Paddle operations. - Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch - should be separately transposed. - - Arguments: - lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`, - the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`. - the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. - out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result. - masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute - in the i-th group. - expected_m: a value hint (which is a value on CPU) for the M expectation of each batch, - correctly setting this value may lead to better performance. - """ - lhs, lhs_scales = lhs - rhs, rhs_scales = rhs - num_groups, m, k = lhs.shape - num_groups_, n, k_ = rhs.shape - num_groups__, m_, n_ = out.shape - num_groups___ = masked_m.numel() - - # Type and shape checks - assert num_groups == num_groups_ == num_groups__ == num_groups___ - assert m == m_ and n == n_ and k == k_ - assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 - assert lhs_scales.shape == [num_groups, m, (k + 127) // 128] - assert rhs_scales.shape == [num_groups, (n + 127) // 128, (k + 127) // 128] - assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32 - assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32 - assert out.dtype == paddle.bfloat16 - assert masked_m.dtype == paddle.int32 - assert lhs.is_contiguous() and rhs.is_contiguous() - assert out.is_contiguous() and masked_m.is_contiguous() - - # LHS scales must be transposed for TMA load, but not for RHS scales - lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert rhs_scales.is_contiguous() - - # Auto-tuning with compilation - global includes, template - num_sms = get_num_sms() - block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs( - expected_m, n, k, num_groups, num_sms - ) - - # Extra checks for TMA store - if num_groups > 1 and m > block_m: - assert ( - m % block_m == 0 - ), f"For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})" - - args = ( - lhs, - lhs_scales, - rhs, - rhs_scales, - out, - masked_m, - m, - paddle.device.cuda.current_stream(), - num_sms, - smem_size, - ) - runtime = jit_tuner.compile_and_tune_group_gemm_masked( - name="m_grouped_gemm_fp8_fp8_bf16_nt", - keys={ - "N": n, - "K": k, - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "NUM_GROUPS": num_groups, - "NUM_STAGES": num_stages, - "NUM_TMA_MULTICAST": num_tma_multicast, - "GEMM_TYPE": "GroupedMasked", - }, - space=(), - includes=includes, - arg_defs=( - ("lhs", paddle.float8_e4m3fn), - ("lhs_scales", paddle.float32), - ("rhs", paddle.float8_e4m3fn), - ("rhs_scales", paddle.float32), - ("out", paddle.bfloat16), - ("grouped_layout", paddle.int32), - ("m", int), - ("stream", paddle.device.cuda.Stream), - ("num_sms", int), - ("smem_size", int), - ), - template=template, - args=args, - ) - - # Run the kernel - runtime(*args) diff --git a/ops/csrc/fp8/deep_gemm/jit_kernels/tuner.py b/ops/csrc/fp8/deep_gemm/jit_kernels/tuner.py deleted file mode 100644 index 4513df4e2d00..000000000000 --- a/ops/csrc/fp8/deep_gemm/jit_kernels/tuner.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -import copy -import os -from typing import Any, Dict - -import paddle - -from ..jit import Runtime, build, cpp_format, generate - - -class JITTuner: - def __init__(self) -> None: - self.tuned = {} - - def compile_and_tune_group_gemm_masked( - self, - name: str, - keys: Dict[str, Any], - space: tuple, - includes: tuple, - arg_defs: tuple, - template: str, - args: tuple, - ) -> Runtime: - # NOTES: we always assume the space and template will not change - # We also assume the GPU device will not be changed - # NOTES: the function must have no accumulated side effects - keys = {k: keys[k] for k in sorted(keys.keys())} - signature = (name, f"{keys}") - if signature in self.tuned: - if os.getenv("DG_JIT_DEBUG", None): - print(f"Using cached JIT kernel {name} with keys {keys}") - return self.tuned[signature] - - if os.getenv("DG_JIT_DEBUG", None): - print(f"Auto-tuning JIT kernel {name} with keys {keys}") - - assert signature not in self.tuned - assert args is not None - space = (dict(),) if len(space) == 0 else space - - kernels = [] - for tuned_keys in space: - assert isinstance(tuned_keys, dict) - full_keys = copy.deepcopy(keys) - full_keys.update(tuned_keys) - code = generate(includes, arg_defs, cpp_format(template, full_keys)) - - # Illegal build must raise errors - kernels.append((build(name, arg_defs, code), tuned_keys)) - - best_runtime, best_time, best_keys = None, None, None - for runtime, tuned_keys in kernels: - if len(space) > 1: - # Check kernel validity - return_code = runtime(*args) - if return_code != 0: - # Pass illegal kernels, e.g. insufficient shared memory capacity - if os.getenv("DG_JIT_DEBUG", None): - print( - f"Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}" - ) - continue - - # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels - start_event = paddle.device.cuda.Event(enable_timing=True) - end_event = paddle.device.cuda.Event(enable_timing=True) - paddle.empty(int(256e6 // 4), dtype=paddle.int32).zero_() - paddle.randn((8192, 8192), dtype=paddle.float32, device="cuda") @ paddle.randn( - (8192, 8192), dtype=paddle.float32 - ) - start_event.record() - for i in range(20): - assert runtime(*args) == 0 - end_event.record() - end_event.synchronize() - elapsed_time = start_event.elapsed_time(end_event) - else: - elapsed_time = 0 - - # Compare if better - if best_time is None or elapsed_time < best_time: - best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys - if os.getenv("DG_JIT_DEBUG", None): - print(f"Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}") - assert best_runtime is not None, f"Failed to tune JIT kernel {name} with keys {keys}" - - # Cache the best runtime and return - if os.getenv("DG_JIT_DEBUG", None) or os.getenv("DG_PRINT_AUTOTUNE", None): - print(f"Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}") - self.tuned[signature] = best_runtime - return best_runtime - - def compile_and_tune( - self, - m, - n, - k, - name: str, - keys: Dict[str, Any], - space: tuple, - includes: tuple, - arg_defs: tuple, - template: str, - # args: tuple, - ) -> Runtime: - # NOTES: we always assume the space and template will not change - # We also assume the GPU device will not be changed - # NOTES: the function must have no accumulated side effects - signature = (name, m, k, n) - if signature in self.tuned: - return self.tuned[signature] - # keys = {k: keys[k] for k in sorted(keys.keys())} - # signature = (name, f"{keys}") - # if signature in self.tuned: - # return self.tuned[signature] - space = (dict(),) if len(space) == 0 else space - - kernels = [] - for tuned_keys in space: - assert isinstance(tuned_keys, dict) - full_keys = copy.deepcopy(keys) - full_keys.update(tuned_keys) - code = generate(includes, arg_defs, cpp_format(template, full_keys)) - - # Illegal build must raise errors - kernels.append((build(name, arg_defs, code), tuned_keys)) - - best_runtime, best_time, best_keys = None, None, None - for runtime, tuned_keys in kernels: - elapsed_time = 0 - - # Compare if better - if best_time is None or elapsed_time < best_time: - best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys - if os.getenv("DG_JIT_DEBUG", None): - print(f"Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}") - assert best_runtime is not None, f"Failed to tune JIT kernel {name} with keys {keys}" - - # Cache the best runtime and return - if os.getenv("DG_JIT_DEBUG", None) or os.getenv("DG_PRINT_AUTOTUNE", None): - print(f"Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}") - self.tuned[signature] = best_runtime - return best_runtime - - -jit_tuner = JITTuner() diff --git a/ops/csrc/fp8/deep_gemm/jit_kernels/utils.py b/ops/csrc/fp8/deep_gemm/jit_kernels/utils.py deleted file mode 100644 index 0f8ee8bb5f5e..000000000000 --- a/ops/csrc/fp8/deep_gemm/jit_kernels/utils.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -import paddle -from paddle import Tensor - -_num_sms = None - - -def set_num_sms(num_sms: int) -> None: - """ - Set the maximum SM count for all GEMM kernels to use. - - Arguments: - num_sms: the desired maximum SM count for all GEMM kernels to use. - """ - global _num_sms - assert 0 < num_sms <= paddle.device.cuda.get_device_properties(device="cuda").multi_processor_count - _num_sms = num_sms - - -def get_num_sms() -> int: - """ - Get the current maximum limit of SM count for all GEMM kernels to use. - If the count is never specified, the function will return the number of device SMs. - - Returns: - Current maximum limit of SM count for all GEMM kernels to use. - """ - global _num_sms - if _num_sms is None: - _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count - return _num_sms - - -def ceil_div(x: int, y: int) -> int: - """ - Perform ceiling division of two integers. - - Args: - x: the dividend. - y: the divisor. - - Returns: - The result of the ceiling division. - """ - return (x + y - 1) // y - - -def get_m_alignment_for_contiguous_layout(): - """ - When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis. - Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well - with GEMM block shape. - - Returns: - Group-level alignment requirement for grouped contiguous layout, which is always 128. - """ - return 128 - - -def get_tma_aligned_size(x: int, element_size: int) -> int: - """ - Global memory address of TMA must be 16-byte aligned. - Since we use column-major layout for the LHS scaling tensor, - the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes. - - Arguments: - x: original M-axis shape of the LHS scaling tensor. - element_size: element size of the LHS scaling tensor. - - Returns: - M-axis shape of the LHS scaling tensor after padding. - """ - tma_alignment_bytes = 16 - assert tma_alignment_bytes % element_size == 0 - alignment = tma_alignment_bytes // element_size - return ceil_div(x, alignment) * alignment - - -def get_col_major_tma_aligned_tensor(x: Tensor) -> Tensor: - """ - Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary. - If the input tensor is already column-major layout and 16-byte aligned along the M axis - (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing. - - Arguments: - x: usually the LHS scaling tensor in GEMM. - - Returns: - The LHS scaling tensor of TMA-aligned transposed format. - """ - # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA - assert x.dim() in (2, 3) - remove_dim = False - if x.dim() == 2: - x, remove_dim = x.unsqueeze(0), True - - b, m, n = x.shape - aligned_m = get_tma_aligned_size(m, x.element_size()) - - # The last kernel gives a column-major TMA aligned layout - if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m: - return x.squeeze(0) if remove_dim else x - - # Normal layout requires transposing - aligned_x = paddle.transpose(paddle.empty((b, n, aligned_m), dtype=x.dtype), perm=[0, 2, 1]) - aligned_x[:, :m, :] = x - aligned_x = aligned_x[:, :m, :] - return aligned_x.squeeze(0) if remove_dim else aligned_x diff --git a/ops/csrc/fp8/deep_gemm/utils.py b/ops/csrc/fp8/deep_gemm/utils.py deleted file mode 100644 index 9f67703c12a9..000000000000 --- a/ops/csrc/fp8/deep_gemm/utils.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -import os -import sys - -import paddle - - -def bench(fn, num_warmups: int = 5, num_tests: int = 10, high_precision: bool = False): - # Flush L2 cache with 256 MB data - paddle.device.cuda.synchronize() - cache = paddle.empty(int(256e6 // 4), dtype=paddle.int32) - cache.zero_() - - # Warmup - for _ in range(num_warmups): - fn() - - # Add a large kernel to eliminate the CPU launch overhead - if high_precision: - x = paddle.randn((8192, 8192), dtype=paddle.float32) - y = paddle.randn((8192, 8192), dtype=paddle.float32) - x @ y - - # Testing - start_event = paddle.device.cuda.Event(enable_timing=True) - end_event = paddle.device.cuda.Event(enable_timing=True) - start_event.record() - for i in range(num_tests): - fn() - end_event.record() - paddle.cuda.synchronize() - - return start_event.elapsed_time(end_event) / num_tests - - -def get_cuda_home(): - cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") - if cuda_home: - return cuda_home - - try: - which_cmd = "which nvcc" - - nvcc_path = os.popen(which_cmd).read().strip() - if nvcc_path: - return os.path.dirname(os.path.dirname(nvcc_path)) - except Exception: - pass - - return None - - -class empty_suppress: - def __enter__(self): - return self - - def __exit__(self, *_): - pass - - -class suppress_stdout_stderr: - def __enter__(self): - self.outnull_file = open(os.devnull, "w") - self.errnull_file = open(os.devnull, "w") - - self.old_stdout_fileno_undup = sys.stdout.fileno() - self.old_stderr_fileno_undup = sys.stderr.fileno() - - self.old_stdout_fileno = os.dup(sys.stdout.fileno()) - self.old_stderr_fileno = os.dup(sys.stderr.fileno()) - - self.old_stdout = sys.stdout - self.old_stderr = sys.stderr - - os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) - os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) - - sys.stdout = self.outnull_file - sys.stderr = self.errnull_file - return self - - def __exit__(self, *_): - sys.stdout = self.old_stdout - sys.stderr = self.old_stderr - - os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) - os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) - - os.close(self.old_stdout_fileno) - os.close(self.old_stderr_fileno) - - self.outnull_file.close() - self.errnull_file.close() - - -def calc_diff(x, y): - x, y = x.astype(paddle.float64), y.astype(paddle.float64) - denominator = (x * x + y * y).sum() - sim = 2 * (x * y).sum() / denominator - return 1 - sim - - -def count_bytes(tensors): - total = 0 - for t in tensors: - if isinstance(t, tuple): - total += count_bytes(t) - else: - total += t.numel() * t.element_size() - return total diff --git a/ops/csrc/fp8/setup.py b/ops/csrc/fp8/setup.py deleted file mode 100644 index 019b28921a8a..000000000000 --- a/ops/csrc/fp8/setup.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -import os -import shutil -import subprocess - -import setuptools -from setuptools.command.build_py import build_py -from setuptools.command.develop import develop - -current_dir = os.path.dirname(os.path.realpath(__file__)) -jit_include_dirs = ("deep_gemm/include/deep_gemm",) -third_party_include_dirs = ( - "../../../csrc/third_party/cutlass/include/cute", - "../../../csrc/third_party/cutlass/include/cutlass", -) - - -class PostDevelopCommand(develop): - def run(self): - develop.run(self) - self.make_jit_include_symlinks() - - @staticmethod - def make_jit_include_symlinks(): - # Make symbolic links of third-party include directories - for d in third_party_include_dirs: - dirname = d.split("/")[-1] - src_dir = f"{current_dir}/{d}" - dst_dir = f"{current_dir}/deep_gemm/include/{dirname}" - assert os.path.exists(src_dir) - if os.path.exists(dst_dir): - assert os.path.islink(dst_dir) - os.unlink(dst_dir) - os.symlink(src_dir, dst_dir, target_is_directory=True) - - -class CustomBuildPy(build_py): - def run(self): - # First, prepare the include directories - self.prepare_includes() - - # Then run the regular build - build_py.run(self) - - def prepare_includes(self): - # Create temporary build directory instead of modifying package directory - build_include_dir = os.path.join(self.build_lib, "deep_gemm/include") - os.makedirs(build_include_dir, exist_ok=True) - - # Copy third-party includes to the build directory - for d in third_party_include_dirs: - dirname = d.split("/")[-1] - src_dir = os.path.join(current_dir, d) - dst_dir = os.path.join(build_include_dir, dirname) - - # Remove existing directory if it exists - if os.path.exists(dst_dir): - shutil.rmtree(dst_dir) - - # Copy the directory - shutil.copytree(src_dir, dst_dir) - - -if __name__ == "__main__": - # noinspection PyBroadException - try: - cmd = ["git", "rev-parse", "--short", "HEAD"] - revision = "+" + subprocess.check_output(cmd).decode("ascii").rstrip() - except: - revision = "" - - setuptools.setup( - name="deep_gemm", - version="1.0.0" + revision, - packages=["deep_gemm", "deep_gemm/jit", "deep_gemm/jit_kernels"], - package_data={ - "deep_gemm": [ - "include/deep_gemm/**/*", - "include/cute/**/*", - "include/cutlass/**/*", - ] - }, - cmdclass={ - "develop": PostDevelopCommand, - "build_py": CustomBuildPy, - }, - ) diff --git a/ops/csrc/fp8/tests/test_core.py b/ops/csrc/fp8/tests/test_core.py deleted file mode 100644 index b10616dac2b5..000000000000 --- a/ops/csrc/fp8/tests/test_core.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file has been adapted from DeepSeek DeepEP project -# Copyright (c) 2025 DeepSeek -# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE - -import random -from typing import Tuple - -import deep_gemm -import paddle -from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor -from paddle import Tensor - - -def per_token_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]: - assert x.dim() == 2 and x.shape[1] % 128 == 0 - m, n = x.shape - x_view = paddle.view(x, (m, -1, 128)) - x_abs = paddle.abs(x_view).astype(paddle.float32) - x_amax = paddle.amax(x_abs, axis=2) - x_amax = paddle.view(x_amax, (m, -1)) - x_amax = paddle.clip(x_amax, min=1e-4) - scaled_x = x_view * (448.0 / x_amax.unsqueeze(2)) - scaled_x_converted = paddle.view(scaled_x.astype(paddle.float8_e4m3fn), (m, n)) - - x_amax_scaled = paddle.view((x_amax / 448.0), (m, -1)) - - result = (scaled_x_converted, x_amax_scaled) - return result - - -def per_block_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = paddle.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype) - x_padded[:m, :n] = x - x_view = paddle.view(x_padded, (-1, 128, x_padded.shape[1] // 128, 128)) - - x_abs = paddle.abs(x_view).astype(paddle.float32) - x_amax = paddle.amax(x_abs, axis=(1, 3), keepdim=True) - x_amax = paddle.clip(x_amax, min=1e-4) - x_scaled = (x_view * (448.0 / x_amax)).astype(paddle.float8_e4m3fn) - - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), ( - paddle.view(x_amax / 448.0, (x_view.shape[0], x_view.shape[2])) - ) - - -def construct(m: int, k: int, n: int) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor, Tensor]: - x = paddle.randn((m, k), dtype=paddle.bfloat16) - y = paddle.randn((n, k), dtype=paddle.bfloat16) - out = paddle.empty((m, n), dtype=paddle.bfloat16) - ref_out = x @ y.t() - - x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) - # Transpose earlier so that the testing will not trigger transposing kernels - x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) - return x_fp8, y_fp8, out, ref_out - - -def construct_grouped( - num_groups: int, m: int, k: int, n: int, is_masked: bool -) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor, Tensor]: - # x_np = np.full((num_groups, m, k), 3) - # y_np = np.full((num_groups, n, k), 2) - # x=paddle.to_tensor(x_np).astype(paddle.bfloat16) - # y=paddle.to_tensor(y_np).astype(paddle.bfloat16) - x = paddle.randn((num_groups, m, k), dtype=paddle.bfloat16) - y = paddle.randn((num_groups, n, k), dtype=paddle.bfloat16) - out = paddle.empty((num_groups, m, n), dtype=paddle.bfloat16) - ref_out = paddle.einsum("gmk,gnk->gmn", x, y) - - assert m % 4 == 0, f"TMA alignment error: {m}" - x_fp8 = ( - paddle.empty_like(x, dtype=paddle.float8_e4m3fn), - paddle.empty((num_groups, m, k // 128), dtype=paddle.float32), - ) - y_fp8 = ( - paddle.empty_like(y, dtype=paddle.float8_e4m3fn), - paddle.empty((num_groups, (n + 127) // 128, k // 128), dtype=paddle.float32), - ) - for i in range(num_groups): - # x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) - # y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) - x_fp8_0_i, x_fp8_1_i = per_token_cast_to_fp8(x[i]) - paddle.assign(x_fp8_0_i, x_fp8[0][i]) - paddle.assign(x_fp8_1_i, x_fp8[1][i]) - y_fp8_0_i, y_fp8_1_i = per_block_cast_to_fp8(y[i]) - paddle.assign(y_fp8_0_i, y_fp8[0][i]) - paddle.assign(y_fp8_1_i, y_fp8[1][i]) - - # For non-masked input, we must merge the group and M dims - if not is_masked: - x_fp8 = (paddle.view(x_fp8[0], (-1, k)), per_token_cast_to_fp8(paddle.view(x, (-1, k)))[1]) - out, ref_out = paddle.view(out, (-1, n)), paddle.view(ref_out, (-1, n)) - - # Transpose earlier so that the testing will not trigger transposing kernels - x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) - return x_fp8, y_fp8, out, ref_out - - -def test_gemm() -> None: - print("Testing GEMM:") - for m in (64,): - for k, n in [ - (7168, 2112), - ]: - x_fp8, y_fp8, out, ref_out = construct(m, k, n) - deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) - diff = calc_diff(out, ref_out) - assert diff < 0.001, f"{m=}, {k=}, {n=}, {diff:.5f}" - - print() - - -def test_m_grouped_gemm_contiguous() -> None: - print("Testing grouped contiguous GEMM:") - - for num_groups, m, k, n in ((4, 8192, 7168, 4096),): - # TODO: make a stronger test - x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=False) - m_indices = paddle.arange(0, num_groups, dtype=paddle.int32) - # m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) - m_indices = paddle.flatten(paddle.expand(paddle.unsqueeze(m_indices, -1), shape=[num_groups, m])) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) - diff = calc_diff(out, ref_out) - print("diff:", diff) - assert diff < 0.001, f"m={m * num_groups}, {k=}, {n=}, {diff:.5f}" - print() - - -def test_m_grouped_gemm_masked() -> None: - print("Testing grouped masked GEMM:") - - for num_groups, m in ((1, 1024),): - for k, n in ((7168, 4096),): - # Test correctness - masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384))) - for i in range(10): - x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True) - masked_m = paddle.empty((num_groups,), dtype=paddle.int32) - for j in range(num_groups): - masked_m[j] = random.choice(masked_m_candidates) - # expected_m = min(int(masked_m.float().mean()) + 1, m) - masked_m_float = paddle.cast(masked_m, "float32") - masked_m_mean = paddle.mean(masked_m_float) - masked_m_mean_int = paddle.cast(masked_m_mean, "int32") - expected_m = min(masked_m_mean_int + 1, m) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m) - for j in range(num_groups): - diff = calc_diff(out[j, : masked_m[j].item()], ref_out[j, : masked_m[j].item()]) - print("diff:", diff) - assert diff < 0.001, f"{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}" - - print() - - -if __name__ == "__main__": - paddle.seed(0) - random.seed(0) - print("Library path:") - print(f" > {deep_gemm.__path__}\n") - test_gemm() - test_m_grouped_gemm_contiguous() - test_m_grouped_gemm_masked() diff --git a/ops/csrc/fused_ln/layer_norm_cuda.cu b/ops/csrc/fused_ln/layer_norm_cuda.cu deleted file mode 100644 index f803952d9e64..000000000000 --- a/ops/csrc/fused_ln/layer_norm_cuda.cu +++ /dev/null @@ -1,241 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ - -/*This code is copied from NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. */ - -#include -#include - -#include "layer_norm_cuda.h" // NOLINT -#include "paddle/extension.h" - -#ifdef CUSTOM_OP_WITH_SPMD -#include "paddle/phi/api/ext/spmd_infer.h" -#include "paddle/phi/infermeta/spmd_rules/rules.h" -#endif - -#define CHECK_CUDA(x) PD_CHECK(!x.is_cpu(), #x " must be a CUDA tensor") - -static void GetRowsCols(const std::vector &shape, - int *p_rows, - int *p_cols) { - int rows = 1; - for (int i = 0; i + 1 < shape.size(); ++i) { - rows *= shape[i]; - } - int cols = shape[shape.size() - 1]; - *p_rows = rows; - *p_cols = cols; -} - -std::vector RMSLnFwd(const paddle::Tensor &x, - const paddle::Tensor &scale, - float epsilon) { - const auto &scale_shape = scale.shape(); - const auto &x_shape = x.shape(); - PD_CHECK(scale_shape.size() == 1); - PD_CHECK(scale_shape[0] == x_shape[x_shape.size() - 1]); - CHECK_CUDA(x); - CHECK_CUDA(scale); - - int rows, cols; - GetRowsCols(x_shape, &rows, &cols); - - auto place = x.place(); - auto y = paddle::empty(x_shape, scale.type(), place); - auto variance_shape = x_shape; - variance_shape.pop_back(); - auto invvar = paddle::empty(variance_shape, paddle::DataType::FLOAT32, place); - cuda_rms_norm(x, scale, rows, cols, epsilon, &y, &invvar); - return {y, invvar}; -} - -std::vector LnFwd(const paddle::Tensor &x, - const paddle::Tensor &scale, - const paddle::Tensor &bias, - float epsilon) { - const auto &scale_shape = scale.shape(); - const auto &bias_shape = bias.shape(); - const auto &x_shape = x.shape(); - PD_CHECK(scale_shape == bias_shape); - PD_CHECK(scale_shape.size() == 1); - PD_CHECK(scale_shape[0] == x_shape[x_shape.size() - 1]); - CHECK_CUDA(x); - CHECK_CUDA(scale); - CHECK_CUDA(bias); - - int rows, cols; - GetRowsCols(x_shape, &rows, &cols); - - auto place = x.place(); - auto y = paddle::empty(x_shape, scale.type(), place); - auto mean = paddle::empty({rows}, paddle::DataType::FLOAT32, place); - auto invvar = paddle::empty_like(mean); - - cuda_layer_norm(x, scale, bias, rows, cols, epsilon, &y, &mean, &invvar); - return {y, mean, invvar}; -} - -std::vector> LnFwdInferShape( - std::vector x_shape, - std::vector scale_shape, - std::vector bias_shape, - float epsilon) { - int rows, cols; - GetRowsCols(x_shape, &rows, &cols); - return {x_shape, {rows}, {rows}}; -} - -std::vector> RMSLnFwdInferShape( - std::vector x_shape, - std::vector scale_shape, - float epsilon) { - auto variance_shape = x_shape; - variance_shape.pop_back(); - return {x_shape, variance_shape}; -} - -std::vector LnFwdInferDtype(paddle::DataType x_dtype, - paddle::DataType scale_dtype, - paddle::DataType bias_dtype) { - return {x_dtype, paddle::DataType::FLOAT32, paddle::DataType::FLOAT32}; -} - -std::vector RMSLnFwdInferDtype(paddle::DataType x_dtype, - paddle::DataType scale_dtype) { - return {x_dtype, paddle::DataType::FLOAT32}; -} - -std::vector LnBwd(const paddle::Tensor &x, - const paddle::Tensor &scale, - const paddle::Tensor &bias, - const paddle::Tensor &mean, - const paddle::Tensor &invvar, - const paddle::Tensor &dy, - float epsilon) { - CHECK_CUDA(dy); - CHECK_CUDA(mean); - CHECK_CUDA(invvar); - CHECK_CUDA(x); - CHECK_CUDA(scale); - CHECK_CUDA(bias); - - int rows, cols; - GetRowsCols(x.shape(), &rows, &cols); - - auto grad_x = paddle::empty_like(x); - auto grad_scale = paddle::empty_like(scale); - auto grad_bias = paddle::empty_like(bias); - - cuda_layer_norm_gradient(x, - scale, - bias, - mean, - invvar, - dy, - rows, - cols, - epsilon, - &grad_x, - &grad_scale, - &grad_bias); - return {grad_x, grad_scale, grad_bias}; -} - -std::vector RMSLnBwd(const paddle::Tensor &x, - const paddle::Tensor &scale, - const paddle::Tensor &invvar, - const paddle::Tensor &dy, - float epsilon) { - CHECK_CUDA(dy); - CHECK_CUDA(invvar); - CHECK_CUDA(x); - CHECK_CUDA(scale); - - int rows, cols; - GetRowsCols(x.shape(), &rows, &cols); - - auto grad_x = paddle::empty_like(x); - auto grad_scale = paddle::empty_like(scale); - - cuda_rms_norm_gradient( - x, scale, invvar, dy, rows, cols, epsilon, &grad_x, &grad_scale); - return {grad_x, grad_scale}; -} - -std::vector> LnBwdInferShape( - std::vector input_shape, - std::vector gamma_shape, - std::vector beta_shape, - std::vector mean_shape, - std::vector invvar_shape, - std::vector dout_shape, - float epsilon) { - return {input_shape, gamma_shape, beta_shape}; -} - -std::vector> RMSLnBwdInferShape( - std::vector input_shape, - std::vector gamma_shape, - std::vector invvar_shape, - std::vector dout_shape, - float epsilon) { - return {input_shape, gamma_shape}; -} - - -PD_BUILD_OP(fused_ln) - .Inputs({"x", "scale", "bias"}) - .Outputs({"y", "mean", "invvar"}) - .Attrs({"epsilon: float"}) - .SetKernelFn(PD_KERNEL(LnFwd)) - .SetInferShapeFn(PD_INFER_SHAPE(LnFwdInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(LnFwdInferDtype)); - -PD_BUILD_GRAD_OP(fused_ln) - .Inputs({"x", "scale", "bias", "mean", "invvar", paddle::Grad("y")}) - .Outputs({paddle::Grad("x"), paddle::Grad("scale"), paddle::Grad("bias")}) - .Attrs({"epsilon: float"}) - .SetKernelFn(PD_KERNEL(LnBwd)) - .SetInferShapeFn(PD_INFER_SHAPE(LnBwdInferShape)); - -PD_BUILD_OP(fused_rms_norm) - .Inputs({"x", "scale"}) - .Outputs({"y", "invvar"}) - .Attrs({"epsilon: float"}) - .SetKernelFn(PD_KERNEL(RMSLnFwd)) - .SetInferShapeFn(PD_INFER_SHAPE(RMSLnFwdInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(RMSLnFwdInferDtype)) -#ifdef CUSTOM_OP_WITH_SPMD - .SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormInferSpmd)) -#endif - ; - -PD_BUILD_GRAD_OP(fused_rms_norm) - .Inputs({"x", "scale", "invvar", paddle::Grad("y")}) - .Outputs({paddle::Grad("x"), paddle::Grad("scale")}) - .Attrs({"epsilon: float"}) - .SetKernelFn(PD_KERNEL(RMSLnBwd)) - .SetInferShapeFn(PD_INFER_SHAPE(RMSLnBwdInferShape)) -#ifdef CUSTOM_OP_WITH_SPMD - .SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormGradInferSpmd)) -#endif - ; - - -// https://github.com/NVIDIA/apex/blob/85e9eddece9d4ac72b48c2407f8162f2173e1bf4/csrc/layer_norm_cuda_kernel.cu#L679 diff --git a/ops/csrc/fused_ln/layer_norm_cuda.h b/ops/csrc/fused_ln/layer_norm_cuda.h deleted file mode 100644 index 121de8b920cf..000000000000 --- a/ops/csrc/fused_ln/layer_norm_cuda.h +++ /dev/null @@ -1,1325 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ - -/*This code is copied from NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. */ - -#pragma once // NOLINT - -#ifdef PADDLE_WITH_HIP -#include -#else -#include // NOLINT -#include // NOLINT -#endif -#include "paddle/extension.h" - -#define DEFAULT_THROW(NAME, TYPE) \ - default: \ - do { \ - PD_THROW(#NAME, " not implemented for '", TYPE, "'"); \ - } while (0); \ - break - -#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ - switch (TYPEIN) { \ - case paddle::DataType::FLOAT32: { \ - using scalar_t_in = float; \ - switch (TYPEOUT) { \ - case paddle::DataType::FLOAT32: { \ - using scalar_t_out = float; \ - __VA_ARGS__; \ - break; \ - } \ - case paddle::DataType::FLOAT16: { \ - using scalar_t_out = phi::dtype::float16; \ - __VA_ARGS__; \ - break; \ - } \ - case paddle::DataType::BFLOAT16: { \ - using scalar_t_out = phi::dtype::bfloat16; \ - __VA_ARGS__; \ - break; \ - } \ - DEFAULT_THROW(NAME, TYPEOUT); \ - } \ - break; \ - } \ - case paddle::DataType::FLOAT16: { \ - using scalar_t_in = phi::dtype::float16; \ - using scalar_t_out = phi::dtype::float16; \ - __VA_ARGS__; \ - break; \ - } \ - case paddle::DataType::BFLOAT16: { \ - using scalar_t_in = phi::dtype::bfloat16; \ - using scalar_t_out = phi::dtype::bfloat16; \ - __VA_ARGS__; \ - break; \ - } \ - DEFAULT_THROW(NAME, TYPEIN); \ - } - -#ifdef PADDLE_WITH_HIP -#define WARP_SIZE 64 -#else -#define WARP_SIZE 32 -#endif - -template -__device__ __forceinline__ T WARP_SHFL_XOR(T value, - int laneMask, - int width = WARP_SIZE, - unsigned int mask = 0xffffffff) { - #ifdef PADDLE_WITH_HIP - return __shfl_xor(value, laneMask, width); - #else - return __shfl_xor_sync(mask,value, laneMask, width); - #endif -} - -template -__device__ __forceinline__ T WARP_SHFL(T value, - int srcLane, - int width = WARP_SIZE, - unsigned int mask = 0xffffffff) { - #ifdef PADDLE_WITH_HIP - return __shfl(value, srcLane, width); - #else - return __shfl_sync(mask, value, srcLane, width); - #endif -} - -template -__device__ void cuWelfordOnlineSum(const U curr, - U& mu, // NOLINT - U& sigma2, // NOLINT - U& count) { // NOLINT - count = count + U(1); - U delta = curr - mu; - U lmean = mu + delta / count; - mu = lmean; - U delta2 = curr - lmean; - sigma2 = sigma2 + delta * delta2; -} - -template -__device__ void cuChanOnlineSum(const U muB, - const U sigma2B, - const U countB, - U& mu, // NOLINT - U& sigma2, // NOLINT - U& count) { // NOLINT - U delta = muB - mu; - U nA = count; - U nB = countB; - count = count + countB; - U nX = count; - if (nX > U(0)) { - nA = nA / nX; - nB = nB / nX; - mu = nA * mu + nB * muB; - sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; - } else { - mu = U(0); - sigma2 = U(0); - } -} - -template -__device__ void cuRMSOnlineSum(const U curr, U& sigma2) { // NOLINT - sigma2 = sigma2 + curr * curr; -} - -template -__device__ void cuChanRMSOnlineSum(const U sigma2B, U& sigma2) { // NOLINT - sigma2 = sigma2 + sigma2B; -} - - -template -__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, - const int n1, - const int n2, - const int i1, - U& mu, // NOLINT - U& sigma2, // NOLINT - U* buf, - bool rms_only) { - // Assumptions: - // 1) blockDim.x == WARP_SIZE - // 2) Tensor is contiguous - // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. - // - // compute variance and mean over n2 - U count = U(0); - mu = U(0); - sigma2 = U(0); - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const T* lvals = vals + i1 * n2; - int l = 4 * thrx; - for (; l + 3 < n2; l += 4 * numx) { - for (int k = 0; k < 4; ++k) { - U curr = static_cast(lvals[l + k]); - if (!rms_only) { - cuWelfordOnlineSum(curr, mu, sigma2, count); - } else { - cuRMSOnlineSum(curr, sigma2); - } - } - } - for (; l < n2; ++l) { - U curr = static_cast(lvals[l]); - if (!rms_only) { - cuWelfordOnlineSum(curr, mu, sigma2, count); - } else { - cuRMSOnlineSum(curr, sigma2); - } - } - // intra-warp reductions - #ifdef PADDLE_WITH_HIP - for (int l = 0; l <= 5; ++l) - #else - for (int l = 0; l <= 4; ++l) - #endif - { - #ifdef PADDLE_WITH_HIP - int srcLaneB = (threadIdx.x + (1 << l)) & 63; - #else - int srcLaneB = (threadIdx.x + (1 << l)) & 31; - #endif - U sigma2B = WARP_SHFL(sigma2, srcLaneB); - if (!rms_only) { - U muB = WARP_SHFL(mu, srcLaneB); - U countB = WARP_SHFL(count, srcLaneB); - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum(sigma2B, sigma2); - } - } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - U* ubuf = (U*)buf; // NOLINT - U* ibuf = (U*)(ubuf + blockDim.y); // NOLINT - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && - threadIdx.y < 2 * offset) { - const int wrt_y = threadIdx.y - offset; - if (!rms_only) { - ubuf[2 * wrt_y] = mu; - ibuf[wrt_y] = count; - } - ubuf[2 * wrt_y + 1] = sigma2; - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - U sigma2B = ubuf[2 * threadIdx.y + 1]; - if (!rms_only) { - U muB = ubuf[2 * threadIdx.y]; - U countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum(sigma2B, sigma2); - } - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - if (!rms_only) { - ubuf[0] = mu; - } - ubuf[1] = sigma2; - } - __syncthreads(); - if (!rms_only) { - mu = ubuf[0]; - } - sigma2 = ubuf[1] / U(n2); - // don't care about final value of count, we know count == n2 - } else { - if (!rms_only) { - mu = WARP_SHFL(mu, 0); - } - mu = WARP_SHFL(mu, 0); - sigma2 = WARP_SHFL(sigma2 / U(n2), 0); - } - } -} - -template <> -__device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals, - const int n1, - const int n2, - const int i1, - float& mu, // NOLINT - float& sigma2, // NOLINT - float* buf, - bool rms_only) { - // Assumptions: - // 1) blockDim.x == WARP_SIZE - // 2) Tensor is contiguous - // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. - // - // compute variance and mean over n2 - float count = 0.0f; - mu = float(0); // NOLINT - sigma2 = float(0); // NOLINT - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const auto* lvals = vals + i1 * n2; - int l = 8 * thrx; - if ((((size_t)lvals) & 3) != 0) { // NOLINT - // 16 bit alignment - // first thread consumes first point - if (thrx == 0) { - float curr = static_cast(lvals[0]); - if (!rms_only) { - cuWelfordOnlineSum(curr, mu, sigma2, count); - } else { - cuRMSOnlineSum(curr, sigma2); - } - } - ++l; - } - // at this point, lvals[l] are 32 bit aligned for all threads. - for (; l + 7 < n2; l += 8 * numx) { - for (int k = 0; k < 8; k += 2) { - float2 curr = __half22float2(*((__half2*)(lvals + l + k))); // NOLINT - if (!rms_only) { - cuWelfordOnlineSum(curr.x, mu, sigma2, count); - cuWelfordOnlineSum(curr.y, mu, sigma2, count); - } else { - cuRMSOnlineSum(curr.x, sigma2); - cuRMSOnlineSum(curr.y, sigma2); - } - } - } - for (; l < n2; ++l) { - float curr = static_cast(lvals[l]); - if (!rms_only) { - cuWelfordOnlineSum(curr, mu, sigma2, count); - } else { - cuRMSOnlineSum(curr, sigma2); - } - } - // intra-warp reductions - #ifdef PADDLE_WITH_HIP - for (int l = 0; l <= 5; ++l) - #else - for (int l = 0; l <= 4; ++l) - #endif - { - #ifdef PADDLE_WITH_HIP - int srcLaneB = (threadIdx.x + (1 << l)) & 63; - #else - int srcLaneB = (threadIdx.x + (1 << l)) & 31; - #endif - float sigma2B = WARP_SHFL(sigma2, srcLaneB); - if (!rms_only) { - float muB = WARP_SHFL(mu, srcLaneB); - float countB = WARP_SHFL(count, srcLaneB); - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum(sigma2B, sigma2); - } - } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - float* ubuf = (float*)buf; // NOLINT - float* ibuf = (float*)(ubuf + blockDim.y); // NOLINT - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && - threadIdx.y < 2 * offset) { - const int wrt_y = threadIdx.y - offset; - ubuf[2 * wrt_y + 1] = sigma2; - if (!rms_only) { - ubuf[2 * wrt_y] = mu; - ibuf[wrt_y] = count; - } - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - float sigma2B = ubuf[2 * threadIdx.y + 1]; - if (!rms_only) { - float muB = ubuf[2 * threadIdx.y]; - float countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum(sigma2B, sigma2); - } - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - if (!rms_only) { - ubuf[0] = mu; - } - ubuf[1] = sigma2; - } - __syncthreads(); - if (!rms_only) { - mu = ubuf[0]; - } - sigma2 = ubuf[1] / float(n2); // NOLINT - // don't care about final value of count, we know count == n2 - } else { - if (!rms_only) { - mu = WARP_SHFL(mu, 0); - } - sigma2 = WARP_SHFL(sigma2 / float(n2), 0); // NOLINT - } - } -} - -template __device__ -U rsqrt(U v) { - return U(1) / sqrt(v); -} -template <> __device__ -float rsqrt(float v) { - return rsqrtf(v); -} -template <> __device__ -double rsqrt(double v) { - return rsqrt(v); -} - -namespace { // NOLINT -// This is the un-specialized struct. Note that we prevent instantiation of -// this struct by putting an undefined symbol in the function body so it won't -// compile. -// template -// struct SharedMemory -// { -// // Ensure that we won't compile any un-specialized types -// __device__ T *getPointer() -// { -// extern __device__ void error(void); -// error(); -// return NULL; -// } -// }; -// https://github.com/NVIDIA/apex/issues/246 -template -struct SharedMemory; - -template <> -struct SharedMemory { - __device__ float* getPointer() { - extern __shared__ float s_float[]; - return s_float; - } -}; - -} // namespace - -template -__device__ void cuApplyLayerNorm_(V* __restrict__ output_vals, - U* __restrict__ mean, - U* __restrict__ invvar, - const T* __restrict__ vals, - const int n1, - const int n2, - const U epsilon, - const V* __restrict__ gamma, - const V* __restrict__ beta, - bool rms_only) { - // Assumptions: - // 1) blockDim.x == WARP_SIZE - // 2) Tensors are contiguous - // - for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { - SharedMemory shared; - U* buf = shared.getPointer(); - U mu, sigma2; - cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, rms_only); - const T* lvals = vals + i1 * n2; - V* ovals = output_vals + i1 * n2; - U c_invvar = rsqrt(sigma2 + epsilon); - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL && (beta != NULL || rms_only)) { - for (int i = thrx; i < n2; i += numx) { - U curr = static_cast(lvals[i]); - if (!rms_only) { - ovals[i] = - gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; - } else { - ovals[i] = gamma[i] * static_cast(c_invvar * curr); - } - } - } else { - for (int i = thrx; i < n2; i += numx) { - U curr = static_cast(lvals[i]); - if (!rms_only) { - ovals[i] = static_cast(c_invvar * (curr - mu)); - } else { - ovals[i] = static_cast(c_invvar * curr); - } - } - } - if (threadIdx.x == 0 && threadIdx.y == 0) { - if (!rms_only) { - mean[i1] = mu; - } - invvar[i1] = c_invvar; - } - __syncthreads(); - } -} - -template -__global__ void cuApplyLayerNorm(V* __restrict__ output_vals, - U* __restrict__ mean, - U* __restrict__ invvar, - const T* __restrict__ vals, - const int n1, - const int n2, - const U epsilon, - const V* __restrict__ gamma, - const V* __restrict__ beta) { - cuApplyLayerNorm_( - output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, false); -} - - -template -__global__ void cuApplyRMSNorm(V* __restrict__ output_vals, - U* __restrict__ invvar, - const T* __restrict__ vals, - const int n1, - const int n2, - const U epsilon, - const V* __restrict__ gamma) { - cuApplyLayerNorm_( - output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, true); -} - -template -__device__ void cuLoadWriteStridedInputs(const int i1_block, - const int thr_load_row_off, - const int thr_load_col_off, - const int i2_off, - const int row_stride, - U* warp_buf1, - U* warp_buf2, - const T* input, - const V* dout, - const int i1_end, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar, - bool rms_only) { - int i1 = i1_block + thr_load_row_off; - if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; - for (int k = 0; k < blockDim.y; ++k) { - int i2 = i2_off + k; - int load_idx = i1 * n2 + i2; - int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; - if (i2 < n2) { - U curr_input = static_cast(input[load_idx]); - U curr_dout = static_cast(dout[load_idx]); - if (!rms_only) { - warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = - curr_dout * (curr_input - curr_mean) * curr_invvar; - } else { - warp_buf2[write_idx] = curr_dout * (curr_input)*curr_invvar; - } - } else { - if (!rms_only) { - warp_buf1[write_idx] = U(0); - } - warp_buf2[write_idx] = U(0); - } - } - } else { - for (int k = 0; k < blockDim.y; ++k) { - int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; - if (!rms_only) { - warp_buf1[write_idx] = U(0); - } - warp_buf2[write_idx] = U(0); - } - } -} - -template -__device__ void cuLoadAddStridedInputs(const int i1_block, - const int thr_load_row_off, - const int thr_load_col_off, - const int i2_off, - const int row_stride, - U* warp_buf1, - U* warp_buf2, - const T* input, - const V* dout, - const int i1_end, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar, - bool rms_only) { - int i1 = i1_block + thr_load_row_off; - if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; - for (int k = 0; k < blockDim.y; ++k) { - int i2 = i2_off + k; - int load_idx = i1 * n2 + i2; - int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; - if (i2 < n2) { - U curr_input = static_cast(input[load_idx]); - U curr_dout = static_cast(dout[load_idx]); - if (!rms_only) { - warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += - curr_dout * (curr_input - curr_mean) * curr_invvar; - } else { - warp_buf2[write_idx] += curr_dout * (curr_input)*curr_invvar; - } - } - } - } -} - -template -__global__ void cuComputePartGradGammaBeta(const V* __restrict__ dout, - const T* __restrict__ input, - const int n1, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar, - U epsilon, - U* part_grad_gamma, - U* part_grad_beta, - bool rms_only) { - const int numsegs_n1 = - (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); - const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; - const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; - const int i1_beg_plus_one = - (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; - const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; - const int row_stride = blockDim.x + 1; - const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); - const int thr_load_row_off = - (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; - const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; - SharedMemory shared; - U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * - // blockDim.y + (blockDim.y - - // 1)*(blockDim.x/blockDim.y) elements - U* warp_buf1 = (U*)buf; // NOLINT - U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; - // compute partial sums from strided inputs - // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg, - thr_load_row_off, - thr_load_col_off, - i2_off, - row_stride, - warp_buf1, - warp_buf2, - input, - dout, - i1_end, - n2, - mean, - invvar, - rms_only); - for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; - i1_block += blockDim.y * blockDim.y) { - cuLoadAddStridedInputs(i1_block, - thr_load_row_off, - thr_load_col_off, - i2_off, - row_stride, - warp_buf1, - warp_buf2, - input, - dout, - i1_end, - n2, - mean, - invvar, - rms_only); - } - __syncthreads(); - // inter-warp reductions - // sum within each warp - U acc1 = U(0); - U acc2 = U(0); - for (int k = 0; k < blockDim.y; ++k) { - int row1 = threadIdx.y + k * blockDim.y; - int idx1 = row1 * row_stride + threadIdx.x; - if (!rms_only) { - acc1 += warp_buf1[idx1]; - } - acc2 += warp_buf2[idx1]; - } - - if (!rms_only) { - warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; - } - warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; - __syncthreads(); - // sum all warps - for (int offset = blockDim.y / 2; offset > 1; offset /= 2) { - if (threadIdx.y < offset) { - int row1 = threadIdx.y; - int row2 = threadIdx.y + offset; - int idx1 = row1 * row_stride + threadIdx.x; - int idx2 = row2 * row_stride + threadIdx.x; - if (!rms_only) { - warp_buf1[idx1] += warp_buf1[idx2]; - } - warp_buf2[idx1] += warp_buf2[idx2]; - } - __syncthreads(); - } - int i2 = blockIdx.x * blockDim.x + threadIdx.x; - if (threadIdx.y == 0 && i2 < n2) { - int row1 = threadIdx.y; - int row2 = threadIdx.y + 1; - int idx1 = row1 * row_stride + threadIdx.x; - int idx2 = row2 * row_stride + threadIdx.x; - if (!rms_only) { - part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; - } - part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; - } -} - -template -__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, - const U* part_grad_beta, - const int part_size, - const int n1, - const int n2, - V* grad_gamma, - V* grad_beta, - bool rms_only) { - // sum partial gradients for gamma and beta - SharedMemory shared; - U* buf = shared.getPointer(); - int i2 = blockIdx.x * blockDim.x + threadIdx.x; - if (i2 < n2) { - // each warp does sequential reductions until reduced part_size is num_warps - int num_warp_reductions = part_size / blockDim.y; - U sum_gamma = U(0); - U sum_beta = U(0); - const U* part_grad_gamma_ptr = - part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; - const U* part_grad_beta_ptr = - part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; - for (int warp_offset = 0; warp_offset < num_warp_reductions; - ++warp_offset) { - sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; - if (!rms_only) { - sum_beta += part_grad_beta_ptr[warp_offset * n2]; - } - } - // inter-warp reductions - const int nbsize3 = blockDim.x * blockDim.y / 2; - for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { - // top half write to shared memory - if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { - const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - buf[write_idx] = sum_gamma; - if (!rms_only) { - buf[write_idx + nbsize3] = sum_beta; - } - } - __syncthreads(); - // bottom half sums - if (threadIdx.y < offset) { - const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; - sum_gamma += buf[read_idx]; - if (!rms_only) { - sum_beta += buf[read_idx + nbsize3]; - } - } - __syncthreads(); - } - // write out fully summed gradients - if (threadIdx.y == 0) { - grad_gamma[i2] = sum_gamma; - if (!rms_only) { - grad_beta[i2] = sum_beta; - } - } - } -} - -template -__global__ void cuComputeGradInput(const V* __restrict__ dout, - const T* __restrict__ input, - const int n1, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar, - U epsilon, - const V* gamma, - T* grad_input, - bool rms_only) { - for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { - U sum_loss1 = U(0); - U sum_loss2 = U(0); - U c_mean; - if (!rms_only) { - c_mean = mean[i1]; - } - const U c_invvar = invvar[i1]; - const T* k_input = input + i1 * n2; - const V* k_dout = dout + i1 * n2; - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL) { - int l = 4 * thrx; - for (; l + 3 < n2; l += 4 * numx) { - for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l + k]); - const U c_loss = static_cast(k_dout[l + k]); - const U gamma_tmp = static_cast(gamma[l + k]); - if (!rms_only) { - sum_loss1 += c_loss * gamma_tmp; - sum_loss2 += c_loss * gamma_tmp * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * gamma_tmp * (c_h)*c_invvar; - } - } - } - for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - const U gamma_tmp = static_cast(gamma[l]); - if (!rms_only) { - sum_loss1 += c_loss * gamma_tmp; - sum_loss2 += c_loss * gamma_tmp * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * gamma_tmp * (c_h)*c_invvar; - } - } - } else { - int l = 4 * thrx; - for (; l + 3 < n2; l += 4 * numx) { - for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l + k]); - const U c_loss = static_cast(k_dout[l + k]); - if (!rms_only) { - sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * (c_h)*c_invvar; - } - } - } - for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - if (!rms_only) { - sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * (c_h)*c_invvar; - } - } - } - // intra-warp reductions - for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { - if (!rms_only) { - sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); - } - sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); - } - // inter-warp reductions - if (blockDim.y > 1) { - SharedMemory shared; - U* buf = shared.getPointer(); - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { - const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - if (!rms_only) { - buf[2 * wrt_i] = sum_loss1; - } - buf[2 * wrt_i + 1] = sum_loss2; - } - __syncthreads(); - // lower half merges - if (threadIdx.y < offset) { - const int read_i = threadIdx.y * blockDim.x + threadIdx.x; - if (!rms_only) { - sum_loss1 += buf[2 * read_i]; - } - sum_loss2 += buf[2 * read_i + 1]; - } - __syncthreads(); - } - if (threadIdx.y == 0) { - if (!rms_only) { - buf[2 * threadIdx.x] = sum_loss1; - } - buf[2 * threadIdx.x + 1] = sum_loss2; - } - __syncthreads(); - if (threadIdx.y != 0) { - if (!rms_only) { - sum_loss1 = buf[2 * threadIdx.x]; - } - sum_loss2 = buf[2 * threadIdx.x + 1]; - } - } - // all threads now have the two sums over l - U fH = (U)n2; - U term1 = (U(1) / fH) * c_invvar; - T* k_grad_input = grad_input + i1 * n2; - if (gamma != NULL) { - for (int l = thrx; l < n2; l += numx) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss * static_cast(gamma[l]); - if (!rms_only) { - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; - } else { - f_grad_input -= (c_h)*c_invvar * sum_loss2; - } - f_grad_input *= term1; - k_grad_input[l] = static_cast(f_grad_input); - } - } else { - for (int l = thrx; l < n2; l += numx) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss; - if (!rms_only) { - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; - } else { - f_grad_input -= (c_h)*c_invvar * sum_loss2; - } - f_grad_input *= term1; - k_grad_input[l] = static_cast(f_grad_input); - } - } - // prevent race where buf is written again before reads are done - __syncthreads(); - } -} - -#ifdef PADDLE_WITH_HIP -static hipDeviceProp_t GetDevicePropImpl() { - int device = -1; - PD_CHECK(hipGetDevice(&device) == hipSuccess); - hipDeviceProp_t prop; - PD_CHECK(hipGetDeviceProperties(&prop, device) == hipSuccess); - return prop; -} - -static hipDeviceProp_t* GetDeviceProp() { - static auto prop = GetDevicePropImpl(); - return ∝ -} - -#else - -static cudaDeviceProp GetDevicePropImpl() { - int device = -1; - PD_CHECK(cudaGetDevice(&device) == cudaSuccess); - cudaDeviceProp prop; - PD_CHECK(cudaGetDeviceProperties(&prop, device) == cudaSuccess); - return prop; -} - -static cudaDeviceProp* GetDeviceProp() { - static auto prop = GetDevicePropImpl(); - return ∝ -} -#endif - -template -#ifdef PADDLE_WITH_HIP -void HostApplyLayerNorm(V* output, - U* mean, - U* invvar, - const T* input, - int n1, - int n2, - double epsilon, - const V* gamma, - const V* beta, - hipStream_t stream) -#else -void HostApplyLayerNorm(V* output, - U* mean, - U* invvar, - const T* input, - int n1, - int n2, - double epsilon, - const V* gamma, - const V* beta, - cudaStream_t stream) -#endif -{ - #ifdef PADDLE_WITH_HIP - const dim3 threads(64, 4, 1); - #else - const dim3 threads(32, 4, 1); - #endif - const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; - const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); - int nshared = - threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; - cuApplyLayerNorm<<>>( - output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); -} - -template -#ifdef PADDLE_WITH_HIP -void HostApplyRMSNorm(V* output, - U* invvar, - const T* input, - int n1, - int n2, - double epsilon, - const V* gamma, - hipStream_t stream) -#else -void HostApplyRMSNorm(V* output, - U* invvar, - const T* input, - int n1, - int n2, - double epsilon, - const V* gamma, - cudaStream_t stream) -#endif -{ - // auto stream = at::cuda::getCurrentCUDAStream().stream(); - #ifdef PADDLE_WITH_HIP - const dim3 threads(64, 4, 1); - #else - const dim3 threads(32, 4, 1); - #endif - // const uint64_t maxGridY = - // at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; - const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); - int nshared = - threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; - cuApplyRMSNorm<<>>( - output, invvar, input, n1, n2, U(epsilon), gamma); -} - -static void cuda_layer_norm(const paddle::Tensor& x, - const paddle::Tensor& scale, - const paddle::Tensor& bias, - int rows, - int cols, - float epsilon, - paddle::Tensor* y, - paddle::Tensor* mean, - paddle::Tensor* invvar) { - DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - x.type(), - y->type(), - "cuda_layer_norm_kernel", - HostApplyLayerNorm(y->data(), - mean->data(), - invvar->data(), - const_cast(x.data()), - rows, - cols, - epsilon, - const_cast(scale.data()), - const_cast(bias.data()), - x.stream())); -} - -static void cuda_rms_norm(const paddle::Tensor& x, - const paddle::Tensor& scale, - int rows, - int cols, - float epsilon, - paddle::Tensor* y, - paddle::Tensor* invvar) { - DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - x.type(), - y->type(), - "cuda_rms_norm_kernel", - HostApplyRMSNorm(y->data(), - invvar->data(), - const_cast(x.data()), - rows, - cols, - epsilon, - const_cast(scale.data()), - x.stream())); -} - -template -#ifdef PADDLE_WITH_HIP -void HostLayerNormGradient(const V* dout, - const U* mean, - const U* invvar, - const paddle::Tensor& input, - int n1, - int n2, - const V* gamma, - const V* beta, - double epsilon, - T* grad_input, - V* grad_gamma, - V* grad_beta, - hipStream_t stream) -#else -void HostLayerNormGradient(const V* dout, - const U* mean, - const U* invvar, - const paddle::Tensor& input, - int n1, - int n2, - const V* gamma, - const V* beta, - double epsilon, - T* grad_input, - V* grad_gamma, - V* grad_beta, - cudaStream_t stream) -#endif -{ - if (gamma != NULL && beta != NULL) { - // compute grad_gamma(j) and grad_beta(j) - const int part_size = 16; - const dim3 threads2(32, 4, 1); - const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); - const int nshared2_a = - 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); - const int nshared2_b = threads2.x * threads2.y * sizeof(U); - const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; - auto place = input.place(); - paddle::Tensor part_grad_gamma = - paddle::empty({part_size, n2}, paddle::DataType::FLOAT32, place); - paddle::Tensor part_grad_beta = paddle::empty_like(part_grad_gamma); - cuComputePartGradGammaBeta<<>>( - dout, - input.data(), - n1, - n2, - mean, - invvar, - U(epsilon), - part_grad_gamma.data(), - part_grad_beta.data(), - false); - - const dim3 threads3(32, 8, 1); - const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); - const int nshared3 = threads3.x * threads3.y * sizeof(U); - cuComputeGradGammaBeta<<>>( - part_grad_gamma.data(), - part_grad_beta.data(), - part_size, - n1, - n2, - grad_gamma, - grad_beta, - false); - } - - // compute grad_input - const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; - const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); - const dim3 threads1(32, 4, 1); - int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; - cuComputeGradInput<<>>(dout, - input.data(), - n1, - n2, - mean, - invvar, - U(epsilon), - gamma, - grad_input, - false); -} - -template -#ifdef PADDLE_WITH_HIP -void HostRMSNormGradient(const V* dout, - const U* invvar, - const paddle::Tensor& input, - int n1, - int n2, - const V* gamma, - double epsilon, - T* grad_input, - V* grad_gamma, - hipStream_t stream) -#else -void HostRMSNormGradient(const V* dout, - const U* invvar, - const paddle::Tensor& input, - int n1, - int n2, - const V* gamma, - double epsilon, - T* grad_input, - V* grad_gamma, - cudaStream_t stream) -#endif -{ - if (gamma != NULL) { - const int part_size = 16; - const dim3 threads2(32, 4, 1); - const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); - const int nshared2_a = - 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); - const int nshared2_b = threads2.x * threads2.y * sizeof(U); - const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; - auto place = input.place(); - paddle::Tensor part_grad_gamma = - paddle::empty({part_size, n2}, paddle::DataType::FLOAT32, place); - cuComputePartGradGammaBeta<<>>( - dout, - input.data(), - n1, - n2, - invvar, // unused - invvar, - U(epsilon), - part_grad_gamma.data(), - part_grad_gamma.data(), /* unused */ - true); - - const dim3 threads3(32, 8, 1); - const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); - const int nshared3 = threads3.x * threads3.y * sizeof(U); - cuComputeGradGammaBeta<<>>( - part_grad_gamma.data(), - part_grad_gamma.data(), /* unused */ - part_size, - n1, - n2, - grad_gamma, - grad_gamma, /* unused */ - true); - } - - // compute grad_input - const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1]; - const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); - const dim3 threads1(32, 4, 1); - int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; - cuComputeGradInput<<>>( - dout, - input.data(), - n1, - n2, - invvar, /* unused */ - invvar, - U(epsilon), - gamma, - grad_input, - true); -} - -static void cuda_layer_norm_gradient(const paddle::Tensor& x, - const paddle::Tensor& scale, - const paddle::Tensor& bias, - const paddle::Tensor& mean, - const paddle::Tensor& invvar, - const paddle::Tensor& dy, - int rows, - int cols, - float epsilon, - paddle::Tensor* grad_x, - paddle::Tensor* grad_scale, - paddle::Tensor* grad_bias) { - DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - x.type(), - scale.type(), - "cuda_layer_norm_gradient_kernel", - HostLayerNormGradient( - dy.data(), - mean.data(), - invvar.data(), - x, - rows, - cols, - scale.data(), - bias.data(), - epsilon, - grad_x->data(), - grad_scale->data(), - grad_bias->data(), - x.stream())); -} - - -static void cuda_rms_norm_gradient(const paddle::Tensor& x, - const paddle::Tensor& scale, - const paddle::Tensor& invvar, - const paddle::Tensor& dy, - int rows, - int cols, - float epsilon, - paddle::Tensor* grad_x, - paddle::Tensor* grad_scale) { - DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - x.type(), - scale.type(), - "cuda_rms_norm_gradient_kernel", - HostRMSNormGradient( - dy.data(), - invvar.data(), - x, - rows, - cols, - scale.data(), - epsilon, - grad_x->data(), - grad_scale->data(), - x.stream())); -} diff --git a/ops/csrc/paddle_bwd_ops/add_bwd.cc b/ops/csrc/paddle_bwd_ops/add_bwd.cc deleted file mode 100644 index 9f32d98f418d..000000000000 --- a/ops/csrc/paddle_bwd_ops/add_bwd.cc +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "paddle/extension.h" -#include "paddle/phi/core/enforce.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/ddim.h" - -using paddle::Tensor; - -namespace paddle { -namespace experimental { - -PADDLE_API void add_grad(const Tensor& x, - const Tensor& y, - const Tensor& out_grad, - int axis, - Tensor* dx, - Tensor* dy); - -} -} // namespace paddle - -namespace phi { - - - -} // namespace phi - -std::vector SRAddBwd(const Tensor& x, - const Tensor& weight, - const Tensor& bias, - const Tensor& out_grad, - int axis); - -std::vector SRAddBwd(const Tensor& x, - const Tensor& weight, - const Tensor& bias, - const Tensor& out_grad, - int axis){ - std::vector res(2); - - std::vector dims_x = phi::vectorize(x.dims()); - std::vector dims_w = phi::vectorize(weight.dims()); - auto ndims_x = dims_x.size(); - auto ndims_w = dims_w.size(); - - PADDLE_ENFORCE_GT(ndims_x, - 1UL, - phi::errors::InvalidArgument( - "The Input(x) dims size must be greater than 1. Other cases are not supported")); - - PADDLE_ENFORCE_GT(ndims_w, - 1UL, - phi::errors::InvalidArgument( - "The Input(w) dims size must be greater than 1. Other cases are not supported")); - - size_t M, N; - M = dims_x[ndims_x - 2]; - N = dims_w[ndims_w - 1]; - - std::vector new_dims; - if (ndims_x > ndims_w) { - new_dims.assign(dims_x.begin(), dims_x.end() - 2); - } else if (ndims_x < ndims_w) { - new_dims.assign(dims_w.begin(), dims_w.end() - 2); - } else { - new_dims.reserve(ndims_x); - for (size_t i = 0; i < ndims_x - 2; ++i) { - new_dims.push_back(std::max(dims_x[i], dims_w[i])); - } - } - - new_dims.push_back(M); - new_dims.push_back(N); - auto ddim_out = phi::make_ddim(new_dims); - - phi::DenseTensor* new_x = new phi::DenseTensor(); - new_x->Resize(ddim_out); - Tensor tensor_x(std::make_shared(*new_x)); - - paddle::experimental::add_grad(tensor_x, bias, out_grad, axis, &res[0], &res[1]); - - delete new_x; - return res; -} - -std::vector SRAddBwdDtype(paddle::DataType x_dtype, - paddle::DataType y_dtype, - paddle::DataType z_dtype){ - return {x_dtype, y_dtype, z_dtype}; -} - -std::vector> SRAddBwdInferShape(std::vector x_shape, - std::vector y_shape, - std::vector z_shape){ - return {x_shape, y_shape, z_shape}; -} - -PD_BUILD_OP(add_bwd) - .Inputs({"x", "weight", "bias", "out_grad"}) - .Outputs({"xw_grad", "bias_grad"}) - .Attrs({"axis: int"}) - .SetKernelFn(PD_KERNEL(SRAddBwd)) - .SetInferShapeFn(PD_INFER_SHAPE(SRAddBwdInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(SRAddBwdDtype)); \ No newline at end of file diff --git a/ops/csrc/paddle_bwd_ops/flash_attn_bwd.cc b/ops/csrc/paddle_bwd_ops/flash_attn_bwd.cc deleted file mode 100644 index 3acd55cbdbe9..000000000000 --- a/ops/csrc/paddle_bwd_ops/flash_attn_bwd.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/extension.h" -#include -#include - -using paddle::Tensor; - -namespace paddle { -namespace experimental { - -PADDLE_API void flash_attn_grad(const Tensor& q, - const Tensor& k, - const Tensor& v, - const Tensor& out, - const Tensor& softmax_lse, - const Tensor& seed_offset, - const paddle::optional &attn_mask, - const Tensor& out_grad, - float dropout, - bool causal, Tensor* q_grad, Tensor* k_grad, Tensor* v_grad); - -} -} // namespace paddle - - - -std::vector SRFlashAttnBwd(const Tensor &q, - const Tensor &k, - const Tensor &v, - const Tensor &out, - const Tensor &softmax_lse, - const Tensor &seed_offset, - const paddle::optional &attn_mask, - const Tensor &out_grad, - float dropout, - bool causal); - - -std::vector SRFlashAttnBwd(const Tensor &q, - const Tensor &k, - const Tensor &v, - const Tensor &out, - const Tensor &softmax_lse, - const Tensor &seed_offset, - const paddle::optional &attn_mask, - const Tensor &out_grad, - float dropout, - bool causal){ - std::vector res(3); - paddle::experimental::flash_attn_grad(q, k, v, out, softmax_lse, seed_offset, attn_mask, - out_grad, dropout, causal, &res[0], &res[1], - &res[2]); - return res; -} - - - -std::vector SRFlashAttnBwdDtype(paddle::DataType q_dtype, - paddle::DataType k_dtype, - paddle::DataType v_dtype) { - return {q_dtype, k_dtype, v_dtype}; - -} - - -std::vector> SRFlashAttnBwdInferShape( - std::vector q_shape, std::vector k_shape, - std::vector v_shape) { - return {q_shape, k_shape, v_shape}; -} - - -PD_BUILD_OP(flash_attn_bwd) - .Inputs({"q", "k", "v", "out", "softmax_lse", "seed_offset", "attn_mask", "out_grad"}) - .Outputs({"q_grad", "k_grad", "v_grad"}) - .Attrs({"dropout: float", "causal: bool"}) - .SetKernelFn(PD_KERNEL(SRFlashAttnBwd)) - .SetInferShapeFn(PD_INFER_SHAPE(SRFlashAttnBwdInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(SRFlashAttnBwdDtype)); diff --git a/ops/csrc/paddle_bwd_ops/flash_attn_with_sparse_mask_bwd.cc b/ops/csrc/paddle_bwd_ops/flash_attn_with_sparse_mask_bwd.cc deleted file mode 100644 index a3ef2a5b84c4..000000000000 --- a/ops/csrc/paddle_bwd_ops/flash_attn_with_sparse_mask_bwd.cc +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/extension.h" -#include - -using paddle::Tensor; - -namespace paddle { -namespace experimental { - -PADDLE_API void flash_attn_with_sparse_mask_grad(const Tensor& q, - const Tensor& k, - const Tensor& v, - const Tensor& attn_mask_start_row_indices, - const Tensor& out, - const Tensor& softmax_lse, - const Tensor& seed_offset, - const Tensor& out_grad, - float dropout, - bool causal, int attn_mask_start_row, Tensor* q_grad, Tensor* k_grad, Tensor* v_grad); -} -} // namespace paddle - - - -std::vector SRFlashAttnWithSparseMaskBwd(const Tensor &q, - const Tensor &k, - const Tensor &v, - const Tensor &attn_mask_start_row_indices, - const Tensor &out, - const Tensor &softmax_lse, - const Tensor &seed_offset, - const Tensor &out_grad, - float dropout, - bool causal, int attn_mask_start_row); - - -std::vector SRFlashAttnWithSparseMaskBwd(const Tensor &q, - const Tensor &k, - const Tensor &v, - const Tensor &attn_mask_start_row_indices, - const Tensor &out, - const Tensor &softmax_lse, - const Tensor &seed_offset, - const Tensor &out_grad, - float dropout, - bool causal, int attn_mask_start_row){ - std::vector res(3); - paddle::experimental::flash_attn_with_sparse_mask_grad(q, k, v, attn_mask_start_row_indices, out, softmax_lse, seed_offset, - out_grad, dropout, causal, attn_mask_start_row, &res[0], &res[1], &res[2]); - return res; -} - - - -std::vector SRFlashAttnWithSparseMaskBwdDtype(paddle::DataType q_dtype, - paddle::DataType k_dtype, - paddle::DataType v_dtype, - paddle::DataType attn_mask_start_row_indices_dtype) { - return {q_dtype, k_dtype, v_dtype, attn_mask_start_row_indices_dtype}; - -} - - -std::vector> SRFlashAttnWithSparseMaskBwdInferShape( - std::vector q_shape, std::vector k_shape, - std::vector v_shape, std::vector attn_mask_start_row_indices_shape) { - return {q_shape, k_shape, v_shape, attn_mask_start_row_indices_shape}; -} - - -PD_BUILD_OP(flash_attn_with_sparse_mask_bwd) - .Inputs({"q", "k", "v", "attn_mask_start_row_indices", "out", "softmax_lse", "seed_offset", "out_grad"}) - .Outputs({"q_grad", "k_grad", "v_grad"}) - .Attrs({"dropout: float", "causal: bool", "attn_mask_start_row: int"}) - .SetKernelFn(PD_KERNEL(SRFlashAttnWithSparseMaskBwd)) - .SetInferShapeFn(PD_INFER_SHAPE(SRFlashAttnWithSparseMaskBwdInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(SRFlashAttnWithSparseMaskBwdDtype)); diff --git a/ops/csrc/paddle_bwd_ops/flashmask_attn_bwd.cc b/ops/csrc/paddle_bwd_ops/flashmask_attn_bwd.cc deleted file mode 100644 index 225f2451b98c..000000000000 --- a/ops/csrc/paddle_bwd_ops/flashmask_attn_bwd.cc +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/extension.h" -#include - -using paddle::Tensor; - -namespace paddle { -namespace experimental { - -PADDLE_API void flashmask_attention_grad(const Tensor& q, - const Tensor& k, - const Tensor& v, - const Tensor& startend_row_indices, - const Tensor& out, - const Tensor& softmax_lse, - const Tensor& seed_offset, - const Tensor& out_grad, - float dropout, - bool causal, Tensor* q_grad, Tensor* k_grad, Tensor* v_grad); -} -} // namespace paddle - - - -std::vector SRFlashMaskAttnBwd(const Tensor &q, - const Tensor &k, - const Tensor &v, - const Tensor &startend_row_indices, - const Tensor &out, - const Tensor &softmax_lse, - const Tensor &seed_offset, - const Tensor &out_grad, - float dropout, - bool causal); - - -std::vector SRFlashMaskAttnBwd(const Tensor &q, - const Tensor &k, - const Tensor &v, - const Tensor &startend_row_indices, - const Tensor &out, - const Tensor &softmax_lse, - const Tensor &seed_offset, - const Tensor &out_grad, - float dropout, - bool causal){ - std::vector res(3); - paddle::experimental::flashmask_attention_grad(q, k, v, startend_row_indices, out, softmax_lse, seed_offset, - out_grad, dropout, causal, &res[0], &res[1], &res[2]); - return res; -} - - - -std::vector SRFlashMaskAttnBwdDtype(paddle::DataType q_dtype, - paddle::DataType k_dtype, - paddle::DataType v_dtype, - paddle::DataType startend_row_indices_dtype) { - return {q_dtype, k_dtype, v_dtype, startend_row_indices_dtype}; - -} - - -std::vector> SRFlashMaskAttnBwdInferShape( - std::vector q_shape, std::vector k_shape, - std::vector v_shape, std::vector startend_row_indices_shape) { - return {q_shape, k_shape, v_shape, startend_row_indices_shape}; -} - - -PD_BUILD_OP(flashmask_attn_bwd) - .Inputs({"q", "k", "v", "startend_row_indices", "out", "softmax_lse", "seed_offset", "out_grad"}) - .Outputs({"q_grad", "k_grad", "v_grad"}) - .Attrs({"dropout: float", "causal: bool"}) - .SetKernelFn(PD_KERNEL(SRFlashMaskAttnBwd)) - .SetInferShapeFn(PD_INFER_SHAPE(SRFlashMaskAttnBwdInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(SRFlashMaskAttnBwdDtype)); diff --git a/ops/csrc/paddle_bwd_ops/matmul_bwd.cc b/ops/csrc/paddle_bwd_ops/matmul_bwd.cc deleted file mode 100644 index 9fff8204a227..000000000000 --- a/ops/csrc/paddle_bwd_ops/matmul_bwd.cc +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/extension.h" -#include - -using paddle::Tensor; - -namespace paddle { -namespace experimental { - -PADDLE_API void matmul_grad(const Tensor& x, - const Tensor& y, - const Tensor& out_grad, - bool transpose_x, - bool transpose_y, - Tensor* dx, - Tensor* dy); - -} -} // namespace paddle - -std::vector SRMatmulBwd(const Tensor& x, - const Tensor& y, - const Tensor& out_grad, - bool transpose_x, - bool transpose_y); - -std::vector SRMatmulBwd(const Tensor& x, - const Tensor& y, - const Tensor& out_grad, - bool transpose_x, - bool transpose_y){ - std::vector res(2); - paddle::experimental::matmul_grad(x, y, out_grad, transpose_x, transpose_y, - &res[0], &res[1]); - return res; -} - -std::vector SRMatmulBwdDtype(paddle::DataType x_dtype, - paddle::DataType y_dtype){ - return {x_dtype, y_dtype}; -} - -std::vector> SRMatmulBwdInferShape(std::vector x_shape, - std::vector y_shape){ - return {x_shape, y_shape}; -} - -PD_BUILD_OP(matmul_bwd) - .Inputs({"x", "y", "out_grad"}) - .Outputs({"x_grad", "y_grad"}) - .Attrs({"transpose_x: bool", "transpose_y: bool"}) - .SetKernelFn(PD_KERNEL(SRMatmulBwd)) - .SetInferShapeFn(PD_INFER_SHAPE(SRMatmulBwdInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(SRMatmulBwdDtype)); \ No newline at end of file diff --git a/ops/csrc/selective_scan/reverse_scan.cuh b/ops/csrc/selective_scan/reverse_scan.cuh deleted file mode 100644 index d19397879bda..000000000000 --- a/ops/csrc/selective_scan/reverse_scan.cuh +++ /dev/null @@ -1,415 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#ifndef USE_ROCM - #include - - #include - #include - #include - // #include -#else - #include - namespace cub = hipcub; -#endif -#include "uninitialized_copy.cuh" - -/** - * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned. - */ -template < - int LENGTH, - typename T, - typename ReductionOp> -__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) { - static_assert(LENGTH > 0); - T retval = input[LENGTH - 1]; - #pragma unroll - for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); } - return retval; -} - -/** - * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. - */ -template < - int LENGTH, - typename T, - typename ScanOp> -__device__ __forceinline__ T ThreadReverseScanInclusive( - const T (&input)[LENGTH], - T (&output)[LENGTH], - ScanOp scan_op, - const T postfix) -{ - T inclusive = postfix; - #pragma unroll - for (int i = LENGTH - 1; i >= 0; --i) { - inclusive = scan_op(inclusive, input[i]); - output[i] = inclusive; - } - return inclusive; -} - -/** - * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. - */ -template < - int LENGTH, - typename T, - typename ScanOp> -__device__ __forceinline__ T ThreadReverseScanExclusive( - const T (&input)[LENGTH], - T (&output)[LENGTH], - ScanOp scan_op, - const T postfix) -{ - // Careful, output maybe be aliased to input - T exclusive = postfix; - T inclusive; - #pragma unroll - for (int i = LENGTH - 1; i >= 0; --i) { - inclusive = scan_op(exclusive, input[i]); - output[i] = exclusive; - exclusive = inclusive; - } - return inclusive; -} - - -/** - * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp. - * - * LOGICAL_WARP_THREADS must be a power-of-two - */ -template < - typename T, ///< Data type being scanned - int LOGICAL_WARP_THREADS ///< Number of threads per logical warp - > -struct WarpReverseScan { - //--------------------------------------------------------------------- - // Constants and type definitions - //--------------------------------------------------------------------- - - /// Whether the logical warp size and the PTX warp size coincide - - // In hipcub, warp_threads is defined as HIPCUB_WARP_THREADS ::rocprim::warp_size() - // While in cub, it's defined as a macro that takes a redundant unused argument. - #ifndef USE_ROCM - #define WARP_THREADS CUB_WARP_THREADS(0) - #else - #define WARP_THREADS HIPCUB_WARP_THREADS - #endif - static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS); - /// The number of warp scan steps - static constexpr int STEPS = cub::Log2::VALUE; - static_assert(LOGICAL_WARP_THREADS == 1 << STEPS); - - - //--------------------------------------------------------------------- - // Thread fields - //--------------------------------------------------------------------- - - /// Lane index in logical warp - unsigned int lane_id; - - /// Logical warp index in 32-thread physical warp - unsigned int warp_id; - - /// 32-thread physical warp member mask of logical warp - unsigned int member_mask; - - //--------------------------------------------------------------------- - // Construction - //--------------------------------------------------------------------- - - /// Constructor - explicit __device__ __forceinline__ - WarpReverseScan() - : lane_id(cub::LaneId()) - , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS)) - , member_mask(cub::WarpMask(warp_id)) - { - if (!IS_ARCH_WARP) { - lane_id = lane_id % LOGICAL_WARP_THREADS; - } - } - - - /// Broadcast - __device__ __forceinline__ T Broadcast( - T input, ///< [in] The value to broadcast - int src_lane) ///< [in] Which warp lane is to do the broadcasting - { - return cub::ShuffleIndex(input, src_lane, member_mask); - } - - - /// Inclusive scan - template - __device__ __forceinline__ void InclusiveReverseScan( - T input, ///< [in] Calling thread's input item. - T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. - ScanOpT scan_op) ///< [in] Binary scan operator - { - inclusive_output = input; - #pragma unroll - for (int STEP = 0; STEP < STEPS; STEP++) { - int offset = 1 << STEP; - T temp = cub::ShuffleDown( - inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask - ); - // Perform scan op if from a valid peer - inclusive_output = static_cast(lane_id) >= LOGICAL_WARP_THREADS - offset - ? inclusive_output : scan_op(temp, inclusive_output); - } - } - - /// Exclusive scan - // Get exclusive from inclusive - template - __device__ __forceinline__ void ExclusiveReverseScan( - T input, ///< [in] Calling thread's input item. - T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. - ScanOpT scan_op, ///< [in] Binary scan operator - T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. - { - T inclusive_output; - InclusiveReverseScan(input, inclusive_output, scan_op); - warp_aggregate = cub::ShuffleIndex(inclusive_output, 0, member_mask); - // initial value unknown - exclusive_output = cub::ShuffleDown( - inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask - ); - } - - /** - * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last warp-lane is undefined. - */ - template - __device__ __forceinline__ void ReverseScan( - T input, ///< [in] Calling thread's input item. - T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item. - T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item. - ScanOpT scan_op) ///< [in] Binary scan operator - { - InclusiveReverseScan(input, inclusive_output, scan_op); - // initial value unknown - exclusive_output = cub::ShuffleDown( - inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask - ); - } - -}; - -/** - * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block. - */ -template < - typename T, ///< Data type being scanned - int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension - bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure - > -struct BlockReverseScan { - //--------------------------------------------------------------------- - // Types and constants - //--------------------------------------------------------------------- - - /// Constants - /// The thread block size in threads - static constexpr int BLOCK_THREADS = BLOCK_DIM_X; - - /// Layout type for padded thread block raking grid - using BlockRakingLayout = cub::BlockRakingLayout; - // The number of reduction elements is not a multiple of the number of raking threads for now - static_assert(BlockRakingLayout::UNGUARDED); - - /// Number of raking threads - static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS; - /// Number of raking elements per warp synchronous raking thread - static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH; - /// Cooperative work can be entirely warp synchronous - static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS)); - - /// WarpReverseScan utility type - using WarpReverseScan = WarpReverseScan; - - /// Shared memory storage layout type - struct _TempStorage { - typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid - }; - - - /// Alias wrapper allowing storage to be unioned - struct TempStorage : cub::Uninitialized<_TempStorage> {}; - - - //--------------------------------------------------------------------- - // Per-thread fields - //--------------------------------------------------------------------- - - // Thread fields - _TempStorage &temp_storage; - unsigned int linear_tid; - T cached_segment[SEGMENT_LENGTH]; - - - //--------------------------------------------------------------------- - // Utility methods - //--------------------------------------------------------------------- - - /// Performs upsweep raking reduction, returning the aggregate - template - __device__ __forceinline__ T Upsweep(ScanOp scan_op) { - T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); - // Read data into registers - #pragma unroll - for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } - T raking_partial = cached_segment[SEGMENT_LENGTH - 1]; - #pragma unroll - for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) { - raking_partial = scan_op(raking_partial, cached_segment[i]); - } - return raking_partial; - } - - - /// Performs exclusive downsweep raking scan - template - __device__ __forceinline__ void ExclusiveDownsweep( - ScanOp scan_op, - T raking_partial) - { - T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); - // Read data back into registers - if (!MEMOIZE) { - #pragma unroll - for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } - } - ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial); - // Write data back to smem - #pragma unroll - for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; } - } - - - //--------------------------------------------------------------------- - // Constructors - //--------------------------------------------------------------------- - - /// Constructor - __device__ __forceinline__ BlockReverseScan( - TempStorage &temp_storage) - : - temp_storage(temp_storage.Alias()), - linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1)) - {} - - - /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. - template < - typename ScanOp, - typename BlockPostfixCallbackOp> - __device__ __forceinline__ void ExclusiveReverseScan( - T input, ///< [in] Calling thread's input item - T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) - ScanOp scan_op, ///< [in] Binary scan operator - BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide postfix to be applied to all inputs. - { - if (WARP_SYNCHRONOUS) { - // Short-circuit directly to warp-synchronous scan - T block_aggregate; - WarpReverseScan warp_scan; - warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate); - // Obtain warp-wide postfix in lane0, then broadcast to other lanes - T block_postfix = block_postfix_callback_op(block_aggregate); - block_postfix = warp_scan.Broadcast(block_postfix, 0); - exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output); - } else { - // Place thread partial into shared memory raking grid - T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); - detail::uninitialized_copy(placement_ptr, input); - cub::CTA_SYNC(); - // Reduce parallelism down to just raking threads - if (linear_tid < RAKING_THREADS) { - WarpReverseScan warp_scan; - // Raking upsweep reduction across shared partials - T upsweep_partial = Upsweep(scan_op); - // Warp-synchronous scan - T exclusive_partial, block_aggregate; - warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate); - // Obtain block-wide postfix in lane0, then broadcast to other lanes - T block_postfix = block_postfix_callback_op(block_aggregate); - block_postfix = warp_scan.Broadcast(block_postfix, 0); - // Update postfix with warpscan exclusive partial - T downsweep_postfix = linear_tid == RAKING_THREADS - 1 - ? block_postfix : scan_op(block_postfix, exclusive_partial); - // Exclusive raking downsweep scan - ExclusiveDownsweep(scan_op, downsweep_postfix); - } - cub::CTA_SYNC(); - // Grab thread postfix from shared memory - exclusive_output = *placement_ptr; - - // // Compute warp scan in each warp. - // // The exclusive output from the last lane in each warp is invalid. - // T inclusive_output; - // WarpReverseScan warp_scan; - // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op); - - // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid. - // T block_aggregate; - // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate); - - // // Apply warp postfix to our lane's partial - // if (warp_id != 0) { - // exclusive_output = scan_op(warp_postfix, exclusive_output); - // if (lane_id == 0) { exclusive_output = warp_postfix; } - // } - - // // Use the first warp to determine the thread block postfix, returning the result in lane0 - // if (warp_id == 0) { - // T block_postfix = block_postfix_callback_op(block_aggregate); - // if (lane_id == 0) { - // // Share the postfix with all threads - // detail::uninitialized_copy(&temp_storage.block_postfix, - // block_postfix); - - // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0 - // } - // } - - // cub::CTA_SYNC(); - - // // Incorporate thread block postfix into outputs - // T block_postfix = temp_storage.block_postfix; - // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); } - } - } - - - /** - * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. - */ - template < - int ITEMS_PER_THREAD, - typename ScanOp, - typename BlockPostfixCallbackOp> - __device__ __forceinline__ void InclusiveReverseScan( - T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items - T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) - ScanOp scan_op, ///< [in] Binary scan functor - BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence. - { - // Reduce consecutive thread items in registers - T thread_postfix = ThreadReverseReduce(input, scan_op); - // Exclusive thread block-scan - ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op); - // Inclusive scan in registers with postfix as seed - ThreadReverseScanInclusive(input, output, scan_op, thread_postfix); - } - -}; \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan.cpp b/ops/csrc/selective_scan/selective_scan.cpp deleted file mode 100644 index 5590b8e096f5..000000000000 --- a/ops/csrc/selective_scan/selective_scan.cpp +++ /dev/null @@ -1,494 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include -#include -#include - -#include "selective_scan.h" - -#define CHECK_SHAPE(x, ...) PD_CHECK(x.dims() == common::make_ddim({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ - if (ITYPE == paddle::DataType::FLOAT16) { \ - using input_t = phi::dtype::float16; \ - __VA_ARGS__(); \ - } else if (ITYPE == paddle::DataType::BFLOAT16) { \ - using input_t = phi::dtype::bfloat16; \ - __VA_ARGS__(); \ - } else if (ITYPE == paddle::DataType::FLOAT32) { \ - using input_t = float; \ - __VA_ARGS__(); \ - } else { \ - PADDLE_THROW(#NAME, " not implemented for input type '", ITYPE, "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ - if (WTYPE == paddle::DataType::FLOAT16) { \ - using weight_t = phi::dtype::float16; \ - __VA_ARGS__(); \ - } else if (WTYPE == paddle::DataType::BFLOAT16) { \ - using weight_t = phi::dtype::bfloat16; \ - __VA_ARGS__(); \ - } else if (WTYPE == paddle::DataType::FLOAT32) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - PADDLE_THROW(#NAME, " not implemented for weight type '", WTYPE, "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \ - if (WTYPE == paddle::DataType::FLOAT32) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else if (WTYPE == paddle::DataType::COMPLEX64) { \ - using weight_t = phi::dtype::complex; \ - __VA_ARGS__(); \ - } else { \ - PADDLE_THROW(#NAME, " not implemented for weight type '", WTYPE, "'"); \ - } - -template -void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); - -template -void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); - -void set_ssm_params_fwd(SSMParamsBase ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t dstate, - const size_t n_groups, - const size_t n_chunks, - const bool is_variable_B, - const bool is_variable_C, - // device pointers - const paddle::Tensor u, - const paddle::Tensor delta, - const paddle::Tensor A, - const paddle::Tensor B, - const paddle::Tensor C, - const paddle::Tensor out, - const paddle::Tensor z, - const paddle::Tensor out_z, - void* D_ptr, - void* delta_bias_ptr, - void* x_ptr, - bool has_z, - bool delta_softplus) { - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.batch = batch; - params.dim = dim; - params.seqlen = seqlen; - params.dstate = dstate; - params.n_groups = n_groups; - params.n_chunks = n_chunks; - params.dim_ngroups_ratio = dim / n_groups; - - params.delta_softplus = delta_softplus; - - params.is_variable_B = is_variable_B; - params.is_variable_C = is_variable_C; - - // Set the pointers and strides. - params.u_ptr = const_cast(u.data()); - params.delta_ptr = const_cast(delta.data()); - params.A_ptr = const_cast(A.data()); - params.B_ptr = const_cast(B.data()); - params.C_ptr = const_cast(C.data()); - params.D_ptr = const_cast(D_ptr); - params.delta_bias_ptr = const_cast(delta_bias_ptr); - params.out_ptr = const_cast(out.data()); - params.x_ptr = const_cast(x_ptr); - params.z_ptr = has_z ? const_cast(z.data()) : nullptr; - params.out_z_ptr = has_z ? const_cast(out_z.data()) : nullptr; - // All stride are in elements, not bytes. - params.A_d_stride = A.strides()[0]; - params.A_dstate_stride = A.strides()[1]; - if (!is_variable_B) { - params.B_d_stride = B.strides()[0]; - } else { - params.B_batch_stride = B.strides()[0]; - params.B_group_stride = B.strides()[1]; - } - params.B_dstate_stride = !is_variable_B ? B.strides()[1] : B.strides()[2]; - if (!is_variable_C) { - params.C_d_stride = C.strides()[0]; - } else { - params.C_batch_stride = C.strides()[0]; - params.C_group_stride = C.strides()[1]; - } - params.C_dstate_stride = !is_variable_C ? C.strides()[1] : C.strides()[2]; - params.u_batch_stride = u.strides()[0]; - params.u_d_stride = u.strides()[1]; - params.delta_batch_stride = delta.strides()[0]; - params.delta_d_stride = delta.strides()[1]; - if (has_z) { - params.z_batch_stride = z.strides()[0]; - params.z_d_stride = z.strides()[1]; - params.out_z_batch_stride = out_z.strides()[0]; - params.out_z_d_stride = out_z.strides()[1]; - } - params.out_batch_stride = out.strides()[0]; - params.out_d_stride = out.strides()[1]; -} - -void set_ssm_params_bwd(SSMParamsBwd ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t dstate, - const size_t n_groups, - const size_t n_chunks, - const bool is_variable_B, - const bool is_variable_C, - // device pointers - const paddle::Tensor u, - const paddle::Tensor delta, - const paddle::Tensor A, - const paddle::Tensor B, - const paddle::Tensor C, - const paddle::Tensor z, - const paddle::Tensor out, - const paddle::Tensor out_z, - void* D_ptr, - void* delta_bias_ptr, - void* x_ptr, - const paddle::Tensor dout, - const paddle::Tensor du, - const paddle::Tensor ddelta, - const paddle::Tensor dA, - const paddle::Tensor dB, - const paddle::Tensor dC, - const paddle::Tensor dz, - void* dD_ptr, - void* ddelta_bias_ptr, - bool has_z, - bool delta_softplus, - bool recompute_out_z) { - // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z - set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, - u, delta, A, B, C, has_z ? out : dout, - has_z ? z : dout, - // If not recompute_out_z, pass dout instead of out_z. - // This won't be used by the bwd kernel - recompute_out_z ? out_z : dout, - D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus); - if (!recompute_out_z) { params.out_z_ptr = nullptr; } - - // Set the pointers and strides. - params.dout_ptr = const_cast(dout.data()); - params.du_ptr = const_cast(du.data()); - params.dA_ptr = const_cast(dA.data()); - params.dB_ptr = const_cast(dB.data()); - params.dC_ptr = const_cast(dC.data()); - params.dD_ptr = const_cast(dD_ptr); - params.ddelta_ptr = const_cast(ddelta.data()); - params.ddelta_bias_ptr = const_cast(ddelta_bias_ptr); - params.dz_ptr = has_z ? const_cast(dz.data()) : nullptr; - // All stride are in elements, not bytes. - params.dout_batch_stride = dout.strides()[0]; - params.dout_d_stride = dout.strides()[1]; - params.dA_d_stride = dA.strides()[0]; - params.dA_dstate_stride = dA.strides()[1]; - if (!is_variable_B) { - params.dB_d_stride = dB.strides()[0]; - } else { - params.dB_batch_stride = dB.strides()[0]; - params.dB_group_stride = dB.strides()[1]; - } - params.dB_dstate_stride = !is_variable_B ? dB.strides()[1] : dB.strides()[2]; - if (!is_variable_C) { - params.dC_d_stride = dC.strides()[0]; - } else { - params.dC_batch_stride = dC.strides()[0]; - params.dC_group_stride = dC.strides()[1]; - } - params.dC_dstate_stride = !is_variable_C ? dC.strides()[1] : dC.strides()[2]; - params.du_batch_stride = du.strides()[0]; - params.du_d_stride = du.strides()[1]; - params.ddelta_batch_stride = ddelta.strides()[0]; - params.ddelta_d_stride = ddelta.strides()[1]; - if (has_z) { - params.dz_batch_stride = dz.strides()[0]; - params.dz_d_stride = dz.strides()[1]; - } -} - -std::vector -selective_scan_fwd(const paddle::Tensor &u, const paddle::Tensor &delta, - const paddle::Tensor &A, const paddle::Tensor &B, const paddle::Tensor &C, - const std::optional &D_, - const std::optional &z_, - const std::optional &delta_bias_, - bool delta_softplus) { - auto input_type = u.dtype(); - auto weight_type = A.dtype(); - PD_CHECK(input_type == paddle::DataType::FLOAT32 || input_type == paddle::DataType::FLOAT16 || input_type == paddle::DataType::BFLOAT16); - PD_CHECK(weight_type == paddle::DataType::FLOAT32 || weight_type == paddle::DataType::COMPLEX64); - - const bool is_variable_B = B.dims().size() >= 3; - const bool is_variable_C = C.dims().size() >= 3; - const bool is_complex = weight_type == paddle::DataType::COMPLEX64; - - PD_CHECK(delta.dtype() == input_type); - PD_CHECK(B.dtype() == (!is_variable_B ? weight_type : input_type)); - PD_CHECK(C.dtype() == (!is_variable_C ? weight_type : input_type)); - - PD_CHECK(u.is_gpu()); - PD_CHECK(delta.is_gpu()); - PD_CHECK(A.is_gpu()); - PD_CHECK(B.is_gpu()); - PD_CHECK(C.is_gpu()); - - PD_CHECK(u.strides()[u.strides().size() - 1] == 1 || u.dims()[u.dims().size() - 1] == 1); - PD_CHECK(delta.strides()[delta.strides().size() - 1] == 1 || delta.dims()[delta.dims().size() - 1] == 1); - - const auto sizes = u.dims(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int dstate = A.dims()[1]; - const int n_groups = is_variable_B ? B.dims()[1] : 1; - - PD_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); - - CHECK_SHAPE(u, batch_size, dim, seqlen); - CHECK_SHAPE(delta, batch_size, dim, seqlen); - CHECK_SHAPE(A, dim, dstate); - if (!is_variable_B) { - CHECK_SHAPE(B, dim, dstate); - } else { - CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); - PD_CHECK(B.strides()[B.strides().size() - 1] == 1 || B.dims()[B.dims().size() - 1] == 1); - } - if (!is_variable_C) { - CHECK_SHAPE(C, dim, dstate); - } else { - CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); - PD_CHECK(C.strides()[C.strides().size() - 1] == 1 || C.dims()[C.dims().size() - 1] == 1); - } - - if (D_.has_value()) { - auto D = D_.value(); - PD_CHECK(D.dtype() == paddle::DataType::FLOAT32); - PD_CHECK(D.is_gpu()); - PD_CHECK(D.strides()[D.strides().size() - 1] == 1 || D.dims()[D.dims().size() - 1] == 1); - CHECK_SHAPE(D, dim); - } - - if (delta_bias_.has_value()) { - auto delta_bias = delta_bias_.value(); - PD_CHECK(delta_bias.dtype() == paddle::DataType::FLOAT32); - PD_CHECK(delta_bias.is_gpu()); - PD_CHECK(delta_bias.strides()[delta_bias.strides().size() - 1] == 1 || delta_bias.dims()[delta_bias.dims().size() - 1] == 1); - CHECK_SHAPE(delta_bias, dim); - } - - paddle::Tensor z, out_z; - const bool has_z = z_.has_value(); - if (has_z) { - z = z_.value(); - PD_CHECK(z.dtype() == input_type); - PD_CHECK(z.is_gpu()); - PD_CHECK(z.strides()[z.strides().size() - 1] == 1 || z.dims()[z.dims().size() - 1] == 1); - CHECK_SHAPE(z, batch_size, dim, seqlen); - out_z = paddle::empty_like(z); - } - - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - // paddle::Tensor out = paddle::empty_like(u); - // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout - paddle::Tensor out = paddle::empty_like(delta); - paddle::Tensor x; - x = paddle::empty({batch_size, dim, n_chunks, dstate * 2}, weight_type, delta.place()); - - SSMParamsBase params; - set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, - u, delta, A, B, C, out, z, out_z, - D_.has_value() ? const_cast(D_.value().data()) : nullptr, - delta_bias_.has_value() ? const_cast(delta_bias_.value().data()) : nullptr, - x.data(), - has_z, - delta_softplus); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - auto stream = x.stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.dtype(), "selective_scan_fwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.dtype(), "selective_scan_fwd", [&] { - selective_scan_fwd_cuda(params, stream); - }); - }); - std::vector result = {out, x}; - if (has_z) { result.push_back(out_z); } - return result; -} - -std::vector -selective_scan_bwd(const paddle::Tensor &u, const paddle::Tensor &delta, - const paddle::Tensor &A, const paddle::Tensor &B, const paddle::Tensor &C, - const std::optional &D_, - const std::optional &z_, - const std::optional &delta_bias_, - const paddle::Tensor &dout, - const std::optional &x_, - const std::optional &out_, - std::optional &dz_, - bool delta_softplus, - bool recompute_out_z) { - auto input_type = u.dtype(); - auto weight_type = A.dtype(); - PD_CHECK(input_type == paddle::DataType::FLOAT32 || input_type == paddle::DataType::FLOAT16 || input_type == paddle::DataType::BFLOAT16); - PD_CHECK(weight_type == paddle::DataType::FLOAT32 || weight_type == paddle::DataType::COMPLEX64); - - const bool is_variable_B = B.dims().size() >= 3; - const bool is_variable_C = C.dims().size() >= 3; - const bool is_complex = weight_type == paddle::DataType::COMPLEX64; - - PD_CHECK(delta.dtype() == input_type); - PD_CHECK(B.dtype() == (!is_variable_B ? weight_type : input_type)); - PD_CHECK(C.dtype() == (!is_variable_C ? weight_type : input_type)); - PD_CHECK(dout.dtype() == input_type); - - PD_CHECK(u.is_gpu()); - PD_CHECK(delta.is_gpu()); - PD_CHECK(A.is_gpu()); - PD_CHECK(B.is_gpu()); - PD_CHECK(C.is_gpu()); - PD_CHECK(dout.is_gpu()); - - PD_CHECK(u.strides()[u.strides().size() - 1] == 1 || u.dims()[u.dims().size() - 1] == 1); - PD_CHECK(delta.strides()[delta.strides().size() - 1] == 1 || delta.dims()[delta.dims().size() - 1] == 1); - PD_CHECK(dout.strides()[dout.strides().size() - 1] == 1 || dout.dims()[dout.dims().size() - 1] == 1); - - const auto sizes = u.dims(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int dstate = A.dims()[1]; - const int n_groups = is_variable_B ? B.dims()[1] : 1; - - PD_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); - - CHECK_SHAPE(u, batch_size, dim, seqlen); - CHECK_SHAPE(delta, batch_size, dim, seqlen); - CHECK_SHAPE(A, dim, dstate); - if (!is_variable_B) { - CHECK_SHAPE(B, dim, dstate); - } else { - CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); - PD_CHECK(B.strides()[B.strides().size() - 1] == 1 || B.dims()[B.dims().size() - 1] == 1); - } - if (!is_variable_C) { - CHECK_SHAPE(C, dim, dstate); - } else { - CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); - PD_CHECK(C.strides()[C.strides().size() - 1] == 1 || C.dims()[C.dims().size() - 1] == 1); - } - CHECK_SHAPE(dout, batch_size, dim, seqlen); - - if (D_.has_value()) { - auto D = D_.value(); - PD_CHECK(D.dtype() == paddle::DataType::FLOAT32); - PD_CHECK(D.is_gpu()); - PD_CHECK(D.strides()[D.strides().size() - 1] == 1 || D.dims()[D.dims().size() - 1] == 1); - CHECK_SHAPE(D, dim); - } - - if (delta_bias_.has_value()) { - auto delta_bias = delta_bias_.value(); - PD_CHECK(delta_bias.dtype() == paddle::DataType::FLOAT32); - PD_CHECK(delta_bias.is_gpu()); - PD_CHECK(delta_bias.strides()[delta_bias.strides().size() - 1] == 1 || delta_bias.dims()[delta_bias.dims().size() - 1] == 1); - CHECK_SHAPE(delta_bias, dim); - } - - paddle::Tensor z, out, dz, out_z; - const bool has_z = z_.has_value(); - if (has_z) { - z = z_.value(); - PD_CHECK(z.dtype() == input_type); - PD_CHECK(z.is_gpu()); - PD_CHECK(z.strides()[z.strides().size() - 1] == 1 || z.dims()[z.dims().size() - 1] == 1); - CHECK_SHAPE(z, batch_size, dim, seqlen); - - PD_CHECK(out_.has_value()); - out = out_.value(); - PD_CHECK(out.dtype() == input_type); - PD_CHECK(out.is_gpu()); - PD_CHECK(out.strides()[out.strides().size() - 1] == 1 || out.dims()[out.dims().size() - 1] == 1); - CHECK_SHAPE(out, batch_size, dim, seqlen); - - if (dz_.has_value()) { - dz = dz_.value(); - PD_CHECK(dz.dtype() == input_type); - PD_CHECK(dz.is_gpu()); - PD_CHECK(dz.strides()[dz.strides().size() - 1] == 1 || dz.dims()[dz.dims().size() - 1] == 1); - CHECK_SHAPE(dz, batch_size, dim, seqlen); - } else { - dz = paddle::empty_like(z); - } - if (recompute_out_z) { - out_z = paddle::empty_like(out); - } - } - - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - if (n_chunks > 1) { PD_CHECK(x_.has_value()); } - if (x_.has_value()) { - auto x = x_.value(); - PD_CHECK(x.dtype() == weight_type); - PD_CHECK(x.is_gpu()); - // PD_CHECK(x.is_contiguous()); - CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate); - } - - paddle::Tensor du = paddle::empty_like(u); - paddle::Tensor ddelta = paddle::empty_like(delta); - paddle::Tensor dA = paddle::experimental::zeros_like(A); - paddle::Tensor dB = !is_variable_B ? paddle::experimental::zeros_like(B) : paddle::experimental::zeros_like(B, paddle::DataType::FLOAT32); - paddle::Tensor dC = !is_variable_C ? paddle::experimental::zeros_like(C) : paddle::experimental::zeros_like(C, paddle::DataType::FLOAT32); - paddle::Tensor dD; - if (D_.has_value()) { dD = paddle::experimental::zeros_like(D_.value()); } - paddle::Tensor ddelta_bias; - if (delta_bias_.has_value()) { ddelta_bias = paddle::experimental::zeros_like(delta_bias_.value()); } - - SSMParamsBwd params; - set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, - u, delta, A, B, C, z, out, out_z, - D_.has_value() ? const_cast(D_.value().data()) : nullptr, - delta_bias_.has_value() ? const_cast(delta_bias_.value().data()) : nullptr, - x_.has_value() ? const_cast(x_.value().data()) : nullptr, - dout, du, ddelta, dA, dB, dC, dz, - D_.has_value() ? const_cast(dD.data()) : nullptr, - delta_bias_.has_value() ? const_cast(ddelta_bias.data()) : nullptr, - has_z, delta_softplus, recompute_out_z); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - auto stream = u.stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.dtype(), "selective_scan_bwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.dtype(), "selective_scan_bwd", [&] { - selective_scan_bwd_cuda(params, stream); - }); - }); - std::vector result = {du, ddelta, dA, dB.cast(B.dtype()), dC.cast(C.dtype()), dD, ddelta_bias}; - if (has_z) { result.push_back(dz); } - if (recompute_out_z) { result.push_back(out_z); } - return result; -} - -PYBIND11_MODULE(selective_scan_cuda_pd, m) { - m.def("fwd", &selective_scan_fwd, "Selective scan forward"); - m.def("bwd", &selective_scan_bwd, "Selective scan backward"); -} diff --git a/ops/csrc/selective_scan/selective_scan.h b/ops/csrc/selective_scan/selective_scan.h deleted file mode 100644 index e2c7bcdbd5dd..000000000000 --- a/ops/csrc/selective_scan/selective_scan.h +++ /dev/null @@ -1,101 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct SSMScanParamsBase { - using index_t = uint32_t; - - int batch, seqlen, n_chunks; - index_t a_batch_stride; - index_t b_batch_stride; - index_t out_batch_stride; - - // Common data pointers. - void *__restrict__ a_ptr; - void *__restrict__ b_ptr; - void *__restrict__ out_ptr; - void *__restrict__ x_ptr; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct SSMParamsBase { - using index_t = uint32_t; - - int batch, dim, seqlen, dstate, n_groups, n_chunks; - int dim_ngroups_ratio; - bool is_variable_B; - bool is_variable_C; - - bool delta_softplus; - - index_t A_d_stride; - index_t A_dstate_stride; - index_t B_batch_stride; - index_t B_d_stride; - index_t B_dstate_stride; - index_t B_group_stride; - index_t C_batch_stride; - index_t C_d_stride; - index_t C_dstate_stride; - index_t C_group_stride; - index_t u_batch_stride; - index_t u_d_stride; - index_t delta_batch_stride; - index_t delta_d_stride; - index_t z_batch_stride; - index_t z_d_stride; - index_t out_batch_stride; - index_t out_d_stride; - index_t out_z_batch_stride; - index_t out_z_d_stride; - - // Common data pointers. - void *__restrict__ A_ptr; - void *__restrict__ B_ptr; - void *__restrict__ C_ptr; - void *__restrict__ D_ptr; - void *__restrict__ u_ptr; - void *__restrict__ delta_ptr; - void *__restrict__ delta_bias_ptr; - void *__restrict__ out_ptr; - void *__restrict__ x_ptr; - void *__restrict__ z_ptr; - void *__restrict__ out_z_ptr; -}; - -struct SSMParamsBwd: public SSMParamsBase { - index_t dout_batch_stride; - index_t dout_d_stride; - index_t dA_d_stride; - index_t dA_dstate_stride; - index_t dB_batch_stride; - index_t dB_group_stride; - index_t dB_d_stride; - index_t dB_dstate_stride; - index_t dC_batch_stride; - index_t dC_group_stride; - index_t dC_d_stride; - index_t dC_dstate_stride; - index_t du_batch_stride; - index_t du_d_stride; - index_t dz_batch_stride; - index_t dz_d_stride; - index_t ddelta_batch_stride; - index_t ddelta_d_stride; - - // Common data pointers. - void *__restrict__ dout_ptr; - void *__restrict__ dA_ptr; - void *__restrict__ dB_ptr; - void *__restrict__ dC_ptr; - void *__restrict__ dD_ptr; - void *__restrict__ du_ptr; - void *__restrict__ dz_ptr; - void *__restrict__ ddelta_ptr; - void *__restrict__ ddelta_bias_ptr; -}; diff --git a/ops/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu b/ops/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu deleted file mode 100644 index f881143655c2..000000000000 --- a/ops/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in parallel - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan_bwd_bf16_real.cu b/ops/csrc/selective_scan/selective_scan_bwd_bf16_real.cu deleted file mode 100644 index dad6506c4388..000000000000 --- a/ops/csrc/selective_scan/selective_scan_bwd_bf16_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in parallel - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu b/ops/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu deleted file mode 100644 index 0d0a3b97c07e..000000000000 --- a/ops/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in parallel - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan_bwd_fp16_real.cu b/ops/csrc/selective_scan/selective_scan_bwd_fp16_real.cu deleted file mode 100644 index 3c8d362660d9..000000000000 --- a/ops/csrc/selective_scan/selective_scan_bwd_fp16_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in parallel - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu b/ops/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu deleted file mode 100644 index bd8694838c5f..000000000000 --- a/ops/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in parallel - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan_bwd_fp32_real.cu b/ops/csrc/selective_scan/selective_scan_bwd_fp32_real.cu deleted file mode 100644 index b0c170ba3918..000000000000 --- a/ops/csrc/selective_scan/selective_scan_bwd_fp32_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in parallel - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/ops/csrc/selective_scan/selective_scan_bwd_kernel.cuh deleted file mode 100755 index a67bb86f6b8c..000000000000 --- a/ops/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ /dev/null @@ -1,559 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include - -#ifndef USE_ROCM - #include - #include - #include - #include -#else - #include - namespace cub = hipcub; -#endif - -#include "selective_scan.h" -#include "selective_scan_common.h" -#include "reverse_scan.cuh" -#include "static_switch.h" - -template __device__ __forceinline__ scalar_t conj(scalar_t x); -template<> __device__ __forceinline__ float conj(float x) { return x; } -// template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } -template<> __device__ __forceinline__ complex_t conj(complex_t x) { return complex_t(x.real, -x.imag); } - -template -struct Selective_Scan_bwd_kernel_traits { - static_assert(kNItems_ % 4 == 0); - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kNItems = kNItems_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); - static_assert(kNItems % kNElts == 0); - static constexpr int kNLoads = kNItems / kNElts; - static constexpr bool kIsComplex = std::is_same_v; - static constexpr bool kIsEvenLen = kIsEvenLen_; - static constexpr bool kIsVariableB = kIsVariableB_; - static constexpr bool kIsVariableC = kIsVariableC_; - static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; - static constexpr bool kHasZ = kHasZ_; - // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. - // For complex this would lead to massive register spilling, so we keep it at 2. - static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2; - using vec_t = typename BytesToType::Type; - using scan_t = std::conditional_t; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = cub::BlockLoad; - using BlockLoadWeightT = cub::BlockLoad; - using BlockLoadWeightVecT = cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = cub::BlockStore; - // using BlockScanT = cub::BlockScan; - using BlockScanT = cub::BlockScan; - // using BlockScanT = cub::BlockScan; - using BlockReverseScanT = BlockReverseScan; - using BlockReduceT = cub::BlockReduce; - using BlockReduceFloatT = cub::BlockReduce; - using BlockReduceComplexT = cub::BlockReduce; - using BlockExchangeT = cub::BlockExchange; - - static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), - sizeof(typename BlockLoadVecT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), - sizeof(typename BlockStoreT::TempStorage), - sizeof(typename BlockStoreVecT::TempStorage)}); - static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); - static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); - static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) -void selective_scan_bwd_kernel(SSMParamsBwd params) { - constexpr bool kIsComplex = Ktraits::kIsComplex; - constexpr bool kIsVariableB = Ktraits::kIsVariableB; - constexpr bool kIsVariableC = Ktraits::kIsVariableC; - constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; - constexpr bool kHasZ = Ktraits::kHasZ; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNItems = Ktraits::kNItems; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using scan_t = typename Ktraits::scan_t; - - // Shared memory. - extern __shared__ char smem_[]; - // cast to lvalue reference of expected type - // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); - // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); - // auto& smem_load = reinterpret_cast(smem_loadstorescan); - auto& smem_load = reinterpret_cast(smem_); - auto& smem_load_weight = reinterpret_cast(smem_); - auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); - auto& smem_store = reinterpret_cast(smem_); - auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); - auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); - auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); - auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce); - auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); - auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); - weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); - scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + 2 * MAX_DSTATE + kNThreads); - weight_t *smem_da = reinterpret_cast(smem_running_postfix + MAX_DSTATE); - weight_t *smem_dbc = reinterpret_cast(smem_da + MAX_DSTATE); - - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - const int group_id = dim_id / (params.dim_ngroups_ratio); - input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride - + dim_id * params.u_d_stride; - input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride - + dim_id * params.delta_d_stride; - input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride - + dim_id * params.dout_d_stride; - weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride; - weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * params.B_d_stride; - input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; - weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * params.C_d_stride; - input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; - weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * params.dA_d_stride; - weight_t *dB = reinterpret_cast(params.dB_ptr) - + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); - weight_t *dC = reinterpret_cast(params.dC_ptr) - + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); - float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id; - float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast(params.D_ptr)[dim_id]; - float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id; - float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast(params.delta_bias_ptr)[dim_id]; - scan_t *x = params.x_ptr == nullptr - ? nullptr - : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; - float dD_val = 0; - float ddelta_bias_val = 0; - - constexpr int kChunkSize = kNThreads * kNItems; - u += (params.n_chunks - 1) * kChunkSize; - delta += (params.n_chunks - 1) * kChunkSize; - dout += (params.n_chunks - 1) * kChunkSize; - Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); - Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); - for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { - input_t u_vals[kNItems]; - input_t delta_vals_load[kNItems]; - input_t dout_vals_load[kNItems]; - __syncthreads(); - load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); - u -= kChunkSize; - __syncthreads(); - load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); - // Will reload delta at the same location if kDeltaSoftplus - if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } - __syncthreads(); - load_input(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize); - dout -= kChunkSize; - - float dout_vals[kNItems], delta_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - dout_vals[i] = float(dout_vals_load[i]); - delta_vals[i] = float(delta_vals_load[i]) + delta_bias; - if constexpr (kDeltaSoftplus) { - delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; - } - } - - if constexpr (kHasZ) { - input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride - + dim_id * params.z_d_stride + chunk * kChunkSize; - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + dim_id * params.out_d_stride + chunk * kChunkSize; - input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride - + dim_id * params.dz_d_stride + chunk * kChunkSize; - input_t z_vals[kNItems], out_vals[kNItems]; - __syncthreads(); - load_input(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize); - __syncthreads(); - load_input(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize); - float dz_vals[kNItems], z_silu_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float z_val = z_vals[i]; - float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); - z_silu_vals[i] = z_val * z_sigmoid_val; - dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val - * (1.0f + z_val * (1.0f - z_sigmoid_val)); - dout_vals[i] *= z_silu_vals[i]; - } - __syncthreads(); - store_output(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize); - if (params.out_z_ptr != nullptr) { // Recompute and store out_z - float out_z_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; } - // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { - // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]); - // } - input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride - + dim_id * params.out_z_d_stride + chunk * kChunkSize; - __syncthreads(); - store_output(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize); - } - } - - float du_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; } - #pragma unroll - for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); } - - float ddelta_vals[kNItems] = {0}; - __syncthreads(); - for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { - const weight_t A_val = A[state_idx * params.A_dstate_stride]; - // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. - weight_t A_scaled; - constexpr float kLog2e = M_LOG2E; - if constexpr (!kIsComplex) { - A_scaled = A_val * kLog2e; - } else { - A_scaled = complex_t(A_val.real * kLog2e, A_val.imag); - } - weight_t B_val, C_val; - weight_t B_vals[kNItems], C_vals[kNItems]; - if constexpr (!kIsVariableB) { - B_val = B[state_idx * params.B_dstate_stride]; - } else { - load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, - smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - } - if constexpr (!kIsVariableC) { - C_val = C[state_idx * params.C_dstate_stride]; - } else { - auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; - load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - } - // const weight_t A_val = smem_a[state_idx]; - scan_t thread_data[kNItems], thread_reverse_data[kNItems]; - if constexpr (!kIsComplex) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); - thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); - if (i == 0) { - smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; - } else { - thread_reverse_data[i - 1].x = delta_a_exp; - } - thread_reverse_data[i].y = dout_vals[i] * - (!kIsVariableC - ? (!kIsVariableB ? B_val * C_val : C_val) - : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); - } - __syncthreads(); - thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 - ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE]) - : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE]; - // Initialize running total - scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); - SSMScanPrefixCallbackOp prefix_op(running_prefix); - typename Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp(), prefix_op - ); - scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f); - SSMScanPrefixCallbackOp postfix_op(running_postfix); - typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( - thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op - ); - if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } - weight_t dA_val = 0, dBC_val = 0; - weight_t dB_vals[kNItems], dC_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - const float dx = thread_reverse_data[i].y; - const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; - du_vals[i] += ddelta_u * delta_vals[i]; - const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); - ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a; - dA_val += dx * delta_vals[i] * a; - if constexpr (!kIsVariableB || !kIsVariableC) { - if constexpr (!kIsVariableB) { // dBC_val is dB_val - dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]); - } else { // dBC_val is dC_val - dBC_val += dout_vals[i] * thread_data[i].y; - } - } - if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } - if constexpr (kIsVariableC) { - dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y); - } - } - // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower - if constexpr (kIsVariableB || kIsVariableC) { - if constexpr (kIsVariableB) { - typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); - } - if constexpr (kIsVariableC) { - auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; - typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); - } - const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; - weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; - weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - if (i * kNThreads < seqlen_remaining) { - if constexpr (kIsVariableB) { phi::CudaAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } - if constexpr (kIsVariableC) { phi::CudaAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } - } - } - } - if constexpr (!kIsVariableB || !kIsVariableC) { - float2 dA_dBC_val = make_float2(dA_val, dBC_val); - dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); - dA_val = dA_dBC_val.x; - if (threadIdx.x == 0) { - smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx]; - } - } else { - dA_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); - } - if (threadIdx.x == 0) { - smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; - } - } else { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - // Pytorch's implementation of complex exp (which calls thrust) is very slow - complex_t delta_a_exp = cexp2f(delta_vals[i] * float(A_scaled)); - weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : float(B_vals[i]) * delta_vals[i] * float(u_vals[i]); - thread_data[i] = make_float4(delta_a_exp.real, delta_a_exp.imag, B_delta_u_val.real, B_delta_u_val.imag); - if (i == 0) { - smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; - } else { - thread_reverse_data[i - 1].x = delta_a_exp.real; - thread_reverse_data[i - 1].y = -delta_a_exp.imag; - } - complex_t dout_BC = 2 * dout_vals[i] - * float(conj(!kIsVariableC - ? (!kIsVariableB ? B_val * C_val : C_val) - : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]))); - thread_reverse_data[i].z = dout_BC.real; - thread_reverse_data[i].w = dout_BC.imag; - } - __syncthreads(); - complex_t delta_a_exp = threadIdx.x == kNThreads - 1 - ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE]) - : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE]; - thread_reverse_data[kNItems - 1].x = delta_a_exp.real; - thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag; - // Initialize running total - scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); - SSMScanPrefixCallbackOp prefix_op(running_prefix); - typename Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp(), prefix_op - ); - scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); - SSMScanPrefixCallbackOp postfix_op(running_postfix); - typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( - thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op - ); - if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } - weight_t dA_val = 0, dBC_val = 0; - weight_t dB_vals[kNItems], dC_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - complex_t x = complex_t(thread_data[i].z, thread_data[i].w); - complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); - float ddelta_u = !kIsVariableB ? dx.real : (dx * conj(B_vals[i])).real; - if constexpr (!kIsVariableB || !kIsVariableC) { - if constexpr (!kIsVariableB) { // dBC_val is dB_val - dBC_val += weight_t((2 * dout_vals[i]) * float(conj(!kIsVariableC ? x : x * C_vals[i]))); - } else { // dBC_val is dC_val - dBC_val += weight_t((2 * dout_vals[i]) * float(conj(x))); - } - } - const complex_t a_conj = conj(float(x) - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * float(B_vals[i]))); - du_vals[i] += ddelta_u * delta_vals[i]; - ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real; - dA_val += complex_t(delta_vals[i]) * dx * a_conj; - if constexpr (kIsVariableB) { dB_vals[i] = float(dx) * delta_vals[i] * float(u_vals[i]); } - if constexpr (kIsVariableC) { - dC_vals[i] = (2 * dout_vals[i]) * float(conj(!kIsVariableB ? x * B_val : x)); - } - } - // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower - if constexpr (kIsVariableB || kIsVariableC) { - float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; - if constexpr (kIsVariableB) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - dB_vals_f[i * 2] = dB_vals[i].real; - dB_vals_f[i * 2 + 1] = dB_vals[i].imag; - } - typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); - } - if constexpr (kIsVariableC) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - dC_vals_f[i * 2] = dC_vals[i].real; - dC_vals_f[i * 2 + 1] = dC_vals[i].imag; - } - auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; - typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); - } - const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; - float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; - float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; - #pragma unroll - for (int i = 0; i < kNItems * 2; ++i) { - if (i * kNThreads < seqlen_remaining) { - if constexpr (kIsVariableB) { phi::CudaAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } - if constexpr (kIsVariableC) { phi::CudaAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } - } - } - } - if constexpr (!kIsVariableB || !kIsVariableC) { - float4 dA_dBC_val = make_float4(dA_val.real, dA_val.imag, dBC_val.real, dBC_val.imag); - dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); - dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); - dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); - if (threadIdx.x == 0) { - smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx]; - } - } else { - dA_val = typename Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); - } - if (threadIdx.x == 0) { - smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; - } - } - } - - if constexpr (kDeltaSoftplus) { - __syncthreads(); - input_t delta_vals_load[kNItems]; - load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); - delta -= kChunkSize; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float delta_val = float(delta_vals_load[i]) + delta_bias; - float delta_val_neg_exp = expf(-delta_val); - ddelta_vals[i] = delta_val <= 20.f - ? ddelta_vals[i] / (1.f + delta_val_neg_exp) - : ddelta_vals[i]; - } - } - for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; } - - input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride - + dim_id * params.du_d_stride + chunk * kChunkSize; - input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride - + dim_id * params.ddelta_d_stride + chunk * kChunkSize; - __syncthreads(); - store_output(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize); - __syncthreads(); - store_output(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize); - - Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); - Cvar -= kChunkSize * (!kIsComplex ? 1 : 2); - } - if (params.dD_ptr != nullptr) { - dD_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val); - if (threadIdx.x == 0) { phi::CudaAtomicAdd(dD, dD_val); } - } - if (params.ddelta_bias_ptr != nullptr) { - __syncthreads(); - ddelta_bias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val); - if (threadIdx.x == 0) { phi::CudaAtomicAdd(ddelta_bias, ddelta_bias_val); } - } - for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { - phi::CudaAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]); - weight_t dBC_val; - if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; } - if constexpr (!kIsVariableB) { - phi::CudaAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]), - !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val); - } - if constexpr (!kIsVariableC) { - phi::CudaAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]), - !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val); - } - } -} - -template -void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { - BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { - BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { - BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { - BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - using Ktraits = Selective_Scan_bwd_kernel_traits; - // using Ktraits = Selective_Scan_bwd_kernel_traits; - // TODO: check this - constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t); - - dim3 grid(params.batch, params.dim); - - auto kernel = &selective_scan_bwd_kernel; - - if (kSmemSize >= 48 * 1024) { - - #ifndef USE_ROCM - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize); - #else - cudaFuncSetAttribute( - (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize); - std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; - #endif - - } - - kernel<<>>(params); - }); - }); - }); - }); - }); -} - -template -void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { - - #ifndef USE_ROCM - if (params.seqlen <= 128) { - selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 256) { - selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream); - } else { - selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream); - } - #else - if (params.seqlen <= 256) { - selective_scan_bwd_launch<64, 4, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_bwd_launch<64, 8, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream); - } else { - selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream); - } - #endif -} \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan_common.h b/ops/csrc/selective_scan/selective_scan_common.h deleted file mode 100644 index 6b80796321a9..000000000000 --- a/ops/csrc/selective_scan/selective_scan_common.h +++ /dev/null @@ -1,257 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#if defined(CUDA_BFLOAT16_AVAILABLE) - #ifndef USE_ROCM - #include - #else - #include - #endif -#endif -#include -#include - - -#ifndef USE_ROCM - - constexpr size_t custom_max(std::initializer_list ilist) - { - return std::max(ilist); - } - - template - constexpr T constexpr_min(T a, T b) { - return std::min(a, b); - } - -#else - constexpr size_t custom_max(std::initializer_list ilist) - { - return *std::max_element(ilist.begin(), ilist.end()); - } - - template - constexpr T constexpr_min(T a, T b) { - return a < b ? a : b; - } -#endif - - -#define MAX_DSTATE 256 - -using complex_t = phi::dtype::complex; - -inline __device__ float2 operator+(const float2 & a, const float2 & b){ - return {a.x + b.x, a.y + b.y}; -} - -inline __device__ float3 operator+(const float3 &a, const float3 &b) { - return {a.x + b.x, a.y + b.y, a.z + b.z}; -} - -inline __device__ float4 operator+(const float4 & a, const float4 & b){ - return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template struct BytesToType {}; - -template<> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template<> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template<> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template<> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template<> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Converter{ - static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { dst[i] = src[i]; } - } -}; - -template -struct Converter{ - static inline __device__ void to_float(const phi::dtype::float16 (&src)[N], float (&dst)[N]) { - static_assert(N % 2 == 0); - auto &src2 = reinterpret_cast(src); - auto &dst2 = reinterpret_cast(dst); - #pragma unroll - for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } - } -}; - -#if __CUDA_ARCH__ >= 800 -template -struct Converter{ - static inline __device__ void to_float(const phi::dtype::bfloat16 (&src)[N], float (&dst)[N]) { - static_assert(N % 2 == 0); - auto &src2 = reinterpret_cast(src); - auto &dst2 = reinterpret_cast(dst); - #pragma unroll - for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp -// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696 -__device__ __forceinline__ complex_t cexp2f(complex_t z) { - float t = exp2f(z.real); - float c, s; - sincosf(z.imag, &s, &c); - return complex_t(c * t, s * t); -} - -__device__ __forceinline__ complex_t cexpf(complex_t z) { - float t = expf(z.real); - float c, s; - sincosf(z.imag, &s, &c); - return complex_t(c * t, s * t); -} - -template struct SSMScanOp; - -template<> -struct SSMScanOp { - __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { - return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); - } -}; - -template<> -struct SSMScanOp { - __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const { - complex_t a0 = complex_t(ab0.x, ab0.y); - complex_t b0 = complex_t(ab0.z, ab0.w); - complex_t a1 = complex_t(ab1.x, ab1.y); - complex_t b1 = complex_t(ab1.z, ab1.w); - complex_t out_a = a1 * a0; - complex_t out_b = a1 * b0 + b1; - return make_float4(out_a.real, out_a.imag, out_b.real, out_b.imag); - } -}; - -// A stateful callback functor that maintains a running prefix to be applied -// during consecutive scan operations. -template struct SSMScanPrefixCallbackOp { - using scan_t = std::conditional_t, float2, float4>; - scan_t running_prefix; - // Constructor - __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} - // Callback operator to be entered by the first warp of threads in the block. - // Thread-0 is responsible for returning a value for seeding the block-wide scan. - __device__ scan_t operator()(scan_t block_aggregate) { - scan_t old_prefix = running_prefix; - running_prefix = SSMScanOp()(running_prefix, block_aggregate); - return old_prefix; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void load_input(typename Ktraits::input_t *u, - typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadT::TempStorage &smem_load, - int seqlen) { - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_vec = reinterpret_cast(smem_load); - using vec_t = typename Ktraits::vec_t; - typename Ktraits::BlockLoadVecT(smem_load_vec).Load( - reinterpret_cast(u), - reinterpret_cast(u_vals) - #ifdef USE_ROCM - , Ktraits::kNThreads * Ktraits::kNLoads - #endif - - ); - } else { - typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); - } -} - -template -inline __device__ void load_weight(typename Ktraits::input_t *Bvar, - typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, - int seqlen) { - constexpr int kNItems = Ktraits::kNItems; - if constexpr (!Ktraits::kIsComplex) { - typename Ktraits::input_t B_vals_load[kNItems]; - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); - using vec_t = typename Ktraits::vec_t; - typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( - reinterpret_cast(Bvar), - reinterpret_cast(B_vals_load) - ); - } else { - typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); - } - // #pragma unroll - // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } - Converter::to_float(B_vals_load, B_vals); - } else { - typename Ktraits::input_t B_vals_load[kNItems * 2]; - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); - using vec_t = typename Ktraits::vec_t; - typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( - reinterpret_cast(Bvar), - reinterpret_cast(B_vals_load) - ); - } else { - typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); - } - #pragma unroll - for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); } - } -} - -template -inline __device__ void store_output(typename Ktraits::input_t *out, - const float (&out_vals)[Ktraits::kNItems], - typename Ktraits::BlockStoreT::TempStorage &smem_store, - int seqlen) { - typename Ktraits::input_t write_vals[Ktraits::kNItems]; - #pragma unroll - for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_store_vec = reinterpret_cast(smem_store); - using vec_t = typename Ktraits::vec_t; - typename Ktraits::BlockStoreVecT(smem_store_vec).Store( - reinterpret_cast(out), - reinterpret_cast(write_vals) - ); - } else { - typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); - } -} diff --git a/ops/csrc/selective_scan/selective_scan_fwd_bf16_complex.cu b/ops/csrc/selective_scan/selective_scan_fwd_bf16_complex.cu deleted file mode 100644 index b897bfa3a917..000000000000 --- a/ops/csrc/selective_scan/selective_scan_fwd_bf16_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in parallel - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan_fwd_bf16_real.cu b/ops/csrc/selective_scan/selective_scan_fwd_bf16_real.cu deleted file mode 100644 index 591ade54afd0..000000000000 --- a/ops/csrc/selective_scan/selective_scan_fwd_bf16_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in parallel - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan_fwd_fp16_complex.cu b/ops/csrc/selective_scan/selective_scan_fwd_fp16_complex.cu deleted file mode 100644 index e17f6a00b71d..000000000000 --- a/ops/csrc/selective_scan/selective_scan_fwd_fp16_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in parallel - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan_fwd_fp16_real.cu b/ops/csrc/selective_scan/selective_scan_fwd_fp16_real.cu deleted file mode 100644 index 304933acfb13..000000000000 --- a/ops/csrc/selective_scan/selective_scan_fwd_fp16_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in parallel - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan_fwd_fp32_complex.cu b/ops/csrc/selective_scan/selective_scan_fwd_fp32_complex.cu deleted file mode 100644 index 8076c860893a..000000000000 --- a/ops/csrc/selective_scan/selective_scan_fwd_fp32_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in parallel - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan_fwd_fp32_real.cu b/ops/csrc/selective_scan/selective_scan_fwd_fp32_real.cu deleted file mode 100644 index be1cf1e234b3..000000000000 --- a/ops/csrc/selective_scan/selective_scan_fwd_fp32_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in parallel - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/ops/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/ops/csrc/selective_scan/selective_scan_fwd_kernel.cuh deleted file mode 100755 index fd9f120f7302..000000000000 --- a/ops/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ /dev/null @@ -1,373 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#ifndef USE_ROCM - #include - #include - #include -#else - #include - namespace cub = hipcub; -#endif - -#include "selective_scan.h" -#include "selective_scan_common.h" -#include "static_switch.h" - -template -struct Selective_Scan_fwd_kernel_traits { - static_assert(kNItems_ % 4 == 0); - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. - static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; - static constexpr int kNItems = kNItems_; - static constexpr int kNRows = kNRows_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); - static_assert(kNItems % kNElts == 0); - static constexpr int kNLoads = kNItems / kNElts; - static constexpr bool kIsComplex = std::is_same_v; - static constexpr bool kIsEvenLen = kIsEvenLen_; - static constexpr bool kIsVariableB = kIsVariableB_; - static constexpr bool kIsVariableC = kIsVariableC_; - static constexpr bool kHasZ = kHasZ_; - - static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; - - using vec_t = typename BytesToType::Type; - using scan_t = std::conditional_t; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = cub::BlockLoad; - using BlockLoadWeightT = cub::BlockLoad; - using BlockLoadWeightVecT = cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = cub::BlockStore; - // using BlockScanT = cub::BlockScan; - // using BlockScanT = cub::BlockScan; - using BlockScanT = cub::BlockScan; - static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), - sizeof(typename BlockLoadVecT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), - sizeof(typename BlockStoreT::TempStorage), - sizeof(typename BlockStoreVecT::TempStorage)}); - static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) -void selective_scan_fwd_kernel(SSMParamsBase params) { - constexpr bool kIsComplex = Ktraits::kIsComplex; - constexpr bool kIsVariableB = Ktraits::kIsVariableB; - constexpr bool kIsVariableC = Ktraits::kIsVariableC; - constexpr bool kHasZ = Ktraits::kHasZ; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNItems = Ktraits::kNItems; - constexpr int kNRows = Ktraits::kNRows; - constexpr bool kDirectIO = Ktraits::kDirectIO; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using scan_t = typename Ktraits::scan_t; - - // Shared memory. - extern __shared__ char smem_[]; - // cast to lvalue reference of expected type - // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); - // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); - // auto& smem_load = reinterpret_cast(smem_loadstorescan); - auto& smem_load = reinterpret_cast(smem_); - auto& smem_load_weight = reinterpret_cast(smem_); - auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); - auto& smem_store = reinterpret_cast(smem_); - auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); - // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); - scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); - - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - const int group_id = dim_id / (params.dim_ngroups_ratio); - input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride - + dim_id * kNRows * params.u_d_stride; - input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride - + dim_id * kNRows * params.delta_d_stride; - weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; - weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; - input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; - weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; - input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; - scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; - - float D_val[kNRows] = {0}; - if (params.D_ptr != nullptr) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; - } - } - float delta_bias[kNRows] = {0}; - if (params.delta_bias_ptr != nullptr) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; - } - } - - // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { - // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; - // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; - // } - - constexpr int kChunkSize = kNThreads * kNItems; - for (int chunk = 0; chunk < params.n_chunks; ++chunk) { - input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; - __syncthreads(); - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if constexpr (!kDirectIO) { - if (r > 0) { __syncthreads(); } - } - load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); - if constexpr (!kDirectIO) { __syncthreads(); } - load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); - } - u += kChunkSize; - delta += kChunkSize; - - float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float u_val = float(u_vals[r][i]); - delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; - if (params.delta_softplus) { - delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; - } - delta_u_vals[r][i] = delta_vals[r][i] * u_val; - out_vals[r][i] = D_val[r] * u_val; - } - } - - __syncthreads(); - for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { - weight_t A_val[kNRows]; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; - // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. - constexpr float kLog2e = M_LOG2E; - if constexpr (!kIsComplex) { - A_val[r] *= kLog2e; - } else { - A_val[r].real *= kLog2e; - } - } - // This variable holds B * C if both B and C are constant across seqlen. If only B varies - // across seqlen, this holds C. If only C varies across seqlen, this holds B. - // If both B and C vary, this is unused. - weight_t BC_val[kNRows]; - weight_t B_vals[kNItems], C_vals[kNItems]; - if constexpr (kIsVariableB) { - load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, - smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - if constexpr (!kIsVariableC) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; - } - } - } - if constexpr (kIsVariableC) { - auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; - load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - if constexpr (!kIsVariableB) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; - } - } - } - if constexpr (!kIsVariableB && !kIsVariableC) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; - } - } - - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if (r > 0) { __syncthreads(); } // Scan could be using the same smem - scan_t thread_data[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - if constexpr (!kIsComplex) { - thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), - !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); - if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct - if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { - thread_data[i] = make_float2(1.f, 0.f); - } - } - } else { - // Pytorch's implementation of complex exp (which calls thrust) is very slow - complex_t delta_a_exp = cexp2f(delta_vals[r][i] * float(A_val[r])); - weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : float(B_vals[i]) * delta_u_vals[r][i]; - thread_data[i] = make_float4(delta_a_exp.real, delta_a_exp.imag, B_delta_u_val.real, B_delta_u_val.imag); - if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct - if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { - thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f); - } - } - } - } - // Initialize running total - scan_t running_prefix; - if constexpr (!kIsComplex) { - // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read - running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f); - // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); - } else { - running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f); - // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); - } - SSMScanPrefixCallbackOp prefix_op(running_prefix); - typename Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp(), prefix_op - ); - // There's a syncthreads in the scan op, so we don't need to sync here. - // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. - if (threadIdx.x == 0) { - smem_running_prefix[state_idx] = prefix_op.running_prefix; - x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; - } - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - const weight_t C_val = !kIsVariableC - ? BC_val[r] - : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]); - if constexpr (!kIsComplex) { - out_vals[r][i] += thread_data[i].y * C_val; - } else { - out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real * 2; - } - } - } - } - - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; - __syncthreads(); - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if constexpr (!kDirectIO) { - if (r > 0) { __syncthreads(); } - } - store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); - } - - if constexpr (kHasZ) { - input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride - + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; - input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride - + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - input_t z_vals[kNItems]; - __syncthreads(); - load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float z_val = z_vals[i]; - out_vals[r][i] *= z_val / (1 + expf(-z_val)); - } - __syncthreads(); - store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); - } - } - - Bvar += kChunkSize * (!kIsComplex ? 1 : 2); - Cvar += kChunkSize * (!kIsComplex ? 1 : 2); - } -} - -template -void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { - // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block - // processing 1 row. - constexpr int kNRows = 1; - BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { - BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { - BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits; - - constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); - dim3 grid(params.batch, params.dim / kNRows); - - // Had to change this substantially since potentially the hip - // interface for setting kernel launch attributes is slightly different from - // cuda's. In particualar, it seems to expect a plain const void * pointer. - - auto kernel = &selective_scan_fwd_kernel; - - - if (kSmemSize >= 48 * 1024) { - #ifndef USE_ROCM - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize); - #else - cudaFuncSetAttribute( - (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize); - std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; - #endif - } - - kernel<<>>(params); - }); - }); - }); - }); -} - -template -void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { - - #ifndef USE_ROCM - if (params.seqlen <= 128) { - selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 256) { - selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); - } else { - selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); - } - #else - if (params.seqlen <= 256) { - selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); - } else { - selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); - } - #endif -} diff --git a/ops/csrc/selective_scan/static_switch.h b/ops/csrc/selective_scan/static_switch.h deleted file mode 100644 index 099e0756c4de..000000000000 --- a/ops/csrc/selective_scan/static_switch.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h -// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h - -#pragma once - -/// @param COND - a boolean expression to switch by -/// @param CONST_NAME - a name given for the constexpr bool variable. -/// @param ... - code to execute for true and false -/// -/// Usage: -/// ``` -/// BOOL_SWITCH(flag, BoolConst, [&] { -/// some_function(...); -/// }); -/// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() diff --git a/ops/csrc/selective_scan/uninitialized_copy.cuh b/ops/csrc/selective_scan/uninitialized_copy.cuh deleted file mode 100644 index cdaf115e34a3..000000000000 --- a/ops/csrc/selective_scan/uninitialized_copy.cuh +++ /dev/null @@ -1,77 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#ifndef USE_ROCM - #include - - #include -#else - #include - // Map ::cuda::std to the standard std namespace - namespace cuda { - namespace std = ::std; - } -#endif - - -namespace detail -{ - -#if defined(_NVHPC_CUDA) -template -__host__ __device__ void uninitialized_copy(T *ptr, U &&val) -{ - // NVBug 3384810 - new (ptr) T(::cuda::std::forward(val)); -} -#else -template ::value, - int - >::type = 0> -__host__ __device__ void uninitialized_copy(T *ptr, U &&val) -{ - *ptr = ::cuda::std::forward(val); -} - -template ::value, - int - >::type = 0> -__host__ __device__ void uninitialized_copy(T *ptr, U &&val) -{ - new (ptr) T(::cuda::std::forward(val)); -} -#endif - -} // namespace detail diff --git a/ops/csrc/setup.py b/ops/csrc/setup.py deleted file mode 100644 index 548a467f8fcb..000000000000 --- a/ops/csrc/setup.py +++ /dev/null @@ -1,267 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import multiprocessing -import os -from site import getsitepackages - -import paddle - -paddle_includes = [] -for site_packages_path in getsitepackages(): - paddle_includes.append(os.path.join(site_packages_path, "paddle", "include")) - paddle_includes.append(os.path.join(site_packages_path, "paddle", "include", "third_party")) - paddle_includes.append(os.path.join(site_packages_path, "nvidia", "cudnn", "include")) - - -def get_gencode_flags(compiled_all=False): - if not compiled_all: - prop = paddle.device.cuda.get_device_properties() - cc = prop.major * 10 + prop.minor - return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)] - else: - return [ - "-gencode", - "arch=compute_80,code=sm_80", - "-gencode", - "arch=compute_75,code=sm_75", - "-gencode", - "arch=compute_70,code=sm_70", - ] - - -def get_sm_version(): - prop = paddle.device.cuda.get_device_properties() - cc = prop.major * 10 + prop.minor - return cc - - -def run_single(func): - p = multiprocessing.Process(target=func) - p.start() - p.join() - - -def run_multi(func_list): - processes = [] - for func in func_list: - processes.append(multiprocessing.Process(target=func)) - processes.append(multiprocessing.Process(target=func)) - processes.append(multiprocessing.Process(target=func)) - - for p in processes: - p.start() - - for p in processes: - p.join() - - -cc_flag = get_gencode_flags(compiled_all=False) -cc = get_sm_version() - - -def setup_fast_ln(): - from paddle.utils.cpp_extension import CUDAExtension, setup - - setup( - name="fast_ln", - ext_modules=CUDAExtension( - include_dirs=paddle_includes, - sources=[ - "fast_ln/ln_api.cpp", - "fast_ln/ln_bwd_semi_cuda_kernel.cu", - "fast_ln/ln_fwd_cuda_kernel.cu", - ], - extra_compile_args={ - "cxx": ["-O3"], - "nvcc": [ - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "-I./apex/contrib/csrc/layer_norm/", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - + cc_flag, - }, - ), - ) - - -def setup_fused_ln(): - from paddle.utils.cpp_extension import CUDAExtension, setup - - setup( - name="fused_ln", - ext_modules=CUDAExtension( - include_dirs=paddle_includes, - sources=[ - "fused_ln/layer_norm_cuda.cu", - ], - extra_compile_args={ - "cxx": ["-O3"], - "nvcc": [ - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "-I./apex/contrib/csrc/layer_norm/", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "-maxrregcount=50", - ] - + cc_flag, - }, - ), - ) - - -def setup_causal_conv1d(): - from paddle.utils.cpp_extension import CUDAExtension, setup - - sources = [ - "causal_conv1d/causal_conv1d.cpp", - "causal_conv1d/causal_conv1d_fwd.cu", - "causal_conv1d/causal_conv1d_bwd.cu", - "causal_conv1d/causal_conv1d_update.cu", - ] - - if cc >= 75: - cc_flag.append("-DCUDA_BFLOAT16_AVAILABLE") - - extra_compile_args = { - "cxx": ["-O3"], - "nvcc": [ - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=-v", - "-lineinfo", - "--threads", - "4", - ] - + cc_flag, - } - - setup( - name="causal_conv1d_cuda_pd", - ext_modules=CUDAExtension( - sources=sources, - extra_compile_args=extra_compile_args, - ), - ) - - -def setup_selective_scan(): - from paddle.utils.cpp_extension import CUDAExtension, setup - - real_complex_list = ["real"] - dtype_list = ["fp16", "fp32"] - - if cc > 75: - dtype_list.insert(1, "bf16") - cc_flag.append("-DCUDA_BFLOAT16_AVAILABLE") - - sources = [ - "selective_scan/selective_scan.cpp", - ] - for real_or_complex in real_complex_list: - for dtype in dtype_list: - sources.append(f"selective_scan/selective_scan_fwd_{dtype}_{real_or_complex}.cu") - sources.append(f"selective_scan/selective_scan_bwd_{dtype}_{real_or_complex}.cu") - - extra_compile_args = { - "cxx": ["-O3", "-std=c++17"], - "nvcc": [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=-v", - "-lineinfo", - "--threads", - "4", - ] - + cc_flag, - } - - setup( - name="selective_scan_cuda_pd", - ext_modules=CUDAExtension( - include_dirs=paddle_includes, - sources=sources, - extra_compile_args=extra_compile_args, - ), - ) - - -def setup_paddle_bwd_ops(): - from paddle.utils.cpp_extension import CUDAExtension, setup - - sources = ["paddle_bwd_ops/flash_attn_bwd.cc", "paddle_bwd_ops/add_bwd.cc", "paddle_bwd_ops/matmul_bwd.cc"] - try: - from paddle.nn.functional.flash_attention import ( # noqa: F401 - flash_attention_with_sparse_mask, - ) - - sources.append("paddle_bwd_ops/flash_attn_with_sparse_mask_bwd.cc") - except ImportError: - from paddle.nn.functional.flash_attention import ( # noqa: F401 - flashmask_attention, - ) - - sources.append("paddle_bwd_ops/flashmask_attn_bwd.cc") - - setup( - name="paddle_bwd_ops", - ext_modules=CUDAExtension( - include_dirs=paddle_includes, - sources=sources, - ), - ) - - -if __name__ == "__main__": - run_multi( - [ - setup_fast_ln, - setup_fused_ln, - setup_causal_conv1d, - setup_selective_scan, - setup_paddle_bwd_ops, - ], - ) diff --git a/ops/requirements.txt b/ops/requirements.txt deleted file mode 100644 index 287c43990e17..000000000000 --- a/ops/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -triton>=3.0.0 -paddlenlp>=3.0.0.b2 -einops>=0.6.1 -numpy \ No newline at end of file diff --git a/ops/setup.py b/ops/setup.py deleted file mode 100644 index 5e6c1f1d3ced..000000000000 --- a/ops/setup.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import shutil -import sys -import textwrap -import warnings -from pathlib import Path - -from setuptools import find_packages, setup - -version_range_max = max(sys.version_info[1], 10) + 1 - - -def read_requirements_file(filepath): - with open(filepath) as fin: - requirements = fin.read() - return requirements - - -def write_custom_op_api_py(libname, filename): - libname = str(libname) - filename = str(filename) - import paddle - - op_names = paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(libname) - api_content = [paddle.utils.cpp_extension.extension_utils._custom_api_content(op_name) for op_name in op_names] - dirname = os.path.dirname(os.path.abspath(filename)) - if not os.path.exists(dirname): - os.makedirs(dirname) - - _stub_template = textwrap.dedent( - """ - # THIS FILE IS GENERATED FROM PADDLEPADDLE SETUP.PY - - {custom_api} - - import os - import sys - import types - import paddle - import importlib.abc - import importlib.util - - cur_dir = os.path.dirname(os.path.abspath(__file__)) - so_path = os.path.join(cur_dir, "lib/{resource}") - - def __bootstrap__(): - assert os.path.exists(so_path) - # load custom op shared library with abs path - custom_ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path) - - if os.name == 'nt' or sys.platform.startswith('darwin'): - # Cpp Extension only support Linux now - mod = types.ModuleType(__name__) - else: - try: - spec = importlib.util.spec_from_file_location(__name__, so_path) - assert spec is not None - mod = importlib.util.module_from_spec(spec) - assert isinstance(spec.loader, importlib.abc.Loader) - spec.loader.exec_module(mod) - except ImportError: - mod = types.ModuleType(__name__) - - for custom_op in custom_ops: - setattr(mod, custom_op, eval(custom_op)) - - __bootstrap__() - - """ - ).lstrip() - - with open(filename, "w", encoding="utf-8") as f: - f.write(_stub_template.format(resource=os.path.basename(libname), custom_api="\n\n".join(api_content))) - - -if len(sys.argv) > 0: - # generate lib files - lib_path = Path("src/paddlenlp_kernel/cuda/lib") - if lib_path.exists(): - shutil.rmtree(lib_path) - lib_path.mkdir(exist_ok=True) - (lib_path / "__init__.py").touch(exist_ok=True) - has_built = False - for so_file in Path("csrc").glob("**/*.so"): - so_filename = so_file.name - # so file - new_so_filename = so_filename.replace(".so", "_pd.so") - new_so_file = lib_path / new_so_filename - # py file - py_filename = so_filename.replace(".so", ".py") - new_py_file = lib_path.parent / py_filename - shutil.copyfile(so_file, new_so_file) - write_custom_op_api_py(new_so_file, new_py_file) - has_built = True - - if not has_built: - warnings.warn("No cuda lib found. Please build cuda lib first. See details in csrc/README.md.") - -# NEW ADDED END -setup( - name="paddlenlp_kernel", - version="0.1.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) - description="PaddleNLP GPU OPS cuda & triton.", - long_description=open("README.md", "r", encoding="utf-8").read(), - long_description_content_type="text/markdown", - keywords="paddlenlp kernel contain cuda & triton", - license="Apache 2.0 License", - author="PaddlePaddle", - author_email="paddlenlp@baidu.com", - url="https://github.com/PaddlePaddle/paddlenlp/ops", - package_dir={"": "src"}, - packages=find_packages("src"), - package_data={"paddlenlp_kernel.cuda.lib": ["*.so", "*.dll", "*.dylib"]}, - include_package_data=True, - python_requires=">=3.8.0", - install_requires=read_requirements_file("requirements.txt"), - classifiers=[ - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Programming Language :: Python :: 3", - ] - + [f"Programming Language :: Python :: 3.{i}" for i in range(8, version_range_max)], -) diff --git a/ops/src/paddlenlp_kernel/__init__.py b/ops/src/paddlenlp_kernel/__init__.py deleted file mode 100644 index 45430611c3b7..000000000000 --- a/ops/src/paddlenlp_kernel/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -__version__ = "0.1.0" diff --git a/ops/src/paddlenlp_kernel/cuda/__init__.py b/ops/src/paddlenlp_kernel/cuda/__init__.py deleted file mode 100644 index fd05a9208165..000000000000 --- a/ops/src/paddlenlp_kernel/cuda/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/ops/src/paddlenlp_kernel/cuda/causal_conv1d.py b/ops/src/paddlenlp_kernel/cuda/causal_conv1d.py deleted file mode 100644 index ed72215f122c..000000000000 --- a/ops/src/paddlenlp_kernel/cuda/causal_conv1d.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright (c) 2024, Tri Dao. -""" -this code is modified from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py -""" -import paddle -import paddle.nn.functional as F - -from ..utils import custom_bwd, custom_fwd - -try: - from . import causal_conv1d_cuda_pd as causal_conv1d_cuda -except ImportError: - causal_conv1d_cuda = None - - -class CausalConv1dFn(paddle.autograd.PyLayer): - @staticmethod - @custom_fwd - def forward( - ctx, - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, - ): - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.strides[2] != 1 and x.strides[1] != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - if seq_idx is not None: - assert initial_states is None, "initial_states must be None if seq_idx is not None" - assert not return_final_states, "If seq_idx is not None, we don't return final_states_out" - seq_idx = seq_idx.contiguous() if seq_idx is not None else None - if initial_states is not None and (initial_states.strides[2] != 1 and initial_states.strides[1] != 1): - initial_states = initial_states.contiguous() - if return_final_states: - assert x.strides[1] == 1, "Only channel-last layout support returning final_states_out" - if final_states_out is not None: - assert final_states_out.strides[2] == 1 or final_states_out.strides[1] == 1 - else: - batch, dim, seqlen = x.shape - width = weight.shape[1] - final_states_out = paddle.empty([batch, width - 1, dim], dtype=x.dtype).transpose([0, 2, 1]) - else: - final_states_out = None - ctx.activation = activation in ["silu", "swish"] - out = causal_conv1d_cuda.causal_conv1d_fwd( - x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation - ) - - if seq_idx is not None and initial_states is not None: - ctx.save_mode = 0 - ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) - elif initial_states is None and seq_idx is not None: - ctx.save_mode = 1 - ctx.save_for_backward(x, weight, bias, seq_idx) - elif seq_idx is None and initial_states is not None: - ctx.save_mode = 2 - ctx.save_for_backward(x, weight, bias, initial_states) - else: - ctx.save_mode = 3 - ctx.save_for_backward(x, weight, bias) - - ctx.return_final_states = return_final_states - ctx.return_dinitial_states = initial_states is not None and not initial_states.stop_gradient - return out if not return_final_states else (out, final_states_out) - - @staticmethod - @custom_bwd - def backward(ctx, dout, *args): - initial_states = seq_idx = None - if ctx.save_mode == 0: - x, weight, bias, seq_idx, initial_states = ctx.saved_tensor() - elif ctx.save_mode == 1: - x, weight, bias, seq_idx = ctx.saved_tensor() - elif ctx.save_mode == 2: - x, weight, bias, initial_states = ctx.saved_tensor() - else: - x, weight, bias = ctx.saved_tensor() - - dfinal_states = args[0] if ctx.return_final_states else None - - # if dout.strides[2] != 1 and dout.strides[1] != 1: - # dout = dout.contiguous() - # NEW ADD, not in c++ code - is_channel_last = x.strides[1] == 1 and x.strides[2] > 1 - if not is_channel_last and dout.strides[2] != 1: - dout = dout.contiguous() - if ctx.return_final_states: - dfinal_states = dfinal_states.contiguous() - - if is_channel_last and dout.strides[1] != 1: - dout = dout.transpose([0, 2, 1]).contiguous().transpose([0, 2, 1]) - if ctx.return_final_states: - dfinal_states = dfinal_states.transpose([0, 2, 1]).contiguous().transpose([0, 2, 1]) - - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). - # Here we just pass in None and dx will be allocated in the C++ code. - dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd( - x, - weight, - bias, - dout, - seq_idx, - initial_states, - dfinal_states, - None, - ctx.return_dinitial_states, - ctx.activation, - ) - return ( - dx, - dweight, - dbias if bias is not None else None, - None, - dinitial_states if initial_states is not None else None, - None, - None, - None, - ) - - -def causal_conv1d_fn( - x, - weight, - bias=None, - seq_idx=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) weight: (dim, width) bias: (dim,) seq_idx: (batch, seqlen) initial_states: (batch, dim, - width - 1) final_states_out: (batch, dim, width - 1), to be written to activation: either None or "silu" or "swish" - - out: (batch, dim, seqlen) - """ - - return CausalConv1dFn.apply( - x, - weight, - bias, - seq_idx, - initial_states, - return_final_states, - final_states_out, - activation, - ) - - -def causal_conv1d_ref( - x, - weight, - bias=None, - initial_states=None, - return_final_states=False, - final_states_out=None, - activation=None, -): - """ - x: (batch, dim, seqlen) weight: (dim, width) bias: (dim,) initial_states: (batch, dim, width - 1) final_states_out: - (batch, dim, width - 1) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.cast(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - if initial_states is None: - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) - else: - x = paddle.concat([initial_states.cast(x.dtype), x], axis=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - tmp = width - 1 - x.shape[-1] - if tmp < 0: - final_states = x[..., -tmp:].cast(dtype_in) # (batch, dim, width - 1) - else: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0), data_format="NCL").cast( - dtype_in - ) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out.copy_(final_states.cast(final_states_out.dtype), False) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).cast(dtype=dtype_in) - return out if not return_final_states else (out, final_states_out) - - -def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): - """ - x: (batch, dim) or (batch, dim, seqlen) conv_state: (batch, dim, state_len), where state_len >= width - 1 weight: - (dim, width) bias: (dim,) cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. The conv_state will be updated by copying x to the - conv_state starting at the index @cache_seqlens % state_len. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - activation = activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - out = causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation, cache_seqlens) - if unsqueeze: - out = out.squeeze(-1) - return out - - -def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): - """ - x: (batch, dim) or (batch, dim, seqlen) conv_state: (batch, dim, state_len), where state_len >= width - 1 weight: - (dim, width) bias: (dim,) cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. The conv_state will be updated by copying x to the - conv_state starting at the index @cache_seqlens % state_len before performing the convolution. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - width = weight.shape[1] - state_len = conv_state.shape[-1] - assert tuple(conv_state.shape) == (batch, dim, state_len) - assert tuple(weight.shape) == (dim, width) - if cache_seqlens is None: - x_new = paddle.concat([conv_state, x], axis=-1).cast(weight.dtype) # (batch, dim, state_len + seqlen) - conv_state.copy_(x_new[:, :, -state_len:].cast(conv_state.dtype), False) - else: - width_idx = paddle.arange(-(width - 1), 0, dtype=cache_seqlens.dtype).unsqueeze(0) + cache_seqlens.unsqueeze(1) - state_len = paddle.to_tensor(state_len, dtype=width_idx.dtype) - width_idx = paddle.remainder(width_idx, state_len).unsqueeze(1).expand([-1, dim, -1]) - x_new = paddle.concat([paddle.take_along_axis(conv_state, width_idx, axis=2), x], axis=-1).cast(weight.dtype) - copy_idx = paddle.arange(seqlen, dtype=cache_seqlens.dtype).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = paddle.remainder(copy_idx, state_len).unsqueeze(1).expand([-1, dim, -1]) - conv_state.copy_(conv_state.put_along_axis(copy_idx, x, axis=2), False) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] - if unsqueeze: - out = out.squeeze(-1) - return (out if activation is None else F.silu(out)).cast(dtype=dtype_in) diff --git a/ops/src/paddlenlp_kernel/cuda/selective_scan.py b/ops/src/paddlenlp_kernel/cuda/selective_scan.py deleted file mode 100644 index 4ed2b637ada7..000000000000 --- a/ops/src/paddlenlp_kernel/cuda/selective_scan.py +++ /dev/null @@ -1,526 +0,0 @@ -# Copyright (c) 2023, Tri Dao, Albert Gu. -""" -this code is modified from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py -""" -import paddle -import paddle.nn.functional as F -from einops import rearrange, repeat - -from ..utils import custom_bwd, custom_fwd - -try: - from .causal_conv1d import causal_conv1d_fn -except ImportError: - causal_conv1d_fn = None - -try: - from . import causal_conv1d_cuda_pd as causal_conv1d_cuda -except ImportError: - causal_conv1d_cuda = None -try: - from . import selective_scan_cuda_pd as selective_scan_cuda -except ImportError: - selective_scan_cuda = None -from paddle.distributed import fleet - -from ..utils import get_autocast_gpu_dtype, is_autocast_enabled - - -class SelectiveScanFn(paddle.autograd.PyLayer): - @staticmethod - @custom_fwd - def forward( - ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False - ): - if u.strides[-1] != 1: - u = u.contiguous() - if delta.strides[-1] != 1: - delta = delta.contiguous() - if D is not None: - D = D.contiguous() - if B.strides[-1] != 1: - B = B.contiguous() - if C.strides[-1] != 1: - C = C.contiguous() - if z is not None and z.strides[-1] != 1: - z = z.contiguous() - if B.dim() == 3: - B = rearrange(B, "b dstate l -> b 1 dstate l") - ctx.squeeze_B = True - if C.dim() == 3: - C = rearrange(C, "b dstate l -> b 1 dstate l") - ctx.squeeze_C = True - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) - ctx.delta_softplus = delta_softplus - ctx.has_z = z is not None - last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) - if not ctx.has_z: - ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) - return out if not return_last_state else (out, last_state) - else: - ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) - out_z = rest[0] - return out_z if not return_last_state else (out_z, last_state) - - @staticmethod - @custom_bwd - def backward(ctx, dout, *args): - if not ctx.has_z: - u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensor() - z = None - out = None - else: - u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensor() - if dout.strides[-1] != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the - # backward of selective_scan_cuda with the backward of chunk). - # Here we just pass in None and dz will be allocated in the C++ code. - du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( - u, - delta, - A, - B, - C, - D, - z, - delta_bias, - dout, - x, - out, - None, - ctx.delta_softplus, - False, # option to recompute out_z, not used here - ) - dz = rest[0] if ctx.has_z else None - dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB - dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC - - # new added for tp parallel - try: - hcg = fleet.get_hybrid_communicate_group() - mp_group = hcg.get_model_parallel_group() - mp_src_rank = hcg.get_model_parallel_group_src_rank() - mp_world_size = hcg.get_model_parallel_world_size() - except Exception: - mp_group = None - mp_src_rank = 0 - mp_world_size = 1 - - if mp_world_size > 1: - paddle.distributed.broadcast( - dB, - src=mp_src_rank, - group=mp_group, - ) - paddle.distributed.broadcast( - dC, - src=mp_src_rank, - group=mp_group, - ) - - return ( - du, - ddelta, - dA, - dB, - dC, - dD if D is not None else None, - dz, - ddelta_bias if delta_bias is not None else None, - None, - None, - ) - - -def selective_scan_fn( - u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False -): - """if return_last_state is True, returns (out, last_state) - last_state has shape (batch, dim, dstate). Note that the gradient of the last state is not considered in the - backward pass. - """ - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) - - -def selective_scan_ref( - u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False -): - """ - u: r(B D L) delta: r(B D L) A: c(D N) or r(D N) B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) C: - c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) D: r(D) z: r(B D L) delta_bias: r(D), fp32 - - out: r(B D L) last_state (optional): r(B D dstate) or c(B D dstate) - """ - dtype_in = u.dtype - u = u.cast("float32") - delta = delta.cast("float32") - if delta_bias is not None: - delta = delta + delta_bias[..., None].cast("float32") - if delta_softplus: - delta = F.softplus(delta) - batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] - is_variable_B = B.dim() >= 3 - is_variable_C = C.dim() >= 3 - if A.is_complex(): - if is_variable_B: - B = paddle.as_complex(rearrange(B.cast("float32"), "... (L two) -> ... L two", two=2)) - if is_variable_C: - C = paddle.as_complex(rearrange(C.cast("float32"), "... (L two) -> ... L two", two=2)) - else: - B = B.cast("float32") - C = C.cast("float32") - x = paddle.zeros((batch, dim, dstate), dtype=A.dtype) - ys = [] - deltaA = paddle.exp(paddle.einsum("bdl,dn->bdln", delta, A)) - if not is_variable_B: - deltaB_u = paddle.einsum("bdl,dn,bdl->bdln", delta, B, u) - else: - if B.dim() == 3: - deltaB_u = paddle.einsum("bdl,bnl,bdl->bdln", delta, B, u) - else: - B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) - deltaB_u = paddle.einsum("bdl,bdnl,bdl->bdln", delta, B, u) - if is_variable_C and C.dim() == 4: - C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) - last_state = None - for i in range(u.shape[2]): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - if not is_variable_C: - y = paddle.einsum("bdn,dn->bd", x, C) - else: - if C.dim() == 3: - y = paddle.einsum("bdn,bn->bd", x, C[:, :, i]) - else: - y = paddle.einsum("bdn,bdn->bd", x, C[:, :, :, i]) - if i == u.shape[2] - 1: - last_state = x - if y.is_complex(): - y = y.real() * 2 - ys.append(y) - y = paddle.stack(ys, axis=2) # (batch dim L) - out = y if D is None else y + u * rearrange(D, "d -> d 1") - if z is not None: - out = out * F.silu(z) - out = out.cast(dtype=dtype_in) - return out if not return_last_state else (out, last_state) - - -class MambaInnerFn(paddle.autograd.PyLayer): - @staticmethod - @custom_fwd - def forward( - ctx, - xz, - conv1d_weight, - conv1d_bias, - x_proj_weight, - delta_proj_weight, - out_proj_weight, - out_proj_bias, - A, - B=None, - C=None, - D=None, - delta_bias=None, - B_proj_bias=None, - C_proj_bias=None, - delta_softplus=True, - checkpoint_lvl=1, - ): - """ - xz: (batch, dim, seqlen) - """ - assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." - assert checkpoint_lvl in [0, 1] - L = xz.shape[-1] - delta_rank = delta_proj_weight.shape[1] - d_state = A.shape[-1] * (1 if not A.is_complex() else 2) - if is_autocast_enabled(): - amp_dtype = get_autocast_gpu_dtype() - x_proj_weight = x_proj_weight.cast(dtype=amp_dtype) - delta_proj_weight = delta_proj_weight.cast(dtype=amp_dtype) - out_proj_weight = out_proj_weight.cast(dtype=amp_dtype) - out_proj_bias = out_proj_bias.cast(dtype=amp_dtype) if out_proj_bias is not None else None - if xz.strides[-1] != 1: - xz = xz.contiguous() - conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") - # chunk do not share memory, so we can't use it here. - # x, z = xz.chunk(2, axis=1) - half_xz = xz.shape[1] // 2 - x = xz[:, :half_xz] - z = xz[:, half_xz:] - conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) - # We're being very careful here about the layout, to avoid extra transposes. - # We want delta to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. - x_dbl = F.linear(rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight.t()) # (bl d) - delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L) - ctx.is_variable_B = B is None - ctx.is_variable_C = C is None - ctx.B_proj_bias_is_None = B_proj_bias is None - ctx.C_proj_bias_is_None = C_proj_bias is None - if B is None: # variable B - B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate) - if B_proj_bias is not None: - B = B + B_proj_bias.cast(dtype=B.dtype) - if not A.is_complex(): - # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() - B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() - else: - B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() - else: - if B.strides[-1] != 1: - B = B.contiguous() - if C is None: # variable C - C = x_dbl[:, -d_state:] # (bl dstate) - if C_proj_bias is not None: - C = C + C_proj_bias.cast(dtype=C.dtype) - if not A.is_complex(): - # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() - C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() - else: - C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() - else: - if C.strides[-1] != 1: - C = C.contiguous() - if D is not None: - D = D.contiguous() - out, scan_intermediates, out_z = selective_scan_cuda.fwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus - ) - ctx.delta_softplus = delta_softplus - ctx.out_proj_bias_is_None = out_proj_bias is None - ctx.checkpoint_lvl = checkpoint_lvl - if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass - conv1d_out, delta = None, None - ctx.save_for_backward( - xz, - conv1d_weight, - conv1d_bias, - x_dbl, - x_proj_weight, - delta_proj_weight, - out_proj_weight, - conv1d_out, - delta, - A, - B, - C, - D, - delta_bias, - scan_intermediates, - out, - ) - return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight.t(), out_proj_bias) - - @staticmethod - @custom_bwd - def backward(ctx, dout): - # dout: (batch, seqlen, dim) - assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." - ( - xz, - conv1d_weight, - conv1d_bias, - x_dbl, - x_proj_weight, - delta_proj_weight, - out_proj_weight, - conv1d_out, - delta, - A, - B, - C, - D, - delta_bias, - scan_intermediates, - out, - ) = ctx.saved_tensor() - L = xz.shape[-1] - delta_rank = delta_proj_weight.shape[1] - d_state = A.shape[-1] * (1 if not A.is_complex() else 2) - x, z = xz.chunk(2, axis=1) - if dout.strides[-1] != 1: - dout = dout.contiguous() - if ctx.checkpoint_lvl == 1: - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True) - delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L) - # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the - # backward of selective_scan_cuda with the backward of chunk). - dxz = paddle.empty_like(xz) # (batch, dim, seqlen) - # dx, dz = dxz.chunk(2, axis=1) - half_dxz = dxz.shape[1] // 2 - dx = dxz[:, :half_dxz] - dz = dxz[:, half_dxz:] - - dout = rearrange(dout, "b l e -> e (b l)") - dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) - dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( - conv1d_out, - delta, - A, - B, - C, - D, - z, - delta_bias, - dout_y, - scan_intermediates, - out, - dz, - ctx.delta_softplus, - True, # option to recompute out_z - ) - dout_proj_weight = paddle.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) - dout_proj_bias = dout.sum(axis=(0, 1)) if not ctx.out_proj_bias_is_None else None - dD = dD if D is not None else None - dx_dbl = paddle.empty_like(x_dbl) - dB_proj_bias = None - if ctx.is_variable_B: - if not A.is_complex(): - dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() - else: - dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() - dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None - dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d) - dB = None - dC_proj_bias = None - if ctx.is_variable_C: - if not A.is_complex(): - dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() - else: - dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() - dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None - dx_dbl[:, -d_state:] = dC # (bl d) - dC = None - ddelta = rearrange(ddelta, "b d l -> d (b l)") - ddelta_proj_weight = paddle.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) - dx_dbl[:, :delta_rank] = paddle.einsum("dB,dr->Br", ddelta, delta_proj_weight) - dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") - dx_proj_weight = paddle.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) - dconv1d_out = paddle.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t()) - dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). - dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( - x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True - ) - dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None - dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") - return ( - dxz, - dconv1d_weight, - dconv1d_bias, - dx_proj_weight, - ddelta_proj_weight, - dout_proj_weight, - dout_proj_bias, - dA, - dB, - dC, - dD, - ddelta_bias if delta_bias is not None else None, - dB_proj_bias, - dC_proj_bias, - None, - ) - - -def mamba_inner_fn( - xz, - conv1d_weight, - conv1d_bias, - x_proj_weight, - delta_proj_weight, - out_proj_weight, - out_proj_bias, - A, - B=None, - C=None, - D=None, - delta_bias=None, - B_proj_bias=None, - C_proj_bias=None, - delta_softplus=True, - is_paddle_linear=False, -): - if is_paddle_linear: - # NOTE: paddle linear weight is transposed - x_proj_weight = x_proj_weight.t().contiguous() - out_proj_weight = out_proj_weight.t().contiguous() - delta_proj_weight = delta_proj_weight.t().contiguous() - - return MambaInnerFn.apply( - xz, - conv1d_weight, - conv1d_bias, - x_proj_weight, - delta_proj_weight, - out_proj_weight, - out_proj_bias, - A, - B, - C, - D, - delta_bias, - B_proj_bias, - C_proj_bias, - delta_softplus, - ) - - -def mamba_inner_ref( - xz, - conv1d_weight, - conv1d_bias, - x_proj_weight, - delta_proj_weight, - out_proj_weight, - out_proj_bias, - A, - B=None, - C=None, - D=None, - delta_bias=None, - B_proj_bias=None, - C_proj_bias=None, - delta_softplus=True, - is_paddle_linear=False, -): - if is_paddle_linear: - # NOTE: paddle linear weight is transposed - x_proj_weight = x_proj_weight.t().contiguous() - out_proj_weight = out_proj_weight.t().contiguous() - delta_proj_weight = delta_proj_weight.t().contiguous() - assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d." - L = xz.shape[-1] - delta_rank = delta_proj_weight.shape[1] - d_state = A.shape[-1] * (1 if not A.is_complex() else 2) - x, z = xz.chunk(2, axis=1) - x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu") - # We're being very careful here about the layout, to avoid extra transposes. - # We want delta to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. - x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight.t()) # (bl d) - delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() - delta = rearrange(delta, "d (b l) -> b d l", l=L) - if B is None: # variable B - B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl d) - if B_proj_bias is not None: - B = B + B_proj_bias.cast(dtype=B.dtype) - if not A.is_complex(): - B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() - else: - B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() - if C is None: # variable B - C = x_dbl[:, -d_state:] # (bl d) - if C_proj_bias is not None: - C = C + C_proj_bias.cast(dtype=C.dtype) - if not A.is_complex(): - C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() - else: - C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() - y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) - return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight.t(), out_proj_bias) diff --git a/ops/src/paddlenlp_kernel/triton/__init__.py b/ops/src/paddlenlp_kernel/triton/__init__.py deleted file mode 100644 index 89e3dd9ffaf7..000000000000 --- a/ops/src/paddlenlp_kernel/triton/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .triton_patch import * diff --git a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/__init__.py b/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/__init__.py deleted file mode 100644 index e61b14c57f84..000000000000 --- a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -the code in this folder is adapted from https://github.com/apple/ml-cross-entropy -""" -from .linear_cross_entropy import ( - LinearCrossEntropy, - LinearCrossEntropyImpl, - linear_cross_entropy, -) diff --git a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/cce.py b/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/cce.py deleted file mode 100644 index d12e0bd49a3d..000000000000 --- a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/cce.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -from dataclasses import dataclass -from typing import Union, cast - -import paddle - -from .cce_backward import cce_backward_kernel -from .cce_lse_forward import cce_lse_forward_kernel -from .constants import IGNORE_INDEX -from .doc import LINEAR_CROSS_ENTROPY_DOC, add_doc_start -from .indexed_dot import indexed_neg_dot_forward_kernel -from .utils import _build_flat_valids, _handle_eps, handle_reduction_none - - -@dataclass -class CCEParams: - targets: paddle.Tensor - valids: Union[paddle.Tensor, None] - softcap: Union[float, None] - reduction: str - filter_eps: Union[float, None] - shift: bool - batch_shape: list - - -def sort_logit_avg(logit_avg: paddle.Tensor) -> paddle.Tensor: - return paddle.argsort(logit_avg) - - -class LinearCrossEntropyFunction(paddle.autograd.PyLayer): - @staticmethod - def forward( - ctx, - e: paddle.Tensor, - c: paddle.Tensor, - params: CCEParams, - ) -> paddle.Tensor: - needs_grad = not e.stop_gradient or not c.stop_gradient - return_logit_avg = needs_grad and params.filter_eps is not None - - ret = cce_lse_forward_kernel( - e, - c, - params.valids, - softcap=params.softcap, - return_logit_avg=return_logit_avg, - ) - if return_logit_avg: - assert isinstance(ret, tuple) - lse, logit_avg = ret - else: - assert isinstance(ret, paddle.Tensor) - lse = ret - logit_avg = None - - neg_dot = indexed_neg_dot_forward_kernel( - e, c, params.targets, params.shift, params.valids, params.softcap, lse.dtype - ) - - nll = neg_dot.add_(lse) - - reduction = params.reduction - if reduction == "mean": - loss = nll.mean() - elif reduction == "sum": - loss = nll.sum() - elif reduction == "none": - loss = handle_reduction_none(params.batch_shape, params.valids, params.shift, nll) - else: - raise ValueError(f"Unknown reduction {reduction}") - - ctx.save_for_backward(e, c, lse, params.targets, params.valids, logit_avg) - ctx.params = params - - return loss - - @staticmethod - def backward(ctx, grad_out: paddle.Tensor) -> tuple[paddle.Tensor, paddle.Tensor, None]: - h, w, lse, targets, valids, logit_avg = ctx.saved_tensor() - - if logit_avg is not None: - vocab_ordering = sort_logit_avg(logit_avg) - else: - vocab_ordering = None - - params = cast(CCEParams, ctx.params) - reduction = params.reduction - if reduction == "mean": - grad_scale = 1 / lse.numel().item() # need cast paddle.Tensor to float - elif reduction == "sum": - grad_scale = 1.0 - elif reduction == "none": - grad_scale = 1.0 - grad_out = grad_out.flatten() - else: - raise ValueError(f"Unknown reduction {reduction}") - - de, dc = cce_backward_kernel( - grad_out, - h, - w, - lse, - valids, - params.softcap, - params.filter_eps, - targets=targets, - shift=params.shift, - vocab_ordering=vocab_ordering, - grad_scale=grad_scale, - ) - - return de, dc - - -def linear_cross_entropy_apply( - e: paddle.Tensor, - c: paddle.Tensor, - params: CCEParams, -) -> paddle.Tensor: - loss = LinearCrossEntropyFunction.apply(e, c, params) - assert isinstance(loss, paddle.Tensor) - - if params.shift and params.reduction == "none": - loss = loss[..., 1:] - - return loss - - -@add_doc_start(LINEAR_CROSS_ENTROPY_DOC) -def cce_linear_cross_entropy( - e: paddle.Tensor, - c: paddle.Tensor, - targets: paddle.Tensor, - ignore_index: int = IGNORE_INDEX, - softcap: Union[float, None] = None, - reduction: str = "mean", - shift: bool = False, - filter_eps: Union[float, str, None] = "auto", -) -> paddle.Tensor: - """ - :param filter_eps: The threshold value used to determine which locations can be safely ignored - in gradient computation. The default value of "auto" will automatically choose a value - based on the input dtype. - """ - assert e.shape[0:-1] == targets.shape - assert e.shape[-1] == c.shape[1] - # if not paddle.device.cuda.is_bf16_supported(): - # raise RuntimeError( - # "Cut Cross Entropy requires an ampere GPU or newer. " - # "Consider using torch_compile_linear_cross_entropy for scenarios where one is not available." - # ) - - batch_shape = targets.shape - - e = e.contiguous() - targets = targets.contiguous() - - valids = _build_flat_valids(targets, ignore_index, shift) - - e = e.flatten(0, -2) - targets = targets.flatten() - - return linear_cross_entropy_apply( - e, - c, - CCEParams( - targets, - valids, - softcap, - reduction, - _handle_eps(filter_eps, e.dtype), - shift, - batch_shape, - ), - ) diff --git a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/cce_backward.py b/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/cce_backward.py deleted file mode 100644 index cd9a06385b17..000000000000 --- a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/cce_backward.py +++ /dev/null @@ -1,347 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -from typing import Union - -import paddle -import triton -import triton.language as tl - -from .tl_autotune import cce_backward_autotune -from .tl_utils import ( - b_bin_fn, - tl_and_reduce_fn, - tl_lock_add, - tl_softcapping, - tl_softcapping_grad, -) - - -@triton.jit -def _mm_backward( - do, - da_ptrs, - partial_mask_a, - da_lock_ptr, - n_locks, - b_ptrs, - partial_mask_b, - stride_ad, - stride_bd, - D, - BLOCK_D: tl.constexpr, - EVEN_D: tl.constexpr, -): - d_inds = tl.arange(0, BLOCK_D)[None, :] - - da_ptrs = da_ptrs + d_inds * stride_ad - b_ptrs = b_ptrs + d_inds * stride_bd - - for d in range(0, tl.cdiv(D, BLOCK_D)): - if EVEN_D: - mask = partial_mask_b - else: - mask = partial_mask_b & (d_inds < (D - d * BLOCK_D)) - - b = tl.load(b_ptrs, mask=mask, other=0.0) - - da_i = tl.dot(do, b).to(da_ptrs.dtype.element_ty) - - if EVEN_D: - mask = partial_mask_a - else: - mask = partial_mask_a & (d_inds < (D - d * BLOCK_D)) - - lock_offset = d // tl.cdiv(D, BLOCK_D * n_locks) - this_da_lock_ptr = da_lock_ptr + lock_offset - - tl_lock_add(da_ptrs, da_i, mask, this_da_lock_ptr) - - b_ptrs += BLOCK_D * stride_bd - da_ptrs += BLOCK_D * stride_ad - - -@triton.jit -def _block_is_filtered(check_val: tl.tensor, filter_eps: tl.tensor) -> tl.tensor: - return tl.reduce(check_val < filter_eps, None, tl_and_reduce_fn) - - -# @cce_backward_autotune() -# @triton.heuristics( -# { -# "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0, -# "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2, -# "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0, -# "HAS_VALIDS": lambda args: args["Valids"] is not None, -# "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None, -# "FILTER_GRAD": lambda args: args["filter_eps"] is not None, -# "HAS_TARGETS": lambda args: args["Targets"] is not None, -# "HAS_SOFTCAP": lambda args: args["softcap"] is not None, -# "ITEM_DO": lambda args: (args["dOut"].numel() == 1).item(), -# "GROUP_B": lambda args: 8, -# } -# ) -# @triton.jit -def _cce_backward_kernel( - E, - C, - LSE, - dOut, - grad_scale, - Valids, - VocabOrdering, - softcap, - Targets, - dE, - dELocks, - dC, - dCLocks, - B, - D, - V, - n_de_locks_0, - n_de_locks_1, - n_dc_locks_0, - n_dc_locks_1, - stride_eb, - stride_ed, - stride_cv, - stride_cd, - stride_vb, - filter_eps, - B_BIN, - BLOCK_B: tl.constexpr, - BLOCK_V: tl.constexpr, - BLOCK_D: tl.constexpr, - MM_BACK_BLOCK_D: tl.constexpr, - GROUP_B: tl.constexpr, - EVEN_D: tl.constexpr, - MM_BACK_EVEN_D: tl.constexpr, - ITEM_DO: tl.constexpr, - HAS_VALIDS: tl.constexpr, - HAS_VOCAB_ORDERING: tl.constexpr, - FILTER_GRAD: tl.constexpr, - HAS_TARGETS: tl.constexpr, - HAS_SOFTCAP: tl.constexpr, - SHIFT: tl.constexpr, -): - pid = tl.program_id(axis=0) - num_b_chunks = tl.cdiv(B, BLOCK_B) - num_v_chunks = tl.cdiv(V, BLOCK_V) - num_v_in_group = GROUP_B * num_v_chunks - group_id = pid // num_v_in_group - first_pid_b = group_id * GROUP_B - group_size_b = min(num_b_chunks - first_pid_b, GROUP_B) - pid_b = first_pid_b + ((pid % num_v_in_group) % group_size_b) - pid_v = (pid % num_v_in_group) // group_size_b - - offs_b = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B - if HAS_VALIDS: - offs_b = tl.load(Valids + stride_vb * offs_b) - - offs_v = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)) % V - if HAS_VOCAB_ORDERING: - offs_v = tl.load(VocabOrdering + offs_v) - - offs_d = tl.arange(0, BLOCK_D) - e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) - c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd) - - accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32) - for d in range(0, tl.cdiv(D, BLOCK_D)): - if EVEN_D: - e = tl.load(e_ptrs) - c = tl.load(c_ptrs) - else: - e = tl.load(e_ptrs, mask=offs_d[None, :] < D - d * BLOCK_D, other=0.0) - c = tl.load(c_ptrs, mask=offs_d[:, None] < D - d * BLOCK_D, other=0.0) - - accum = tl.dot(e, c, accum) - - e_ptrs += BLOCK_D * stride_ed - c_ptrs += BLOCK_D * stride_cd - - if HAS_SOFTCAP: - accum = tl_softcapping(accum, softcap) - - if HAS_VALIDS: - lse = tl.load(LSE + (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B) - else: - lse = tl.load(LSE + offs_b) - - d_accum = tl.exp(accum - lse[:, None]) - - if HAS_TARGETS: - targets = tl.load(Targets + ((offs_b + 1) if SHIFT else offs_b)) - is_target = targets[:, None] == offs_v[None, :] - d_accum += tl.where(is_target, -1.0, 0.0) - else: - is_target = None - - accum_valid_mask = ((pid_b * BLOCK_B + tl.arange(0, BLOCK_B))[:, None] < B) & ( - (pid_v * BLOCK_V + tl.arange(0, BLOCK_V))[None, :] < V - ) - d_accum = tl.where(accum_valid_mask, d_accum, 0.0) - - if FILTER_GRAD: - if _block_is_filtered(tl.abs(d_accum), filter_eps): - return - - if HAS_SOFTCAP: - d_accum = tl_softcapping_grad(d_accum, accum, softcap) - - if ITEM_DO: - d_out = tl.load(dOut) - else: - d_out = tl.load(dOut + ((offs_b + 1) if SHIFT else offs_b))[:, None] - - d_out = grad_scale * d_out - - d_accum = (d_accum * d_out).to(e_ptrs.dtype.element_ty) - - b_mask = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)[:, None]) < B - v_mask = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)[:, None]) < V - - lock_offset = (pid_b // tl.cdiv(B, BLOCK_B * n_de_locks_0)) * n_de_locks_1 - dELocks += lock_offset - - _mm_backward( - d_accum, - dE + (offs_b[:, None] * stride_eb), - b_mask, - dELocks, - n_de_locks_1, - C + offs_v[:, None] * stride_cv, - v_mask, - stride_ed, - stride_cd, - D, - MM_BACK_BLOCK_D, - MM_BACK_EVEN_D, - ) - - lock_offset = (pid_v // tl.cdiv(V, BLOCK_V * n_dc_locks_0)) * n_dc_locks_1 - dCLocks += lock_offset - - _mm_backward( - tl.trans(d_accum), - dC + (offs_v[:, None] * stride_cv), - v_mask, - dCLocks, - n_dc_locks_1, - E + (offs_b[:, None] * stride_eb), - b_mask, - stride_cd, - stride_ed, - D, - MM_BACK_BLOCK_D, - MM_BACK_EVEN_D, - ) - - -# Fix issues https://github.com/apple/ml-cross-entropy/issues/6 -_cce_backward_kernel = triton.jit(_cce_backward_kernel) -_cce_backward_kernel = triton.heuristics( - { - "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0, - "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2, - "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0, - "HAS_VALIDS": lambda args: args["Valids"] is not None, - "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None, - "FILTER_GRAD": lambda args: args["filter_eps"] is not None, - "HAS_TARGETS": lambda args: args["Targets"] is not None, - "HAS_SOFTCAP": lambda args: args["softcap"] is not None, - "ITEM_DO": lambda args: (args["dOut"].numel() == 1).item(), - "GROUP_B": lambda args: 8, - } -)(_cce_backward_kernel) -_cce_backward_kernel = cce_backward_autotune()(_cce_backward_kernel) - - -def cce_backward_kernel( - do: paddle.Tensor, - e: paddle.Tensor, - c: paddle.Tensor, - lse: paddle.Tensor, - valids: Union[paddle.Tensor, None], - softcap: Union[float, None], - filter_eps: Union[float, None], - targets: Union[paddle.Tensor, None] = None, - shift: bool = False, - vocab_ordering: Union[paddle.Tensor, None] = None, - grad_scale: float = 1.0, -) -> tuple[paddle.Tensor, paddle.Tensor]: - assert do.numel().item() in (e.shape[0], 1) - assert c.shape[1] == e.shape[1] - assert lse.shape[0] == e.shape[0] or (valids is not None and lse.shape[0] == valids.shape[0]) - assert e.dtype in ( - paddle.float16, - paddle.bfloat16, - ), "Backwards requires embeddings to be bf16 or fp16" - assert c.dtype in ( - paddle.float16, - paddle.bfloat16, - ), "Backwards requires classifier to be bf16 or fp16" - - do = do.contiguous() - lse = lse.contiguous() - - de = paddle.zeros_like(e) - dc = paddle.zeros_like(c) - - assert de.strides == e.strides - assert dc.strides == c.strides - - if valids is not None: - assert valids.ndim == 1 - B = valids.shape[0] - else: - B = e.shape[0] - - if do.numel().item() > 1: - do = do.contiguous() - lse = lse.contiguous() - assert do.strides[0] == lse.strides[0], f"{do.strides=}, {lse.strides=}" - - def grid(META): - return (triton.cdiv(B, META["BLOCK_B"]) * triton.cdiv(c.shape[0], META["BLOCK_V"]),) - - if vocab_ordering is not None: - assert vocab_ordering.ndim == 1 - assert vocab_ordering.numel().item() == dc.shape[0] - assert vocab_ordering.strides[0] == 1 - - nd_locks = triton.cdiv(c.shape[1], 64) - de_locks = paddle.zeros((triton.cdiv(B, nd_locks), nd_locks), dtype=paddle.int32) - dc_locks = paddle.zeros((triton.cdiv(c.shape[0], nd_locks), nd_locks), dtype=paddle.int32) - - _cce_backward_kernel[grid]( - e, - c, - lse, - do, - grad_scale, - valids, - vocab_ordering, - softcap, - targets, - de, - de_locks, - dc, - dc_locks, - B, - e.shape[1], - c.shape[0], - de_locks.shape[0], - de_locks.shape[1], - dc_locks.shape[0], - dc_locks.shape[1], - e.strides[0], - e.strides[1], - c.strides[0], - c.strides[1], - 1 if valids is None else valids.strides[0], - filter_eps, - B_BIN=b_bin_fn(B), - SHIFT=shift, - ) - - return de, dc diff --git a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/cce_lse_forward.py b/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/cce_lse_forward.py deleted file mode 100644 index f8c03427e376..000000000000 --- a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/cce_lse_forward.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -from typing import Literal, Union, overload - -import paddle -import triton -import triton.language as tl - -from .tl_autotune import cce_forward_autotune -from .tl_utils import b_bin_fn, tl_logaddexp, tl_softcapping - - -def get_float32_matmul_precision(): - return "high" - - -@cce_forward_autotune() -@triton.heuristics( - { - "EVEN_D": lambda args: args["D"] % args["BLOCK_D"] == 0, - "HAS_VALIDS": lambda args: args["Valids"] is not None, - "HAS_SOFTCAP": lambda args: args["softcap"] is not None, - "HAS_LA": lambda args: args["LA"] is not None, - "GROUP_B": lambda args: 8, - "DOT_PRECISION": lambda args: "tf32" if get_float32_matmul_precision() == "high" else "ieee", - } -) -@triton.jit -def _cce_lse_forward_kernel( - E, - C, - LSE, - LA, - Locks, - Valids, - softcap, - B, - V, - D, - stride_eb, - stride_ed, - stride_cv, - stride_cd, - stride_lse_b, - stride_vb, - num_locks, - # Meta-parameters - B_BIN, - HAS_VALIDS: tl.constexpr, - BLOCK_B: tl.constexpr, - BLOCK_V: tl.constexpr, - BLOCK_D: tl.constexpr, # - GROUP_B: tl.constexpr, # - EVEN_D: tl.constexpr, - HAS_SOFTCAP: tl.constexpr, - HAS_LA: tl.constexpr, - DOT_PRECISION: tl.constexpr, -): - pid = tl.program_id(axis=0) - num_pid_b = tl.cdiv(B, BLOCK_B) - num_pid_v = tl.cdiv(V, BLOCK_V) - num_pid_in_group = GROUP_B * num_pid_v - group_id = pid // num_pid_in_group - first_pid_b = group_id * GROUP_B - group_size_b = min(num_pid_b - first_pid_b, GROUP_B) - pid_b = first_pid_b + ((pid % num_pid_in_group) % group_size_b) - pid_v = (pid % num_pid_in_group) // group_size_b - - offs_b = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B - if HAS_VALIDS: - offs_b = tl.load(Valids + stride_vb * offs_b) - - offs_v = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)) % V - offs_d = tl.arange(0, BLOCK_D) - e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) - c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd) - - accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32) - for d in range(0, tl.cdiv(D, BLOCK_D)): - # Load the next block of A and B, generate a mask by checking the K dimension. - # If it is out of bounds, set it to 0. - if EVEN_D: - e = tl.load(e_ptrs) - c = tl.load(c_ptrs) - else: - e = tl.load(e_ptrs, mask=offs_d[None, :] < D - d * BLOCK_D, other=0.0) - c = tl.load(c_ptrs, mask=offs_d[:, None] < D - d * BLOCK_D, other=0.0) - accum = tl.dot(e, c, accum, input_precision=DOT_PRECISION) - e_ptrs += BLOCK_D * stride_ed - c_ptrs += BLOCK_D * stride_cd - - v_mask = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)) < V - logits = tl.where(v_mask[None, :], accum, -float("inf")) - if HAS_SOFTCAP: - logits = tl_softcapping(logits, softcap) - - off_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B) - o_mask = off_b < B - if HAS_LA: - logits = tl.where(o_mask[:, None], logits, 0.0) - this_avg_logit = tl.sum(logits, 0) / B - tl.atomic_add(LA + offs_v, this_avg_logit, mask=v_mask) - - this_mx = tl.max(logits, axis=1) - e = tl.exp(logits - this_mx[:, None]) - this_lse = this_mx + tl.log(tl.sum(e, axis=1)) - - lse_ptrs = LSE + (stride_lse_b * off_b) - - this_locks = Locks + (pid_b // tl.cdiv(B, BLOCK_B * num_locks)) - while tl.atomic_cas(this_locks, 0, 1) == 1: - pass - - lse = tl.load(lse_ptrs, mask=o_mask, other=0.0, eviction_policy="evict_last") - lse = tl_logaddexp(lse, this_lse) - tl.store(lse_ptrs, lse, mask=o_mask, eviction_policy="evict_last") - - tl.atomic_xchg(this_locks, 0) - - -@overload -def cce_lse_forward_kernel( - e, - c, - valids: Union[paddle.Tensor, None] = None, - softcap: Union[float, None] = None, - return_logit_avg: Literal[False] = False, -) -> paddle.Tensor: - ... - - -@overload -def cce_lse_forward_kernel( - e, - c, - valids: Union[paddle.Tensor, None] = None, - softcap: Union[float, None] = None, - return_logit_avg: Literal[True] = True, -) -> tuple[paddle.Tensor, paddle.Tensor]: - ... - - -@overload -def cce_lse_forward_kernel( - e, - c, - valids: Union[paddle.Tensor, None] = None, - softcap: Union[float, None] = None, - return_logit_avg: bool = False, -) -> Union[tuple[paddle.Tensor, paddle.Tensor], paddle.Tensor]: - ... - - -def cce_lse_forward_kernel( - e: paddle.Tensor, - c: paddle.Tensor, - valids: Union[paddle.Tensor, None] = None, - softcap: Union[float, None] = None, - return_logit_avg: bool = False, -) -> Union[tuple[paddle.Tensor, paddle.Tensor], paddle.Tensor]: - # Check constraints. - assert e.shape[1] == c.shape[1], "Incompatible dimensions" - assert e.is_contiguous(), "Matrix A must be contiguous" - if valids is not None: - assert valids.ndim == 1 - B = valids.numel().item() - else: - B, _ = e.shape - - V, D = c.shape - # Allocates output. - lse = paddle.full((B,), -float("inf"), dtype=paddle.float32) - - locks = paddle.full( - (triton.cdiv(B, 128),), - 0, - dtype="int32", # paddle do not support uint32, so we use int32 - ) - if return_logit_avg: - logit_avg = paddle.full((V,), 0.0, dtype=paddle.float32) - else: - logit_avg = None - - # 1D launch kernel where each block gets its own program. - def grid(META) -> tuple[int]: - return (triton.cdiv(B, META["BLOCK_B"]) * triton.cdiv(V, META["BLOCK_V"]),) - - _cce_lse_forward_kernel[grid]( - e, - c, - lse, # - logit_avg, - locks, - valids, - softcap, - B, - V, - D, # - e.strides[0], - e.strides[1], # - c.strides[0], - c.strides[1], # - lse.strides[0], - 1 if valids is None else valids.strides[0], - num_locks=locks.shape[0], - B_BIN=b_bin_fn(B), - ) - - if return_logit_avg: - assert logit_avg is not None - return lse, logit_avg - else: - return lse diff --git a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/constants.py b/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/constants.py deleted file mode 100644 index 5747d4e6ec79..000000000000 --- a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/constants.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -IGNORE_INDEX: int = -100 diff --git a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/doc.py b/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/doc.py deleted file mode 100644 index 1ee6baaa0f95..000000000000 --- a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/doc.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -LINEAR_CROSS_ENTROPY_DOC = """Computes cross-entropy loss using the logits generated by performing - the matrix multiplication between the embeddings (e) and classifier (c). - - This method saves GPU memory by not materializing the logits into GPU - main memory. - - - Specifically, this computes - - ```python - - loss = F.cross_entropy((e @ c.T).float(), targets) - ``` - - without allocating the intermediary (e @ c.T).float() matrix. - - :param e: Embedding of the inputs used to compute the logits. Shape (..., D) - :param c: Classifier matrix. Shape (NumClasses, D) - :param targets: The target class for each input. Values must be in [0, NumClasses). Shape (...) - :param ignore_index: If an input as a target of this value, it is ignored in the loss computation. - :param softcap: The value for logit softcapping. - :param reduction: The reduction to perform over the loss. Supports "mean", "sum", and "none". - :param shift: If true, the embedding and targets are assumed to require a shift along the - temporal axis to perform next token prediction. Specifically, setting this to true - will efficiently compute - - ```python - shift_e = e[..., :-1, :].flatten(0, -2) - shift_targets = targets[..., 1:].flatten() - - loss = F.cross_entropy((shift_e @ c.T), targets) - ``` -""" - - -def add_doc_start(*docstr: str): - def add_doc(fn): - fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") - - return fn - - return add_doc diff --git a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/indexed_dot.py b/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/indexed_dot.py deleted file mode 100644 index 599d40a9b535..000000000000 --- a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/indexed_dot.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -from typing import Union - -import paddle -import triton -import triton.language as tl - -from .tl_autotune import indexed_dot_autotune -from .tl_utils import b_bin_fn -from .utils import softcapping - - -@indexed_dot_autotune() -@triton.heuristics( - { - "EVEN_D": lambda args: args["D"] % args["BLOCK_D"] == 0, - "HAS_VALIDS": lambda args: args["Valids"] is not None, - "GROUP_B": lambda args: 8, - } -) -@triton.jit -def _indexed_neg_dot_forward_kernel( - E, - C, - Inds, - Valids, - Out, - B, - D, - stride_eb, - stride_ed, - stride_cv, - stride_cd, - stride_ib, - stride_vb, - B_BIN, - BLOCK_B: tl.constexpr, - BLOCK_D: tl.constexpr, - GROUP_B: tl.constexpr, - HAS_VALIDS: tl.constexpr, - EVEN_D: tl.constexpr, - SHIFT: tl.constexpr, -): - pid = tl.program_id(axis=0) - num_b_chunks = tl.cdiv(B, BLOCK_B) - num_d_chunks = tl.cdiv(D, BLOCK_D) - num_d_in_group = GROUP_B * num_d_chunks - group_id = pid // num_d_in_group - first_pid_b = group_id * GROUP_B - group_size_b = min(num_b_chunks - first_pid_b, GROUP_B) - pid_b = first_pid_b + ((pid % num_d_in_group) % group_size_b) - pid_d = (pid % num_d_in_group) // group_size_b - - offs_b = (tl.arange(0, BLOCK_B) + pid_b * BLOCK_B) % B - if HAS_VALIDS: - offs_b = tl.load(Valids + stride_vb * offs_b) - - offs_d = tl.arange(0, BLOCK_D) + pid_d * BLOCK_D - e_ptrs = E + (stride_eb * offs_b[:, None] + stride_ed * offs_d[None, :]) - if EVEN_D: - e = tl.load(e_ptrs) - else: - e = tl.load(e_ptrs, mask=offs_d[None, :] < D, other=0.0) - - inds = tl.load(Inds + stride_ib * ((offs_b + 1) if SHIFT else offs_b)) - - c_ptrs = C + (inds[:, None] * stride_cv + offs_d[None, :] * stride_cd) - if EVEN_D: - c = tl.load(c_ptrs) - else: - c = tl.load(c_ptrs, mask=offs_d[None, :] < D, other=0.0) - - offs_b = tl.arange(0, BLOCK_B) + pid_b * BLOCK_B - out_ptrs = Out + offs_b - dot = (e * c).to(tl.float32) - neg_dot = -tl.sum(dot, 1).to(out_ptrs.dtype.element_ty) - tl.atomic_add(out_ptrs, neg_dot, mask=offs_b < B) - - -def indexed_neg_dot_forward_kernel( - e: paddle.Tensor, - c: paddle.Tensor, - inds: paddle.Tensor, - shift: bool = False, - valids: Union[paddle.Tensor, None] = None, - softcap: Union[float, None] = None, - out_dtype: Union[paddle.dtype, None] = None, -) -> paddle.Tensor: - assert inds.ndim == 1 - assert e.ndim == 2 - assert c.ndim == 2 - assert inds.shape[0] == e.shape[0] - assert c.shape[1] == e.shape[1] - - if valids is not None: - assert valids.ndim == 1 - B = valids.shape[0] - else: - B = e.shape[0] - - out = paddle.zeros((B,), dtype=paddle.float32) - - def grid(META) -> tuple[int]: - return (triton.cdiv(B, META["BLOCK_B"]) * triton.cdiv(e.shape[1], META["BLOCK_D"]),) - - _indexed_neg_dot_forward_kernel[grid]( - e, - c, - inds, - valids, - out, - B, - e.shape[1], - e.strides[0], - e.strides[1], - c.strides[0], - c.strides[1], - inds.strides[0], - 1 if valids is None else valids.strides[0], - B_BIN=b_bin_fn(B), - SHIFT=shift, - ) - - if softcap is not None: - out = softcapping(out, softcap) - - if out_dtype is None: - out_dtype = e.dtype - - out = out.cast(out_dtype) - - return out diff --git a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/linear_cross_entropy.py b/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/linear_cross_entropy.py deleted file mode 100644 index 289f4557b427..000000000000 --- a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/linear_cross_entropy.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -import enum -from enum import auto -from typing import Union - -import paddle -import paddle.nn as nn - -from .cce import cce_linear_cross_entropy -from .constants import IGNORE_INDEX -from .doc import LINEAR_CROSS_ENTROPY_DOC, add_doc_start - - -class LinearCrossEntropyImpl(enum.IntEnum): - CCE = auto() - - -@add_doc_start(LINEAR_CROSS_ENTROPY_DOC) -def linear_cross_entropy( - e: paddle.Tensor, - c: paddle.Tensor, - targets: paddle.Tensor, - ignore_index: int = IGNORE_INDEX, - softcap: Union[float, None] = None, - reduction: str = "mean", - shift: bool = False, - filter_eps: Union[float, str, None] = "auto", - impl: Union[str, LinearCrossEntropyImpl] = LinearCrossEntropyImpl.CCE, -) -> paddle.Tensor: - """ - :param filter_eps: The threshold value used to determine which locations can be safely ignored - in gradient computation. The default value of "auto" will automatically choose a value - based on the input dtype. Only valid for the CCE implementation. - :param impl: The linear cross entropy implementation to use. Currently supports cce and torch_compile. - """ - - if isinstance(impl, LinearCrossEntropyImpl): - impl = impl.name.lower() - - if impl == "cce": - return cce_linear_cross_entropy(e, c, targets, ignore_index, softcap, reduction, shift, filter_eps) - else: - raise NotImplementedError(f"{impl} is not implemented.") - - -class LinearCrossEntropy(nn.Layer): - def __init__( - self, - ignore_index: int = IGNORE_INDEX, - softcap: Union[float, None] = None, - reduction: str = "mean", - filter_eps: Union[float, str, None] = "auto", - shift: bool = False, - impl: Union[str, LinearCrossEntropyImpl] = LinearCrossEntropyImpl.CCE, - ): - super().__init__() - self.ignore_index = ignore_index - self.softcap = softcap - self.reduction = reduction - self.filter_eps = filter_eps - self.shift = shift - - self.impl = impl - - def forward(self, e: paddle.Tensor, c: paddle.Tensor, targets: paddle.Tensor) -> paddle.Tensor: - return linear_cross_entropy( - e, - c, - targets, - self.ignore_index, - self.softcap, - reduction=self.reduction, - filter_eps=self.filter_eps, - shift=self.shift, - impl=self.impl, - ) diff --git a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/tl_autotune.py b/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/tl_autotune.py deleted file mode 100644 index 77f492d6a89b..000000000000 --- a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/tl_autotune.py +++ /dev/null @@ -1,472 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -import functools -import heapq -import os -from typing import Callable - -import paddle -import triton -from triton import Config, cdiv -from triton.runtime import autotuner, driver -from triton.testing import ( - get_dram_gbps, - get_max_simd_tflops, - get_max_tensorcore_tflops, - nvsmi, -) - -_AUTOTUNE: bool = os.getenv("CCE_AUTOTUNE", "0") != "0" - - -@functools.lru_cache() -def get_clock_rate_in_khz(): - try: - return nvsmi(["clocks.max.sm"])[0] * 1e3 - except FileNotFoundError: - import pynvml - - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(0) - return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3 - - -def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): - """return compute throughput in TOPS""" - total_warps = num_ctas * min(num_warps, 4) - num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs - tflops = ( - min(num_subcores, total_warps) - / num_subcores - * get_max_tensorcore_tflops(dtype, get_clock_rate_in_khz(), device) - ) - return tflops - - -def get_simd_tflops(device, num_ctas, num_warps, dtype): - """return compute throughput in TOPS""" - total_warps = num_ctas * min(num_warps, 4) - num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs - tflops = ( - min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device) - ) - return tflops - - -def get_tflops(device, num_ctas, num_warps, dtype): - capability = paddle.device.cuda.get_device_capability(device) - if capability[0] < 8 and dtype == paddle.float32: - return get_simd_tflops(device, num_ctas, num_warps, dtype) - return get_tensorcore_tflops(device, num_ctas, num_warps, dtype) - - -def early_config_prune( - configs, - named_args, - *, - shared_memory_factor: float = 1.0, - max_num_warps: int = None, - **kwargs, -): - device = paddle.get_device() - device = int(device.split(":")[-1]) - capability = paddle.device.cuda.get_device_capability() - # BLOCK_B, BLOCK_V, BLOCK_D, SPLIT_K, num_warps, num_stages - dtsize = named_args["E"].element_size() - - if max_num_warps is not None: - configs = [config for config in configs if config.num_warps <= max_num_warps] - - # 1. make sure we have enough smem - pruned_configs = [] - for config in configs: - kw = config.kwargs - BLOCK_B, BLOCK_V, BLOCK_D, num_stages = ( - kw["BLOCK_B"], - kw["BLOCK_V"], - kw["BLOCK_D"], - config.num_stages, - ) - - max_shared_memory = driver.active.utils.get_device_properties(device)["max_shared_mem"] - required_shared_memory = shared_memory_factor * (BLOCK_B + BLOCK_V) * BLOCK_D * num_stages * dtsize - if required_shared_memory > max_shared_memory: - continue - - pruned_configs.append(config) - - configs = pruned_configs - - # group configs by (BLOCK_B,_N,_K, num_warps) - configs_map = {} - for config in configs: - kw = config.kwargs - BLOCK_B, BLOCK_V, BLOCK_D, num_warps, num_stages = ( - kw["BLOCK_B"], - kw["BLOCK_V"], - kw["BLOCK_D"], - config.num_warps, - config.num_stages, - ) - - key = (BLOCK_B, BLOCK_V, BLOCK_D, num_warps) - if key in configs_map: - configs_map[key].append((config, num_stages)) - else: - configs_map[key] = [(config, num_stages)] - - pruned_configs = [] - for k, v in configs_map.items(): - BLOCK_B, BLOCK_V, BLOCK_D, num_warps = k - if capability[0] >= 8: - # compute cycles (only works for ampere GPUs) - mmas = BLOCK_B * BLOCK_V * BLOCK_D / (16 * 8 * 16) - mma_cycles = mmas / min(4, num_warps) * 8 - - ldgsts_latency = 300 # Does this matter? - optimal_num_stages = ldgsts_latency / mma_cycles - - # nearest stages, prefer large #stages - nearest = heapq.nsmallest( - 2, - v, - key=lambda x: 10 + abs(x[1] - optimal_num_stages) - if (x[1] - optimal_num_stages) < 0 - else x[1] - optimal_num_stages, - ) - - for n in nearest: - pruned_configs.append(n[0]) - else: # Volta & Turing only supports num_stages <= 2 - random_config = v[0][0] - random_config.num_stages = 2 - pruned_configs.append(random_config) - return pruned_configs - - -def _total_ops_fn(B, V, D) -> float: - return 2 * B * V * D + 10 * B * V - - -def _total_store_fn(B, V, D, dtsize, num_cta_b, num_cta_v): - return B * dtsize - - -def estimate_matmul_time( - # backend, device, - num_warps, - num_stages, # - E, - B, - V, - D, # - BLOCK_B, - BLOCK_V, - BLOCK_D, - debug=False, - total_ops_fn=_total_ops_fn, - total_store_fn=_total_store_fn, - **kwargs, # -): - """return estimated running time in ms - = max(compute, loading) + store""" - device = paddle.get_device() - device = int(device.split(":")[-1]) - dtype = E.dtype - dtsize = E.element_size() - - num_cta_b = cdiv(B, BLOCK_B) - num_cta_v = cdiv(V, BLOCK_V) - num_ctas = num_cta_b * num_cta_v - - # If the input is smaller than the block size - B, V = max(B, BLOCK_B), max(V, BLOCK_V) - - # time to compute - total_ops = total_ops_fn(B, V, D) - total_ops = total_ops / (1024 * 1024 * 1024) # GOPS - tput = get_tflops(device, num_ctas, num_warps, dtype) - compute_ms = total_ops / tput - - # time to load data - num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"] - active_cta_ratio = min(1, num_ctas / num_sm) - active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate - active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5% - dram_bw = get_dram_gbps(device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s - l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) - # assume 80% of (following) loads are in L2 cache - load_a_dram = B * D * dtsize * (1 + 0.2 * (num_cta_v - 1)) - load_a_l2 = B * D * dtsize * 0.8 * (num_cta_v - 1) - load_b_dram = V * D * dtsize * (1 + 0.2 * (num_cta_b - 1)) - load_b_l2 = V * D * dtsize * 0.8 * (num_cta_b - 1) - # total - total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB - total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) - # loading time in ms - load_ms = total_dram / dram_bw + total_l2 / l2_bw - - # estimate storing time - store_bw = dram_bw * 0.4 # :o - store_dram = total_store_fn(B, V, D, dtsize, num_cta_b, num_cta_v) / (1024 * 1024) - store_ms = store_dram / store_bw - - total_time_ms = max(compute_ms, load_ms) + store_ms - if debug: - print( - f"{BLOCK_B=}, {BLOCK_V=}, {BLOCK_D=}, {num_warps=}, {num_stages=}, " - f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, " - f"loading time: {load_ms}ms, store time: {store_ms}ms, " - f"Activate CTAs: {active_cta_ratio*100}%" - ) - return total_time_ms - - -def get_configs_io_bound(): - configs = [] - for num_stages in [2, 3, 4, 5, 6]: - for block_m in [16, 32]: - for block_k in [32, 64]: - for block_n in [32, 64, 128, 256]: - num_warps = 2 if block_n <= 64 else 4 - configs.append( - Config( - { - "BLOCK_B": block_m, - "BLOCK_V": block_n, - "BLOCK_D": block_k, - }, - num_stages=num_stages, - num_warps=num_warps, - ) - ) - return configs - - -def get_autotune_config(): - return [ - # basic configs for compute-bound matmuls - Config( - {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 128}, - num_stages=2, - num_warps=4, - ), - Config( - {"BLOCK_B": 128, "BLOCK_V": 256, "BLOCK_D": 32}, - num_stages=3, - num_warps=8, - ), - Config( - {"BLOCK_B": 256, "BLOCK_V": 128, "BLOCK_D": 32}, - num_stages=3, - num_warps=8, - ), - Config( - {"BLOCK_B": 256, "BLOCK_V": 64, "BLOCK_D": 32}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_B": 64, "BLOCK_V": 256, "BLOCK_D": 32}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 32}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 32}, - num_stages=3, - num_warps=8, - ), - Config( - {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 32}, - num_stages=4, - num_warps=8, - ), - Config( - {"BLOCK_B": 128, "BLOCK_V": 64, "BLOCK_D": 32}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_B": 64, "BLOCK_V": 128, "BLOCK_D": 32}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_B": 128, "BLOCK_V": 32, "BLOCK_D": 32}, - num_stages=4, - num_warps=4, - ), - Config({"BLOCK_B": 64, "BLOCK_V": 32, "BLOCK_D": 32}, num_stages=5, num_warps=2), - # good for int8 - Config( - {"BLOCK_B": 128, "BLOCK_V": 256, "BLOCK_D": 128}, - num_stages=3, - num_warps=8, - ), - Config( - {"BLOCK_B": 128, "BLOCK_V": 256, "BLOCK_D": 128}, - num_stages=3, - num_warps=16, - ), - Config( - {"BLOCK_B": 256, "BLOCK_V": 128, "BLOCK_D": 128}, - num_stages=3, - num_warps=8, - ), - Config( - {"BLOCK_B": 256, "BLOCK_V": 128, "BLOCK_D": 128}, - num_stages=3, - num_warps=16, - ), - Config( - {"BLOCK_B": 256, "BLOCK_V": 64, "BLOCK_D": 128}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_B": 64, "BLOCK_V": 256, "BLOCK_D": 128}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 128}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_B": 128, "BLOCK_V": 64, "BLOCK_D": 64}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_B": 64, "BLOCK_V": 128, "BLOCK_D": 64}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_B": 128, "BLOCK_V": 32, "BLOCK_D": 64}, - num_stages=4, - num_warps=4, - ), - Config({"BLOCK_B": 64, "BLOCK_V": 32, "BLOCK_D": 64}, num_stages=5, num_warps=2), - ] + get_configs_io_bound() - - -def _heuristics_from_config(config: Config) -> Callable[..., autotuner.Heuristics]: - return triton.heuristics({k: (lambda args, _v=v: _v) for k, v in config.all_kwargs().items()}) - - -def _cce_forward_best_config() -> Config: - return Config(dict(BLOCK_B=256, BLOCK_V=128, BLOCK_D=32), num_warps=8, num_stages=3) - - -def cce_forward_autotune() -> Callable[..., autotuner.PaddleAutotuner]: - if _AUTOTUNE: - return triton.paddle_autotune( - configs=get_autotune_config(), - key=["V", "D", "B_BIN"], - prune_configs_by={ - "early_config_prune": early_config_prune, - "perf_model": estimate_matmul_time, - "top_k": 10, - }, - restore_value=["LSE"], - ) - else: - return _heuristics_from_config(_cce_forward_best_config()) - - -def _bw_total_ops_fn(B, V, D) -> float: - return 2 * B * V * D + 6 * B * V + 0.2 * (2 * B * V * D + 2 * B * V * D) - - -def _bw_total_store_fn(B, V, D, dtsize, num_cta_b, num_cta_v): - return 0.2 * (num_cta_v * B * D * dtsize + num_cta_b * D * V * dtsize) - - -def _cce_backward_best_config() -> Config: - return Config(dict(BLOCK_B=128, BLOCK_V=128, BLOCK_D=32), num_warps=4, num_stages=4) - - -def cce_backward_autotune() -> Callable[..., autotuner.PaddleAutotuner]: - if _AUTOTUNE: - return triton.paddle_autotune( - configs=get_autotune_config(), - key=["V", "D", "B_BIN"], - prune_configs_by={ - "early_config_prune": functools.partial(early_config_prune, shared_memory_factor=2.0), - "perf_model": functools.partial( - estimate_matmul_time, - total_ops_fn=_bw_total_ops_fn, - total_store_fn=_bw_total_store_fn, - ), - "top_k": 5, - }, - reset_to_zero=["dE", "dC"], - ) - else: - return _heuristics_from_config(_cce_backward_best_config()) - - -def _indexed_dot_best_config() -> Config: - return Config(dict(BLOCK_B=128, BLOCK_D=256), num_warps=16, num_stages=4) - - -def _indexed_dot_all_configs() -> list[Config]: - return [ - Config( - dict( - BLOCK_B=128, - BLOCK_D=128, - ), - num_warps=4, - num_stages=4, - ), - Config( - dict( - BLOCK_B=128, - BLOCK_D=128, - ), - num_warps=8, - num_stages=4, - ), - Config( - dict( - BLOCK_B=256, - BLOCK_D=256, - ), - num_warps=16, - num_stages=4, - ), - Config( - dict( - BLOCK_B=256, - BLOCK_D=128, - ), - num_warps=16, - num_stages=4, - ), - Config( - dict( - BLOCK_B=128, - BLOCK_D=256, - ), - num_warps=16, - num_stages=4, - ), - ] - - -def indexed_dot_autotune() -> Callable[..., autotuner.PaddleAutotuner]: - if _AUTOTUNE: - return triton.paddle_autotune( - configs=_indexed_dot_all_configs(), - key=["D", "B_BIN"], - reset_to_zero=["Out"], - ) - else: - return _heuristics_from_config(_indexed_dot_best_config()) diff --git a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/tl_utils.py b/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/tl_utils.py deleted file mode 100644 index e46ce70b4324..000000000000 --- a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/tl_utils.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -import triton -import triton.language as tl -from triton.language.extra import libdevice as tl_libdevice - - -@triton.jit -def tl_and_reduce_fn(a, b): - return a & b - - -@triton.jit -def tl_tanh(a: tl.tensor) -> tl.tensor: - return tl_libdevice.tanh(a) - - -@triton.jit -def tl_log1p(a: tl.tensor) -> tl.tensor: - return tl_libdevice.log1p(a) - - -@triton.jit -def tl_softcapping(v: tl.tensor, softcap: float) -> tl.tensor: - return tl_tanh(v / softcap) * softcap - - -@triton.jit -def tl_softcapping_grad(dv: tl.tensor, v: tl.tensor, softcap: float) -> tl.tensor: - v = v / softcap - return dv * (1 - v * v) - - -@triton.jit -def tl_logaddexp(a, b) -> tl.tensor: - minx = tl.minimum(a, b) - mx = tl.maximum(a, b) - return tl_log1p(tl.exp(minx - mx)) + mx - - -@triton.jit -def tl_lock_add(ptrs, v, mask, lock_ptr): - while tl.atomic_cas(lock_ptr, 0, 1) == 1: - pass - - cur_v = tl.load(ptrs, mask=mask, other=0.0, eviction_policy="evict_last") - new_v = v + cur_v - tl.store(ptrs, new_v, mask=mask, eviction_policy="evict_last") - - tl.atomic_xchg(lock_ptr, 0) - - -def b_bin_fn(b: int) -> int: - if b >= 1024: - return 1024 - elif b <= 128: - return 128 - else: - return 512 diff --git a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/utils.py b/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/utils.py deleted file mode 100644 index 8f31f8dd8fa4..000000000000 --- a/ops/src/paddlenlp_kernel/triton/cut_cross_entropy/utils.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -from typing import Union - -import numpy as np -import paddle - - -def softcapping(logits: paddle.Tensor, softcap: float) -> paddle.Tensor: - return paddle.tanh(logits / softcap) * softcap - - -def _handle_eps(filter_eps: Union[float, str, None], dtype: paddle.dtype) -> Union[float, None]: - if filter_eps is None: - return None - elif isinstance(filter_eps, float): - return filter_eps - elif isinstance(filter_eps, str) and filter_eps == "auto": - return paddle.finfo(dtype).eps / 32 - else: - raise RuntimeError(f"Unknown eps {filter_eps=}") - - -def _build_flat_valids( - targets: paddle.Tensor, - ignore_index: int, - shift: bool, -) -> Union[paddle.Tensor, None]: - if shift: - targets = targets[..., 1:] - else: - targets = targets.flatten() - - valids = (targets != ignore_index).nonzero().cast(paddle.int32) - - if not shift: - assert valids.shape[1] == 1 - return valids.squeeze(1) if valids.numel() != targets.numel() else None - - for i in range(targets.ndim - 1): - valids[:, i] *= targets.strides[i] - - assert targets.strides[-1] == 1 - - return valids.sum(1) - - -def handle_reduction_none( - batch_shape: list, valids: Union[paddle.Tensor, None], shift: bool, loss: paddle.Tensor -) -> paddle.Tensor: - if valids is None: - return loss.reshape(batch_shape) - - full_loss = paddle.zeros(np.prod(batch_shape), dtype=loss.dtype) - full_loss[(valids + 1) if shift else valids] = loss - - return full_loss.reshape(batch_shape) diff --git a/ops/src/paddlenlp_kernel/triton/inf_cl/__init__.py b/ops/src/paddlenlp_kernel/triton/inf_cl/__init__.py deleted file mode 100644 index 371bdba8a6de..000000000000 --- a/ops/src/paddlenlp_kernel/triton/inf_cl/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .flash import cal_flash_loss -from .ring import cal_inf_loss, cal_ring_loss diff --git a/ops/src/paddlenlp_kernel/triton/inf_cl/flash.py b/ops/src/paddlenlp_kernel/triton/inf_cl/flash.py deleted file mode 100644 index 0a1c709f62c6..000000000000 --- a/ops/src/paddlenlp_kernel/triton/inf_cl/flash.py +++ /dev/null @@ -1,414 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -this code is modified from https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_cl/flash.py -""" -import math - -import numpy as np -import paddle -import paddle.autograd -import paddle.nn.functional as F -import triton -import triton.language as tl - - -@triton.jit -def _prob_fwd_kernel( - Q, - K, - LSE, - nheads, - seqlen_q, - seqlen_k, - BLOCK_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - # start index of sequence length - start_m = tl.program_id(0) - - # initialize offsets - ndims = nheads * BLOCK_HEADDIM - offs_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_HEADDIM) - - # Initialize pointers to Q, K, V - q_ptrs = Q + ndims * offs_m[:, None] - k_ptrs = K + ndims * offs_n[:, None] - # initialize pointer to m and l - lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - - # loop over k, v and update accumulator - end_n = seqlen_k - for start_n in range(0, end_n, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - for off_h in range(nheads): - offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :] - # -- fetch q and k of a single head ---- - q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0) - k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0) - # -- compute qk ---- - qk += tl.dot(q, tl.trans(k)) - - # Trying to combine the two masks seem to make the result wrong - m_ij = tl.maximum(tl.max(qk, 1), m_i) - p = tl.exp(qk - m_ij[:, None]) - # Fix out of bound access - p = tl.where((start_n + offs_n)[None, :] < seqlen_k, p, 0.0) - # -- update statistics - lse_i = tl.exp(m_i - m_ij) * lse_i + tl.sum(p, 1) - m_i = m_ij - - lse_i = m_i + tl.log(lse_i) - # mask out the padded values - lse_i = tl.where(offs_m < seqlen_q, lse_i, 0.0) - - tl.store(LSE + offs_m, lse_i) - - -@triton.jit -def _dq_prob_bwd_kernel( - Q, - K, - dQ, - LSE, - dLSE, - nheads, - seqlen_q, - seqlen_k, - BLOCK_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;" - # start index of sequence length - start_m = tl.program_id(0) - - # initialize offsets - ndims = nheads * BLOCK_HEADDIM - offs_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_HEADDIM) - - # Initialize pointers to Q, K, V - q_ptrs = Q + ndims * offs_m[:, None] - dq_ptrs = dQ + ndims * offs_m[:, None] - k_ptrs = K + ndims * offs_n[:, None] - # setting lse - lse = tl.load(LSE + offs_m, mask=offs_m < seqlen_q, other=0.0) - dlse = tl.load(dLSE + offs_m, mask=offs_m < seqlen_q, other=0.0) - - # loop over k, v and update accumulator - end_n = seqlen_k - for start_n in range(0, end_n, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - for off_h in range(nheads): - offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :] - # -- fetch q and k of a single head ---- - q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0) - k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0) - # -- compute qk ---- - qk += tl.dot(q, tl.trans(k)) - - qk_grad = tl.exp(qk - lse[:, None]) - qk_grad = tl.where((start_n + offs_n)[None, :] < seqlen_k, qk_grad, 0.0) - qk_grad = qk_grad * dlse[:, None] - qk_grad = tl.inline_asm_elementwise(ASM, "=r, r", [qk_grad], dtype=tl.float32, is_pure=True, pack=1) - for off_h in range(nheads): - offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :] - # -- fetch q and k of a single head ---- - q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0) - k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0) - # -- compute q grad ---- - # NOTE: tl.float32 adopt tf32, which causes precision inconsistency with torch - # A solution for this problem - # Refer to issue: https://github.com/triton-lang/triton/issues/4574 - # if allow_tf32: - k = tl.inline_asm_elementwise(ASM, "=r, r", [k], dtype=tl.float32, is_pure=True, pack=1) - q_grad = tl.dot(qk_grad, k) - # Another solution for this problem - # Refer to https://github.com/triton-lang/triton/issues/376 - # q_grad = tl.dot(qk_grad, k.to(tl.float32), allow_tf32=False) - # -- store dq ---- - dq_h = tl.load(dq_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0) - tl.store(dq_ptrs + offs_hd, dq_h + q_grad, mask=offs_m[:, None] < seqlen_q) - - -@triton.jit -def _dk_prob_bwd_kernel( - Q, - K, - dK, - LSE, - dLSE, - nheads, - seqlen_q, - seqlen_k, - BLOCK_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;" - # start index of sequence length - start_n = tl.program_id(0) - - # initialize offsets - ndims = nheads * BLOCK_HEADDIM - offs_m = tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) + start_n * BLOCK_N - offs_d = tl.arange(0, BLOCK_HEADDIM) - - # Initialize pointers to Q, K, V - q_ptrs = Q + ndims * offs_m[:, None] - k_ptrs = K + ndims * offs_n[:, None] - dk_ptrs = dK + ndims * offs_n[:, None] - - # loop over q and update accumulator - end_m = seqlen_q - for start_m in range(0, end_m, BLOCK_M): - start_m = tl.multiple_of(start_m, BLOCK_M) - - # setting lse - lse = tl.load(LSE + offs_m + start_m, mask=offs_m < seqlen_q, other=0.0) - dlse = tl.load(dLSE + offs_m + start_m, mask=offs_m < seqlen_q, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - for off_h in range(nheads): - offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :] - # -- fetch q and k of a single head ---- - q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(offs_m + start_m)[:, None] < seqlen_q, other=0.0) - k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0) - # -- compute qk ---- - qk += tl.dot(q, tl.trans(k)) - - qk_grad = tl.exp(qk - lse[:, None]) - qk_grad = tl.where((start_m + offs_m)[:, None] < seqlen_q, qk_grad, 0.0) - qk_grad = qk_grad * dlse[:, None] - qk_grad = tl.inline_asm_elementwise(ASM, "=r, r", [qk_grad], dtype=tl.float32, is_pure=True, pack=1) - for off_h in range(nheads): - offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :] - # -- fetch q and k of a single head ---- - q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0) - k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0) - # -- compute k grad ---- - q = tl.inline_asm_elementwise(ASM, "=r, r", [q], dtype=tl.float32, is_pure=True, pack=1) - k_grad = tl.dot(tl.trans(qk_grad), q) - # k_grad = tl.dot(tl.trans(qk_grad), q.to(tl.float32)) - # -- store dk ---- - dk_h = tl.load(dk_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0) - tl.store(dk_ptrs + offs_hd, dk_h + k_grad, mask=(offs_n)[:, None] < seqlen_k) - - -def _flash_prob_forward(q, k): - # shape constraints - seqlen_q, nheads, d = q.shape - seqlen_k, _, _ = k.shape - assert k.shape == [seqlen_k, nheads, d] - # assert d <= 128, "FlashAttention only support head dimensions up to 128" - assert q.dtype == k.dtype, "All tensors must have the same type" - # assert q.dtype in [paddle.float16, paddle.bfloat16], "Only support fp16 and bf16" - - seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - lse = paddle.empty((seqlen_q_rounded,), dtype=paddle.float32) - - BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - BLOCK_M = 64 - BLOCK_N = 64 - num_warps = 8 - num_stages = 1 - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), 1) - _prob_fwd_kernel[grid]( - q, - k, - lse, - nheads, - seqlen_q, - seqlen_k, - BLOCK_HEADDIM, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=num_stages, - ) - - lse = lse[:seqlen_q] - - return lse - - -def _flash_prob_backward(q, k, lse, dlse): - # shape constraints - seqlen_q, nheads, d = q.shape - seqlen_k, _, _ = k.shape - assert k.shape == [seqlen_k, nheads, d] - # assert d <= 128, "FlashAttention only support head dimensions up to 128" - assert q.dtype == k.dtype, "All tensors must have the same type" - # assert q.dtype in [paddle.float16, paddle.bfloat16], "Only support fp16 and bf16" - - dq = paddle.zeros_like(q, dtype=paddle.float32) - dk = paddle.zeros_like(k, dtype=paddle.float32) - - q = q.contiguous() - k = k.contiguous() - dlse = dlse.contiguous() - - BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - BLOCK_M = 64 - BLOCK_N = 64 - num_warps = 8 - num_stages = 1 - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), 1) - _dq_prob_bwd_kernel[grid]( - q, - k, - dq, - lse, - dlse, - nheads, - seqlen_q, - seqlen_k, - BLOCK_HEADDIM, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=num_stages, - ) - - BLOCK_N = BLOCK_M - BLOCK_M = BLOCK_N - grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]), 1) - _dk_prob_bwd_kernel[grid]( - q, - k, - dk, - lse, - dlse, - nheads, - seqlen_q, - seqlen_k, - BLOCK_HEADDIM, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=num_stages, - ) - - dq = dq[:seqlen_q] - dk = dk[:seqlen_k] - - return dq, dk - - -class FlashProb(paddle.autograd.PyLayer): - @staticmethod - def forward(ctx, q, k): - lse = _flash_prob_forward(q, k) - ctx.save_for_backward(q, k, lse) - - return lse - - @staticmethod - def backward(ctx, dlse): - q, k, lse = ctx.saved_tensor() - dq, dk = _flash_prob_backward(q, k, lse, dlse) - - return dq, dk - - -def _cal_flash_loss(q, k, labels, head_dim=256): - bq = q.shape[0] - bk = k.shape[0] - # NOTE: logits forward or backward should keep fp32 for better precision - q = q.reshape([bq, -1, head_dim]).cast("float32") - k = k.reshape([bk, -1, head_dim]).cast("float32") - - lse = FlashProb.apply(q, k) - numerator = paddle.einsum("mhd,mhd->m", q, k[labels, ...]) - loss = -numerator + lse - - return loss - - -def cal_flash_loss(q, k, labels=None, scale=None, head_dim=256): - if labels is None: - labels = paddle.arange(q.shape[0]) - if scale is not None and scale != 1.0: - q = q * scale - return _cal_flash_loss(q, k, labels, head_dim) - - -if __name__ == "__main__": - import time - - # Parameters - num_heads = 3 # Number of attention heads - seq_length_q = 32768 # Sequence length - seq_length_k = 32768 - d_model = 256 # Dimension of each head (must be 16, 32, 64, or 128) - - # Randomly initialize inputs - q = paddle.rand((seq_length_q, num_heads * d_model), dtype=paddle.float32) # Query - k = paddle.rand((seq_length_k, num_heads * d_model), dtype=paddle.float32) # Key - l = paddle.ones([]) * np.log(1 / 0.02) - l.stop_gradient = False - - q = F.normalize(q, p=2, axis=-1) - q.stop_gradient = False - k = F.normalize(k, p=2, axis=-1) - k.stop_gradient = False - - q1 = q.clone().detach() - q1.stop_gradient = False - k1 = k.clone().detach() - k1.stop_gradient = False - l1 = l.clone().detach() - l1.stop_gradient = False - - labels = paddle.arange(seq_length_q) - - for i in range(1000): - - # A. paddle gradient - start = time.time() - qk = paddle.einsum("md,nd->mn", l.exp() * q, k) - loss = F.cross_entropy(qk, labels, reduction="mean") - loss.backward() - end = time.time() - - # B. triton gradient - start1 = time.time() - loss1 = cal_flash_loss(q1, k1, labels, l1.exp()) - loss1 = loss1.mean() - loss1.backward() - end1 = time.time() - - print("========= Difference =========") - print(end - start, end1 - start1, l.grad, l1.grad) - print(paddle.max(paddle.abs(q.grad - q1.grad)), paddle.max(paddle.abs(k.grad - k1.grad))) - - set_to_zero = False - q.clear_gradient(set_to_zero) - k.clear_gradient(set_to_zero) - l.clear_gradient(set_to_zero) - q1.clear_gradient(set_to_zero) - k1.clear_gradient(set_to_zero) - l1.clear_gradient(set_to_zero) diff --git a/ops/src/paddlenlp_kernel/triton/inf_cl/ring.py b/ops/src/paddlenlp_kernel/triton/inf_cl/ring.py deleted file mode 100644 index 09e19137eed7..000000000000 --- a/ops/src/paddlenlp_kernel/triton/inf_cl/ring.py +++ /dev/null @@ -1,434 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -this code is modified from https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_cl/ring.py -""" -import paddle -import paddle.autograd -import paddle.distributed -import paddle.distributed as dist -import paddle.nn.functional as F - -from .flash import _cal_flash_loss, _flash_prob_backward, _flash_prob_forward - - -def init_dp_sd_comm_group(): - hcg = dist.fleet.get_hybrid_communicate_group() - dp_world_size = hcg.get_data_parallel_world_size() - sd_world_size = hcg.get_sharding_parallel_world_size() - - if dp_world_size > 1 and sd_world_size > 1: - dp_sd_group, dp_sd_comm_group = hcg.create_fuse_group(["data", "sharding"]) - elif dp_world_size > 1: - dp_sd_group, dp_sd_comm_group = hcg._dp_group, hcg.get_data_parallel_group() - elif sd_world_size > 1: - dp_sd_group, dp_sd_comm_group = hcg._sharding_group, hcg.get_sharding_parallel_group() - - hcg._dp_sd_group = dp_sd_group - hcg._dp_sd_comm_group = dp_sd_comm_group - return dp_sd_group, dp_sd_comm_group - - -class RingComm: - def __init__(self, group): - self.group = group - self._ops = [] - self._reqs = None - self.group_rank = group.rank - self.world_size = group.world_size - self.send_rank = self.group.ranks[(self.group_rank + 1) % self.world_size] - self.recv_rank = self.group.ranks[(self.group_rank - 1) % self.world_size] - - def send_recv(self, to_send, recv_tensor=None): - if recv_tensor is None: - res = paddle.empty_like(to_send) - else: - res = recv_tensor - - send_op = dist.P2POp(dist.isend, to_send, self.send_rank, self.group) - recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, self.group) - - self._ops.append(send_op) - self._ops.append(recv_op) - return res - - def commit(self): - if self._reqs is not None: - raise RuntimeError("commit called twice") - self._reqs = dist.batch_isend_irecv(self._ops) - - def wait(self): - if self._reqs is None: - raise RuntimeError("wait called before commit") - for req in self._reqs: - req.wait() - self._reqs = None - self._ops = [] - - -class RingProb(paddle.autograd.PyLayer): - @staticmethod - def forward( - ctx, - q, - k, - group=None, - ): - if group is None: - hcg = dist.fleet.get_hybrid_communicate_group() - if not hasattr(hcg, "_dp_sd_group") and not hasattr(hcg, "_dp_sd_comm_group"): - init_dp_sd_comm_group() - group = hcg._dp_sd_comm_group - - assert group is not None, "Communication group must be specified!" - - k = k.contiguous() - comm = RingComm(group) - - colle = [q, k] - - lse = None - next_k = None - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: paddle.Tensor = comm.send_recv(k) - comm.commit() - - # vanilla lse - qk = paddle.einsum("mhd,nhd->mn", q, k) - block_lse = paddle.log(paddle.exp(qk).sum(axis=-1)) - - if step == 0: - lse = block_lse - else: - lse = lse - F.sigmoid(lse - block_lse).log() - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - - # this should be out_padded - colle.append(lse) - ctx.save_for_backward(*colle) - ctx.group = group - return lse - - @staticmethod - def backward(ctx, dlse): - q, k, lse = ctx.saved_tensor() - k_comm = RingComm(ctx.group) - d_k_comm = RingComm(ctx.group) - dq, dk = None, None - next_dk = None - - block_dq_buffer = paddle.empty(q.shape, dtype=paddle.float32) - block_dk_buffer = paddle.empty(k.shape, dtype=paddle.float32) - - next_dk, next_k = None, None - - for step in range(k_comm.world_size): - if step + 1 != k_comm.world_size: - next_k = k_comm.send_recv(k) - k_comm.commit() - - # vanilla gradient calculation - qk = paddle.einsum("mhd,nhd->mn", q, k) - qk_grad = paddle.exp(qk - lse[:, None]).cast("float32") - qk_grad = qk_grad * dlse[:, None] - block_dq_buffer = paddle.einsum("mn,nhd->mhd", qk_grad, k.cast("float32")) - block_dk_buffer = paddle.einsum("nm,mhd->nhd", qk_grad.T, q.cast("float32")) - - if step == 0: - dq = block_dq_buffer - dk = block_dk_buffer - else: - dq += block_dq_buffer - d_k_comm.wait() - dk = block_dk_buffer + next_dk - - if step + 1 != k_comm.world_size: - k_comm.wait() - k = next_k - - next_dk = d_k_comm.send_recv(dk) - d_k_comm.commit() - - d_k_comm.wait() - - return dq, next_dk - - -class InfProb(paddle.autograd.PyLayer): - @staticmethod - def forward(ctx, q, k, group): - if group is None: - hcg = dist.fleet.get_hybrid_communicate_group() - if not hasattr(hcg, "_dp_sd_group") and not hasattr(hcg, "_dp_sd_comm_group"): - init_dp_sd_comm_group() - group = hcg._dp_sd_comm_group - - assert group is not None, "Communication group must be specified!" - - k = k.contiguous() - comm = RingComm(group) - - colle = [q, k] - - lse = None - next_k = None - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k: paddle.Tensor = comm.send_recv(k) - comm.commit() - - # flash lse - block_lse = _flash_prob_forward(q, k) - - if step == 0: - lse = block_lse - else: - lse = lse - F.sigmoid(lse - block_lse).log() - - if step + 1 != comm.world_size: - comm.wait() - k = next_k - - # this should be out_padded - colle.append(lse) - ctx.save_for_backward(*colle) - ctx.group = group - return lse - - @staticmethod - def backward(ctx, dlse): - q, k, lse = ctx.saved_tensor() - k_comm = RingComm(ctx.group) - d_k_comm = RingComm(ctx.group) - dq, dk = None, None - next_dk = None - - block_dq_buffer = paddle.empty(q.shape, dtype=paddle.float32) - block_dk_buffer = paddle.empty(k.shape, dtype=paddle.float32) - - next_dk, next_k = None, None - - for step in range(k_comm.world_size): - if step + 1 != k_comm.world_size: - next_k = k_comm.send_recv(k) - k_comm.commit() - - # flash gradient calculation - block_dq_buffer, block_dk_buffer = _flash_prob_backward(q, k, lse, dlse) - - if step == 0: - dq = block_dq_buffer - dk = block_dk_buffer - else: - dq += block_dq_buffer - d_k_comm.wait() - dk = block_dk_buffer + next_dk - - if step + 1 != k_comm.world_size: - k_comm.wait() - k = next_k - - next_dk = d_k_comm.send_recv(dk) - d_k_comm.commit() - - d_k_comm.wait() - - return dq, next_dk - - -def _cal_ring_loss(q, k, labels, head_dim=256): - bq = q.shape[0] - bk = k.shape[0] - q = q.reshape([bq, -1, head_dim]).cast("float32") - k = k.reshape([bk, -1, head_dim]).cast("float32") - - lse = RingProb.apply(q, k, None) - numerator = paddle.einsum("mhd,mhd->m", q, k[labels, ...]) - loss = -numerator + lse - - return loss - - -def _cal_inf_loss(q, k, labels, head_dim=256): - bq = q.shape[0] - bk = k.shape[0] - q = q.reshape([bq, -1, head_dim]).cast("float32") - k = k.reshape([bk, -1, head_dim]).cast("float32") - - lse = InfProb.apply(q, k, None) - numerator = paddle.einsum("mhd,mhd->m", q, k[labels, ...]) - loss = -numerator + lse - - return loss - - -class GradientGather(paddle.autograd.PyLayer): - @staticmethod - def forward(ctx, x): - ctx.save_for_backward(x) - return x - - @staticmethod - def backward(ctx, dx): - dist.all_reduce(dx) - return dx - - -def cal_ring_loss(q, k, labels=None, scale=None, head_dim=256): - """The paddle implementation of the ring-cl. - - Args: - q (paddle.Tensor): The column tensor in contrastive loss. The shape is [B, D]. - k (paddle.Tensor): The row tensor in contrastive loss. The shape is [B, D]. - labels (paddle.Tensor, optional): In CLIP loss, the labels are the indices of the positive pairs. The shape is [B]. When setting to None, the labels are the range of [0, B). Defaults to None. - scale (paddle.Tensor, optional): The scale tensor of the query tensor. Defaults to None. - head_dim (int, optional): The head dimension. (must be 16, 32, 64, 128 or 256). Defaults to 256. - - """ - - if labels is None: - labels = paddle.arange(q.shape[0]) - if scale is not None and scale != 1.0: - scale = GradientGather.apply(scale) - q = scale * q - - if paddle.distributed.is_initialized(): - return _cal_ring_loss(q, k, labels, head_dim).mean() - else: - return _cal_flash_loss(q, k, labels, head_dim).mean() - - -def cal_inf_loss(q, k, labels=None, scale=None, head_dim=256): - """The triton implementation of the inf-cl. - - Args: - q (paddle.Tensor): The column tensor in contrastive loss. The shape is [B, D]. - k (paddle.Tensor): The row tensor in contrastive loss. The shape is [B, D]. - labels (paddle.Tensor, optional): In CLIP loss, the labels are the indices of the positive pairs. The shape is [B]. When setting to None, the labels are the range of [0, B). Defaults to None. - scale (paddle.Tensor, optional): The scale tensor of the query tensor. Defaults to None. - head_dim (int, optional): The head dimension. (must be 16, 32, 64, 128 or 256). Defaults to 256. - - """ - - if labels is None: - labels = paddle.arange(q.shape[0]) - if scale is not None and scale != 1.0: - scale = GradientGather.apply(scale) - q = scale * q - if paddle.distributed.is_initialized(): - return _cal_inf_loss(q, k, labels, head_dim).mean() - else: - return _cal_flash_loss(q, k, labels, head_dim).mean() - - -if __name__ == "__main__": - import time - - import numpy as np - - strategy = paddle.distributed.fleet.DistributedStrategy() - strategy.hybrid_configs = { - "dp_degree": 2, - "mp_degree": 1, - "pp_degree": 1, - "sharding_degree": 1, - "sep_degree": 1, - } - paddle.distributed.fleet.init(is_collective=True, strategy=strategy) - - rank = dist.get_rank() - world_size = dist.get_world_size() - - # Parameters - dtype = paddle.float32 - num_heads = 3 # Number of attention heads - seq_length_q = 32768 # Sequence length - seq_length_k = 32768 - d_model = 256 # Dimension of each head (must be 16, 32, 64, or 128) - - # Randomly initialize inputs - q = paddle.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype) - k = paddle.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype) - l = paddle.ones([], dtype=dtype) * np.log(1 / 0.07) - l.stop_gradient = False # Logit scale - - q = F.normalize(q, p=2, axis=-1) - q.stop_gradient = False # Query - k = F.normalize(k, p=2, axis=-1) - k.stop_gradient = False # Key - - q1 = q.clone().detach() - q1.stop_gradient = False - k1 = k.clone().detach() - k1.stop_gradient = False - l1 = l.clone().detach() - l1.stop_gradient = False - - for i in range(1000): - # A. local torch gradient - start = time.time() - # A.1. gather q, k - gathered_q = [paddle.zeros_like(q) for _ in range(world_size)] - gathered_k = [paddle.zeros_like(k) for _ in range(world_size)] - dist.all_gather(gathered_q, q) - dist.all_gather(gathered_k, k) - gathered_q[rank] = q - gathered_k[rank] = k - all_q = paddle.concat(gathered_q, axis=0) - all_k = paddle.concat(gathered_k, axis=0) - # A.2. calculating qk logits - qk = paddle.einsum("md,nd->mn", l.exp() * all_q, all_k) - kq = qk.T - _labels = paddle.arange(seq_length_q) - # A.3. calculating loss - loss_i2t = F.cross_entropy(qk, _labels, reduction="mean") - loss_t2i = F.cross_entropy(kq, _labels, reduction="mean") - # A.4. scaling loss to normal value - scale_factor = all_q.shape[0] / q.shape[0] - loss = (loss_i2t + loss_t2i) * 0.5 * scale_factor - loss.backward() - show_loss = loss.detach().clone() - dist.all_reduce(show_loss) - show_loss = show_loss / (world_size * scale_factor) - end = time.time() - - dist.barrier() - - # B. triton implementation - start1 = time.time() - # labels = torch.arange(seq_length_q // world_size).to(q.device) - loss1_i2t = cal_inf_loss(q1, k1, scale=l1.exp()) - loss1_t2i = cal_inf_loss(k1, q1, scale=l1.exp()) - loss1 = (loss1_i2t + loss1_t2i).mean() * 0.5 - loss1.backward() - end1 = time.time() - - dist.barrier() - - if rank == 0: - print(rank, end - start, end1 - start1, loss, show_loss, loss1) - print(l.grad, l1.grad, paddle.max(paddle.abs(q.grad - q1.grad)), paddle.max(paddle.abs(k.grad - k1.grad))) - - set_to_zero = False - q.clear_gradient(set_to_zero) - k.clear_gradient(set_to_zero) - l.clear_gradient(set_to_zero) - q1.clear_gradient(set_to_zero) - k1.clear_gradient(set_to_zero) - l1.clear_gradient(set_to_zero) diff --git a/ops/src/paddlenlp_kernel/triton/mamba/__init__.py b/ops/src/paddlenlp_kernel/triton/mamba/__init__.py deleted file mode 100644 index 603b9e56f43f..000000000000 --- a/ops/src/paddlenlp_kernel/triton/mamba/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .causal_conv1d_varlen import ( - causal_conv1d_varlen_states, - causal_conv1d_varlen_states_ref, -) -from .k_activations import swiglu -from .selective_state_update import selective_state_update, selective_state_update_ref -from .ssd_chunk_scan import chunk_scan, chunk_scan_ref -from .ssd_chunk_state import chunk_state, chunk_state_ref -from .ssd_combined import ( - mamba_chunk_scan, - mamba_chunk_scan_combined, - mamba_conv1d_scan_ref, - mamba_split_conv1d_scan_combined, - mamba_split_conv1d_scan_ref, - ssd_selective_scan, -) -from .ssd_state_passing import state_passing, state_passing_ref diff --git a/ops/src/paddlenlp_kernel/triton/mamba/causal_conv1d_varlen.py b/ops/src/paddlenlp_kernel/triton/mamba/causal_conv1d_varlen.py deleted file mode 100644 index 8ee0acca6298..000000000000 --- a/ops/src/paddlenlp_kernel/triton/mamba/causal_conv1d_varlen.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) 2024, Tri Dao. -""" -this code is modified from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_varlen.py -""" -import paddle -import triton -import triton.language as tl -from paddle import Tensor - - -@triton.jit -def _causal_conv1d_varlen_states( - X, - CU_SEQLENS, - STATES, - state_len, - dim, - stride_x_seqlen, - stride_x_dim, - stride_states_batch, - stride_states_seqlen, - stride_states_dim, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - batch_idx = tl.program_id(2) - STATES += batch_idx * stride_states_batch - end_idx = tl.load(CU_SEQLENS + batch_idx + 1) - start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) - rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) - x = tl.load( - X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, - mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), - other=0, - ) - rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) - tl.store( - STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, - x, - mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim), - ) - - -def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = paddle.empty([batch, state_len, dim], dtype=x.dtype).transpose([0, 2, 1]) - BLOCK_M = min(triton.next_power_of_2(state_len), 16) - BLOCK_N = min(triton.next_power_of_2(dim), 256) - grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) - _causal_conv1d_varlen_states[grid]( - x, - cu_seqlens, - states, - state_len, - dim, - x.strides[0], - x.strides[1], - states.strides[0], - states.strides[2], - states.strides[1], - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) - return states - - -def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: - """ - Forward pass only, does not support backward pass. - - Parameters: - x: (total_tokens, dim) - cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. - state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. - If some of those elements belong to a different sequence, the value of the states will be zero. - Return: - states: (batch, dim, state_len) - """ - _, dim = x.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - states = paddle.zeros([batch, state_len, dim], dtype=x.dtype).transpose([0, 2, 1]) - for i in range(batch): - end_idx = cu_seqlens[i + 1] - start_idx = paddle.maximum(cu_seqlens[i], end_idx - state_len) - states[i, :, -(end_idx - start_idx) :] = x[start_idx:end_idx].T - return states diff --git a/ops/src/paddlenlp_kernel/triton/mamba/k_activations.py b/ops/src/paddlenlp_kernel/triton/mamba/k_activations.py deleted file mode 100644 index 40c7d8b026a2..000000000000 --- a/ops/src/paddlenlp_kernel/triton/mamba/k_activations.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. -""" -this code is modified from https://github.com/state-spaces/mamba/tree/main/mamba_ssm/ops/triton -""" -import paddle -import triton -import triton.language as tl - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_N": 32}), - triton.Config({"BLOCK_N": 64}), - triton.Config({"BLOCK_N": 128}), - triton.Config({"BLOCK_N": 256}), - triton.Config({"BLOCK_N": 512}), - triton.Config({"BLOCK_N": 1024}), - ], - key=["ncols"], -) -@triton.jit -def _swiglu_fwd_kernel( - X, - Y, - OUT, - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_out_row, - ncols, - BLOCK_N: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - start_col = tl.program_id(1) * BLOCK_N - X += row * stride_x_row - Y += row * stride_y_row - OUT += row * stride_out_row - cols = start_col + tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < ncols, other=0.0).to(tl.float32) - y = tl.load(Y + cols, mask=cols < ncols, other=0.0).to(tl.float32) - out = x * tl.sigmoid(x) * y - tl.store(OUT + cols, out, mask=cols < ncols) - - -def _swiglu_fwd(xy, out=None): - if xy.strides[-1] != 1: - xy = xy.contiguous() - batch_shape = xy.shape[:-1] - xy = xy.reshape([-1, xy.shape[-1]]) - x, y = xy.chunk(2, axis=-1) - if out is None: - out = paddle.empty_like(x) - else: - out = out.reshape([-1, out.shape[-1]]) - assert out.shape == x.shape - assert out.strides[-1] == 1 - M, N = x.shape - grid = lambda META: (M, triton.cdiv(N, META["BLOCK_N"])) - _swiglu_fwd_kernel[grid](x, y, out, x.strides[0], y.strides[0], out.strides[0], N) - return out.reshape([*batch_shape, out.shape[-1]]) - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_N": 32}), - triton.Config({"BLOCK_N": 64}), - triton.Config({"BLOCK_N": 128}), - triton.Config({"BLOCK_N": 256}), - triton.Config({"BLOCK_N": 512}), - triton.Config({"BLOCK_N": 1024}), - ], - key=["ncols"], -) -@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None}) -@triton.jit -def _swiglu_bwd_kernel( - X, - Y, - DOUT, - OUT, - DX, - DY, - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_dout_row, - stride_out_row, - stride_dx_row, - stride_dy_row, - ncols, - BLOCK_N: tl.constexpr, - RECOMPUTE_OUTPUT: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - start_col = tl.program_id(1) * BLOCK_N - X += row * stride_x_row - Y += row * stride_y_row - DOUT += row * stride_dout_row - if RECOMPUTE_OUTPUT: - OUT += row * stride_out_row - DX += row * stride_dx_row - DY += row * stride_dy_row - cols = start_col + tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < ncols, other=0.0).to(tl.float32) - y = tl.load(Y + cols, mask=cols < ncols, other=0.0).to(tl.float32) - dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.0).to(tl.float32) - x_sigmoid = tl.sigmoid(x) - dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout - dy = x * x_sigmoid * dout - tl.store(DX + cols, dx, mask=cols < ncols) - tl.store(DY + cols, dy, mask=cols < ncols) - if RECOMPUTE_OUTPUT: - out = x * x_sigmoid * y - tl.store(OUT + cols, out, mask=cols < ncols) - - -def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None): - if xy.strides[-1] != 1: - xy = xy.contiguous() - if dout.strides[-1] != 1: - dout = dout.contiguous() - batch_shape = xy.shape[:-1] - xy = xy.reshape([-1, xy.shape[-1]]) - x, y = xy.chunk(2, axis=-1) - dout = dout.reshape([-1, dout.shape[-1]]) - assert dout.shape == x.shape - if dxy is None: - dxy = paddle.empty_like(xy) - else: - dxy = dxy.reshape([-1, dxy.shape[-1]]) - assert dxy.shape == xy.shape - dx, dy = dxy.chunk(2, axis=-1) - assert dx.strides[-1] == 1 - assert dy.strides[-1] == 1 - if recompute_output: - if out is None: - out = paddle.empty_like(x) - else: - out = out.reshape([-1, out.shape[-1]]) - assert out.shape == x.shape - assert out.strides[-1] == 1 - M, N = x.shape - grid = lambda META: (M, triton.cdiv(N, META["BLOCK_N"])) - _swiglu_bwd_kernel[grid]( - x, - y, - dout, - out if recompute_output else None, - dx, - dy, - x.strides[0], - y.strides[0], - dout.strides[0], - out.strides[0] if recompute_output else 0, - dx.strides[0], - dy.strides[0], - N, - ) - if not recompute_output: - return dxy.reshape([*batch_shape, dxy.shape[-1]]) - else: - return dxy.reshape([*batch_shape, dxy.shape[-1]]), out.reshape([*batch_shape, out.shape[-1]]) - - -class SwiGLU(paddle.autograd.PyLayer): - @staticmethod - def forward(ctx, xy): - ctx.save_for_backward(xy) - return _swiglu_fwd(xy) - - @staticmethod - def backward(ctx, dout): - (xy,) = ctx.saved_tensor() - return _swiglu_bwd(xy, dout) - - -swiglu = SwiGLU.apply diff --git a/ops/src/paddlenlp_kernel/triton/mamba/layer_norm.py b/ops/src/paddlenlp_kernel/triton/mamba/layer_norm.py deleted file mode 100755 index cc61853eff8b..000000000000 --- a/ops/src/paddlenlp_kernel/triton/mamba/layer_norm.py +++ /dev/null @@ -1,1101 +0,0 @@ -# Copyright (c) 2024, Tri Dao. -# Implement dropout + residual + layer_norm / rms_norm. - -# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html -# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. -# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. -# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. -""" -this code is modified from https://github.com/state-spaces/mamba/tree/main/mamba_ssm/ops/triton -""" -import math - -import paddle -import paddle.nn as nn -import paddle.nn.functional as F -import triton -import triton.language as tl - -from ...utils import custom_bwd, custom_fwd, get_autocast_gpu_dtype, is_autocast_enabled - - -def layer_norm_ref( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - dropout_mask=None, - dropout_mask1=None, - upcast=False, - epsilon=None, -): - if epsilon is not None: - eps = epsilon - dtype = x.dtype - if upcast: - x = x.cast("float32") - weight = weight.cast("float32") - bias = bias.cast("float32") if bias is not None else None - residual = residual.cast("float32") if residual is not None else residual - x1 = x1.cast("float32") if x1 is not None else None - weight1 = weight1.cast("float32") if weight1 is not None else None - bias1 = bias1.cast("float32") if bias1 is not None else None - if x1 is not None: - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - if rowscale is not None: - x = x * rowscale[..., None] - if dropout_p > 0.0: - if dropout_mask is not None: - x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) - else: - x = F.dropout(x, p=dropout_p) - if x1 is not None: - if dropout_mask1 is not None: - x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) - else: - x1 = F.dropout(x1, p=dropout_p) - if x1 is not None: - x = x + x1 - if residual is not None: - x = (x + residual).cast(x.dtype) - out = F.layer_norm(x.cast(weight.dtype), x.shape[-1:], weight=weight, bias=bias, epsilon=eps).cast(dtype) - if weight1 is None: - return out if not prenorm else (out, x) - else: - out1 = F.layer_norm(x.cast(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, epsilon=eps).cast(dtype) - return (out, out1) if not prenorm else (out, out1, x) - - -def rms_norm_ref( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - dropout_mask=None, - dropout_mask1=None, - upcast=False, - epsilon=None, -): - if epsilon is not None: - eps = epsilon - dtype = x.dtype - if upcast: - x = x.cast("float32") - weight = weight.cast("float32") - bias = bias.cast("float32") if bias is not None else None - residual = residual.cast("float32") if residual is not None else residual - x1 = x1.cast("float32") if x1 is not None else None - weight1 = weight1.cast("float32") if weight1 is not None else None - bias1 = bias1.cast("float32") if bias1 is not None else None - if x1 is not None: - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - if rowscale is not None: - x = x * rowscale[..., None] - if dropout_p > 0.0: - if dropout_mask is not None: - x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) - else: - x = F.dropout(x, p=dropout_p) - if x1 is not None: - if dropout_mask1 is not None: - x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) - else: - x1 = F.dropout(x1, p=dropout_p) - if x1 is not None: - x = x + x1 - if residual is not None: - x = (x + residual).cast(x.dtype) - rstd = 1 / paddle.sqrt((x.square()).mean(axis=-1, keepdim=True) + eps) - out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).cast(dtype) - if weight1 is None: - return out if not prenorm else (out, x) - else: - out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).cast(dtype) - return (out, out1) if not prenorm else (out, out1, x) - - -def config_prune(configs): - # if paddle.version.hip: - # try: - # # set warp size based on gcn architecture - # gcn_arch_name = paddle.device.get_device_properties(0).gcnArchName - # if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name: - # # radeon - # warp_size = 32 - # else: - # # instinct - # warp_size = 64 - # except AttributeError as e: - # # fall back to crude method to set warp size - # device_name = paddle.device.cuda.get_device_properties(0).name - # if 'instinct' in device_name.lower(): - # warp_size = 64 - # else: - # warp_size = 32 - # warnings.warn(f"{e}, warp size set to {warp_size} based on device name: {device_name}", UserWarning) - - # else: - # cuda - warp_size = 32 - - max_block_sz = 1024 - max_num_warps = max_block_sz // warp_size - pruned_configs = [config for config in configs if config.num_warps <= max_num_warps] - return pruned_configs - - -configs_autotune = [ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16), - triton.Config({}, num_warps=32), -] - -pruned_configs_autotune = config_prune(configs_autotune) - - -@triton.paddle_autotune( - configs=pruned_configs_autotune, - key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], -) -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) -@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) -@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) -@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) -@triton.jit -def _layer_norm_fwd_1pass_kernel( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - RESIDUAL, # pointer to the residual - X1, - W1, - B1, - Y1, - RESIDUAL_OUT, # pointer to the residual - ROWSCALE, - SEEDS, # Dropout seeds for each row - DROPOUT_MASK, - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_res_row, - stride_res_out_row, - stride_x1_row, - stride_y1_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by zero - dropout_p, # Dropout probability - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_RESIDUAL: tl.constexpr, - STORE_RESIDUAL_OUT: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_DROPOUT: tl.constexpr, - STORE_DROPOUT_MASK: tl.constexpr, - HAS_ROWSCALE: tl.constexpr, - HAS_X1: tl.constexpr, - HAS_W1: tl.constexpr, - HAS_B1: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - X += row * stride_x_row - Y += row * stride_y_row - if HAS_RESIDUAL: - RESIDUAL += row * stride_res_row - if STORE_RESIDUAL_OUT: - RESIDUAL_OUT += row * stride_res_out_row - if HAS_X1: - X1 += row * stride_x1_row - if HAS_W1: - Y1 += row * stride_y1_row - # Compute mean and variance - cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + row).to(tl.float32) - x *= rowscale - if HAS_DROPOUT: - # Compute dropout mask - # 7 rounds is good enough, and reduces register pressure - keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) - if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) - if HAS_X1: - x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) - x1 *= rowscale - if HAS_DROPOUT: - # Compute dropout mask - # 7 rounds is good enough, and reduces register pressure - keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) - if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N) - x += x1 - if HAS_RESIDUAL: - residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) - x += residual - if STORE_RESIDUAL_OUT: - tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) - if not IS_RMS_NORM: - mean = tl.sum(x, axis=0) / N - tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - else: - xbar = tl.where(cols < N, x, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) - # Normalize and apply linear transformation - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - if HAS_BIAS: - b = tl.load(B + cols, mask=mask).to(tl.float32) - x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - y = x_hat * w + b if HAS_BIAS else x_hat * w - # Write output - tl.store(Y + cols, y, mask=mask) - if HAS_W1: - w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) - if HAS_B1: - b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) - y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 - tl.store(Y1 + cols, y1, mask=mask) - - -def _layer_norm_fwd( - x, - weight, - bias, - eps, - residual=None, - x1=None, - weight1=None, - bias1=None, - dropout_p=0.0, - rowscale=None, - out_dtype=None, - residual_dtype=None, - is_rms_norm=False, - return_dropout_mask=False, -): - if residual is not None: - residual_dtype = residual.dtype - M, N = x.shape - assert x.strides[-1] == 1 - if residual is not None: - assert residual.strides[-1] == 1 - assert tuple(residual.shape) == (M, N) - assert weight.shape[0] == N - assert weight.strides[-1] == 1 - if bias is not None: - assert bias.strides[-1] == 1 - assert bias.shape[0] == N - if x1 is not None: - assert x1.shape == x.shape - assert rowscale is None - assert x1.strides[-1] == 1 - if weight1 is not None: - assert weight1.shape[0] == N - assert weight1.strides[-1] == 1 - if bias1 is not None: - assert bias1.shape[0] == N - assert bias1.strides[-1] == 1 - if rowscale is not None: - assert rowscale.is_contiguous() - assert rowscale.shape[0] == M - # allocate output - y = paddle.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) - assert y.strides[-1] == 1 - if weight1 is not None: - y1 = paddle.empty_like(y) - assert y1.strides[-1] == 1 - else: - y1 = None - if ( - residual is not None - or (residual_dtype is not None and residual_dtype != x.dtype) - or dropout_p > 0.0 - or rowscale is not None - or x1 is not None - ): - residual_out = paddle.empty(M, N, dtype=residual_dtype if residual_dtype is not None else x.dtype) - assert residual_out.strides[-1] == 1 - else: - residual_out = None - mean = paddle.empty((M,), dtype=paddle.float32) if not is_rms_norm else None - rstd = paddle.empty((M,), dtype=paddle.float32) - if dropout_p > 0.0: - seeds = paddle.randint(2**32, (M if x1 is None else 2 * M,), dtype=paddle.int64) - else: - seeds = None - if return_dropout_mask and dropout_p > 0.0: - dropout_mask = paddle.empty(M if x1 is None else 2 * M, N, dtype=paddle.bool) - else: - dropout_mask = None - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - _layer_norm_fwd_1pass_kernel[(M,)]( - x, - y, - weight, - bias, - residual, - x1, - weight1, - bias1, - y1, - residual_out, - rowscale, - seeds, - dropout_mask, - mean, - rstd, - x.strides[0], - y.strides[0], - residual.strides[0] if residual is not None else 0, - residual_out.strides[0] if residual_out is not None else 0, - x1.strides[0] if x1 is not None else 0, - y1.strides[0] if y1 is not None else 0, - M, - N, - eps, - dropout_p, - is_rms_norm, - BLOCK_N, - residual is not None, - residual_out is not None, - bias is not None, - dropout_p > 0.0, - dropout_mask is not None, - rowscale is not None, - ) - # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 - if dropout_mask is not None and x1 is not None: - dropout_mask, dropout_mask1 = dropout_mask.chunk(2, axis=0) - else: - dropout_mask1 = None - return ( - y, - y1, - mean, - rstd, - residual_out if residual_out is not None else x, - seeds, - dropout_mask, - dropout_mask1, - ) - - -@triton.paddle_autotune( - configs=pruned_configs_autotune, - key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], -) -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) -# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) -@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) -@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) -@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) -@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) -@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) -@triton.jit -def _layer_norm_bwd_kernel( - X, # pointer to the input - W, # pointer to the weights - B, # pointer to the biases - Y, # pointer to the output to be recomputed - DY, # pointer to the output gradient - DX, # pointer to the input gradient - DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - DRESIDUAL, - W1, - DY1, - DX1, - DW1, - DB1, - DRESIDUAL_IN, - ROWSCALE, - SEEDS, - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_dy_row, - stride_dx_row, - stride_dres_row, - stride_dy1_row, - stride_dx1_row, - stride_dres_in_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by zero - dropout_p, - rows_per_program, - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_DRESIDUAL: tl.constexpr, - STORE_DRESIDUAL: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_DROPOUT: tl.constexpr, - HAS_ROWSCALE: tl.constexpr, - HAS_DY1: tl.constexpr, - HAS_DX1: tl.constexpr, - HAS_B1: tl.constexpr, - RECOMPUTE_OUTPUT: tl.constexpr, -): - # Map the program id to the elements of X, DX, and DY it should compute. - row_block_id = tl.program_id(0) - row_start = row_block_id * rows_per_program - # Do not early exit if row_start >= M, because we need to write DW and DB - cols = tl.arange(0, BLOCK_N) - mask = cols < N - X += row_start * stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += row_start * stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += row_start * stride_dres_in_row - DY += row_start * stride_dy_row - DX += row_start * stride_dx_row - if HAS_DY1: - DY1 += row_start * stride_dy1_row - if HAS_DX1: - DX1 += row_start * stride_dx1_row - if RECOMPUTE_OUTPUT: - Y += row_start * stride_y_row - w = tl.load(W + cols, mask=mask).to(tl.float32) - if RECOMPUTE_OUTPUT and HAS_BIAS: - b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) - if HAS_DY1: - w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) - dw = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_BIAS: - db = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_DY1: - dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_B1: - db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) - row_end = min((row_block_id + 1) * rows_per_program, M) - for row in range(row_start, row_end): - # Load data to SRAM - x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) - dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) - if HAS_DY1: - dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) - if not IS_RMS_NORM: - mean = tl.load(Mean + row) - rstd = tl.load(Rstd + row) - # Compute dx - xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - xhat = tl.where(mask, xhat, 0.0) - if RECOMPUTE_OUTPUT: - y = xhat * w + b if HAS_BIAS else xhat * w - tl.store(Y + cols, y, mask=mask) - wdy = w * dy - dw += dy * xhat - if HAS_BIAS: - db += dy - if HAS_DY1: - wdy += w1 * dy1 - dw1 += dy1 * xhat - if HAS_B1: - db1 += dy1 - if not IS_RMS_NORM: - c1 = tl.sum(xhat * wdy, axis=0) / N - c2 = tl.sum(wdy, axis=0) / N - dx = (wdy - (xhat * c1 + c2)) * rstd - else: - c1 = tl.sum(xhat * wdy, axis=0) / N - dx = (wdy - xhat * c1) * rstd - if HAS_DRESIDUAL: - dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) - dx += dres - # Write dx - if STORE_DRESIDUAL: - tl.store(DRESIDUAL_IN + cols, dx, mask=mask) - if HAS_DX1: - if HAS_DROPOUT: - keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) - else: - dx1 = dx - tl.store(DX1 + cols, dx1, mask=mask) - if HAS_DROPOUT: - keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + row).to(tl.float32) - dx *= rowscale - tl.store(DX + cols, dx, mask=mask) - - X += stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += stride_dres_in_row - if RECOMPUTE_OUTPUT: - Y += stride_y_row - DY += stride_dy_row - DX += stride_dx_row - if HAS_DY1: - DY1 += stride_dy1_row - if HAS_DX1: - DX1 += stride_dx1_row - tl.store(DW + row_block_id * N + cols, dw, mask=mask) - if HAS_BIAS: - tl.store(DB + row_block_id * N + cols, db, mask=mask) - if HAS_DY1: - tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) - if HAS_B1: - tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) - - -def _layer_norm_bwd( - dy, - x, - weight, - bias, - eps, - mean, - rstd, - dresidual=None, - dy1=None, - weight1=None, - bias1=None, - seeds=None, - dropout_p=0.0, - rowscale=None, - has_residual=False, - has_x1=False, - is_rms_norm=False, - x_dtype=None, - recompute_output=False, -): - M, N = x.shape - assert x.strides[-1] == 1 - assert dy.strides[-1] == 1 - assert tuple(dy.shape) == (M, N) - if dresidual is not None: - assert dresidual.strides[-1] == 1 - assert tuple(dresidual.shape) == (M, N) - assert weight.shape[0] == N - assert weight.strides[-1] == 1 - if bias is not None: - assert bias.strides[-1] == 1 - assert bias.shape[0] == N - if dy1 is not None: - assert weight1 is not None - assert dy1.shape == dy.shape - assert dy1.strides[-1] == 1 - if weight1 is not None: - assert weight1.shape[0] == N - assert weight1.strides[-1] == 1 - if bias1 is not None: - assert bias1.shape[0] == N - assert bias1.strides[-1] == 1 - if seeds is not None: - assert seeds.is_contiguous() - assert seeds.shape[0] == M if not has_x1 else M * 2 - if rowscale is not None: - assert rowscale.is_contiguous() - assert rowscale.shape[0] == M - # allocate output - dx = paddle.empty_like(x) if x_dtype is None else paddle.empty([M, N], dtype=x_dtype) - dresidual_in = ( - paddle.empty_like(x) - if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1) - else None - ) - dx1 = paddle.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None - y = paddle.empty(M, N, dtype=dy.dtype) if recompute_output else None - if recompute_output: - assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" - - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - sm_count = paddle.device.cuda.get_device_properties(paddle.get_device()).multi_processor_count - _dw = paddle.empty((sm_count, N), dtype=paddle.float32) - _db = paddle.empty((sm_count, N), dtype=paddle.float32) if bias is not None else None - _dw1 = paddle.empty_like(_dw) if weight1 is not None else None - _db1 = paddle.empty_like(_db) if bias1 is not None else None - rows_per_program = math.ceil(M / sm_count) - grid = (sm_count,) - _layer_norm_bwd_kernel[grid]( - x, - weight, - bias, - y, - dy, - dx, - _dw, - _db, - dresidual, - weight1, - dy1, - dx1, - _dw1, - _db1, - dresidual_in, - rowscale, - seeds, - mean, - rstd, - x.strides[0], - 0 if not recompute_output else y.strides[0], - dy.strides[0], - dx.strides[0], - dresidual.strides[0] if dresidual is not None else 0, - dy1.strides[0] if dy1 is not None else 0, - dx1.strides[0] if dx1 is not None else 0, - dresidual_in.strides[0] if dresidual_in is not None else 0, - M, - N, - eps, - dropout_p, - rows_per_program, - is_rms_norm, - BLOCK_N, - dresidual is not None, - dresidual_in is not None, - bias is not None, - dropout_p > 0.0, - ) - dw = _dw.sum(0).cast(weight.dtype) - db = _db.sum(0).cast(bias.dtype) if bias is not None else None - dw1 = _dw1.sum(0).cast(weight1.dtype) if weight1 is not None else None - db1 = _db1.sum(0).cast(bias1.dtype) if bias1 is not None else None - # Don't need to compute dresidual_in separately in this case - if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: - dresidual_in = dx - if has_x1 and dropout_p == 0.0: - dx1 = dx - return ( - (dx, dw, db, dresidual_in, dx1, dw1, db1) - if not recompute_output - else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) - ) - - -class LayerNormFn(paddle.autograd.PyLayer): - @staticmethod - @custom_fwd - def forward( - ctx, - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - return_dropout_mask=False, - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = x.reshape([-1, x.shape[-1]]) - if x.strides[-1] != 1: - x = x.contiguous() - if residual is not None: - assert residual.shape == x_shape_og - residual = residual.reshape([-1, residual.shape[-1]]) - if residual.strides[-1] != 1: - residual = residual.contiguous() - if x1 is not None: - assert x1.shape == x_shape_og - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - x1 = x1.reshape([-1, x1.shape[-1]]) - if x1.strides[-1] != 1: - x1 = x1.contiguous() - weight = weight.contiguous() - if bias is not None: - bias = bias.contiguous() - if weight1 is not None: - weight1 = weight1.contiguous() - if bias1 is not None: - bias1 = bias1.contiguous() - if rowscale is not None: - rowscale = rowscale.reshape( - [ - -1, - ] - ).contiguous() - residual_dtype = residual.dtype if residual is not None else (paddle.float32 if residual_in_fp32 else None) - y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( - x, - weight, - bias, - eps, - residual, - x1, - weight1, - bias1, - dropout_p=dropout_p, - rowscale=rowscale, - residual_dtype=residual_dtype, - is_rms_norm=is_rms_norm, - return_dropout_mask=return_dropout_mask, - ) - ctx.save_for_backward(residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.dropout_p = dropout_p - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.has_x1 = x1 is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - y = y.reshape(x_shape_og) - y1 = y1.reshape(x_shape_og) if y1 is not None else None - residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None - dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None - dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None - if not return_dropout_mask: - if weight1 is None: - return y if not prenorm else (y, residual_out) - else: - return (y, y1) if not prenorm else (y, y1, residual_out) - else: - if weight1 is None: - return ( - (y, dropout_mask, dropout_mask1) if not prenorm else (y, residual_out, dropout_mask, dropout_mask1) - ) - else: - return ( - (y, y1, dropout_mask, dropout_mask1) - if not prenorm - else (y, y1, residual_out, dropout_mask, dropout_mask1) - ) - - @staticmethod - @custom_bwd - def backward(ctx, dy, *args): - x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensor() - dy = dy.reshape([-1, dy.shape[-1]]) - if dy.strides[-1] != 1: - dy = dy.contiguous() - assert dy.shape == x.shape - if weight1 is not None: - dy1, args = args[0], args[1:] - dy1 = dy1.reshape([-1, dy1.shape[-1]]) - if dy1.strides[-1] != 1: - dy1 = dy1.contiguous() - assert dy1.shape == x.shape - else: - dy1 = None - if ctx.prenorm: - dresidual = args[0] - dresidual = dresidual.reshape([-1, dresidual.shape[-1]]) - if dresidual.strides[-1] != 1: - dresidual = dresidual.contiguous() - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( - dy, - x, - weight, - bias, - ctx.eps, - mean, - rstd, - dresidual, - dy1, - weight1, - bias1, - seeds, - ctx.dropout_p, - rowscale, - ctx.has_residual, - ctx.has_x1, - ctx.is_rms_norm, - x_dtype=ctx.x_dtype, - ) - return ( - dx.reshape(ctx.x_shape_og), - dw, - db, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, - dw1, - db1, - None, - None, - None, - None, - None, - None, - None, - ) - - -def layer_norm_fn( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - return_dropout_mask=False, - epsilon=None, -): - if epsilon is not None: - eps = epsilon - return LayerNormFn.apply( - x, - weight, - bias, - residual, - x1, - weight1, - bias1, - eps, - dropout_p, - rowscale, - prenorm, - residual_in_fp32, - is_rms_norm, - return_dropout_mask, - ) - - -def rms_norm_fn( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, - epsilon=None, -): - if epsilon is not None: - eps = epsilon - return LayerNormFn.apply( - x, - weight, - bias, - residual, - x1, - weight1, - bias1, - eps, - dropout_p, - rowscale, - prenorm, - residual_in_fp32, - True, - return_dropout_mask, - ) - - -class RMSNorm(nn.Layer): - def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, epsilon=None, dtype=None): - super().__init__() - self.eps = epsilon or eps - if dropout_p > 0.0: - self.drop = nn.Dropout(dropout_p) - else: - self.drop = None - self.weight = self.create_parameter( - shape=[ - hidden_size, - ], - default_initializer=nn.initializer.Constant(value=1.0), - dtype=paddle.get_default_dtype(), - ) - self.bias = None - - def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): - return rms_norm_fn( - x, - self.weight, - self.bias, - residual=residual, - eps=self.eps, - dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, - prenorm=prenorm, - residual_in_fp32=residual_in_fp32, - ) - - -class LayerNormLinearFn(paddle.autograd.PyLayer): - @staticmethod - @custom_fwd - def forward( - ctx, - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = x.reshape([-1, x.shape[-1]]) - if x.strides[-1] != 1: - x = x.contiguous() - if residual is not None: - assert residual.shape == x_shape_og - residual = residual.reshape([-1, residual.shape[-1]]) - if residual.strides[-1] != 1: - residual = residual.contiguous() - norm_weight = norm_weight.contiguous() - if norm_bias is not None: - norm_bias = norm_bias.contiguous() - residual_dtype = residual.dtype if residual is not None else (paddle.float32 if residual_in_fp32 else None) - y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd( - x, - norm_weight, - norm_bias, - eps, - residual, - out_dtype=None if is_autocast_enabled() else get_autocast_gpu_dtype(), - residual_dtype=residual_dtype, - is_rms_norm=is_rms_norm, - ) - y = y.reshape(x_shape_og) - dtype = get_autocast_gpu_dtype() if is_autocast_enabled() else y.dtype - linear_weight = linear_weight.cast(dtype) - linear_bias = linear_bias.cast(dtype) if linear_bias is not None else None - out = F.linear(y.cast(linear_weight.dtype), linear_weight, linear_bias) - # We don't store y, will be recomputed in the backward pass to save memory - ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - ctx.linear_bias_is_none = linear_bias is None - return out if not prenorm else (out, residual_out.reshape(x_shape_og)) - - @staticmethod - @custom_bwd - def backward(ctx, dout, *args): - x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensor() - dout = dout.reshape([-1, dout.shape[-1]]) - dy = F.linear(dout, linear_weight, transpose_y=True) - dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) - if dy.strides[-1] != 1: - dy = dy.contiguous() - assert dy.shape == x.shape - if ctx.prenorm: - dresidual = args[0] - dresidual = dresidual.reshape([-1, dresidual.shape[-1]]) - if dresidual.strides[-1] != 1: - dresidual = dresidual.contiguous() - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd( - dy, - x, - norm_weight, - norm_bias, - ctx.eps, - mean, - rstd, - dresidual=dresidual, - has_residual=ctx.has_residual, - is_rms_norm=ctx.is_rms_norm, - x_dtype=ctx.x_dtype, - recompute_output=True, - ) - dlinear_weight = paddle.einsum("bo,bi->oi", dout, y) - return ( - dx.reshape(ctx.x_shape_og), - dnorm_weight, - dnorm_bias, - dlinear_weight, - dlinear_bias, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - None, - None, - None, - None, - ) - - -def layer_norm_linear_fn( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - epsilon=None, -): - if epsilon is not None: - eps = epsilon - return LayerNormLinearFn.apply( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual, - eps, - prenorm, - residual_in_fp32, - is_rms_norm, - ) diff --git a/ops/src/paddlenlp_kernel/triton/mamba/layernorm_gated.py b/ops/src/paddlenlp_kernel/triton/mamba/layernorm_gated.py deleted file mode 100644 index 2b7d24a51884..000000000000 --- a/ops/src/paddlenlp_kernel/triton/mamba/layernorm_gated.py +++ /dev/null @@ -1,521 +0,0 @@ -# Copyright (c) 2024, Tri Dao. -# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html -# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. -# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. -# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. -""" -this code is modified from https://github.com/state-spaces/mamba/tree/main/mamba_ssm/ops/triton -""" -import math - -import paddle -import paddle.nn as nn -import paddle.nn.functional as F -import triton -import triton.language as tl -from einops import rearrange - -from ...utils import custom_bwd, custom_fwd - - -def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True, epsilon=None): - if epsilon is not None: - eps = epsilon - dtype = x.dtype - # N = x.shape[-1] - weight = weight.cast("float32") - bias = bias.cast("float32") if bias is not None else None - if upcast: - x = x.cast("float32") - z = z.cast("float32") if z is not None else z - if z is not None and not norm_before_gate: - x = x * F.silu(z) - if group_size is None: - rstd = 1 / paddle.sqrt((x.square()).mean(axis=-1, keepdim=True) + eps) - out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) - else: - x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) - rstd = 1 / paddle.sqrt((x_group.square()).mean(axis=-1, keepdim=True) + eps) - out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight - if bias is not None: - out = out + bias - if z is not None and norm_before_gate: - out *= F.silu(z) - return out.cast(dtype) - - -@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) -@triton.jit -def _layer_norm_fwd_1pass_kernel( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - Z, # pointer to the other branch - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_z_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_N: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_Z: tl.constexpr, - NORM_BEFORE_GATE: tl.constexpr, - IS_RMS_NORM: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - group = tl.program_id(1) - X += row * stride_x_row + group * N - Y += row * stride_y_row + group * N - if HAS_Z: - Z += row * stride_z_row + group * N - if not IS_RMS_NORM: - Mean += group * M - Rstd += group * M - W += group * N - if HAS_BIAS: - B += group * N - # Compute mean and variance - cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_Z and not NORM_BEFORE_GATE: - z = tl.load(Z + cols, mask=cols < N).to(tl.float32) - x *= z * tl.sigmoid(z) - if not IS_RMS_NORM: - mean = tl.sum(x, axis=0) / N - tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - else: - xbar = tl.where(cols < N, x, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) - # Normalize and apply linear transformation - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - if HAS_BIAS: - b = tl.load(B + cols, mask=mask).to(tl.float32) - x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - y = x_hat * w + b if HAS_BIAS else x_hat * w - if HAS_Z and NORM_BEFORE_GATE: - z = tl.load(Z + cols, mask=mask).to(tl.float32) - y *= z * tl.sigmoid(z) - # Write output - tl.store(Y + cols, y, mask=mask) - - -def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False): - M, N = x.shape - if group_size is None: - group_size = N - assert N % group_size == 0 - ngroups = N // group_size - assert x.strides[-1] == 1 - if z is not None: - assert z.strides[-1] == 1 - assert tuple(z.shape) == (M, N) - assert weight.shape[0] == N - assert weight.strides[-1] == 1 - if bias is not None: - assert bias.strides[-1] == 1 - assert bias.shape[0] == N - # allocate output - if out is not None: - assert out.shape == x.shape - else: - out = paddle.empty_like(x) - assert out.strides[-1] == 1 - mean = paddle.empty((ngroups * M,), dtype=paddle.float32) if not is_rms_norm else None - rstd = paddle.empty((ngroups * M,), dtype=paddle.float32) - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) - if group_size > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - # heuristics for number of warps - num_warps = min(max(BLOCK_N // 256, 1), 8) - grid = (M, ngroups) - _layer_norm_fwd_1pass_kernel[grid]( - x, - out, - weight, - bias, - z, - mean, - rstd, - x.strides[0], - out.strides[0], - z.strides[0] if z is not None else 0, - M, - group_size, - eps, - BLOCK_N=BLOCK_N, - NORM_BEFORE_GATE=norm_before_gate, - IS_RMS_NORM=is_rms_norm, - num_warps=num_warps, - ) - return out, mean, rstd - - -@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) -@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) -@triton.jit -def _layer_norm_bwd_kernel( - X, # pointer to the input - W, # pointer to the weights - B, # pointer to the biases - Z, # pointer to the other branch - Y, # pointer to the output to be recomputed - DY, # pointer to the output gradient - DX, # pointer to the input gradient - DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - DZ, # pointer to the other branch - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_z_row, - stride_y_row, - stride_dy_row, - stride_dx_row, - stride_dz_row, - stride_dw_row, - stride_db_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by zero - rows_per_program, - NORM_BEFORE_GATE: tl.constexpr, - IS_RMS_NORM: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_Z: tl.constexpr, - RECOMPUTE_OUTPUT: tl.constexpr, - BLOCK_N: tl.constexpr, -): - # Map the program id to the elements of X, DX, and DY it should compute. - row_block_id = tl.program_id(0) - group = tl.program_id(1) - row_start = row_block_id * rows_per_program - cols = tl.arange(0, BLOCK_N) - mask = cols < N - X += row_start * stride_x_row + group * N - if HAS_Z: - Z += row_start * stride_z_row + group * N - DZ += row_start * stride_dz_row + group * N - DY += row_start * stride_dy_row + group * N - DX += row_start * stride_dx_row + group * N - if RECOMPUTE_OUTPUT: - Y += row_start * stride_y_row + group * N - if not IS_RMS_NORM: - Mean += group * M - Rstd += group * M - W += group * N - w = tl.load(W + cols, mask=mask).to(tl.float32) - if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS: - B += group * N - b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) - dw = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_BIAS: - db = tl.zeros((BLOCK_N,), dtype=tl.float32) - row_end = min((row_block_id + 1) * rows_per_program, M) - for row in range(row_start, row_end): - # Load data to SRAM - x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) - dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) - if not IS_RMS_NORM: - mean = tl.load(Mean + row) - if HAS_Z and not NORM_BEFORE_GATE: - z = tl.load(Z + cols, mask=mask, other=0.0).to(tl.float32) - x_og = x - x = x_og * z * tl.sigmoid(z) - rstd = tl.load(Rstd + row) - # Compute dx - xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - xhat = tl.where(mask, xhat, 0.0) - if HAS_Z and NORM_BEFORE_GATE: - z = tl.load(Z + cols, mask=mask, other=0.0).to(tl.float32) - z_sigmoid = tl.sigmoid(z) - y = xhat * w + b if HAS_BIAS else xhat * w - if RECOMPUTE_OUTPUT: - tl.store(Y + cols, y * z * z_sigmoid, mask=mask) - dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid)) - tl.store(DZ + cols, dz, mask=mask) - dy *= z * z_sigmoid - else: - if RECOMPUTE_OUTPUT: - y = xhat * w + b if HAS_BIAS else xhat * w - tl.store(Y + cols, y, mask=mask) - wdy = w * dy - c1 = tl.sum(xhat * wdy, axis=0) / N - if not IS_RMS_NORM: - c2 = tl.sum(wdy, axis=0) / N - dx = (wdy - (xhat * c1 + c2)) * rstd - else: - dx = (wdy - xhat * c1) * rstd - dw += dy * xhat - if HAS_BIAS: - db += dy - if HAS_Z and not NORM_BEFORE_GATE: - z_sigmoid = tl.sigmoid(z) - dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid)) - tl.store(DZ + cols, dz, mask=mask) - dx *= z * z_sigmoid - # Write dx - tl.store(DX + cols, dx, mask=mask) - - X += stride_x_row - if HAS_Z: - Z += stride_z_row - DZ += stride_dz_row - if RECOMPUTE_OUTPUT: - Y += stride_y_row - DY += stride_dy_row - DX += stride_dx_row - tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask) - if HAS_BIAS: - tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask) - - -def _layer_norm_bwd( - dy, - x, - weight, - bias, - eps, - mean, - rstd, - z=None, - group_size=None, - norm_before_gate=True, - is_rms_norm=False, - recompute_output=False, - dz=None, - out=None, -): - M, N = x.shape - if group_size is None: - group_size = N - assert N % group_size == 0 - ngroups = N // group_size - assert x.strides[-1] == 1 - assert dy.strides[-1] == 1 - assert tuple(dy.shape) == (M, N) - if z is not None: - assert z.strides[-1] == 1 - assert tuple(z.shape) == (M, N) - assert weight.shape[0] == N - assert weight.strides[-1] == 1 - if bias is not None: - assert bias.strides[-1] == 1 - assert bias.shape[0] == N - # allocate output - dx = paddle.empty_like(x) - if dz is not None: - assert z is not None - assert dz.shape == z.shape - assert dz.strides[-1] == 1 - else: - dz = paddle.empty_like(z) if z is not None else None - if recompute_output: - if out is None: - out = paddle.empty_like(x) - assert out.shape == x.shape - - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) - if group_size > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - # heuristics for number of warps - num_warps = min(max(BLOCK_N // 256, 1), 8) - sm_count = paddle.device.cuda.get_device_properties(paddle.get_device()).multi_processor_count - # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs - # would limit the occupancy. - nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups) - _dw = paddle.empty((nrow_groups, N), dtype=paddle.float32) - _db = paddle.empty((nrow_groups, N), dtype=paddle.float32) if bias is not None else None - rows_per_program = math.ceil(M / nrow_groups) - grid = (nrow_groups, ngroups) - _layer_norm_bwd_kernel[grid]( - x, - weight, - bias, - z, - out if recompute_output else None, - dy, - dx, - _dw, - _db, - dz, - mean, - rstd, - x.strides[0], - z.strides[0] if z is not None else 0, - 0 if not recompute_output else out.strides[0], - dy.strides[0], - dx.strides[0], - dz.strides[0] if dz is not None else 0, - _dw.strides[0], - _db.strides[0] if _db is not None else 0, - M, - group_size, - eps, - rows_per_program, - BLOCK_N=BLOCK_N, - NORM_BEFORE_GATE=norm_before_gate, - IS_RMS_NORM=is_rms_norm, - num_warps=num_warps, - ) - dw = _dw.sum(0).cast(weight.dtype) - db = _db.sum(0).cast(bias.dtype) if bias is not None else None - return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out) - - -class LayerNormFn(paddle.autograd.PyLayer): - @staticmethod - @custom_fwd - def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" - - x_shape_og = x.shape - # reshape input data into 2D tensor - x = x.reshape([-1, x.shape[-1]]) - if x.strides[-1] != 1: - x = x.contiguous() - if z is not None: - assert z.shape == x_shape_og - z = z.reshape([-1, z.shape[-1]]) - if z.strides[-1] != 1: - z = z.contiguous() - weight = weight.contiguous() - if bias is not None: - bias = bias.contiguous() - y, mean, rstd = _layer_norm_fwd( - x, - weight, - bias, - eps, - z=z, - group_size=group_size, - norm_before_gate=norm_before_gate, - is_rms_norm=is_rms_norm, - ) - ctx.save_for_backward(x, weight, bias, mean, rstd, z) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.group_size = group_size - ctx.norm_before_gate = norm_before_gate - ctx.is_rms_norm = is_rms_norm - return y.reshape(x_shape_og) - - @staticmethod - @custom_bwd - def backward(ctx, dy): - x, weight, bias, mean, rstd, z = ctx.saved_tensor() - dy = dy.reshape([-1, dy.shape[-1]]) - if dy.strides[-1] != 1: - dy = dy.contiguous() - assert dy.shape == x.shape - dx, dw, db, dz = _layer_norm_bwd( - dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size, ctx.norm_before_gate, ctx.is_rms_norm - ) - return ( - dx.reshape(ctx.x_shape_og), - dw, - db, - dz.reshape(ctx.x_shape_og) if dz is not None else None, - None, - None, - None, - None, - ) - - -def layernorm_fn( - x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False, epsilon=None -): - if epsilon is not None: - eps = epsilon - return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm) - - -def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, epsilon=None): - if epsilon is not None: - eps = epsilon - return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True) - - -class LayerNorm(nn.Layer): - def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, epsilon=None, dtype=None): - """If group_size is not None, we do GroupNorm with each group having group_size elements. - group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). - """ - - super().__init__() - self.eps = epsilon or eps - self.weight = self.create_parameter( - shape=[ - hidden_size, - ], - default_initializer=nn.initializer.Constant(value=1.0), - dtype=paddle.get_default_dtype(), - ) - self.bias = self.create_parameter( - shape=[ - hidden_size, - ], - default_initializer=nn.initializer.Constant(value=0.0), - dtype=paddle.get_default_dtype(), - ) - self.group_size = group_size - self.norm_before_gate = norm_before_gate - - def forward(self, x, z=None): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" - return layernorm_fn( - x, - self.weight, - self.bias, - z=z, - group_size=self.group_size, - eps=self.eps, - norm_before_gate=self.norm_before_gate, - ) - - -class RMSNorm(nn.Layer): - def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, epsilon=None, dtype=None): - """If group_size is not None, we do GroupNorm with each group having group_size elements. - group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). - """ - super().__init__() - self.eps = epsilon or eps - self.weight = self.create_parameter( - shape=[ - hidden_size, - ], - default_initializer=nn.initializer.Constant(value=1.0), - dtype=paddle.get_default_dtype(), - ) - self.bias = None - self.group_size = group_size - self.norm_before_gate = norm_before_gate - - def forward(self, x, z=None): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" - return rmsnorm_fn( - x, - self.weight, - self.bias, - z=z, - eps=self.eps, - group_size=self.group_size, - norm_before_gate=self.norm_before_gate, - ) diff --git a/ops/src/paddlenlp_kernel/triton/mamba/math.py b/ops/src/paddlenlp_kernel/triton/mamba/math.py deleted file mode 100644 index 309404f40107..000000000000 --- a/ops/src/paddlenlp_kernel/triton/mamba/math.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import operator - -import triton -import triton.language as tl - -from ...utils import compare_version - -__all__ = ["softplus", "tanh", "rsqrt"] - -if compare_version("triton", operator.ge, "3.0.0"): - # softplus - @triton.jit - def softplus(dt): - dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) - return dt - - # tanh - try: - # typical import path with dispatch available - from triton.language.extra.libdevice import tanh - except ModuleNotFoundError: - # for working with NGC containers - from triton.language.extra.cuda.libdevice import tanh - # rsqrt - try: - # typical import path with dispatch available - from triton.language.extra.libdevice import rsqrt - except ModuleNotFoundError: - # for working with NGC containers - from triton.language.extra.cuda.libdevice import rsqrt -else: - - @triton.jit - def softplus(dt): - dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) - return dt - - from triton.language.math import rsqrt, tanh diff --git a/ops/src/paddlenlp_kernel/triton/mamba/selective_state_update.py b/ops/src/paddlenlp_kernel/triton/mamba/selective_state_update.py deleted file mode 100644 index 0e0153521494..000000000000 --- a/ops/src/paddlenlp_kernel/triton/mamba/selective_state_update.py +++ /dev/null @@ -1,324 +0,0 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. -""" -this code is modified from https://github.com/state-spaces/mamba/tree/main/mamba_ssm/ops/triton -""" -"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this""" - -import paddle -import paddle.nn.functional as F -import triton -import triton.language as tl -from einops import rearrange, repeat - -from .math import softplus - - -@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) -@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) -@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) -@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) -@triton.jit -def _selective_scan_update_kernel( - # Pointers to matrices - state_ptr, - x_ptr, - dt_ptr, - dt_bias_ptr, - A_ptr, - B_ptr, - C_ptr, - D_ptr, - z_ptr, - out_ptr, - # Matrix dimensions - batch, - nheads, - dim, - dstate, - nheads_ngroups_ratio, - # Strides - stride_state_batch, - stride_state_head, - stride_state_dim, - stride_state_dstate, - stride_x_batch, - stride_x_head, - stride_x_dim, - stride_dt_batch, - stride_dt_head, - stride_dt_dim, - stride_dt_bias_head, - stride_dt_bias_dim, - stride_A_head, - stride_A_dim, - stride_A_dstate, - stride_B_batch, - stride_B_group, - stride_B_dstate, - stride_C_batch, - stride_C_group, - stride_C_dstate, - stride_D_head, - stride_D_dim, - stride_z_batch, - stride_z_head, - stride_z_dim, - stride_out_batch, - stride_out_head, - stride_out_dim, - # Meta-parameters - DT_SOFTPLUS: tl.constexpr, - TIE_HDIM: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - HAS_DT_BIAS: tl.constexpr, - HAS_D: tl.constexpr, - HAS_Z: tl.constexpr, - BLOCK_SIZE_DSTATE: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_b = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) - state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head - x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head - if HAS_DT_BIAS: - dt_bias_ptr += pid_h * stride_dt_bias_head - A_ptr += pid_h * stride_A_head - B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group - C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group - if HAS_Z: - z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head - out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) - state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) - x_ptrs = x_ptr + offs_m * stride_x_dim - dt_ptrs = dt_ptr + offs_m * stride_dt_dim - if HAS_DT_BIAS: - dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim - if HAS_D: - D_ptr += pid_h * stride_D_head - A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) - B_ptrs = B_ptr + offs_n * stride_B_dstate - C_ptrs = C_ptr + offs_n * stride_C_dstate - if HAS_D: - D_ptrs = D_ptr + offs_m * stride_D_dim - if HAS_Z: - z_ptrs = z_ptr + offs_m * stride_z_dim - out_ptrs = out_ptr + offs_m * stride_out_dim - - state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) - x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if not TIE_HDIM: - dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if DT_SOFTPLUS: - dt = tl.where(dt <= 20.0, softplus(dt), dt) - A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) - dA = tl.exp(A * dt[:, None]) - else: - dt = tl.load(dt_ptr).to(tl.float32) - if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptr).to(tl.float32) - if DT_SOFTPLUS: - dt = tl.where(dt <= 20.0, softplus(dt), dt) - A = tl.load(A_ptr).to(tl.float32) - dA = tl.exp(A * dt) # scalar, not a matrix - - B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - if HAS_D: - D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if HAS_Z: - z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - - if not TIE_HDIM: - dB = B[None, :] * dt[:, None] - else: - dB = B * dt # vector of size (dstate,) - state = state * dA + dB * x[:, None] - tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) - out = tl.sum(state * C[None, :], axis=1) - if HAS_D: - out += x * D - if HAS_Z: - out *= z * tl.sigmoid(z) - tl.store(out_ptrs, out, mask=offs_m < dim) - - -def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): - """ - Argument: - state: (batch, dim, dstate) or (batch, nheads, dim, dstate) - x: (batch, dim) or (batch, nheads, dim) - dt: (batch, dim) or (batch, nheads, dim) - A: (dim, dstate) or (nheads, dim, dstate) - B: (batch, dstate) or (batch, ngroups, dstate) - C: (batch, dstate) or (batch, ngroups, dstate) - D: (dim,) or (nheads, dim) - z: (batch, dim) or (batch, nheads, dim) - dt_bias: (dim,) or (nheads, dim) - Return: - out: (batch, dim) or (batch, nheads, dim) - """ - has_heads = state.dim() > 3 - if state.dim() == 3: - state = state.unsqueeze(1) - if x.dim() == 2: - x = x.unsqueeze(1) - if dt.dim() == 2: - dt = dt.unsqueeze(1) - if A.dim() == 2: - A = A.unsqueeze(0) - if B.dim() == 2: - B = B.unsqueeze(1) - if C.dim() == 2: - C = C.unsqueeze(1) - if D is not None and D.dim() == 1: - D = D.unsqueeze(0) - if z is not None and z.dim() == 2: - z = z.unsqueeze(1) - if dt_bias is not None and dt_bias.dim() == 1: - dt_bias = dt_bias.unsqueeze(0) - batch, nheads, dim, dstate = state.shape - assert tuple(x.shape) == (batch, nheads, dim) - assert dt.shape == x.shape - assert tuple(A.shape) == (nheads, dim, dstate) - ngroups = B.shape[1] - assert nheads % ngroups == 0, "nheads must be divisible by ngroups" - assert tuple(B.shape) == (batch, ngroups, dstate) - assert C.shape == B.shape - if D is not None: - assert tuple(D.shape) == (nheads, dim) - if z is not None: - assert z.shape == x.shape - if dt_bias is not None: - assert tuple(dt_bias.shape) == (nheads, dim) - out = paddle.empty_like(x) - grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads) - z_strides = (z.strides[0], z.strides[1], z.strides[2]) if z is not None else (0, 0, 0) - # We don't want autotune since it will overwrite the state - # We instead tune by hand. - BLOCK_SIZE_M, num_warps = ( - (32, 4) - if dstate <= 16 - else ((16, 4) if dstate <= 32 else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))) - ) - tie_hdim = A.strides[-1] == 0 and A.strides[-2] == 0 and dt.strides[-1] == 0 and dt_bias.strides[-1] == 0 - _selective_scan_update_kernel[grid]( - state, - x, - dt, - dt_bias, - A, - B, - C, - D, - z, - out, - batch, - nheads, - dim, - dstate, - nheads // ngroups, - state.strides[0], - state.strides[1], - state.strides[2], - state.strides[3], - x.strides[0], - x.strides[1], - x.strides[2], - dt.strides[0], - dt.strides[1], - dt.strides[2], - *(dt_bias.strides[0], dt_bias.strides[1]) if dt_bias is not None else 0, - A.strides[0], - A.strides[1], - A.strides[2], - B.strides[0], - B.strides[1], - B.strides[2], - C.strides[0], - C.strides[1], - C.strides[2], - *(D.strides[0], D.strides[1]) if D is not None else 0, - z_strides[0], - z_strides[1], - z_strides[2], - out.strides[0], - out.strides[1], - out.strides[2], - dt_softplus, - tie_hdim, - BLOCK_SIZE_M, - num_warps=num_warps, - ) - if not has_heads: - out = out.squeeze(1) - return out - - -def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): - """ - Argument: - state: (batch, dim, dstate) or (batch, nheads, dim, dstate) - x: (batch, dim) or (batch, nheads, dim) - dt: (batch, dim) or (batch, nheads, dim) - A: (dim, dstate) or (nheads, dim, dstate) - B: (batch, dstate) or (batch, ngroups, dstate) - C: (batch, dstate) or (batch, ngroups, dstate) - D: (dim,) or (nheads, dim) - z: (batch, dim) or (batch, nheads, dim) - dt_bias: (dim,) or (nheads, dim) - Return: - out: (batch, dim) or (batch, nheads, dim) - """ - has_heads = state.dim() > 3 - if state.dim() == 3: - state = state.unsqueeze(1) - if x.dim() == 2: - x = x.unsqueeze(1) - if dt.dim() == 2: - dt = dt.unsqueeze(1) - if A.dim() == 2: - A = A.unsqueeze(0) - if B.dim() == 2: - B = B.unsqueeze(1) - if C.dim() == 2: - C = C.unsqueeze(1) - if D is not None and D.dim() == 1: - D = D.unsqueeze(0) - if z is not None and z.dim() == 2: - z = z.unsqueeze(1) - if dt_bias is not None and dt_bias.dim() == 1: - dt_bias = dt_bias.unsqueeze(0) - batch, nheads, dim, dstate = state.shape - assert tuple(x.shape) == (batch, nheads, dim) - assert dt.shape == x.shape - assert tuple(A.shape) == (nheads, dim, dstate) - ngroups = B.shape[1] - assert nheads % ngroups == 0, "nheads must be divisible by ngroups" - assert tuple(B.shape) == (batch, ngroups, dstate) - assert C.shape == B.shape - if D is not None: - assert tuple(D.shape) == (nheads, dim) - if z is not None: - assert z.shape == x.shape - if dt_bias is not None: - assert tuple(dt_bias.shape) == (nheads, dim) - dt = dt + dt_bias - dt = F.softplus(dt) if dt_softplus else dt - dA = paddle.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate) - B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) - C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) - dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) - state.copy_((state * dA + dB * rearrange(x, "b h d -> b h d 1")).cast(state.dtype), False) # (batch, dim, dstate - out = paddle.einsum("bhdn,bhn->bhd", state.cast(C.dtype), C) - if D is not None: - out += (x * D).cast(out.dtype) - out = (out if z is None else out * F.silu(z)).cast(x.dtype) - if not has_heads: - out = out.squeeze(1) - return out diff --git a/ops/src/paddlenlp_kernel/triton/mamba/ssd_bmm.py b/ops/src/paddlenlp_kernel/triton/mamba/ssd_bmm.py deleted file mode 100644 index 36b8906d4c50..000000000000 --- a/ops/src/paddlenlp_kernel/triton/mamba/ssd_bmm.py +++ /dev/null @@ -1,361 +0,0 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. -""" -this code is modified from https://github.com/state-spaces/mamba/tree/main/mamba_ssm/ops/triton -""" -"""We want triton==2.1.0 or 2.2.0 for this""" - -import math - -import paddle -import triton -import triton.language as tl - - -def init_to_zero(names): - return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=2), - ], - key=["chunk_size", "K", "IS_CAUSAL"], -) -@triton.jit -def _bmm_chunk_fwd_kernel( - # Pointers to matrices - a_ptr, - b_ptr, - out_ptr, - seq_idx_ptr, - # Matrix dimensions - seqlen, - chunk_size, - K, - ngroups, - stride_a_batch, - stride_a_seqlen, - stride_a_head, - stride_ak, - stride_b_batch, - stride_b_seqlen, - stride_b_head, - stride_bk, - stride_out_batch, - stride_out_chunk, - stride_out_head, - stride_outm, - stride_outn, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - # Meta-parameters - IS_CAUSAL: tl.constexpr, - dot_dtype: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_b = tl.program_id(axis=1) - pid_ch = tl.program_id(axis=2) - pid_c = pid_ch // ngroups - pid_h = pid_ch - pid_c * ngroups - num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - if IS_CAUSAL: - if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: - return - a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load( - a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0 - ).to(dot_dtype) - b = tl.load( - b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0 - ).to(dot_dtype) - acc += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - if HAS_SEQ_IDX: - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) - acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) - out = acc.to(out_ptr.dtype.element_ty) - - out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head - out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) - tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_CS": 64}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_CS": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_CS": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_CS": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_CS": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_CS": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_CS": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_CS": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_CS": 32}, num_stages=4, num_warps=2), - ], - key=["chunk_size", "K"], -) -@triton.jit -def _bmm_chunk_bwd_kernel( - # Pointers to matrices - a_ptr, - dout_ptr, - db_ptr, - res_ptr, - # Matrix dimensions - seqlen, - chunk_size, - K, - ngroups, - stride_a_batch, - stride_a_seqlen, - stride_a_head, - stride_ak, - stride_dout_batch, - stride_dout_chunk, - stride_dout_head, - stride_dout_csize_m, - stride_dout_csize_n, - stride_db_batch, - stride_db_seqlen, - stride_db_head, - stride_db_k, - stride_res_batch, - stride_res_seqlen, - stride_res_head, - stride_res_k, - # Meta-parameters - dot_dtype: tl.constexpr, - HAS_RESIDUAL: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_CS: tl.constexpr, -): - pid_b = tl.program_id(axis=1) - pid_ch = tl.program_id(axis=2) - pid_c = pid_ch // ngroups - pid_h = pid_ch - pid_c * ngroups - num_pid_n = tl.cdiv(K, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - - a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head - dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_cs = tl.arange(0, BLOCK_SIZE_CS) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m) - a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)): - dout = tl.load( - dout_ptrs, - mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), - other=0.0, - ).to(dot_dtype) - a = tl.load( - a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0 - ).to(dot_dtype) - acc += tl.dot(dout, a) - dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m - a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - if HAS_RESIDUAL: - res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head - res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k) - res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32) - acc += res - db = acc.to(db_ptr.dtype.element_ty) - - db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head - db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k) - tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)) - - -def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None): - """ - Argument: - a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. - causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are - guaranteed to be correct. - Return: - out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) - """ - # Check constraints. - has_groups = a.dim() == 4 - if not has_groups: - batch, seqlen, k = a.shape - else: - batch, seqlen, ngroups, k = a.shape - assert b.shape == a.shape - if seq_idx is not None: - assert tuple(seq_idx.shape) == (batch, seqlen) - if a.strides[-1] != 1 and a.strides[1] != 1: - a = a.contiguous() - if b.strides[-1] != 1 and b.strides[1] != 1: - b = b.contiguous() - nchunks = math.ceil(seqlen / chunk_size) - # Allocates output. - out_dtype = a.dtype if output_dtype is None else output_dtype - out = paddle.empty( - (batch, nchunks, chunk_size, chunk_size) - if not has_groups - else (batch, nchunks, ngroups, chunk_size, chunk_size), - dtype=out_dtype, - ) - dot_dtype = ( - tl.bfloat16 - if a.dtype == paddle.bfloat16 or b.dtype == paddle.bfloat16 - else (tl.float16 if a.dtype == paddle.float16 or b.dtype == paddle.float16 else tl.float32) - ) - grid = lambda META: ( - triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]), - batch, - nchunks if not has_groups else nchunks * ngroups, - ) - _bmm_chunk_fwd_kernel[grid]( - a, - b, - out, - seq_idx, - seqlen, - chunk_size, - k, - ngroups if has_groups else 1, - a.strides[0], - a.strides[1], - 0 if not has_groups else a.strides[2], - a.strides[-1], - b.strides[0], - b.strides[1], - 0 if not has_groups else b.strides[2], - b.strides[-1], - out.strides[0], - out.strides[1], - 0 if not has_groups else out.strides[2], - out.strides[-2], - out.strides[-1], - *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)), - causal, - dot_dtype, - HAS_SEQ_IDX=seq_idx is not None, - ) - return out - - -def _bmm_chunk_bwd(a, dout, residual=None, out=None): - """ - Argument: - a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) - residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - Return: - out: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - - If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be zeroed out - before calling this function. - """ - # Check constraints. - has_groups = a.dim() == 4 - if not has_groups: - batch, seqlen, k = a.shape - else: - batch, seqlen, ngroups, k = a.shape - nchunks, chunk_size = dout.shape[1], dout.shape[-1] - if a.strides[-1] != 1 and a.strides[-2] != 1: - a = a.contiguous() - if dout.strides[-1] != 1 and dout.strides[-2] != 1: - dout = dout.contiguous() - if residual is not None: - assert tuple(residual.shape) == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k) - if residual.strides[-1] != 1 and residual.strides[1] != 1: - residual = residual.contiguous() - # Allocates output. - if out is not None: - assert out.shape == a.shape - assert out.strides[-1] == 1 or out.strides[1] == 1 - else: - out = paddle.empty_like(a) - dot_dtype = ( - tl.bfloat16 - if a.dtype == paddle.bfloat16 or dout.dtype == paddle.bfloat16 - else (tl.float16 if a.dtype == paddle.float16 or dout.dtype == paddle.float16 else tl.float32) - ) - grid = lambda META: ( - triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) * triton.cdiv(k, META["BLOCK_SIZE_N"]), - batch, - nchunks if not has_groups else nchunks * ngroups, - ) - residual_strides = ( - (residual.strides[0], residual.strides[1], 0 if not has_groups else residual.strides[2], residual.strides[-1]) - if residual is not None - else (0, 0, 0, 0) - ) - _bmm_chunk_bwd_kernel[grid]( - a, - dout, - out, - residual, - seqlen, - chunk_size, - k, - ngroups if has_groups else 1, - a.strides[0], - a.strides[1], - 0 if not has_groups else a.strides[2], - a.strides[-1], - dout.strides[0], - dout.strides[1], - 0 if not has_groups else dout.strides[2], - dout.strides[-2], - dout.strides[-1], - out.strides[0], - out.strides[1], - 0 if not has_groups else out.strides[2], - out.strides[-1], - residual_strides[0], - residual_strides[1], - residual_strides[2], - residual_strides[3], - dot_dtype, - HAS_RESIDUAL=residual is not None, - ) - return out diff --git a/ops/src/paddlenlp_kernel/triton/mamba/ssd_chunk_scan.py b/ops/src/paddlenlp_kernel/triton/mamba/ssd_chunk_scan.py deleted file mode 100644 index 5f3e71bf60e4..000000000000 --- a/ops/src/paddlenlp_kernel/triton/mamba/ssd_chunk_scan.py +++ /dev/null @@ -1,2858 +0,0 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. -""" -this code is modified from https://github.com/state-spaces/mamba/tree/main/mamba_ssm/ops/triton -""" -"""We want triton==2.1.0 or 2.2.0 for this""" - -import math -import operator - -import paddle -import paddle.nn.functional as F -import triton -import triton.language as tl -from einops import rearrange, repeat - -from ...utils import compare_version, custom_bwd, custom_fwd -from .ssd_bmm import _bmm_chunk_bwd, _bmm_chunk_fwd - -TRITON_22 = compare_version("triton", operator.ge, "2.2.0") - - -def init_to_zero(names): - return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=2), - ], - key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"], -) -@triton.jit -def _chunk_scan_fwd_kernel( - # Pointers to matrices - cb_ptr, - x_ptr, - z_ptr, - out_ptr, - out_x_ptr, - dt_ptr, - dA_cumsum_ptr, - seq_idx_ptr, - C_ptr, - prev_states_ptr, - D_ptr, - # Matrix dimensions - chunk_size, - hdim, - dstate, - batch, - seqlen, - nheads_ngroups_ratio, - # Strides - stride_cb_batch, - stride_cb_chunk, - stride_cb_head, - stride_cb_csize_m, - stride_cb_csize_k, - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_z_batch, - stride_z_seqlen, - stride_z_head, - stride_z_hdim, - stride_out_batch, - stride_out_seqlen, - stride_out_head, - stride_out_hdim, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - stride_C_batch, - stride_C_seqlen, - stride_C_head, - stride_C_dstate, - stride_states_batch, - stride_states_chunk, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_D_head, - # Meta-parameters - IS_CAUSAL: tl.constexpr, - HAS_D: tl.constexpr, - D_HAS_HDIM: tl.constexpr, - HAS_Z: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - BLOCK_SIZE_DSTATE: tl.constexpr, - IS_TRITON_22: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - C_ptr += ( - pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head - ) - prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - if HAS_SEQ_IDX: - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Without the if (pid_c > -1), with Triton 2.1.0, I get - # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. - # With Triton 2.2.0, this works - if IS_TRITON_22 or pid_c > -1: - # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - prev_states_ptrs = prev_states_ptr + ( - offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate - ) - if not HAS_SEQ_IDX: - scale_m = tl.exp(dA_cs_m) - else: - scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) - if BLOCK_SIZE_DSTATE <= 128: - C = tl.load( - C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0 - ) - prev_states = tl.load( - prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0 - ) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - acc = tl.dot(C, prev_states) * scale_m[:, None] - else: - for k in range(0, dstate, BLOCK_SIZE_K): - C = tl.load( - C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), - other=0.0, - ) - # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) - prev_states = tl.load( - prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0 - ) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - acc += tl.dot(C, prev_states) - C_ptrs += BLOCK_SIZE_K - prev_states_ptrs += BLOCK_SIZE_K - acc *= scale_m[:, None] - - offs_k = tl.arange(0, BLOCK_SIZE_K) - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) - x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) - for k in range(0, K_MAX, BLOCK_SIZE_K): - cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to( - tl.float32 - ) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) - # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. - # So we don't need masking wrt seq_idx here. - cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :])) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) - cb *= dt_k - if IS_CAUSAL: - mask = offs_m[:, None] >= k + offs_k[None, :] - cb = tl.where(mask, cb, 0.0) - cb = cb.to(x_ptr.dtype.element_ty) - x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0) - acc += tl.dot(cb, x) - cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k - x_ptrs += BLOCK_SIZE_K * stride_x_seqlen - dt_ptrs += BLOCK_SIZE_K * stride_dt_csize - dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - - offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - if HAS_D: - if D_HAS_HDIM: - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) - else: - D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - x_residual = tl.load( - x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), - mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), - other=0.0, - ).to(tl.float32) - acc += x_residual * D - - if HAS_Z: - out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) - tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) - - z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head - z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) - z = tl.load( - z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0 - ).to(tl.float32) - acc *= z * tl.sigmoid(z) - - out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) - tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) - - -@triton.paddle_autotune( - configs=[ - # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4), - # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_N": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8), - triton.Config({"BLOCK_SIZE_N": 32}, num_stages=4, num_warps=8), - ], - key=["chunk_size", "hdim", "dstate"], -) -@triton.jit -def _chunk_scan_fwd_kernel_wip( - # Pointers to matrices - cb_ptr, - x_ptr, - z_ptr, - out_ptr, - out_x_ptr, - dt_ptr, - dA_cumsum_ptr, - seq_idx_ptr, - C_ptr, - B_ptr, - prev_states_ptr, - D_ptr, - # Matrix dimensions - chunk_size, - hdim, - dstate, - batch, - seqlen, - nheads_ngroups_ratio, - # Strides - stride_cb_batch, - stride_cb_chunk, - stride_cb_head, - stride_cb_csize_m, - stride_cb_csize_k, - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_z_batch, - stride_z_seqlen, - stride_z_head, - stride_z_hdim, - stride_out_batch, - stride_out_seqlen, - stride_out_head, - stride_out_hdim, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - stride_C_batch, - stride_C_seqlen, - stride_C_head, - stride_C_dstate, - stride_B_batch, - stride_B_seqlen, - stride_B_head, - stride_B_dstate, - stride_states_batch, - stride_states_chunk, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_D_head, - # Meta-parameters - HAS_D: tl.constexpr, - D_HAS_HDIM: tl.constexpr, - HAS_Z: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_DSTATE: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - pid_n = tl.program_id(axis=0) - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - C_ptr += ( - pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head - ) - B_ptr += ( - pid_b * stride_B_batch + pid_c * chunk_size * stride_B_seqlen + (pid_h // nheads_ngroups_ratio) * stride_B_head - ) - prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - - offs_m = tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE) - - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - B_ptrs = B_ptr + (offs_m[None, :] * stride_B_seqlen + offs_k_dstate[:, None] * stride_B_dstate) - prev_states_ptrs = prev_states_ptr + ( - offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate - ) - # num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_m[None, :] * stride_cb_csize_k) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - dt_ptrs = dt_ptr + offs_m * stride_dt_csize - out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) - - prev_states = tl.load( - prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0 - ) - # if pid_c == 0: - # if pid_b == 0: - # if pid_h == 0: - # tl.device_print("", prev_states) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - - # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - # scale_m = tl.exp(dA_cs_m) - # C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) - # acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] - # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_m[None, :] < chunk_size), other=0.0).to(tl.float32) - # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) - # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - # cb *= dt_m - # mask = offs_m[:, None] >= offs_m[None, :] - # cb = tl.where(mask, cb, 0.0) - # cb = cb.to(x_ptr.dtype.element_ty) - # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0) - # acc += tl.dot(cb, x) - # if HAS_D: - # if D_HAS_HDIM: - # D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) - # else: - # D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - # acc += x.to(tl.float32) * D - # tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - - for start_m in range(0, chunk_size_limit, BLOCK_SIZE_M): - start_m = tl.multiple_of(start_m, BLOCK_SIZE_M) - dA_cs_m = tl.load( - dA_cumsum_ptr + (start_m + offs_m) * stride_dA_cs_csize, mask=offs_m < chunk_size - start_m, other=0.0 - ).to(tl.float32) - if HAS_SEQ_IDX: - seq_idx_prev = tl.load(seq_idx_ptr + start_m - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) - seq_idx_m = tl.load( - seq_idx_ptr + (start_m + offs_m) * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit - start_m, - other=-1, - ) - if not HAS_SEQ_IDX: - scale_m = tl.exp(dA_cs_m) - else: - scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) - C = tl.load( - C_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_k_dstate[None, :] < dstate), other=0.0 - ) - acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] - # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size - start_m) & (offs_m[None, :] < chunk_size - start_m), other=0.0).to(tl.float32) - # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) - # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) - # cb *= dt_m - # mask = offs_m[:, None] >= offs_m[None, :] - # cb = tl.where(mask, cb, 0.0) - # cb = cb.to(x_ptr.dtype.element_ty) - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim), other=0.0) - # acc += tl.dot(cb, x) - - if HAS_D: - if D_HAS_HDIM: - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) - else: - D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - acc += x.to(tl.float32) * D - - # if HAS_Z: - # out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - # out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) - # tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) - - # z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head - # z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) - # z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) - # acc *= z * tl.sigmoid(z) - - tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim)) - - # TODO: this is not correct, and quite a bit slower - if start_m + BLOCK_SIZE_M < chunk_size_limit: - # B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0).to(tl.float32) - B = tl.load( - B_ptrs, - mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), - other=0.0, - ) - # dA_cs_last = tl.load(dA_cumsum_ptr + (start_m + BLOCK_SIZE_M) * stride_dA_cs_csize).to(tl.float32) - # TODO: seq_idx - # scale = tl.exp((dA_cs_last - dA_cs_m)) * dt_m - # B *= scale - B = B.to(x_ptr.dtype.element_ty) - tmp = tl.dot(B, x) - prev_states += tmp.to(prev_states.dtype) - - C_ptrs += BLOCK_SIZE_M * stride_C_seqlen - B_ptrs += BLOCK_SIZE_M * stride_B_seqlen - cb_ptrs += BLOCK_SIZE_M * stride_cb_csize_m + BLOCK_SIZE_M * stride_cb_csize_k - x_ptrs += BLOCK_SIZE_M * stride_x_seqlen - dt_ptrs += BLOCK_SIZE_M * stride_dt_csize - out_ptrs += BLOCK_SIZE_M * stride_out_seqlen - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE_M": 32}), - triton.Config({"BLOCK_SIZE_M": 64}), - triton.Config({"BLOCK_SIZE_M": 128}), - triton.Config({"BLOCK_SIZE_M": 256}), - ], - key=["chunk_size", "hdim"], -) -@triton.jit -def _chunk_scan_bwd_dz_kernel( - # Pointers to matrices - dout_ptr, - out_ptr, - z_ptr, - x_ptr, - D_ptr, - outz_ptr, - dz_ptr, - dout_x_ptr, - dD_ptr, - ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, - hdim, - batch, - seqlen, - # Strides - stride_dout_batch, - stride_dout_seqlen, - stride_dout_head, - stride_dout_hdim, - stride_out_batch, - stride_out_seqlen, - stride_out_head, - stride_out_hdim, - stride_z_batch, - stride_z_seqlen, - stride_z_head, - stride_z_hdim, - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_D_head, - stride_outz_batch, - stride_outz_seqlen, - stride_outz_head, - stride_outz_hdim, - stride_dz_batch, - stride_dz_seqlen, - stride_dz_head, - stride_dz_hdim, - stride_doutx_batch, - stride_doutx_seqlen, - stride_doutx_head, - stride_doutx_hdim, - stride_dD_batch, - stride_dD_chunk, - stride_dD_head, - stride_dD_csize, - stride_dD_hdim, - stride_ddA_cs_batch, - stride_ddA_cs_chunk, - stride_ddA_cs_head, - stride_ddA_cs_csize, - # Meta-parameters - HAS_D: tl.constexpr, - D_HAS_HDIM: tl.constexpr, - HAS_DDACS: tl.constexpr, - RECOMPUTE_OUTPUT: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - pid_m = tl.program_id(axis=0) - - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - dout_x_ptr += pid_b * stride_doutx_batch + pid_c * chunk_size * stride_doutx_seqlen + pid_h * stride_doutx_head - out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head - dz_ptr += pid_b * stride_dz_batch + pid_c * chunk_size * stride_dz_seqlen + pid_h * stride_dz_head - if RECOMPUTE_OUTPUT: - outz_ptr += pid_b * stride_outz_batch + pid_c * chunk_size * stride_outz_seqlen + pid_h * stride_outz_head - if HAS_DDACS: - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head - if HAS_D: - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = tl.arange(0, BLOCK_SIZE_N) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) - dout_x_ptrs = dout_x_ptr + (offs_m[:, None] * stride_doutx_seqlen + offs_n[None, :] * stride_doutx_hdim) - out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) - z_ptrs = z_ptr + (offs_m[:, None] * stride_z_seqlen + offs_n[None, :] * stride_z_hdim) - dz_ptrs = dz_ptr + (offs_m[:, None] * stride_dz_seqlen + offs_n[None, :] * stride_dz_hdim) - if RECOMPUTE_OUTPUT: - outz_ptrs = outz_ptr + (offs_m[:, None] * stride_outz_seqlen + offs_n[None, :] * stride_outz_hdim) - if HAS_D: - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - if D_HAS_HDIM: - dD_ptrs = dD_ptr + offs_n * stride_dD_hdim - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to( - tl.float32 - ) - out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to( - tl.float32 - ) - z = tl.load(z_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - z_sigmoid = tl.sigmoid(z) - if RECOMPUTE_OUTPUT: - outz = out * z * z_sigmoid - tl.store(outz_ptrs, outz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - dz = dout * out * z_sigmoid * (1 + z * (1 - z_sigmoid)) - tl.store(dz_ptrs, dz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - dout *= z * z_sigmoid - tl.store(dout_x_ptrs, dout, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - if HAS_D: - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to( - tl.float32 - ) - if D_HAS_HDIM: - dD = tl.sum(dout * x, axis=0) - tl.store(dD_ptrs, dD, mask=offs_n < hdim) - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) - else: - dD = tl.sum(dout * x) - tl.store(dD_ptr, dD) - D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - out -= x * D - if HAS_DDACS: - ddA_cs = tl.sum(dout * out, axis=1) - tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=2), - ], - key=["hdim", "dstate", "chunk_size"], -) -@triton.jit -def _chunk_scan_bwd_dstates_kernel( - # Pointers to matrices - dout_ptr, - c_ptr, - dprev_states_ptr, - dA_cumsum_ptr, - seq_idx_ptr, - # Matrix dimensions - hdim, - dstate, - chunk_size, - batch, - seqlen, - nchunks, - nheads_ngroups_ratio, - # Strides - stride_dout_batch, - stride_dout_seqlen, - stride_dout_head, - stride_dout_hdim, - stride_c_batch, - stride_c_seqlen, - stride_c_head, - stride_c_dstate, - stride_dprev_states_batch, - stride_dprev_states_chunk, - stride_dprev_states_head, - stride_dprev_states_hdim, - stride_dprev_states_dstate, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - # Meta-parameters - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - c_ptr += ( - pid_b * stride_c_batch + pid_c * chunk_size * stride_c_seqlen + (pid_h // nheads_ngroups_ratio) * stride_c_head - ) - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_hdim + offs_k[None, :] * stride_dout_seqlen) - c_ptrs = c_ptr + (offs_n[None, :] * stride_c_dstate + offs_k[:, None] * stride_c_seqlen) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - if HAS_SEQ_IDX: - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) - for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - dout = tl.load( - dout_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0 - ).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) - if not HAS_SEQ_IDX: - scale_k = tl.exp(dA_cs_k) - else: - seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) - scale_k = tl.where(seq_idx_k == seq_idx_prev, tl.exp(dA_cs_k), 0.0) - dout = (dout * scale_k).to(dout_ptr.dtype.element_ty) - c = tl.load(c_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0) - acc += tl.dot(dout, c) - dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen - c_ptrs += BLOCK_SIZE_K * stride_c_seqlen - dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen - out = acc.to(dprev_states_ptr.dtype.element_ty) - - dprev_states_ptr += ( - pid_b * stride_dprev_states_batch + pid_c * stride_dprev_states_chunk + pid_h * stride_dprev_states_head - ) - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dprev_states_ptrs = dprev_states_ptr + ( - offs_m[:, None] * stride_dprev_states_hdim + offs_n[None, :] * stride_dprev_states_dstate - ) - tl.store(dprev_states_ptrs, out, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)) - - -@triton.paddle_autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - ], - key=["chunk_size", "dstate", "hdim"], -) -@triton.jit -def _chunk_scan_bwd_dc_kernel( - # Pointers to matrices - dout_ptr, - prev_states_ptr, - C_ptr, - dA_cumsum_ptr, - seq_idx_ptr, - dc_ptr, - ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, - dstate, - hdim, - batch, - seqlen, - nheads, - nheads_per_program, - ngroups, - # Strides - stride_dout_batch, - stride_dout_seqlen, - stride_dout_head, - stride_dout_hdim, - stride_prev_states_batch, - stride_prev_states_chunk, - stride_prev_states_head, - stride_prev_states_hdim, - stride_prev_states_dstate, - stride_C_batch, - stride_C_seqlen, - stride_C_head, - stride_C_dstate, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - stride_dc_batch, - stride_dc_seqlen, - stride_dc_split, - stride_dc_group, - stride_dc_dstate, - stride_ddA_cs_batch, - stride_ddA_cs_chunk, - stride_ddA_cs_head, - stride_ddA_cs_csize, - # Meta-parameters - HAS_DDA_CS: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_sg = tl.program_id(axis=2) - pid_s = pid_sg // ngroups - pid_g = pid_sg - pid_s * ngroups - num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - dout_ptr += ( - pid_b * stride_dout_batch - + pid_c * chunk_size * stride_dout_seqlen - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head - ) - dc_ptr += ( - pid_b * stride_dc_batch - + pid_c * chunk_size * stride_dc_seqlen - + pid_g * stride_dc_group - + pid_s * stride_dc_split - ) - prev_states_ptr += ( - pid_b * stride_prev_states_batch - + pid_c * stride_prev_states_chunk - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_prev_states_head - ) - dA_cumsum_ptr += ( - pid_b * stride_dA_cs_batch - + pid_c * stride_dA_cs_chunk - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head - ) - if HAS_DDA_CS: - C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + pid_g * stride_C_head - ddA_cumsum_ptr += ( - pid_b * stride_ddA_cs_batch - + pid_c * stride_ddA_cs_chunk - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head - ) - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) - prev_states_ptrs = prev_states_ptr + ( - offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim - ) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize - if HAS_DDA_CS: - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) - ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - if HAS_DDA_CS: - c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to( - tl.float32 - ) - if HAS_SEQ_IDX: - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) - for h in range(nheads_iter): - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) - prev_states = prev_states.to(dout_ptrs.dtype.element_ty) - dc = tl.dot(dout, prev_states) - dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_m) - else: - scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) - dc *= scale[:, None] - if HAS_DDA_CS: - ddA_cs = tl.sum(dc * c, axis=1) - tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) - acc += dc - dout_ptrs += stride_dout_head - prev_states_ptrs += stride_prev_states_head - dA_cumsum_ptrs += stride_dA_cs_head - if HAS_DDA_CS: - ddA_cumsum_ptrs += stride_ddA_cs_head - # if HAS_SEQ_IDX: - # seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) - # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - # acc = tl.where(seq_idx_m[:, None] == seq_idx_prev, acc, 0.0) - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dc_ptrs = dc_ptr + (offs_m[:, None] * stride_dc_seqlen + offs_n[None, :] * stride_dc_dstate) - tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) - - -@triton.paddle_autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, - num_stages=3, - num_warps=8, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - ], - key=["chunk_size", "hdim"], -) -@triton.jit -def _chunk_scan_bwd_dx_kernel( - # Pointers to matrices - x_ptr, - cb_ptr, - dout_ptr, - dt_ptr, - dA_cumsum_ptr, - D_ptr, - dx_ptr, - ddt_ptr, # dD_ptr, - # Matrix dimensions - chunk_size, - hdim, - batch, - seqlen, - nheads_ngroups_ratio, - # Strides - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_cb_batch, - stride_cb_chunk, - stride_cb_head, - stride_cb_csize_m, - stride_cb_csize_k, - stride_dout_batch, - stride_dout_seqlen, - stride_dout_head, - stride_dout_hdim, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_D_head, - stride_dx_batch, - stride_dx_seqlen, - stride_dx_head, - stride_dx_hdim, - stride_ddt_batch, - stride_ddt_chunk, - stride_ddt_head, - stride_ddt_csize, - # stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize, - # Meta-parameters - HAS_D: tl.constexpr, - D_HAS_HDIM: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - # if HAS_D: - # dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) - dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to( - tl.float32 - ) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Idk why limiting K_MAX gives wrong results, is it a Triton bug? - # K_MAX = min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) - K_MAX = chunk_size_limit - for k in range(0, K_MAX, BLOCK_SIZE_K): - # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower - cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) - dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) - cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) - # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range, - # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. - # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. - # This will cause NaN in acc, and hence NaN in dx and ddt. - mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX) - cb = tl.where(mask, cb, 0.0) - cb = cb.to(dout_ptr.dtype.element_ty) - acc += tl.dot(cb, dout) - cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k - dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen - dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dt_ptrs = dt_ptr + offs_m * stride_dt_csize - dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) - dx = acc * dt_m[:, None] - dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head - dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) - if HAS_D: - dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) - dout_res = tl.load( - dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0 - ).to(tl.float32) - if D_HAS_HDIM: - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) - else: - D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - dx += dout_res * D - tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - ddt = tl.sum(acc * x, axis=1) - ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize - tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) - - # if HAS_D: - # dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim) - # dout = tl.load(dout_new_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32) - # dD = tl.sum(x * dout, axis=0) - # tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N) - - -# Disabling HAS_DDA_CS for now since it's much slower -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8), - # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), - # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), - # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), - ], - key=["chunk_size", "hdim"], -) -# @triton.heuristics({"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["chunk_size"]), 16)}) -# @triton.heuristics({"BLOCK_SIZE_N": lambda args: 32}) -@triton.jit -def _chunk_scan_bwd_dcb_kernel( - # Pointers to matrices - x_ptr, - dout_ptr, - cb_ptr, - dt_ptr, - dA_cumsum_ptr, - seq_idx_ptr, - dcb_ptr, - ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, - hdim, - batch, - seqlen, - nheads, - nheads_per_program, - ngroups, - # Strides - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_dout_batch, - stride_dout_seqlen, - stride_dout_head, - stride_dout_hdim, - stride_cb_batch, - stride_cb_chunk, - stride_cb_head, - stride_cb_csize_m, - stride_cb_csize_n, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - stride_dcb_batch, - stride_dcb_chunk, - stride_dcb_split, - stride_dcb_group, - stride_dcb_csize_m, - stride_dcb_csize_n, - stride_ddA_cs_batch, - stride_ddA_cs_chunk, - stride_ddA_cs_head, - stride_ddA_cs_csize_m, - stride_ddA_cs_csize_n, - # Meta-parameters - HAS_DDA_CS: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_sg = tl.program_id(axis=2) - pid_s = pid_sg // ngroups - pid_g = pid_sg - pid_s * ngroups - num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - - x_ptr += ( - pid_b * stride_x_batch - + pid_c * chunk_size * stride_x_seqlen - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head - ) - dout_ptr += ( - pid_b * stride_dout_batch - + pid_c * chunk_size * stride_dout_seqlen - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head - ) - dt_ptr += ( - pid_b * stride_dt_batch - + pid_c * stride_dt_chunk - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head - ) - dA_cumsum_ptr += ( - pid_b * stride_dA_cs_batch - + pid_c * stride_dA_cs_chunk - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head - ) - if HAS_DDA_CS: - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + pid_g * stride_cb_head - ddA_cumsum_ptr += ( - pid_b * stride_ddA_cs_batch - + pid_c * stride_ddA_cs_chunk - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head - + pid_m * stride_ddA_cs_csize_m - ) - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) - x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) - dt_ptrs = dt_ptr + offs_n * stride_dt_csize - if HAS_DDA_CS: - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) - ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n - - if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: - dcb_ptr += ( - pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split - ) - dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) - tl.store( - dcb_ptrs, - tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty), - mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), - ) - return - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - if HAS_DDA_CS: - cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to( - tl.float32 - ) - nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) - for h in range(nheads_iter): - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) - dcb = tl.dot(dout, x) - dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) - dcb *= dt_n - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to( - tl.float32 - ) - dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to( - tl.float32 - ) - dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) - if HAS_DDA_CS: - tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet") - ddA_cs = dcb * cb - mask = offs_m[:, None] >= offs_n[None, :] + 1 - ddA_cs = tl.where(mask, ddA_cs, 0.0) - ddA_cs = tl.cumsum(ddA_cs, axis=1) - ddA_cs = tl.where(mask, ddA_cs, 0.0) - ddA_cs = tl.sum(ddA_cs, axis=0) - tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) - tl.store(ddA_cumsum_ptr, 0.0) - acc += dcb - dout_ptrs += stride_dout_head - x_ptrs += stride_x_head - dt_ptrs += stride_dt_head - dA_cumsum_ptr += stride_dA_cs_head - if HAS_DDA_CS: - ddA_cumsum_ptr += stride_ddA_cs_head - ddA_cumsum_ptrs += stride_ddA_cs_head - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - if HAS_SEQ_IDX: - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) - acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) - mask = offs_m[:, None] >= offs_n[None, :] - acc = tl.where(mask, acc, 0.0) - dcb_ptr += ( - pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split - ) - dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) - tl.store(dcb_ptrs, acc, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) - - -# Not numerically stable and should not be used. Leaving here for reference. -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE_M": 32}), - triton.Config({"BLOCK_SIZE_M": 64}), - triton.Config({"BLOCK_SIZE_M": 128}), - triton.Config({"BLOCK_SIZE_M": 256}), - ], - key=["chunk_size", "hdim"], -) -@triton.jit -def _chunk_scan_bwd_ddAcs_unstable_kernel( - # Pointers to matrices - dout_ptr, - out_ptr, - dt_ptr, - ddt_ptr, - x_ptr, - D_ptr, - ddA_cumsum_ptr, - dD_ptr, - # Matrix dimensions - chunk_size, - hdim, - batch, - seqlen, - # Strides - stride_dout_batch, - stride_dout_seqlen, - stride_dout_head, - stride_dout_hdim, - stride_out_batch, - stride_out_seqlen, - stride_out_head, - stride_out_hdim, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_ddt_batch, - stride_ddt_chunk, - stride_ddt_head, - stride_ddt_csize, - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_D_head, - stride_ddA_cs_batch, - stride_ddA_cs_chunk, - stride_ddA_cs_head, - stride_ddA_cs_csize, - stride_dD_batch, - stride_dD_chunk, - stride_dD_head, - stride_dD_csize, - stride_dD_hdim, - # Meta-parameters - HAS_D: tl.constexpr, - D_HAS_HDIM: tl.constexpr, - SUBTRACT_DDTDT: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - pid_m = tl.program_id(axis=0) - - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head - if HAS_D: - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = tl.arange(0, BLOCK_SIZE_N) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) - out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) - if HAS_D: - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - if D_HAS_HDIM: - dD_ptrs = dD_ptr + offs_n * stride_dD_hdim - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to( - tl.float32 - ) - out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to( - tl.float32 - ) - if HAS_D: - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to( - tl.float32 - ) - if D_HAS_HDIM: - dD = tl.sum(dout * x, axis=0) - tl.store(dD_ptrs, dD, mask=offs_n < hdim) - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) - else: - dD = tl.sum(dout * x) - tl.store(dD_ptr, dD) - D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - out -= x * D - ddA_cs = tl.sum(dout * out, axis=1) - if SUBTRACT_DDTDT: - dt = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - ddt = tl.load(ddt_ptr + offs_m * stride_ddt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - ddA_cs -= dt * ddt - tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) - - -@triton.paddle_autotune( - configs=[ - # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), - # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), - # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), - # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), - triton.Config({"BLOCK_SIZE_M": 16}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 32}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 16}, num_stages=4, num_warps=8), - triton.Config({"BLOCK_SIZE_M": 32}, num_stages=4, num_warps=8), - triton.Config({"BLOCK_SIZE_M": 64}, num_stages=4, num_warps=8), - triton.Config({"BLOCK_SIZE_M": 128}, num_stages=4, num_warps=8), - ], - key=["chunk_size", "hdim"], -) -@triton.jit -def _chunk_scan_bwd_ddAcs_stable_kernel_old( - # Pointers to matrices - x_ptr, - dout_ptr, - dt_ptr, - dA_cumsum_ptr, - cb_ptr, - ddAcs_ptr, - # Matrix dimensions - chunk_size, - hdim, - batch, - seqlen, - nheads_ngroups_ratio, - # Strides - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_dout_batch, - stride_dout_seqlen, - stride_dout_head, - stride_dout_hdim, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_cb_batch, - stride_cb_chunk, - stride_cb_head, - stride_cb_csize_m, - stride_cb_csize_n, - stride_ddAcs_batch, - stride_ddAcs_chunk, - stride_ddAcs_head, - stride_ddAcs_csize_m, - stride_ddAcs_csize_n, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) - x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) - dt_ptrs = dt_ptr + offs_n * stride_dt_csize - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) - # Doing a matmul loop with cumsum later on will cause Triton to crash - # Instead we do just one big matmul - # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # for k in range(0, hdim, BLOCK_SIZE_K): - # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) - # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) - # acc += tl.dot(dout, x) - # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim - # x_ptrs += BLOCK_SIZE_K * stride_x_hdim - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) - acc = tl.dot(dout, x) - cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to( - tl.float32 - ) - acc *= cb - dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) - acc *= dt_n - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32) - acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) - mask = offs_m[:, None] >= offs_n[None, :] + 1 - acc = tl.where(mask, acc, 0.0) - acc = tl.cumsum(acc, axis=1) - acc = tl.where(mask, acc, 0.0) - ddA_cs = tl.sum(acc, axis=0) - ddAcs_ptr += ( - pid_b * stride_ddAcs_batch - + pid_c * stride_ddAcs_chunk - + pid_h * stride_ddAcs_head - + pid_m * stride_ddAcs_csize_m - ) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n - tl.store(ddAcs_ptrs + stride_ddAcs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) - tl.store(ddAcs_ptr, 0.0) - - # offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 64) - # offs_k = tl.arange(0, BLOCK_SIZE_K) - # dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) - # x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) - # dt_ptrs = dt_ptr + offs_n * stride_dt_csize - # cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) - - # chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - # chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) - # rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - # ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m - # ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n - # for n in range(0, chunk_size_limit_n, 64): - # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n - n), other=0.0) - # acc = tl.dot(dout, x) - # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - n), other=0.0).to(tl.float32) - # acc *= cb - # dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) - # acc *= dt_n - # dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) - # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) - # mask = offs_m[:, None] >= offs_n[None, :] + 1 + n - # acc = tl.where(mask, acc, 0.0) - # acc = tl.cumsum(acc, axis=1) - # acc = tl.where(mask, acc, 0.0) - # ddA_cs = tl.sum(acc, axis=0) - # tl.store(ddAcs_ptrs, ddA_cs, mask=offs_n < chunk_size - 1 - n) - # # tl.store(ddAcs_ptr, 0.0) - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, num_stages=3, num_warps=4), - # triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64}, num_stages=3, num_warps=4), - # triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, num_stages=3, num_warps=4), - # triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_stages=3, num_warps=4), - ], - key=["chunk_size", "hdim"], -) -@triton.jit -def _chunk_scan_bwd_ddAcs_stable_kernel( - # Pointers to matrices - x_ptr, - dout_ptr, - dt_ptr, - dA_cumsum_ptr, - cb_ptr, - ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, - hdim, - batch, - seqlen, - nheads_ngroups_ratio, - # Strides - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_dout_batch, - stride_dout_seqlen, - stride_dout_head, - stride_dout_hdim, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_cb_batch, - stride_cb_chunk, - stride_cb_head, - stride_cb_csize_m, - stride_cb_csize_n, - stride_ddA_cs_batch, - stride_ddA_cs_chunk, - stride_ddA_cs_head, - stride_ddA_cs_csize_m, - stride_ddA_cs_csize_n, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - pid_m = tl.program_id(axis=0) - - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head - ddA_cumsum_ptr += ( - pid_b * stride_ddA_cs_batch - + pid_c * stride_ddA_cs_chunk - + pid_h * stride_ddA_cs_head - + pid_m * stride_ddA_cs_csize_m - ) - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) - x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) - dt_ptrs = dt_ptr + offs_n * stride_dt_csize - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) - ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n - tl.store(ddA_cumsum_ptr, 0.0) - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - # Actually hi is (pid_m + 1) * BLOCK_SIZE_M - 1 but subtracting 1 makes it slower - lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M - # lo, hi = 0, chunk_size - for start_n in range(lo, hi, BLOCK_SIZE_N): - start_n = tl.multiple_of(start_n, BLOCK_SIZE_N) - # Doing a matmul loop with cumsum later on will cause Triton to crash - # Instead we do just one big matmul - # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # for k in range(0, hdim, BLOCK_SIZE_K): - # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) - # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) - # acc += tl.dot(dout, x) - # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim - # x_ptrs += BLOCK_SIZE_K * stride_x_hdim - # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) - x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0) - acc = tl.dot(dout, x) - dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) - acc *= dt_n - # If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j] - cb = tl.load( - cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0 - ).to(tl.float32) - acc *= cb - dA_cs_n = tl.load( - dA_cumsum_ptr + start_n + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0 - ).to(tl.float32) - acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) - mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1 - acc = tl.where(mask, acc, 0.0) - rowsum_new = rowsum + tl.sum(acc, axis=1) - acc = rowsum[:, None] + tl.cumsum(acc, axis=1) - rowsum = rowsum_new - acc = tl.where(mask, acc, 0.0) - ddA_cs = tl.sum(acc, axis=0) - tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1) - x_ptrs += BLOCK_SIZE_N * stride_x_seqlen - dt_ptrs += BLOCK_SIZE_N * stride_dt_csize - cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n - ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n - - # Need to zero out the rest, since we'll be summing the rows together - for start_n in range(hi, chunk_size, BLOCK_SIZE_N): - tl.store( - ddAcs_ptrs + stride_ddA_cs_csize_n, - tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), - mask=offs_n < chunk_size - start_n - 1, - ) - ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n - - -@triton.paddle_autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - ], - key=["chunk_size", "dstate", "hdim"], -) -@triton.jit -def _chunk_scan_bwd_ddAcs_prev_kernel( - # Pointers to matrices - dout_ptr, - prev_states_ptr, - C_ptr, - dA_cumsum_ptr, - seq_idx_ptr, - ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, - dstate, - hdim, - batch, - seqlen, - nchunks, - nheads_ngroups_ratio, - # Strides - stride_dout_batch, - stride_dout_seqlen, - stride_dout_head, - stride_dout_hdim, - stride_prev_states_batch, - stride_prev_states_chunk, - stride_prev_states_head, - stride_prev_states_hdim, - stride_prev_states_dstate, - stride_C_batch, - stride_C_seqlen, - stride_C_head, - stride_C_dstate, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - stride_ddA_cs_batch, - stride_ddA_cs_chunk, - stride_ddA_cs_head, - stride_ddA_cs_csize, - # Meta-parameters - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - prev_states_ptr += ( - pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + pid_h * stride_prev_states_head - ) - C_ptr += ( - pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head - ) - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) - prev_states_ptrs = prev_states_ptr + ( - offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim - ) - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) - prev_states = prev_states.to(dout_ptrs.dtype.element_ty) - acc = tl.dot(dout, prev_states) - c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to( - tl.float32 - ) - ddA_cs = tl.sum(acc * c, axis=1) - dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_m) - if HAS_SEQ_IDX: - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) - ddA_cs *= scale - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) - - -def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = C.shape - assert nheads % ngroups == 0 - assert tuple(C.shape) == (batch, seqlen, ngroups, dstate) - assert tuple(cb.shape) == (batch, nchunks, ngroups, chunk_size, chunk_size) - if z is not None: - assert z.shape == x.shape - if D is not None: - assert tuple(D.shape) == (nheads, headdim) or D.shape[0] == nheads - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - assert tuple(dA_cumsum.shape) == (batch, nheads, nchunks, chunk_size) - assert tuple(states.shape) == (batch, nchunks, nheads, headdim, dstate) - if seq_idx is not None: - assert tuple(seq_idx.shape) == (batch, seqlen) - # Allocates output. - out = paddle.empty([batch, seqlen, nheads, headdim], dtype=x.dtype) - if z is not None: - out_x = paddle.empty([batch, seqlen, nheads, headdim], dtype=x.dtype) - assert out_x.strides == out.strides - else: - out_x = None - grid = lambda META: ( - triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) * triton.cdiv(headdim, META["BLOCK_SIZE_N"]), - batch * nchunks, - nheads, - ) - z_strides = (z.strides[0], z.strides[1], z.strides[2], z.strides[3]) if z is not None else (0, 0, 0, 0) - _chunk_scan_fwd_kernel[grid]( - cb, - x, - z, - out, - out_x, - dt, - dA_cumsum, - seq_idx, - C, - states, - D, - chunk_size, - headdim, - dstate, - batch, - seqlen, - nheads // ngroups, - cb.strides[0], - cb.strides[1], - cb.strides[2], - cb.strides[3], - cb.strides[4], - x.strides[0], - x.strides[1], - x.strides[2], - x.strides[3], - z_strides[0], - z_strides[1], - z_strides[2], - z_strides[3], - out.strides[0], - out.strides[1], - out.strides[2], - out.strides[3], - dt.strides[0], - dt.strides[2], - dt.strides[1], - dt.strides[3], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)), - C.strides[0], - C.strides[1], - C.strides[2], - C.strides[3], - states.strides[0], - states.strides[1], - states.strides[2], - states.strides[3], - states.strides[4], - D.strides[0] if D is not None else 0, - True, - D is not None, - D.dim() == 2 if D is not None else True, - BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), - HAS_Z=z is not None, - HAS_SEQ_IDX=seq_idx is not None, - IS_TRITON_22=TRITON_22, - ) - return out, out_x - - -def _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=None, seq_idx=None): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = C.shape - assert nheads % ngroups == 0 - assert tuple(C.shape) == (batch, seqlen, ngroups, dstate) - assert tuple(B.shape) == C.shape - assert tuple(cb.shape) == (batch, nchunks, ngroups, chunk_size, chunk_size) - if z is not None: - assert z.shape == x.shape - if D is not None: - assert tuple(D.shape) == (nheads, headdim) or D.shape[0] == nheads - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - assert tuple(dA_cumsum.shape) == (batch, nheads, nchunks, chunk_size) - assert tuple(states.shape) == (batch, nchunks, nheads, headdim, dstate) - if seq_idx is not None: - assert tuple(seq_idx.shape) == (batch, seqlen) - # Allocates output. - out = paddle.empty([batch, seqlen, nheads, headdim], dtype=x.dtype) - if z is not None: - out_x = paddle.empty([batch, seqlen, nheads, headdim], dtype=x.dtype) - assert out_x.strides == out.strides - else: - out_x = None - grid = lambda META: (triton.cdiv(headdim, META["BLOCK_SIZE_N"]), batch * nchunks, nheads) - z_strides = (z.strides[0], z.strides[1], z.strides[2], z.strides[3]) if z is not None else (0, 0, 0, 0) - _chunk_scan_fwd_kernel_wip[grid]( - cb, - x, - z, - out, - out_x, - dt, - dA_cumsum, - seq_idx, - C, - B, - states, - D, - chunk_size, - headdim, - dstate, - batch, - seqlen, - nheads // ngroups, - cb.strides[0], - cb.strides[1], - cb.strides[2], - cb.strides[3], - cb.strides[4], - x.strides[0], - x.strides[1], - x.strides[2], - x.strides[3], - z_strides[0], - z_strides[1], - z_strides[2], - z_strides[3], - out.strides[0], - out.strides[1], - out.strides[2], - out.strides[3], - dt.strides[0], - dt.strides[2], - dt.strides[1], - dt.strides[3], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)), - C.strides[0], - C.strides[1], - C.strides[2], - C.strides[3], - B.strides[0], - B.strides[1], - B.strides[2], - B.strides[3], - states.strides[0], - states.strides[1], - states.strides[2], - states.strides[3], - states.strides[4], - D.strides[0] if D is not None else 0, - D is not None, - D.dim() == 2 if D is not None else True, - BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), - BLOCK_SIZE_M=128, - HAS_Z=z is not None, - HAS_SEQ_IDX=seq_idx is not None, - ) - return out, out_x - - -def _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=None, dz=None, recompute_output=False): - batch, seqlen, nheads, headdim = x.shape - assert z.shape == x.shape - assert out.shape == x.shape - assert dout.shape == out.shape - nchunks = math.ceil(seqlen / chunk_size) - if D is not None: - assert tuple(D.shape) == (nheads, headdim) or D.shape[0] == nheads - assert D.strides[-1] == 1 - if has_ddAcs: - ddA_cumsum = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32) - if D is not None: - BLOCK_SIZE_min = 32 - dD = paddle.empty( - [triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, headdim if D.dim() == 2 else 1], - dtype=paddle.float32, - ) - else: - dD = None - if dz is not None: - assert dz.shape == z.shape - else: - dz = paddle.empty_like(z) - if recompute_output: - outz = paddle.empty_like(x) - dout_x = paddle.empty_like(dout) - dD_strides = ( - (dD.strides[0], dD.strides[1], dD.strides[2], dD.strides[3], dD.strides[4]) - if D is not None - else (0, 0, 0, 0, 0) - ) - grid_dz = lambda META: (triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]), batch * nchunks, nheads) - _chunk_scan_bwd_dz_kernel[grid_dz]( - dout, - out, - z, - x, - D, - outz if recompute_output else None, - dz, - dout_x, - dD, - ddA_cumsum if has_ddAcs else None, - chunk_size, - headdim, - batch, - seqlen, - dout.strides[0], - dout.strides[1], - dout.strides[2], - dout.strides[3], - out.strides[0], - out.strides[1], - out.strides[2], - out.strides[3], - z.strides[0], - z.strides[1], - z.strides[2], - z.strides[3], - x.strides[0], - x.strides[1], - x.strides[2], - x.strides[3], - D.strides[0] if D is not None else 0, - *((outz.strides[0], outz.strides[1], outz.strides[2], outz.strides[3]) if recompute_output else (0, 0, 0, 0)), - dz.strides[0], - dz.strides[1], - dz.strides[2], - dz.strides[3], - dout_x.strides[0], - dout_x.strides[1], - dout_x.strides[2], - dout_x.strides[3], - dD_strides[1], - dD_strides[2], - dD_strides[3], - dD_strides[0], - dD_strides[4], - *( - (ddA_cumsum.strides[0], ddA_cumsum.strides[2], ddA_cumsum.strides[1], ddA_cumsum.strides[3]) - if has_ddAcs - else (0, 0, 0, 0) - ), - D is not None, - D.dim() == 2 if D is not None else True, - has_ddAcs, - BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), - RECOMPUTE_OUTPUT=recompute_output, - ) - if D is not None: - BLOCK_SIZE_actual = _chunk_scan_bwd_dz_kernel.best_config.kwargs["BLOCK_SIZE_M"] - n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual - dD = dD[:n_valid_blocks].sum(axis=(0, 1, 2)).cast(dtype=D.dtype) - if D.dim() == 1: - dD = rearrange(dD, "h 1 -> h") - return_vals = (dz, dout_x, dD, ddA_cumsum) if has_ddAcs else (dz, dout_x, dD) - return return_vals if not recompute_output else (*return_vals, outz) - - -def _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None): - batch, seqlen, nheads, headdim = dout.shape - _, _, nchunks, chunk_size = dA_cumsum.shape - _, _, ngroups, dstate = C.shape - assert nheads % ngroups == 0 - assert tuple(C.shape) == (batch, seqlen, ngroups, dstate) - assert tuple(dA_cumsum.shape) == (batch, nheads, nchunks, chunk_size) - if seq_idx is not None: - assert tuple(seq_idx.shape) == (batch, seqlen) - dtype = C.dtype if dtype is None else dtype - dprev_states = paddle.empty([batch, nchunks, nheads, headdim, dstate], dtype=dtype) - grid_dstates = lambda META: ( - triton.cdiv(headdim, META["BLOCK_SIZE_M"]) * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), - batch * nchunks, - nheads, - ) - _chunk_scan_bwd_dstates_kernel[grid_dstates]( - dout, - C, - dprev_states, - dA_cumsum, - seq_idx, - headdim, - dstate, - chunk_size, - batch, - seqlen, - nchunks, - nheads // ngroups, - dout.strides[0], - dout.strides[1], - dout.strides[2], - dout.strides[3], - C.strides[0], - C.strides[1], - C.strides[2], - C.strides[3], - dprev_states.strides[0], - dprev_states.strides[1], - dprev_states.strides[2], - dprev_states.strides[3], - dprev_states.strides[4], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)), - HAS_SEQ_IDX=seq_idx is not None, - ) - return dprev_states - - -def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngroups=1): - batch, nchunks, nheads, headdim, dstate = prev_states.shape - _, seqlen, _, _ = dout.shape - _, _, _, chunk_size = dA_cumsum.shape - assert tuple(prev_states.shape) == (batch, nchunks, nheads, headdim, dstate) - assert tuple(dA_cumsum.shape) == (batch, nheads, nchunks, chunk_size) - assert tuple(dout.shape) == (batch, seqlen, nheads, headdim) - if seq_idx is not None: - assert tuple(seq_idx.shape) == (batch, seqlen) - if C is not None: - assert tuple(C.shape) == (batch, seqlen, ngroups, dstate) - C_strides = (C.strides[0], C.strides[1], C.strides[2], C.strides[3]) - ddA_cumsum_prev = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32) - ddA_cumsum_prev_strides = ( - ddA_cumsum_prev.strides[0], - ddA_cumsum_prev.strides[2], - ddA_cumsum_prev.strides[1], - ddA_cumsum_prev.strides[3], - ) - else: - C_strides = (0, 0, 0, 0) - ddA_cumsum_prev = None - ddA_cumsum_prev_strides = (0, 0, 0, 0) - nheads_ngroups_ratio = nheads // ngroups - sm_count = paddle.device.cuda.get_device_properties(paddle.get_device()).multi_processor_count - nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) - nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) - dC = paddle.empty([batch, seqlen, nsplits, ngroups, dstate], dtype=paddle.float32) - grid_dc = lambda META: ( - triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), - batch * nchunks, - nsplits * ngroups, - ) - _chunk_scan_bwd_dc_kernel[grid_dc]( - dout, - prev_states, - C, - dA_cumsum, - seq_idx, - dC, - ddA_cumsum_prev, - chunk_size, - dstate, - headdim, - batch, - seqlen, - nheads, - nheads_per_program, - ngroups, - dout.strides[0], - dout.strides[1], - dout.strides[2], - dout.strides[3], - prev_states.strides[0], - prev_states.strides[1], - prev_states.strides[2], - prev_states.strides[3], - prev_states.strides[4], - *C_strides, - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)), - dC.strides[0], - dC.strides[1], - dC.strides[2], - dC.strides[3], - dC.strides[4], - *ddA_cumsum_prev_strides, - HAS_DDA_CS=ddA_cumsum_prev is not None, - HAS_SEQ_IDX=seq_idx is not None, - BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), - ) - dC = dC.sum(2) - return dC if C is None else (dC, ddA_cumsum_prev) - - -def _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, ngroups=1): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert dout.shape == x.shape - if seq_idx is not None: - assert tuple(seq_idx.shape) == (batch, seqlen) - if CB is not None: - assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) - CB_strides = (CB.strides[0], CB.strides[1], CB.strides[2], CB.strides[3], CB.strides[4]) - BLOCK_SIZE_M_min = 16 - ddA_cumsum = paddle.empty( - [batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), chunk_size], dtype=paddle.float32 - ) - ddA_cumsum_strides = ( - ddA_cumsum.strides[0], - ddA_cumsum.strides[2], - ddA_cumsum.strides[1], - ddA_cumsum.strides[3], - ddA_cumsum.strides[4], - ) - else: - CB_strides = (0, 0, 0, 0, 0) - ddA_cumsum = None - ddA_cumsum_strides = (0, 0, 0, 0, 0) - nheads_ngroups_ratio = nheads // ngroups - sm_count = paddle.device.cuda.get_device_properties(paddle.get_device()).multi_processor_count - nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) - nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) - dcb = paddle.empty([batch, nchunks, nsplits, ngroups, chunk_size, chunk_size], dtype=paddle.float32) - grid_dcb = lambda META: ( - triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]), - batch * nchunks, - nsplits * ngroups, - ) - _chunk_scan_bwd_dcb_kernel[grid_dcb]( - x, - dout, - CB, - dt, - dA_cumsum, - seq_idx, - dcb, - ddA_cumsum, - chunk_size, - headdim, - batch, - seqlen, - nheads, - nheads_per_program, - ngroups, - x.strides[0], - x.strides[1], - x.strides[2], - x.strides[3], - dout.strides[0], - dout.strides[1], - dout.strides[2], - dout.strides[3], - *CB_strides, - dt.strides[0], - dt.strides[2], - dt.strides[1], - dt.strides[3], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)), - dcb.strides[0], - dcb.strides[1], - dcb.strides[2], - dcb.strides[3], - dcb.strides[4], - dcb.strides[5], - *ddA_cumsum_strides, - HAS_DDA_CS=ddA_cumsum is not None, - HAS_SEQ_IDX=seq_idx is not None, - BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), - ) - dcb = dcb.sum(2) - if ddA_cumsum is not None: - BLOCK_SIZE_M_actual = _chunk_scan_bwd_dcb_kernel.best_config.kwargs["BLOCK_SIZE_M"] - n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual - ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(axis=3) - return dcb if CB is None else (dcb, ddA_cumsum) - - -def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - ngroups = cb.shape[2] - assert nheads % ngroups == 0 - assert tuple(cb.shape) == (batch, nchunks, ngroups, chunk_size, chunk_size) - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert dout.shape == x.shape - # if D is not None: - # BLOCK_SIZE_M_min = 32 - # dD = paddle.empty(triton.cdiv(chunk_size, BLOCK_SIZE_M_min), batch, nchunks, nheads, headdim, device=D.device, dtype=paddle.float32) - # else: - # dD = None - dx = paddle.empty_like(x) - ddt = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32) - grid_dx = lambda META: ( - triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) * triton.cdiv(headdim, META["BLOCK_SIZE_N"]), - batch * nchunks, - nheads, - ) - _chunk_scan_bwd_dx_kernel[grid_dx]( - x, - cb, - dout, - dt, - dA_cumsum, - D, - dx, - ddt, # dD, - chunk_size, - headdim, - batch, - seqlen, - nheads // ngroups, - x.strides[0], - x.strides[1], - x.strides[2], - x.strides[3], - cb.strides[0], - cb.strides[1], - cb.strides[2], - cb.strides[-1], - cb.strides[-2], - dout.strides[0], - dout.strides[1], - dout.strides[2], - dout.strides[3], - dt.strides[0], - dt.strides[2], - dt.strides[1], - dt.strides[3], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - D.strides[0] if D is not None else 0, - dx.strides[0], - dx.strides[1], - dx.strides[2], - dx.strides[3], - ddt.strides[0], - ddt.strides[2], - ddt.strides[1], - ddt.strides[3], - # dD.strides[1] if dD is not None else 0, dD.strides[2] if dD is not None else 0, dD.strides[3] if dD is not None else 0, dD.strides[4] if dD is not None else 0, dD.strides[0] if dD is not None else 0, - D is not None, - D.dim() == 2 if D is not None else True, - ) - # if D is not None: - # BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] - # n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual - # dD = dD[:n_valid_blocks].sum(axis=(0, 1, 2)).cast(dtype=D.dtype) - return dx, ddt.cast(dtype=dt.dtype) - - -def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True): - """Not numerically stable and should not be used. Leaving here for reference.""" - - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - assert ddt.shape == dt.shape - assert out.shape == x.shape - assert dout.shape == x.shape - if D is not None: - assert tuple(D.shape) == (nheads, headdim) or D.shape[0] == nheads - ddA_cumsum = paddle.empty_like(dt) - grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]), batch * nchunks, nheads) - if D is not None: # Triton gives wrong results if we write to the same location - BLOCK_SIZE_min = 32 - dD = paddle.empty( - [triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, headdim if D.dim() == 2 else 1], - dtype=paddle.float32, - ) - else: - dD = None - dD_strides = ( - (dD.strides[0], dD.strides[1], dD.strides[2], dD.strides[3], dD.strides[4]) - if D is not None - else (0, 0, 0, 0, 0) - ) - _chunk_scan_bwd_ddAcs_unstable_kernel[grid_ddtcs]( - dout, - out, - dt, - ddt, - x, - D, - ddA_cumsum, - dD, - chunk_size, - headdim, - batch, - seqlen, - dout.strides[0], - dout.strides[1], - dout.strides[2], - dout.strides[3], - out.strides[0], - out.strides[1], - out.strides[2], - out.strides[3], - dt.strides[0], - dt.strides[2], - dt.strides[1], - dt.strides[3], - ddt.strides[0], - ddt.strides[2], - ddt.strides[1], - ddt.strides[3], - x.strides[0], - x.strides[1], - x.strides[2], - x.strides[3], - D.strides[0] if D is not None else 0, - ddA_cumsum.strides[0], - ddA_cumsum.strides[2], - ddA_cumsum.strides[1], - ddA_cumsum.strides[3], - dD_strides[1], - dD_strides[2], - dD_strides[3], - dD_strides[0], - dD_strides[4], - D is not None, - D.dim() == 2 if D is not None else True, - subtract_ddtdt, - BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), - ) - if D is not None: - BLOCK_SIZE_actual = _chunk_scan_bwd_ddAcs_unstable_kernel.best_config.kwargs["BLOCK_SIZE_M"] - n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual - dD = dD[:n_valid_blocks].sum(axis=(0, 1, 2)).cast(dtype=D.dtype) - if D.dim() == 1: - dD = rearrange(dD, "h 1 -> h") - return ddA_cumsum, dD - - -def _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout, cb): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - assert dout.shape == x.shape - assert dA_cumsum.shape == dt.shape - ngroups = cb.shape[2] - assert nheads % ngroups == 0 - assert tuple(cb.shape) == (batch, nchunks, ngroups, chunk_size, chunk_size) - BLOCK_SIZE_M_min = 16 - ddA_cumsum = paddle.empty( - [batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), chunk_size], dtype=paddle.float32 - ) - grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]), batch * nchunks, nheads) - _chunk_scan_bwd_ddAcs_stable_kernel_old[grid_ddtcs]( - x, - dout, - dt, - dA_cumsum, - cb, - ddA_cumsum, - chunk_size, - headdim, - batch, - seqlen, - nheads // ngroups, - x.strides[0], - x.strides[1], - x.strides[2], - x.strides[3], - dout.strides[0], - dout.strides[1], - dout.strides[2], - dout.strides[3], - dt.strides[0], - dt.strides[2], - dt.strides[1], - dt.strides[3], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - cb.strides[0], - cb.strides[1], - cb.strides[2], - cb.strides[3], - cb.strides[4], - ddA_cumsum.strides[0], - ddA_cumsum.strides[2], - ddA_cumsum.strides[1], - ddA_cumsum.strides[3], - ddA_cumsum.strides[4], - BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), - BLOCK_SIZE_N=max(triton.next_power_of_2(chunk_size), 16), - ) - BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel_old.best_config.kwargs["BLOCK_SIZE_M"] - n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual - ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(axis=3) - return ddA_cumsum - - -def _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - assert dout.shape == x.shape - assert dA_cumsum.shape == dt.shape - ngroups = cb.shape[2] - assert nheads % ngroups == 0 - assert tuple(cb.shape) == (batch, nchunks, ngroups, chunk_size, chunk_size) - BLOCK_SIZE_M_min = 32 - ddA_cumsum = paddle.empty( - [batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), chunk_size], dtype=paddle.float32 - ) - grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]), batch * nchunks, nheads) - _chunk_scan_bwd_ddAcs_stable_kernel[grid_ddtcs]( - x, - dout, - dt, - dA_cumsum, - cb, - ddA_cumsum, - chunk_size, - headdim, - batch, - seqlen, - nheads // ngroups, - x.strides[0], - x.strides[1], - x.strides[2], - x.strides[3], - dout.strides[0], - dout.strides[1], - dout.strides[2], - dout.strides[3], - dt.strides[0], - dt.strides[2], - dt.strides[1], - dt.strides[3], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - cb.strides[0], - cb.strides[1], - cb.strides[2], - cb.strides[3], - cb.strides[4], - ddA_cumsum.strides[0], - ddA_cumsum.strides[2], - ddA_cumsum.strides[1], - ddA_cumsum.strides[3], - ddA_cumsum.strides[4], - BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), - ) - BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs["BLOCK_SIZE_M"] - n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual - ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(axis=3) - return ddA_cumsum - - -def _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=None): - batch, nchunks, nheads, headdim, dstate = prev_states.shape - _, seqlen, _, _ = dout.shape - _, _, _, chunk_size = dA_cumsum.shape - assert tuple(prev_states.shape) == (batch, nchunks, nheads, headdim, dstate) - assert tuple(dA_cumsum.shape) == (batch, nheads, nchunks, chunk_size) - assert tuple(dout.shape) == (batch, seqlen, nheads, headdim) - ngroups = C.shape[2] - assert nheads % ngroups == 0 - assert tuple(C.shape) == (batch, seqlen, ngroups, dstate) - if seq_idx is not None: - assert tuple(seq_idx.shape) == (batch, seqlen) - ddA_cumsum_prev = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32) - grid_ddAcs = lambda META: ( - triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), - batch * nchunks, - nheads, - ) - _chunk_scan_bwd_ddAcs_prev_kernel[grid_ddAcs]( - dout, - prev_states, - C, - dA_cumsum, - seq_idx, - ddA_cumsum_prev, - chunk_size, - dstate, - headdim, - batch, - seqlen, - nchunks, - nheads // ngroups, - dout.strides[0], - dout.strides[1], - dout.strides[2], - dout.strides[3], - prev_states.strides[0], - prev_states.strides[1], - prev_states.strides[2], - prev_states.strides[3], - prev_states.strides[4], - C.strides[0], - C.strides[1], - C.strides[2], - C.strides[3], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)), - ddA_cumsum_prev.strides[0], - ddA_cumsum_prev.strides[2], - ddA_cumsum_prev.strides[1], - ddA_cumsum_prev.strides[3], - HAS_SEQ_IDX=seq_idx is not None, - BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), - ) - return ddA_cumsum_prev - - -class ChunkScanFn(paddle.autograd.PyLayer): - @staticmethod - @custom_fwd - def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): - # Check constraints. - batch, seqlen, nheads, headdim = x.shape - _, _, ngroups, dstate = B.shape - assert B.shape == (batch, seqlen, ngroups, dstate) - _, _, nchunks, chunk_size = dt.shape - assert seqlen == nchunks * chunk_size - assert C.shape == B.shape - if z is not None: - assert z.shape == x.shape - if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads,) - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) - assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) - if B.strides[-1] != 1: - B = B.contiguous() - if C.strides[-1] != 1: - C = C.contiguous() - if x.strides[-1] != 1 and x.strides[1] != 1: # Either M or K dimension should be contiguous - x = x.contiguous() - if z is not None and z.strides[-1] != 1 and z.strides[1] != 1: # Either M or K dimension should be contiguous - z = z.contiguous() - if D is not None and D.strides[-1] != 1: - D = D.contiguous() - CB = _bmm_chunk_fwd(C, B, chunk_size) - out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z) - ctx.save_for_backward(out if z is None else out_x, B, C, CB, x, dt, dA_cumsum, prev_states, D, z) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dout): - if dout.strides[-1] != 1: - dout = dout.contiguous() - out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z = ctx.saved_tensor() - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape - assert dout.shape == (batch, seqlen, nheads, headdim) - if z is not None: - dz, dout, dD, ddA_cumsum = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, D=D) - else: - dz = None - dprev_states = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, dtype=prev_states.dtype) - dC = _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, ngroups=ngroups) - dC = dC.cast(C.dtype) - dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, ngroups=ngroups) - dCB = dCB.cast(CB.dtype) - dB = _bmm_chunk_bwd(C, dCB) - dC = _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC) - dx, ddt = _chunk_scan_bwd_dx(CB, x, dt, dA_cumsum, dout, D=D) - # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D. - # ddA_cumsum = paddle.einsum("bclhp,bclhp->bhcl", out.cast("float32"), dout.cast("float32")) - ddt * dt - if z is not None: - ddA_cumsum -= ddt * dt - else: # If z is not None, we already calculated ddA_cumsum and dD when computing dz - ddA_cumsum, dD = _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=D) - ddA_cumsum = ddA_cumsum.cast(dA_cumsum.dtype) - return dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz - - -def chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): - """ - prev_states contains the initial_states at index 0, and the state for the next-to-last chunk at index -1. - - Argument: - B: (batch, seqlen, ngroups, dstate) - C: (batch, seqlen, ngroups, dstate) - x: (batch, seqlen, nheads, headdim) - dt: (batch, nheads, nchunks, chunk_size) - dA_cumsum: (batch, nheads, nchunks, chunk_size) - prev_states: (batch, nchunks, nheads, headdim, dstate) - D: (nheads, headdim) or (nheads,) - z: (batch, seqlen, nheads, headdim) - Return: - out: (batch, seqlen, nheads, headdim) - """ - return ChunkScanFn.apply(B, C, x, dt, dA_cumsum, prev_states, D, z) - - -def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): - """ - Argument: - B: (batch, seqlen, ngroups, dstate) - C: (batch, seqlen, ngroups, dstate) - x: (batch, seqlen, nheads, headdim) - dt: (batch, nheads, nchunks, chunk_size) - dA_cumsum: (batch, nheads, nchunks, chunk_size) - prev_states: (batch, nchunks, nheads, headdim, dstate) - D: (nheads, headdim) or (nheads,) - z: (batch, seqlen, nheads, headdim) - Return: - out: (batch, seqlen, nheads, headdim) - """ - batch, seqlen, nheads, headdim = x.shape - _, _, ngroups, dstate = B.shape - assert tuple(B.shape) == (batch, seqlen, ngroups, dstate) - _, _, nchunks, chunk_size = dt.shape - assert seqlen == nchunks * chunk_size - assert C.shape == B.shape - B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) - C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups) - CB = paddle.einsum( - "bclhn,bcshn->bchls", - rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), - rearrange(B, "b (c s) h n -> b c s h n", c=nchunks), - ) - # (batch, nheads, nchunks, chunksize, chunksize) - dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] - decay = paddle.exp(dt_segment_sum) - scores_decay = CB * rearrange(decay, "b h c l s -> b c h l s") - causal_mask = paddle.tril(paddle.ones(chunk_size, chunk_size, dtype=bool), diagonal=0) - scores_decay = scores_decay.masked_fill(~causal_mask, 0) - out = paddle.einsum( - "bchls,bhcs,bcshp->bclhp", - scores_decay.cast(x.dtype), - dt.cast(x.dtype), - rearrange(x, "b (c s) h p -> b c s h p", c=nchunks), - ) - state_decay_out = paddle.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) - out_prev = ( - paddle.einsum( - "bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.cast(C.dtype) - ) - * state_decay_out - ) - out = out + out_prev - out = rearrange(out, "b c l h p -> b (c l) h p") - if D is not None: - if D.dim() == 1: - D = rearrange(D, "h -> h 1") - out = out + x * D - return out if z is None else out * F.silu(z) diff --git a/ops/src/paddlenlp_kernel/triton/mamba/ssd_chunk_state.py b/ops/src/paddlenlp_kernel/triton/mamba/ssd_chunk_state.py deleted file mode 100644 index 36bb7efc0704..000000000000 --- a/ops/src/paddlenlp_kernel/triton/mamba/ssd_chunk_state.py +++ /dev/null @@ -1,1569 +0,0 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. -""" -this code is modified from https://github.com/state-spaces/mamba/tree/main/mamba_ssm/ops/triton -""" -"""We want triton==2.1.0 or 2.2.0 for this""" - -import math - -import paddle -import paddle.autograd -import paddle.nn.functional as F -import triton -import triton.language as tl -from einops import rearrange, repeat - -from ...utils import custom_bwd, custom_fwd -from .math import softplus - - -def init_to_zero(names): - return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE_H": 1}), - triton.Config({"BLOCK_SIZE_H": 2}), - triton.Config({"BLOCK_SIZE_H": 4}), - triton.Config({"BLOCK_SIZE_H": 8}), - triton.Config({"BLOCK_SIZE_H": 16}), - triton.Config({"BLOCK_SIZE_H": 32}), - triton.Config({"BLOCK_SIZE_H": 64}), - ], - key=["chunk_size", "nheads"], -) -@triton.jit -def _chunk_cumsum_fwd_kernel( - # Pointers to matrices - dt_ptr, - A_ptr, - dt_bias_ptr, - dt_out_ptr, - dA_cumsum_ptr, - # Matrix dimension - batch, - seqlen, - nheads, - chunk_size, - dt_min, - dt_max, - # Strides - stride_dt_batch, - stride_dt_seqlen, - stride_dt_head, - stride_A_head, - stride_dt_bias_head, - stride_dt_out_batch, - stride_dt_out_chunk, - stride_dt_out_head, - stride_dt_out_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - # Meta-parameters - DT_SOFTPLUS: tl.constexpr, - HAS_DT_BIAS: tl.constexpr, - BLOCK_SIZE_H: tl.constexpr, - BLOCK_SIZE_CHUNK: tl.constexpr, -): - pid_b = tl.program_id(axis=0) - pid_c = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) - dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen - dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk - - offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) - offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) - dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen) - A_ptrs = A_ptr + offs_h * stride_A_head - dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize) - dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - - dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to( - tl.float32 - ) - if HAS_DT_BIAS: - dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32) - dt += dt_bias[:, None] - if DT_SOFTPLUS: - dt = tl.where(dt <= 20.0, softplus(dt), dt) - # As of Triton 2.2.0, tl.clamp is not available yet - # dt = tl.clamp(dt, dt_min, dt_max) - dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) - dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) - tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) - A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) - dA = dt * A[:, None] - dA_cs = tl.cumsum(dA, axis=1) - tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE_H": 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - triton.Config({"BLOCK_SIZE_H": 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - triton.Config({"BLOCK_SIZE_H": 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - triton.Config({"BLOCK_SIZE_H": 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - triton.Config({"BLOCK_SIZE_H": 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - triton.Config({"BLOCK_SIZE_H": 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - triton.Config({"BLOCK_SIZE_H": 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - ], - key=["chunk_size", "nheads"], -) -@triton.jit -def _chunk_cumsum_bwd_kernel( - # Pointers to matrices - ddA_ptr, - ddt_out_ptr, - dt_ptr, - A_ptr, - dt_bias_ptr, - ddt_ptr, - dA_ptr, - ddt_bias_ptr, - # Matrix dimensions - batch, - seqlen, - nheads, - chunk_size, - dt_min, - dt_max, - # Strides - stride_ddA_batch, - stride_ddA_chunk, - stride_ddA_head, - stride_ddA_csize, - stride_ddt_out_batch, - stride_ddt_out_chunk, - stride_ddt_out_head, - stride_ddt_out_csize, - stride_dt_batch, - stride_dt_seqlen, - stride_dt_head, - stride_A_head, - stride_dt_bias_head, - stride_ddt_batch, - stride_ddt_seqlen, - stride_ddt_head, - stride_dA_head, - stride_ddt_bias_head, - # Meta-parameters - DT_SOFTPLUS: tl.constexpr, - HAS_DT_BIAS: tl.constexpr, - BLOCK_SIZE_H: tl.constexpr, - BLOCK_SIZE_CHUNK: tl.constexpr, -): - pid_b = tl.program_id(axis=0) - pid_c = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) - ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk - ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk - dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen - ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen - - offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) - offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) - ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize) - ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize) - dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen) - ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen) - A_ptrs = A_ptr + offs_h * stride_A_head - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - - ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to( - tl.float32 - ) - ddt_out = tl.load( - ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0 - ).to(tl.float32) - A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) - ddt = ddA * A[:, None] + ddt_out - dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to( - tl.float32 - ) - if HAS_DT_BIAS: - dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32) - dt += dt_bias[:, None] - if DT_SOFTPLUS: - dt_presoftplus = dt - dt = softplus(dt) - clamp_mask = (dt < dt_min) | (dt > dt_max) - # As of Triton 2.2.0, tl.clamp is not available yet - # dt = tl.clamp(dt, dt_min, dt_max) - dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) - dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) - ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0) - ddt = tl.where(clamp_mask, 0.0, ddt) - if DT_SOFTPLUS: - ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt) - tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) - dA = tl.sum(ddA * dt, axis=1) - tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads) - if HAS_DT_BIAS: - ddt_bias = tl.sum(ddt, axis=1) - tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads) - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=2), - ], - key=["hdim", "dstate", "chunk_size"], -) -@triton.jit -def _chunk_state_fwd_kernel( - # Pointers to matrices - x_ptr, - b_ptr, - states_ptr, - dt_ptr, - dA_cumsum_ptr, - seq_idx_ptr, - # Matrix dimensions - hdim, - dstate, - chunk_size, - batch, - seqlen, - nheads_ngroups_ratio, - # Strides - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_b_batch, - stride_b_seqlen, - stride_b_head, - stride_b_dstate, - stride_states_batch, - stride_states_chunk, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - # Meta-parameters - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - b_ptr += ( - pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head - ) - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen) - b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) - dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - if HAS_SEQ_IDX: - seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0) - b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to( - tl.float32 - ) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - if HAS_SEQ_IDX: - seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k - else: - scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) - b *= scale[:, None] - b = b.to(x_ptr.dtype.element_ty) - acc += tl.dot(x, b) - x_ptrs += BLOCK_SIZE_K * stride_x_seqlen - b_ptrs += BLOCK_SIZE_K * stride_b_seqlen - dt_ptrs += BLOCK_SIZE_K * stride_dt_csize - dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen - states = acc.to(states_ptr.dtype.element_ty) - - states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate) - c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) - tl.store(states_ptrs, states, mask=c_mask) - - -@triton.paddle_autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, - num_stages=3, - num_warps=8, - pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]), - ), - ], - key=["chunk_size", "hdim", "dstate"], -) -@triton.jit -def _chunk_state_bwd_dx_kernel( - # Pointers to matrices - x_ptr, - b_ptr, - dstates_ptr, - dt_ptr, - dA_cumsum_ptr, - dx_ptr, - ddt_ptr, - ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, - hdim, - dstate, - batch, - seqlen, - nheads_ngroups_ratio, - # Strides - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_b_batch, - stride_b_seqlen, - stride_b_head, - stride_b_dstate, - stride_dstates_batch, - stride_dstates_chunk, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_dx_batch, - stride_dx_seqlen, - stride_dx_head, - stride_dx_hdim, - stride_ddt_batch, - stride_ddt_chunk, - stride_ddt_head, - stride_ddt_csize, - stride_ddA_cs_batch, - stride_ddA_cs_chunk, - stride_ddA_cs_head, - stride_ddA_cs_csize, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - BLOCK_SIZE_DSTATE: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - b_ptr += ( - pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head - ) - dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate) - dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate) - if BLOCK_SIZE_DSTATE <= 128: - b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0) - dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) - dstates = dstates.to(b_ptr.dtype.element_ty) - acc = tl.dot(b, dstates) - else: - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, dstate, BLOCK_SIZE_K): - b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0) - dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) - dstates = dstates.to(b_ptr.dtype.element_ty) - acc += tl.dot(b, dstates) - b_ptrs += BLOCK_SIZE_K * stride_b_dstate - dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) - dt_ptrs = dt_ptr + offs_m * stride_dt_csize - dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize - dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None] - - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - ddt = tl.sum(acc * x, axis=1) - ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize - tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) - ddA_cs = -(ddt * dt_m) - ddA_cs_last = -tl.sum(ddA_cs) - ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) - tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last) - - dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty) - dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head - dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) - tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - - -@triton.paddle_autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - ], - key=["chunk_size", "dstate", "hdim"], -) -@triton.jit -def _chunk_state_bwd_db_kernel( - # Pointers to matrices - x_ptr, - dstates_ptr, - b_ptr, - dt_ptr, - dA_cumsum_ptr, - seq_idx_ptr, - db_ptr, - ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, - dstate, - hdim, - batch, - seqlen, - nheads, - nheads_per_program, - ngroups, - # Strides - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_dstates_batch, - stride_dstates_chunk, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_b_batch, - stride_b_seqlen, - stride_b_head, - stride_b_dstate, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - stride_db_batch, - stride_db_seqlen, - stride_db_split, - stride_db_group, - stride_db_dstate, - stride_ddA_cs_batch, - stride_ddA_cs_chunk, - stride_ddA_cs_head, - stride_ddA_cs_csize, - # Meta-parameters - HAS_DDA_CS: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_sg = tl.program_id(axis=2) - pid_s = pid_sg // ngroups - pid_g = pid_sg - pid_s * ngroups - num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - x_ptr += ( - pid_b * stride_x_batch - + pid_c * chunk_size * stride_x_seqlen - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head - ) - db_ptr += ( - pid_b * stride_db_batch - + pid_c * chunk_size * stride_db_seqlen - + pid_g * stride_db_group - + pid_s * stride_db_split - ) - dstates_ptr += ( - pid_b * stride_dstates_batch - + pid_c * stride_dstates_chunk - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head - ) - dt_ptr += ( - pid_b * stride_dt_batch - + pid_c * stride_dt_chunk - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head - ) - dA_cumsum_ptr += ( - pid_b * stride_dA_cs_batch - + pid_c * stride_dA_cs_chunk - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head - ) - if HAS_DDA_CS: - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head - ddA_cumsum_ptr += ( - pid_b * stride_ddA_cs_batch - + pid_c * stride_ddA_cs_chunk - + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head - ) - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim) - dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim) - dt_ptrs = dt_ptr + offs_m * stride_dt_csize - dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize - if HAS_DDA_CS: - b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate) - ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - if HAS_DDA_CS: - b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to( - tl.float32 - ) - if HAS_SEQ_IDX: - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) - nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) - for h in range(nheads_iter): - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) - dstates = dstates.to(x_ptrs.dtype.element_ty) - db = tl.dot(x, dstates) - dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) - dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_last - dA_cs_m) - else: - scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) - db *= (scale * dt_m)[:, None] - if HAS_DDA_CS: - # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum - ddA_cs = tl.sum(db * b, axis=1) - tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) - acc += db - x_ptrs += stride_x_head - dstates_ptrs += stride_states_head - dt_ptrs += stride_dt_head - dA_cumsum_ptr += stride_dA_cs_head - dA_cumsum_ptrs += stride_dA_cs_head - if HAS_DDA_CS: - ddA_cumsum_ptrs += stride_ddA_cs_head - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - # if HAS_SEQ_IDX: - # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) - # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0) - db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate) - tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) - - -@triton.paddle_autotune( - configs=[ - # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config( - {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=3, - num_warps=4, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=8, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=8, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=8, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=8, - pre_hook=init_to_zero(["ddA_cumsum_ptr"]), - ), - ], - key=["chunk_size", "hdim", "dstate"], -) -@triton.jit -def _chunk_state_bwd_ddAcs_stable_kernel( - # Pointers to matrices - x_ptr, - b_ptr, - dstates_ptr, - dt_ptr, - dA_cumsum_ptr, - seq_idx_ptr, - ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, - hdim, - dstate, - batch, - seqlen, - nheads_ngroups_ratio, - # Strides - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_b_batch, - stride_b_seqlen, - stride_b_head, - stride_b_dstate, - stride_dstates_batch, - stride_dstates_chunk, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - stride_ddA_cs_batch, - stride_ddA_cs_chunk, - stride_ddA_cs_head, - stride_ddA_cs_csize, - # Meta-parameters - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - BLOCK_SIZE_DSTATE: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - b_ptr += ( - pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head - ) - dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate) - dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate) - if BLOCK_SIZE_DSTATE <= 128: - b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0) - dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) - dstates = dstates.to(b_ptr.dtype.element_ty) - acc = tl.dot(b, dstates) - else: - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, dstate, BLOCK_SIZE_K): - b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0) - dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) - dstates = dstates.to(b_ptr.dtype.element_ty) - acc += tl.dot(b, dstates) - b_ptrs += BLOCK_SIZE_K * stride_b_dstate - dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_last - dA_cs_m) - else: - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) - scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) - acc *= scale[:, None] - - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - dt_ptrs = dt_ptr + offs_m * stride_dt_csize - dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - ddt = tl.sum(acc * x, axis=1) - # ddA_cs = -(ddt * dt_m) - # Triton 2.2.0 errors if we have the cumsum here, so we just write it out - # then call paddle.cumsum outside this kernel. - # ddA_cs = tl.cumsum(ddt * dt_m) - ddA_cs = ddt * dt_m - ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) - tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, num_warps=2), - ], - key=["hdim", "dstate", "chunk_size"], -) -@triton.jit -def _chunk_state_varlen_kernel( - # Pointers to matrices - x_ptr, - b_ptr, - dt_ptr, - dA_cumsum_ptr, - chunk_states_ptr, - cu_seqlens_ptr, - states_ptr, - # Matrix dimensions - hdim, - dstate, - chunk_size, - seqlen, - nheads_ngroups_ratio, - # Strides - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_b_seqlen, - stride_b_head, - stride_b_dstate, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_chunk_states_chunk, - stride_chunk_states_head, - stride_chunk_states_hdim, - stride_chunk_states_dstate, - stride_states_batch, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_b = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) - pid_c = (end_idx - 1) // chunk_size - b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head - x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen) - b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) - dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - - chunk_size_limit = end_idx - pid_c * chunk_size - start_idx = tl.load(cu_seqlens_ptr + pid_b) - start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - x = tl.load( - x_ptrs, - mask=(offs_m[:, None] < hdim) - & (offs_k[None, :] < chunk_size_limit - k) - & (offs_k[None, :] >= start_idx_cur - k), - other=0.0, - ) - b = tl.load( - b_ptrs, - mask=(offs_k[:, None] < chunk_size_limit - k) - & (offs_n[None, :] < dstate) - & (offs_k[:, None] >= start_idx_cur - k), - other=0.0, - ).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - scale = tl.where( - (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0 - ) - b *= scale[:, None] - b = b.to(x_ptr.dtype.element_ty) - acc += tl.dot(x, b) - x_ptrs += BLOCK_SIZE_K * stride_x_seqlen - b_ptrs += BLOCK_SIZE_K * stride_b_seqlen - dt_ptrs += BLOCK_SIZE_K * stride_dt_csize - dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - - # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk - if start_idx < pid_c * chunk_size: - chunk_states_ptrs = chunk_states_ptr + ( - offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate - ) - chunk_states = tl.load( - chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0 - ).to(tl.float32) - # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0) - scale = tl.exp(dA_cs_last) - acc += chunk_states * scale - - states = acc.to(states_ptr.dtype.element_ty) - - states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate) - c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) - tl.store(states_ptrs, states, mask=c_mask) - - -def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): - batch, seqlen, nheads = dt.shape - assert A.shape[0] == nheads - if dt_bias is not None: - assert dt_bias.shape[0] == nheads - nchunks = math.ceil(seqlen / chunk_size) - dt_out = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32) - dA_cumsum = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32) - grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"])) - _chunk_cumsum_fwd_kernel[grid_chunk_cs]( - dt, - A, - dt_bias, - dt_out, - dA_cumsum, - batch, - seqlen, - nheads, - chunk_size, - dt_limit[0], - dt_limit[1], - dt.strides[0], - dt.strides[1], - dt.strides[2], - A.strides[0], - dt_bias.strides[0] if dt_bias is not None else 0, - dt_out.strides[0], - dt_out.strides[2], - dt_out.strides[1], - dt_out.strides[3], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - dt_softplus, - HAS_DT_BIAS=dt_bias is not None, - BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), - ) - return dA_cumsum, dt_out - - -def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), ddt=None): - batch, seqlen, nheads = dt.shape - _, _, nchunks, chunk_size = ddA.shape - assert tuple(ddA.shape) == (batch, nheads, nchunks, chunk_size) - assert tuple(ddt_out.shape) == (batch, nheads, nchunks, chunk_size) - assert A.shape[0] == nheads - if dt_bias is not None: - assert dt_bias.shape[0] == nheads - ddt_bias = paddle.empty_like(dt_bias, dtype=paddle.float32) - else: - ddt_bias = None - if ddt is not None: - assert ddt.shape == dt.shape - else: - ddt = paddle.empty_like(dt) - dA = paddle.empty_like(A, dtype=paddle.float32) - grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"])) - _chunk_cumsum_bwd_kernel[grid_chunk_cs]( - ddA, - ddt_out, - dt, - A, - dt_bias, - ddt, - dA, - ddt_bias, - batch, - seqlen, - nheads, - chunk_size, - dt_limit[0], - dt_limit[1], - ddA.strides[0], - ddA.strides[2], - ddA.strides[1], - ddA.strides[3], - ddt_out.strides[0], - ddt_out.strides[2], - ddt_out.strides[1], - ddt_out.strides[3], - dt.strides[0], - dt.strides[1], - dt.strides[2], - A.strides[0], - dt_bias.strides[0] if dt_bias is not None else 0, - ddt.strides[0], - ddt.strides[1], - ddt.strides[2], - dA.strides[0], - ddt_bias.strides[0] if ddt_bias is not None else 0, - dt_softplus, - HAS_DT_BIAS=dt_bias is not None, - BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), - ) - return ddt, dA, ddt_bias - - -def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape - assert nheads % ngroups == 0 - assert tuple(B.shape) == (batch, seqlen, ngroups, dstate) - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - if seq_idx is not None: - assert tuple(seq_idx.shape) == (batch, seqlen) - if states is not None: - assert tuple(states.shape) == (batch, nchunks, nheads, headdim, dstate) - else: - states_dtype = paddle.float32 if states_in_fp32 else B.dtype - states = paddle.empty((batch, nchunks, nheads, headdim, dstate), dtype=states_dtype) - grid = lambda META: ( - triton.cdiv(headdim, META["BLOCK_SIZE_M"]) * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), - batch * nchunks, - nheads, - ) - _chunk_state_fwd_kernel[grid]( - x, - B, - states, - dt, - dA_cumsum, - seq_idx, - headdim, - dstate, - chunk_size, - batch, - seqlen, - nheads // ngroups, - x.strides[0], - x.strides[1], - x.strides[2], - x.strides[3], - B.strides[0], - B.strides[1], - B.strides[2], - B.strides[-1], - states.strides[0], - states.strides[1], - states.strides[2], - states.strides[3], - states.strides[4], - dt.strides[0], - dt.strides[2], - dt.strides[1], - dt.strides[3], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)), - HAS_SEQ_IDX=seq_idx is not None, - ) - return states - - -def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape - assert nheads % ngroups == 0 - assert tuple(B.shape) == (batch, seqlen, ngroups, dstate) - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert tuple(dstates.shape) == (batch, nchunks, nheads, headdim, dstate) - if dx is not None: - assert dx.shape == x.shape - else: - dx = paddle.empty_like(x) - ddt = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32) - ddA_cumsum = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32) - grid_dx = lambda META: ( - triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) * triton.cdiv(headdim, META["BLOCK_SIZE_N"]), - batch * nchunks, - nheads, - ) - _chunk_state_bwd_dx_kernel[grid_dx]( - x, - B, - dstates, - dt, - dA_cumsum, - dx, - ddt, - ddA_cumsum, - chunk_size, - headdim, - dstate, - batch, - seqlen, - nheads // ngroups, - x.strides[0], - x.strides[1], - x.strides[2], - x.strides[3], - B.strides[0], - B.strides[1], - B.strides[2], - B.strides[-1], - dstates.strides[0], - dstates.strides[1], - dstates.strides[2], - dstates.strides[3], - dstates.strides[4], - dt.strides[0], - dt.strides[2], - dt.strides[1], - dt.strides[3], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - dx.strides[0], - dx.strides[1], - dx.strides[2], - dx.strides[3], - ddt.strides[0], - ddt.strides[2], - ddt.strides[1], - ddt.strides[3], - ddA_cumsum.strides[0], - ddA_cumsum.strides[2], - ddA_cumsum.strides[1], - ddA_cumsum.strides[3], - BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), - ) - return dx, ddt.cast(dt.dtype), ddA_cumsum.cast(dA_cumsum.dtype) - - -def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - dstate = dstates.shape[-1] - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert tuple(dstates.shape) == (batch, nchunks, nheads, headdim, dstate) - if seq_idx is not None: - assert tuple(seq_idx.shape) == (batch, seqlen) - if B is not None: - assert tuple(B.shape) == (batch, seqlen, ngroups, dstate) - B_strides = (B.strides[0], B.strides[1], B.strides[2], B.strides[3]) - # Use paddle.empty since the Triton kernel will call init_to_zero - ddA_cumsum = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32) - ddA_cumsum_strides = ( - ddA_cumsum.strides[0], - ddA_cumsum.strides[2], - ddA_cumsum.strides[1], - ddA_cumsum.strides[3], - ) - else: - B_strides = (0, 0, 0, 0) - ddA_cumsum = None - ddA_cumsum_strides = (0, 0, 0, 0) - nheads_ngroups_ratio = nheads // ngroups - sm_count = paddle.device.cuda.get_device_properties(paddle.get_device()).multi_processor_count - nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) - nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) - dB = paddle.empty([batch, seqlen, nsplits, ngroups, dstate], dtype=paddle.float32) - grid_db = lambda META: ( - triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), - batch * nchunks, - nsplits * ngroups, - ) - _chunk_state_bwd_db_kernel[grid_db]( - x, - dstates, - B, - dt, - dA_cumsum, - seq_idx, - dB, - ddA_cumsum, - chunk_size, - dstate, - headdim, - batch, - seqlen, - nheads, - nheads_per_program, - ngroups, - x.strides[0], - x.strides[1], - x.strides[2], - x.strides[3], - dstates.strides[0], - dstates.strides[1], - dstates.strides[2], - dstates.strides[3], - dstates.strides[4], - *B_strides, - dt.strides[0], - dt.strides[2], - dt.strides[1], - dt.strides[3], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)), - dB.strides[0], - dB.strides[1], - dB.strides[2], - dB.strides[3], - dB.strides[4], - *ddA_cumsum_strides, - HAS_DDA_CS=ddA_cumsum is not None, - HAS_SEQ_IDX=seq_idx is not None, - BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), - ) - dB = dB.sum(2) - if ddA_cumsum is not None: - # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute - # to the state of the chunk. - # paddle.cumsum(ddA_cumsum[..., 1:], axis=-1, out=ddA_cumsum[..., 1:]) - # But it's easier to just do the cumsum for all elements, the result will be the same. - ddA_cumsum = paddle.cumsum(ddA_cumsum, axis=-1) - return dB if B is None else (dB, ddA_cumsum) - - -def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape - assert nheads % ngroups == 0 - assert tuple(B.shape) == (batch, seqlen, ngroups, dstate) - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert tuple(dstates.shape) == (batch, nchunks, nheads, headdim, dstate) - if seq_idx is not None: - assert tuple(seq_idx.shape) == (batch, seqlen) - # Use paddle.empty since the Triton kernel will call init_to_zero - ddA_cumsum = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32) - grid_ddtcs = lambda META: ( - triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) * triton.cdiv(headdim, META["BLOCK_SIZE_N"]), - batch * nchunks, - nheads, - ) - _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs]( - x, - B, - dstates, - dt, - dA_cumsum, - seq_idx, - ddA_cumsum, - chunk_size, - headdim, - dstate, - batch, - seqlen, - nheads // ngroups, - x.strides[0], - x.strides[1], - x.strides[2], - x.strides[3], - B.strides[0], - B.strides[1], - B.strides[2], - B.strides[-1], - dstates.strides[0], - dstates.strides[1], - dstates.strides[2], - dstates.strides[3], - dstates.strides[4], - dt.strides[0], - dt.strides[2], - dt.strides[1], - dt.strides[3], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)), - ddA_cumsum.strides[0], - ddA_cumsum.strides[2], - ddA_cumsum.strides[1], - ddA_cumsum.strides[3], - HAS_SEQ_IDX=seq_idx is not None, - BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16), - BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), - ) - ddA_cumsum[..., 1:] = paddle.cumsum(ddA_cumsum[..., 1:], axis=-1) - return ddA_cumsum - - -def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): - total_seqlen, nheads, headdim = x.shape - _, nchunks, chunk_size = dt.shape - _, ngroups, dstate = B.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - assert nheads % ngroups == 0 - assert tuple(B.shape) == (total_seqlen, ngroups, dstate) - assert tuple(dt.shape) == (nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert tuple(chunk_states.shape) == (nchunks, nheads, headdim, dstate) - states = paddle.empty([batch, nheads, headdim, dstate], dtype=chunk_states.dtype) - grid = lambda META: ( - triton.cdiv(headdim, META["BLOCK_SIZE_M"]) * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), - batch, - nheads, - ) - _chunk_state_varlen_kernel[grid]( - x, - B, - dt, - dA_cumsum, - chunk_states, - cu_seqlens, - states, - headdim, - dstate, - chunk_size, - total_seqlen, - nheads // ngroups, - x.strides[0], - x.strides[1], - x.strides[2], - B.strides[0], - B.strides[1], - B.strides[2], - dt.strides[1], - dt.strides[0], - dt.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - chunk_states.strides[0], - chunk_states.strides[1], - chunk_states.strides[2], - chunk_states.strides[3], - states.strides[0], - states.strides[1], - states.strides[2], - states.strides[3], - ) - return states - - -class ChunkStateFn(paddle.autograd.PyLayer): - @staticmethod - @custom_fwd - def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - assert seqlen <= nchunks * chunk_size - _, _, ngroups, dstate = B.shape - assert tuple(B.shape) == (batch, seqlen, ngroups, dstate) - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - assert tuple(dA_cumsum.shape) == (batch, nheads, nchunks, chunk_size) - if B.strides[-1] != 1: - B = B.contiguous() - if x.strides[-1] != 1 and x.strides[1] != 1: # Either M or K dimension should be contiguous - x = x.contiguous() - states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32) - ctx.save_for_backward(B, x, dt, dA_cumsum) - return states - - @staticmethod - @custom_bwd - def backward(ctx, dstates): - B, x, dt, dA_cumsum = ctx.saved_tensor() - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape - assert tuple(dstates.shape) == (batch, nchunks, nheads, headdim, dstate) - if dstates.strides[-1] != 1: - dstates = dstates.contiguous() - dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates) - dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups) - dB = dB.cast(B.dtype) - return dB, dx, ddt, ddA_cumsum, None - - -def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True): - """ - Argument: - B: (batch, seqlen, ngroups, headdim) - x: (batch, seqlen, nheads, headdim) - dt: (batch, nheads, nchunks, chunk_size) - dA_cumsum: (batch, nheads, nchunks, chunk_size) - Return: - states: (batch, nchunks, nheads, headdim, dstate) - """ - return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32) - - -def chunk_state_ref(B, x, dt, dA_cumsum): - """ - Argument: - B: (batch, seqlen, ngroups, headdim) - x: (batch, seqlen, nheads, headdim) - dt: (batch, nheads, nchunks, chunk_size) - dA_cumsum: (batch, nheads, nchunks, chunk_size) - Return: - states: (batch, nchunks, nheads, headdim, dstate) - """ - # Check constraints. - batch, seqlen, nheads, headdim = x.shape - dstate = B.shape[-1] - _, _, nchunks, chunk_size = dt.shape - assert seqlen <= nchunks * chunk_size - assert tuple(x.shape) == (batch, seqlen, nheads, headdim) - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - ngroups = B.shape[2] - assert nheads % ngroups == 0 - assert tuple(B.shape) == (batch, seqlen, ngroups, dstate) - B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) - assert tuple(dA_cumsum.shape) == (batch, nheads, nchunks, chunk_size) - if seqlen < nchunks * chunk_size: - x = F.pad(x, (0, 0, 0, nchunks * chunk_size - seqlen), data_format="NHWC") - B = F.pad(B, (0, 0, 0, nchunks * chunk_size - seqlen), data_format="NHWC") - x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) - B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size) - decay_states = paddle.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) - return paddle.einsum( - "bclhn,bhcl,bhcl,bclhp->bchpn", B.cast(x.dtype), decay_states.cast(x.dtype), dt.cast(x.dtype), x - ) diff --git a/ops/src/paddlenlp_kernel/triton/mamba/ssd_combined.py b/ops/src/paddlenlp_kernel/triton/mamba/ssd_combined.py deleted file mode 100644 index 0996599343c6..000000000000 --- a/ops/src/paddlenlp_kernel/triton/mamba/ssd_combined.py +++ /dev/null @@ -1,1597 +0,0 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. -""" -this code is modified from https://github.com/state-spaces/mamba/tree/main/mamba_ssm/ops/triton -""" -"""We want triton==2.1.0 or 2.2.0 for this""" - -import operator - -import paddle -import paddle.nn.functional as F -import triton -import triton.language as tl -from einops import rearrange, repeat - -from ...utils import ( - compare_version, - custom_bwd, - custom_fwd, - get_autocast_gpu_dtype, - is_autocast_enabled, -) - -try: - import causal_conv1d_cuda_pd as causal_conv1d_cuda - - from ..cuda.causal_conv1d import causal_conv1d_fn -except ImportError: - causal_conv1d_fn, causal_conv1d_cuda = None, None - - -from .k_activations import _swiglu_bwd, _swiglu_fwd -from .layernorm_gated import _layer_norm_bwd, _layer_norm_fwd, rmsnorm_fn -from .ssd_bmm import _bmm_chunk_bwd, _bmm_chunk_fwd -from .ssd_chunk_scan import ( - _chunk_scan_bwd_dC, - _chunk_scan_bwd_dcb, - _chunk_scan_bwd_ddAcs_stable, - _chunk_scan_bwd_dstates, - _chunk_scan_bwd_dz, - _chunk_scan_fwd, - chunk_scan, - chunk_scan_ref, -) -from .ssd_chunk_state import ( - _chunk_cumsum_bwd, - _chunk_cumsum_fwd, - _chunk_state_bwd_db, - _chunk_state_fwd, - chunk_state, - chunk_state_ref, - chunk_state_varlen, -) -from .ssd_state_passing import ( - _state_passing_bwd, - _state_passing_fwd, - state_passing, - state_passing_ref, -) - -TRITON_22 = compare_version("triton", operator.ge, "2.2.0") - - -def init_to_zero(names): - return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] - - -@triton.paddle_autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, - num_stages=3, - num_warps=8, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - pre_hook=init_to_zero(["ddt_ptr"]), - ), - ], - key=["chunk_size", "hdim", "dstate"], -) -@triton.jit -def _chunk_scan_chunk_state_bwd_dx_kernel( - # Pointers to matrices - x_ptr, - cb_ptr, - dout_ptr, - dt_ptr, - dA_cumsum_ptr, - seq_idx_ptr, - D_ptr, - b_ptr, - dstates_ptr, - dx_ptr, - ddt_ptr, - dD_ptr, - # Matrix dimensions - chunk_size, - hdim, - dstate, - batch, - seqlen, - nheads_ngroups_ratio, - # Strides - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_cb_batch, - stride_cb_chunk, - stride_cb_head, - stride_cb_csize_m, - stride_cb_csize_k, - stride_dout_batch, - stride_dout_seqlen, - stride_dout_head, - stride_dout_hdim, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - stride_D_head, - stride_b_batch, - stride_b_seqlen, - stride_b_head, - stride_b_dstate, - stride_dstates_batch, - stride_dstates_chunk, - stride_dstates_head, - stride_dstates_hdim, - stride_dstates_dstate, - stride_dx_batch, - stride_dx_seqlen, - stride_dx_head, - stride_dx_hdim, - stride_ddt_batch, - stride_ddt_chunk, - stride_ddt_head, - stride_ddt_csize, - stride_dD_batch, - stride_dD_chunk, - stride_dD_head, - stride_dD_csize, - stride_dD_hdim, - # Meta-parameters - HAS_D: tl.constexpr, - D_HAS_HDIM: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - BLOCK_SIZE_DSTATE: tl.constexpr, - IS_TRITON_22: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - b_ptr += ( - pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head - ) - dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to( - tl.float32 - ) - - dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_last - dA_cs_m) - else: - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) - scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) - # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - # However, we're getting error with the Triton compiler 2.1.0 for that code path: - # Unexpected mma -> mma layout conversion - # Triton 2.2.0 fixes this - offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate) - dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate) - if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128: - b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0) - dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) - dstates = dstates.to(b_ptr.dtype.element_ty) - acc = tl.dot(b, dstates) * scale[:, None] - else: - for k in range(0, dstate, BLOCK_SIZE_K): - b = tl.load( - b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0 - ) - dstates = tl.load( - dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0 - ) - dstates = dstates.to(b_ptr.dtype.element_ty) - acc += tl.dot(b, dstates) - b_ptrs += BLOCK_SIZE_K * stride_b_dstate - dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate - acc *= scale[:, None] - - # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - # dt_ptrs = dt_ptr + offs_m * stride_dt_csize - # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) - # ddt = tl.sum(acc * x, axis=1) * dt_m - # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize - # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) - - offs_k = tl.arange(0, BLOCK_SIZE_K) - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) - dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - K_MAX = chunk_size_limit - K_MIN = pid_m * BLOCK_SIZE_M - cb_ptrs += K_MIN * stride_cb_csize_k - dout_ptrs += K_MIN * stride_dout_seqlen - dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize - for k in range(K_MIN, K_MAX, BLOCK_SIZE_K): - k = tl.multiple_of(k, BLOCK_SIZE_K) - # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower - cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) - dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) - cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) - # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range, - # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. - # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. - # This will cause NaN in acc, and hence NaN in dx and ddt. - mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX) - cb = tl.where(mask, cb, 0.0) - cb = cb.to(dout_ptr.dtype.element_ty) - acc += tl.dot(cb, dout) - cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k - dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen - dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dt_ptrs = dt_ptr + offs_m * stride_dt_csize - dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) - dx = acc * dt_m[:, None] - dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head - dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) - if HAS_D: - dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) - dout_res = tl.load( - dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0 - ).to(tl.float32) - if D_HAS_HDIM: - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) - else: - D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - dx += dout_res * D - tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - if HAS_D: - dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize - if D_HAS_HDIM: - dD_ptrs = dD_ptr + offs_n * stride_dD_hdim - dD = tl.sum(dout_res * x, axis=0) - tl.store(dD_ptrs, dD, mask=offs_n < hdim) - else: - dD = tl.sum(dout_res * x) - tl.store(dD_ptr, dD) - ddt = tl.sum(acc * x, axis=1) - ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize - tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) - - -def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape - assert nheads % ngroups == 0 - assert tuple(B.shape) == (batch, seqlen, ngroups, dstate) - assert tuple(CB.shape) == (batch, nchunks, ngroups, chunk_size, chunk_size) - assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert dout.shape == x.shape - assert tuple(dstates.shape) == (batch, nchunks, nheads, headdim, dstate) - if seq_idx is not None: - assert tuple(seq_idx.shape) == (batch, seqlen) - if D is not None: - assert tuple(D.shape) == (nheads, headdim) or D.shape[0] == nheads - assert D.strides[-1] == 1 - BLOCK_SIZE_min = 32 - dD = paddle.empty( - [triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, headdim if D.dim() == 2 else 1], - dtype=paddle.float32, - ) - else: - dD = None - dD_strides = ( - (dD.strides[0], dD.strides[1], dD.strides[2], dD.strides[3], dD.strides[4]) - if D is not None - else (0, 0, 0, 0, 0) - ) - if dx is None: - dx = paddle.empty_like(x) - else: - assert dx.shape == x.shape - ddt = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32) - grid_dx = lambda META: ( - triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) * triton.cdiv(headdim, META["BLOCK_SIZE_N"]), - batch * nchunks, - nheads, - ) - _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx]( - x, - CB, - dout, - dt, - dA_cumsum, - seq_idx, - D, - B, - dstates, - dx, - ddt, - dD, - chunk_size, - headdim, - dstate, - batch, - seqlen, - nheads // ngroups, - x.strides[0], - x.strides[1], - x.strides[2], - x.strides[3], - CB.strides[0], - CB.strides[1], - CB.strides[2], - CB.strides[-1], - CB.strides[-2], - dout.strides[0], - dout.strides[1], - dout.strides[2], - dout.strides[3], - dt.strides[0], - dt.strides[2], - dt.strides[1], - dt.strides[3], - dA_cumsum.strides[0], - dA_cumsum.strides[2], - dA_cumsum.strides[1], - dA_cumsum.strides[3], - *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)), - D.strides[0] if D is not None else 0, - B.strides[0], - B.strides[1], - B.strides[2], - B.strides[3], - dstates.strides[0], - dstates.strides[1], - dstates.strides[2], - dstates.strides[3], - dstates.strides[4], - dx.strides[0], - dx.strides[1], - dx.strides[2], - dx.strides[3], - ddt.strides[0], - ddt.strides[2], - ddt.strides[1], - ddt.strides[3], - dD_strides[1], - dD_strides[2], - dD_strides[3], - dD_strides[0], - dD_strides[4], - D is not None, - D.dim() == 2 if D is not None else True, - HAS_SEQ_IDX=seq_idx is not None, - BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), - IS_TRITON_22=TRITON_22, - ) - if D is not None: - BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] - n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual - dD = dD[:n_valid_blocks].sum(axis=(0, 1, 2)).cast(dtype=D.dtype) - if D.dim() == 1: - dD = rearrange(dD, "h 1 -> h") - return dx, ddt.cast(dtype=dt.dtype), dD - - -def _mamba_chunk_scan_combined_fwd( - x, - dt, - A, - B, - C, - chunk_size, - D=None, - z=None, - dt_bias=None, - initial_states=None, - seq_idx=None, - cu_seqlens=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), -): - batch, seqlen, nheads, headdim = x.shape - _, _, ngroups, dstate = B.shape - assert nheads % ngroups == 0 - assert tuple(B.shape) == (batch, seqlen, ngroups, dstate) - assert tuple(x.shape) == (batch, seqlen, nheads, headdim) - assert tuple(dt.shape) == (batch, seqlen, nheads) - assert A.shape[0] == nheads - assert C.shape == B.shape - if z is not None: - assert z.shape == x.shape - if D is not None: - assert D.shape == [nheads, headdim] or D.shape[0] == nheads - if seq_idx is not None: - assert tuple(seq_idx.shape) == (batch, seqlen) - if B.strides[-1] != 1: - B = B.contiguous() - if C.strides[-1] != 1: - C = C.contiguous() - if x.strides[-1] != 1 and x.strides[1] != 1: # Either M or K dimension should be contiguous - x = x.contiguous() - if z is not None and z.strides[-1] != 1 and z.strides[1] != 1: # Either M or K dimension should be contiguous - z = z.contiguous() - if D is not None and D.strides[-1] != 1: - D = D.contiguous() - if initial_states is not None: - assert tuple(initial_states.shape) == (batch, nheads, headdim, dstate) - # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size) - # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) - # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) - # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) - dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) - states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) - # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True) - # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True) - # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True) - states, final_states = _state_passing_fwd( - rearrange(states, "... p n -> ... (p n)"), - dA_cumsum[:, :, :, -1], - initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, - seq_idx=seq_idx, - chunk_size=chunk_size, - out_dtype=C.dtype, - ) - states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]] - # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate) - # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate) - CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=paddle.float32) - out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx) - if cu_seqlens is None: - return out, out_x, dt, dA_cumsum, states, final_states - else: - assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" - varlen_states = chunk_state_varlen( - B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0), cu_seqlens, states.squeeze(0) - ) - return out, out_x, dt, dA_cumsum, states, final_states, varlen_states - - -def _mamba_chunk_scan_combined_bwd( - dout, - x, - dt, - A, - B, - C, - out, - chunk_size, - D=None, - z=None, - dt_bias=None, - initial_states=None, - dfinal_states=None, - seq_idx=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - dx=None, - ddt=None, - dB=None, - dC=None, - dz=None, - recompute_output=False, -): - if dout.strides[-1] != 1: - dout = dout.contiguous() - batch, seqlen, nheads, headdim = x.shape - # nchunks = math.ceil(seqlen / chunk_size) - _, _, ngroups, dstate = B.shape - assert tuple(dout.shape) == (batch, seqlen, nheads, headdim) - assert tuple(dt.shape) == (batch, seqlen, nheads) - assert A.shape[0] == nheads - assert nheads % ngroups == 0 - assert tuple(B.shape) == (batch, seqlen, ngroups, dstate) - assert C.shape == B.shape - assert out.shape == x.shape - if initial_states is not None: - assert tuple(initial_states.shape) == (batch, nheads, headdim, dstate) - if seq_idx is not None: - assert tuple(seq_idx.shape) == (batch, seqlen) - if dx is not None: - assert dx.shape == x.shape - if dB is not None: - assert dB.shape == B.shape - dB_given = dB - else: - dB_given = paddle.empty_like(B) - if dC is not None: - assert dC.shape == C.shape - dC_given = dC - else: - dC_given = paddle.empty_like(C) - if dz is not None: - assert z is not None - assert dz.shape == z.shape - if ddt is not None: - assert ddt.shape == dt.shape - ddt_given = ddt - else: - ddt_given = paddle.empty_like(dt) - # TD: For some reason Triton (2.1.0 and 2.2.0) errors with - # "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why. - dt_in = dt.clone() - dA_cumsum, dt = _chunk_cumsum_fwd( - dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit - ) - CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=paddle.float32) - states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) - states, _ = _state_passing_fwd( - rearrange(states, "... p n -> ... (p n)"), - dA_cumsum[:, :, :, -1], - initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, - seq_idx=seq_idx, - chunk_size=chunk_size, - ) - states = rearrange(states, "... (p n) -> ... p n", n=dstate) - if z is not None: - dz, dout, dD, *rest = _chunk_scan_bwd_dz( - x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output - ) - outz = rest[0] if recompute_output else out - else: - dz = None - outz = out - dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype) - # dstates has length nchunks, containing the gradient to initial states at index 0 and - # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1) - # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states - # will be used in matmul in the next kernels. - dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd( - rearrange(states, "... p n -> ... (p n)"), - dA_cumsum[:, :, :, -1], - rearrange(dstates, "... p n -> ... (p n)"), - dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None, - seq_idx=seq_idx, - has_initial_states=initial_states is not None, - dstates_dtype=x.dtype, - states_dtype=x.dtype, - chunk_size=chunk_size, - ) - # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and - # gradient to the final states at index (nchunks - 1) - # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1) - # The final states is not stored. - states = rearrange(states, "... (p n) -> ... p n", n=dstate) - dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate) - dinitial_states = ( - rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate) if dinitial_states is not None else None - ) - dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx( - x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx - ) - # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups) - dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups) - # dC = _chunk_scan_bwd_dC(states[:, :-1].cast(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups) - dC, ddA_cumsum_prev = _chunk_scan_bwd_dC( - states.cast(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups - ) - # Computing ddA with the dcb kernel is much slower, so we're not using it for now - dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups) - # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups) - dCB = dCB.cast(CB.dtype) - _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given) - _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given) - # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate - # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16 - if z is None: - dD = dD_from_x - # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D. - # ddA_cumsum = paddle.einsum("bclhp,bclhp->bhcl", out.cast("float32"), dout.cast("float32")) - ddt * dt - # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might - # be a lot of underflow. - - # This is already done as part of bwd_dC kernel - # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx) - ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum - ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(axis=-1).flip([-1]) - # This is already done as part of bwd_dB kernel - # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx) - # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j] - ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB) - ddA += ddA_next + ddA_prev - - ddt_given, dA, ddt_bias = _chunk_cumsum_bwd( - ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given - ) - - # These 2 lines are just to test ddt and dA being computed by old code - # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.cast("float32"), z=z) - # ddt_given.copy_(ddt) - - return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states) - return return_vals if not recompute_output else (*return_vals, outz) - - -def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None): - """ - Argument: - dout: (batch, seqlen, nheads, headdim) - x: (batch, seqlen, nheads, headdim) - dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size) - A: (nheads) or (dim, dstate) - B: (batch, seqlen, ngroups, dstate) - C: (batch, seqlen, ngroups, dstate) - D: (nheads, headdim) or (nheads,) - z: (batch, seqlen, nheads, headdim) - Return: - out: (batch, seqlen, nheads, headdim) - """ - from ...cuda import selective_scan_cuda_pd as selective_scan - - batch, seqlen, nheads, headdim = x.shape - chunk_size = dt.shape[-1] - _, _, ngroups, dstate = B.shape - assert nheads % ngroups == 0 - x = rearrange(x, "b l h p -> b (h p) l") - squeeze_dt = dt.dim() == 4 - if dt.dim() == 4: - dt = repeat(dt, "b h c l -> b h p c l", p=headdim) - dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim) - squeeze_A = A.dim() == 1 - if A.dim() == 1: - A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).cast(dtype=paddle.float32) - else: - A = A.cast(dtype=paddle.float32) - B = rearrange(B, "b l g n -> b g n l") - C = rearrange(C, "b l g n -> b g n l") - if D is not None: - if D.dim() == 2: - D = rearrange(D, "h p -> (h p)") - else: - D = repeat(D, "h -> (h p)", p=headdim) - if z is not None: - z = rearrange(z, "b l h p -> b (h p) l") - - if x.strides[-1] != 1: - x = x.contiguous() - if dt.strides[-1] != 1: - dt = dt.contiguous() - if D is not None: - D = D.contiguous() - if B.strides[-1] != 1: - B = B.contiguous() - if C.strides[-1] != 1: - C = C.contiguous() - if z is not None and z.strides[-1] != 1: - z = z.contiguous() - _, intermediate, *rest = selective_scan.fwd(x, dt.cast(dtype=x.dtype), A, B, C, D, z, None, False) - if z is not None: - out = rest[0] - else: - out = None - - dout = rearrange(dout, "b l h p -> b (h p) l") - - if dout.strides[-1] != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the - # backward of selective_scan with the backward of chunk). - # Here we just pass in None and dz will be allocated in the C++ code. - _, ddt, dA, *rest = selective_scan.bwd( - x, - dt.cast(dtype=x.dtype), - A, - B, - C, - D, - z, - None, - dout, - intermediate, - out, - None, - False, - False, # option to recompute out_z, not used here - ) - ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size) - if squeeze_dt: - ddt = ddt.cast("float32").sum(axis=2) - if squeeze_A: - dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(axis=(1, 2)) - return ddt, dA - - -class MambaChunkScanCombinedFn(paddle.autograd.PyLayer): - @staticmethod - @custom_fwd - def forward( - ctx, - x, - dt, - A, - B, - C, - chunk_size, - D=None, - z=None, - dt_bias=None, - initial_states=None, - seq_idx=None, - cu_seqlens=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - return_final_states=False, - return_varlen_states=False, - ): - ctx.dt_dtype = dt.dtype - if not return_varlen_states: - cu_seqlens = None - else: - assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" - out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( - x, - dt, - A, - B, - C, - chunk_size, - D=D, - z=z, - dt_bias=dt_bias, - initial_states=initial_states, - seq_idx=seq_idx, - cu_seqlens=cu_seqlens, - dt_softplus=dt_softplus, - dt_limit=dt_limit, - ) - ctx.save_for_backward( - out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx - ) - ctx.dt_softplus = dt_softplus - ctx.chunk_size = chunk_size - ctx.dt_limit = dt_limit - ctx.return_final_states = return_final_states - ctx.return_varlen_states = return_varlen_states - if not return_varlen_states: - return out if not return_final_states else (out, final_states) - else: - varlen_states = rest[0] - return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states) - - @staticmethod - @custom_bwd - def backward(ctx, dout, *args): - out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensor() - assert not ctx.return_varlen_states, "return_varlen_states is not supported in backward" - dfinal_states = args[0] if ctx.return_final_states else None - dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd( - dout, - x, - dt, - A, - B, - C, - out, - ctx.chunk_size, - D=D, - z=z, - dt_bias=dt_bias, - initial_states=initial_states, - dfinal_states=dfinal_states, - seq_idx=seq_idx, - dt_softplus=ctx.dt_softplus, - dt_limit=ctx.dt_limit, - ) - return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None, None, None - - -def mamba_chunk_scan_combined( - x, - dt, - A, - B, - C, - chunk_size, - D=None, - z=None, - dt_bias=None, - initial_states=None, - seq_idx=None, - cu_seqlens=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - return_final_states=False, - return_varlen_states=False, -): - """ - Argument: - x: (batch, seqlen, nheads, headdim) - dt: (batch, seqlen, nheads) - A: (nheads) - B: (batch, seqlen, ngroups, dstate) - C: (batch, seqlen, ngroups, dstate) - chunk_size: int - D: (nheads, headdim) or (nheads,) - z: (batch, seqlen, nheads, headdim) - dt_bias: (nheads,) - initial_states: (batch, nheads, headdim, dstate) - seq_idx: (batch, seqlen) - cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True - dt_softplus: Whether to apply softplus to dt - Return: - out: (batch, seqlen, nheads, headdim) - """ - return MambaChunkScanCombinedFn.apply( - x, - dt, - A, - B, - C, - chunk_size, - D, - z, - dt_bias, - initial_states, - seq_idx, - cu_seqlens, - dt_softplus, - dt_limit, - return_final_states, - return_varlen_states, - ) - - -def mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False): - """ - Argument: - x: (batch, seqlen, nheads, headdim) - dt: (batch, seqlen, nheads) - A: (nheads) - B: (batch, seqlen, ngroups, dstate) - C: (batch, seqlen, ngroups, dstate) - D: (nheads, headdim) or (nheads,) - z: (batch, seqlen, nheads, headdim) - dt_bias: (nheads,) - Return: - out: (batch, seqlen, nheads, headdim) - """ - batch, seqlen, nheads, headdim = x.shape - dstate = B.shape[-1] - if seqlen % chunk_size != 0: - dt = F.pad(dt, (0, chunk_size - seqlen % chunk_size), data_format="NLC") - dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size) - dt = dt.cast("float32") # We want high precision for this before cumsum - if dt_bias is not None: - dt = dt + rearrange(dt_bias, "h -> h 1 1") - if dt_softplus: - dt = F.softplus(dt) - dA = dt * rearrange(A, "h -> h 1 1") - dA = dt * rearrange(A, "h -> h 1 1") - dA_cumsum = paddle.cumsum(dA, axis=-1) - # 1. Compute the state for each chunk - states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True) - # 2. Pass the state to all the chunks by weighted cumsum. - states = rearrange( - state_passing(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0], - "... (p n) -> ... p n", - n=dstate, - ) - # 3. Compute the output for each chunk - out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z) - return out - - -def ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False): - """ - Argument: - x: (batch, seqlen, nheads, headdim) - dt: (batch, seqlen, nheads) - A: (nheads) - B: (batch, seqlen, ngroups, dstate) - C: (batch, seqlen, ngroups, dstate) - D: (nheads, headdim) or (nheads,) - z: (batch, seqlen, nheads, headdim) - dt_bias: (nheads,) - Return: - out: (batch, seqlen, nheads, headdim) - """ - batch, seqlen, nheads, headdim = x.shape - dstate = B.shape[-1] - if seqlen % chunk_size != 0: - dt = F.pad(dt, (0, chunk_size - seqlen % chunk_size), data_format="NLC") - dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size) - dt = dt.cast("float32") # We want high precision for this before cumsum - if dt_bias is not None: - dt = dt + rearrange(dt_bias, "h -> h 1 1") - if dt_softplus: - dt = F.softplus(dt) - dA = dt * rearrange(A, "h -> h 1 1") - dA_cumsum = paddle.cumsum(dA, axis=-1) - # 1. Compute the state for each chunk - states = chunk_state_ref(B, x, dt, dA_cumsum) - states_dtype = states.dtype - if states.dtype not in [paddle.float32, paddle.float64]: - states = states.cast(paddle.float32) - # 2. Pass the state to all the chunks by weighted cumsum. - # state_passing_ref is much less numerically stable - states = rearrange( - state_passing_ref(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0], - "... (p n) -> ... p n", - n=dstate, - ) - states = states.cast(states_dtype) - # 3. Compute the output for each chunk - out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z) - return out - - -def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): - """ - Argument: - x: (batch, seqlen, nheads, headdim) - dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim) - A: (nheads) or (dim, dstate) - B: (batch, seqlen, ngroups, dstate) - C: (batch, seqlen, ngroups, dstate) - D: (nheads, headdim) or (nheads,) - z: (batch, seqlen, nheads, headdim) - dt_bias: (nheads,) or (nheads, headdim) - Return: - out: (batch, seqlen, nheads, headdim) - """ - from ...cuda.selective_scan import selective_scan_fn - - batch, seqlen, nheads, headdim = x.shape - _, _, ngroups, dstate = B.shape - x = rearrange(x, "b l h p -> b (h p) l") - if dt.dim() == 3: - dt = repeat(dt, "b l h -> b l h p", p=headdim) - dt = rearrange(dt, "b l h p -> b (h p) l") - if A.dim() == 1: - A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).cast(dtype=paddle.float32) - else: - A = A.cast(dtype=paddle.float32) - B = rearrange(B, "b l g n -> b g n l") - C = rearrange(C, "b l g n -> b g n l") - if D is not None: - if D.dim() == 2: - D = rearrange(D, "h p -> (h p)") - else: - D = repeat(D, "h -> (h p)", p=headdim) - if z is not None: - z = rearrange(z, "b l h p -> b (h p) l") - if dt_bias is not None: - if dt_bias.dim() == 1: - dt_bias = repeat(dt_bias, "h -> h p", p=headdim) - dt_bias = rearrange(dt_bias, "h p -> (h p)") - if dt_limit != (0.0, float("inf")): - if dt_bias is not None: - dt = dt + rearrange(dt_bias, "d -> d 1") - if dt_softplus: - dt = F.softplus(dt) - dt = dt.clip(min=dt_limit[0], max=dt_limit[1]).cast(x.dtype) - dt_bias = None - dt_softplus = None - out = selective_scan_fn(x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus) - return rearrange(out, "b (h p) l -> b l h p", p=headdim) - - -def mamba_conv1d_scan_ref( - xBC, - conv1d_weight, - conv1d_bias, - dt, - A, - chunk_size, - D=None, - z=None, - dt_bias=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - activation="silu", - headdim=None, - ngroups=1, -): - """ - Argument: - xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim - conv1d_weight: (dim + 2 * ngroups * dstate, width) - conv1d_bias: (dim + 2 * ngroups * dstate,) - dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim) - A: (nheads) - D: (nheads, headdim) or (nheads,) - z: (batch, seqlen, dim) - dt_bias: (nheads) or (nheads, headdim) - headdim: if D is 1D and z is None, headdim must be passed in - Return: - out: (batch, seqlen, dim) - """ - batch, seqlen, nheads = dt.shape[:3] - assert nheads % ngroups == 0 - if z is not None: - dim = z.shape[-1] - assert dim % nheads == 0 - headdim = dim // nheads - else: - if D.dim() == 1: - assert headdim is not None - else: - headdim = D.shape[1] - dim = nheads * headdim - xBC = rearrange( - causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation), - "b d s -> b s d", - ) - dstate = (xBC.shape[-1] - dim) // ngroups // 2 - x, B, C = paddle.split(xBC, [dim, ngroups * dstate, ngroups * dstate], axis=-1) - x = rearrange(x, "b l (h p) -> b l h p", h=nheads) - B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) - C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) - z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None - out = ssd_selective_scan( - x, - dt.cast(x.dtype), - A, - B, - C, - D=D.cast("float32"), - z=z, - dt_bias=dt_bias, - dt_softplus=dt_softplus, - dt_limit=dt_limit, - ) - return rearrange(out, "b s h p -> b s (h p)") - - -class MambaSplitConv1dScanCombinedFn(paddle.autograd.PyLayer): - @staticmethod - @custom_fwd - def forward( - ctx, - zxbcdt, - conv1d_weight, - conv1d_bias, - dt_bias, - A, - D, - chunk_size, - initial_states=None, - seq_idx=None, - dt_limit=(0.0, float("inf")), - return_final_states=False, - activation="silu", - rmsnorm_weight=None, - rmsnorm_eps=1e-6, - outproj_weight=None, - outproj_bias=None, - headdim=None, - ngroups=1, - norm_before_gate=True, - ): - assert activation in [None, "silu", "swish"] - if D.dim() == 1: - assert headdim is not None - nheads = D.shape[0] - else: - nheads, headdim = D.shape - batch, seqlen, _ = zxbcdt.shape - dim = nheads * headdim - assert nheads % ngroups == 0 - dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2 - d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2 - assert d_nonssm >= 0 - assert tuple(zxbcdt.shape) == (batch, seqlen, 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads) - assert dt_bias.shape[0] == nheads - assert A.shape[0] == nheads - zx0, z, xBC, dt = paddle.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], axis=-1) - seq_idx = seq_idx.contiguous() if seq_idx is not None else None - xBC_conv = rearrange( - causal_conv1d_cuda.causal_conv1d_fwd( - rearrange(xBC, "b s d -> b d s"), - conv1d_weight, - conv1d_bias, - seq_idx, - None, - None, - activation in ["silu", "swish"], - ), - "b d s -> b s d", - ) - x, B, C = paddle.split(xBC_conv, [dim, ngroups * dstate, ngroups * dstate], axis=-1) - x = rearrange(x, "b l (h p) -> b l h p", h=nheads) - B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) - C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) - z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None - if rmsnorm_weight is None: - out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd( - x, - dt, - A, - B, - C, - chunk_size=chunk_size, - D=D, - z=z, - dt_bias=dt_bias, - initial_states=initial_states, - seq_idx=seq_idx, - dt_softplus=True, - dt_limit=dt_limit, - ) - out = rearrange(out, "b s h p -> b s (h p)") - rstd = None - if d_nonssm > 0: - out = paddle.concat([_swiglu_fwd(zx0), out], axis=-1) - else: - out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd( - x, - dt, - A, - B, - C, - chunk_size=chunk_size, - D=D, - z=None, - dt_bias=dt_bias, - initial_states=initial_states, - seq_idx=seq_idx, - dt_softplus=True, - dt_limit=dt_limit, - ) - # reshape input data into 2D tensor - x_rms = rearrange(out_x, "b s h p -> (b s) (h p)") - z_rms = rearrange(z, "b s h p -> (b s) (h p)") - rmsnorm_weight = rmsnorm_weight.contiguous() - if d_nonssm == 0: - out = None - else: - out01 = paddle.empty((batch, seqlen, d_nonssm + dim), dtype=x_rms.dtype) - out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d") - _swiglu_fwd(zx0, out=out01[..., :d_nonssm]) - out, _, rstd = _layer_norm_fwd( - x_rms, - rmsnorm_weight, - None, - rmsnorm_eps, - z_rms, - out=out, - group_size=dim // ngroups, - norm_before_gate=norm_before_gate, - is_rms_norm=True, - ) - if d_nonssm == 0: - out = rearrange(out, "(b s) d -> b s d", b=batch) - else: - out = out01 - ctx.outproj_weight_dtype = outproj_weight.dtype if outproj_weight is not None else None - if outproj_weight is not None: - if is_autocast_enabled(): - dtype = get_autocast_gpu_dtype() - out, outproj_weight = out.cast(dtype), outproj_weight.cast(dtype) - outproj_bias = outproj_bias.cast(dtype) if outproj_bias is not None else None - out = F.linear(out, outproj_weight, outproj_bias) - else: - assert outproj_bias is None - ctx.save_for_backward( - zxbcdt, - conv1d_weight, - conv1d_bias, - out_x, - A, - D, - dt_bias, - initial_states, - seq_idx, - rmsnorm_weight, - rstd, - outproj_weight, - outproj_bias, - ) - ctx.dt_limit = dt_limit - ctx.return_final_states = return_final_states - ctx.activation = activation - ctx.rmsnorm_eps = rmsnorm_eps - ctx.norm_before_gate = norm_before_gate - ctx.chunk_size = chunk_size - ctx.headdim = headdim - ctx.ngroups = ngroups - return out if not return_final_states else (out, final_states) - - @staticmethod - @custom_bwd - def backward(ctx, dout, *args): - ( - zxbcdt, - conv1d_weight, - conv1d_bias, - out, - A, - D, - dt_bias, - initial_states, - seq_idx, - rmsnorm_weight, - rstd, - outproj_weight, - outproj_bias, - ) = ctx.saved_tensor() - dfinal_states = args[0] if ctx.return_final_states else None - headdim = ctx.headdim - nheads = D.shape[0] - dim = nheads * headdim - assert nheads % ctx.ngroups == 0 - dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2 - d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2 - assert d_nonssm >= 0 - recompute_output = outproj_weight is not None - if recompute_output: - out_recompute = paddle.empty([*out.shape[:2], d_nonssm + dim], dtype=out.dtype) - out0_recompute, out1_recompute = out_recompute.split([d_nonssm, dim], axis=-1) - zx0, z, xBC, dt = paddle.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], axis=-1) - # Recompute x, B, C - xBC_conv = rearrange( - causal_conv1d_cuda.causal_conv1d_fwd( - rearrange(xBC, "b s d -> b d s"), - conv1d_weight, - conv1d_bias, - seq_idx, - None, - None, - ctx.activation in ["silu", "swish"], - ), - "b d s -> b s d", - ) - x, B, C = paddle.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], axis=-1) - x = rearrange(x, "b l (h p) -> b l h p", h=nheads) - B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups) - C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups) - dzxbcdt = paddle.empty_like(zxbcdt) - dzx0, dz, dxBC_given, ddt_given = paddle.split( - dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], axis=-1 - ) - dxBC = paddle.empty_like(xBC) - dx, dB, dC = paddle.split(dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], axis=-1) - z = rearrange(z, "b l (h p) -> b l h p", h=nheads) - dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads) - dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups) - dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups) - if outproj_weight is not None: - dout_og = dout - dout = F.linear(dout, outproj_weight.t()) - if d_nonssm > 0: - dout0, dout = dout.split([d_nonssm, dim], axis=-1) - _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute) - dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim) - if rmsnorm_weight is None: - dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads) - dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = _mamba_chunk_scan_combined_bwd( - dout, - x, - dt, - A, - B, - C, - out, - ctx.chunk_size, - D=D, - z=z, - dt_bias=dt_bias, - initial_states=initial_states, - dfinal_states=dfinal_states, - seq_idx=seq_idx, - dt_softplus=True, - dt_limit=ctx.dt_limit, - dx=dx, - ddt=ddt_given, - dB=dB, - dC=dC, - dz=dz, - recompute_output=recompute_output, - ) - out_for_linear = rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None - drmsnorm_weight = None - else: - batch = dout.shape[0] - dy_rms = rearrange(dout, "b s h p -> (b s) (h p)") - dz = rearrange(dz, "b l d -> (b l) d") - x_rms = rearrange(out, "b s h p -> (b s) (h p)") - z_rms = rearrange(z, "b s h p -> (b s) (h p)") - out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None - dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd( - dy_rms, - x_rms, - rmsnorm_weight, - None, - ctx.rmsnorm_eps, - None, - rstd, - z_rms, - norm_before_gate=ctx.norm_before_gate, - is_rms_norm=True, - recompute_output=recompute_output, - dz=dz, - out=out1_recompute if recompute_output else None, - ) - out_for_linear = out_recompute if recompute_output else None - dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim) - dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd( - dout, - x, - dt, - A, - B, - C, - out, - ctx.chunk_size, - D=D, - z=None, - dt_bias=dt_bias, - initial_states=initial_states, - dfinal_states=dfinal_states, - seq_idx=seq_idx, - dt_softplus=True, - dt_limit=ctx.dt_limit, - dx=dx, - ddt=ddt_given, - dB=dB, - dC=dC, - ) - - if outproj_weight is not None: - doutproj_weight = paddle.einsum("bso,bsd->od", dout_og, out_for_linear) - doutproj_bias = dout_og.sum(axis=(0, 1)) if outproj_bias is not None else None - else: - doutproj_weight, doutproj_bias = None, None - dxBC_given = rearrange(dxBC_given, "b s d -> b d s") - dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( - rearrange(xBC, "b s d -> b d s"), - conv1d_weight, - conv1d_bias, - rearrange(dxBC, "b s d -> b d s"), - seq_idx, - None, - None, - dxBC_given, - False, - ctx.activation in ["silu", "swish"], - ) - dxBC_given = rearrange(dxBC_given, "b d s -> b s d") - return ( - dzxbcdt, - dweight, - dbias, - ddt_bias, - dA, - dD, - None, - dinitial_states, - None, - None, - None, - None, - drmsnorm_weight, - None, - doutproj_weight, - doutproj_bias, - None, - None, - None, - ) - - -def mamba_split_conv1d_scan_combined( - zxbcdt, - conv1d_weight, - conv1d_bias, - dt_bias, - A, - D, - chunk_size, - initial_states=None, - seq_idx=None, - dt_limit=(0.0, float("inf")), - return_final_states=False, - activation="silu", - rmsnorm_weight=None, - rmsnorm_eps=1e-6, - outproj_weight=None, - outproj_bias=None, - headdim=None, - ngroups=1, - norm_before_gate=True, -): - """ - Argument: - zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim - conv1d_weight: (dim + 2 * ngroups * dstate, width) - conv1d_bias: (dim + 2 * ngroups * dstate,) - dt_bias: (nheads,) - A: (nheads) - D: (nheads, headdim) or (nheads,) - initial_states: (batch, nheads, headdim, dstate) - seq_idx: (batch, seqlen), int32 - rmsnorm_weight: (dim,) - outproj_weight: (out_dim, dim) - outproj_bias: (out_dim,) - headdim: if D is 1D, headdim must be passed in - norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z)) - Return: - out: (batch, seqlen, dim) - """ - return MambaSplitConv1dScanCombinedFn.apply( - zxbcdt, - conv1d_weight, - conv1d_bias, - dt_bias, - A, - D, - chunk_size, - initial_states, - seq_idx, - dt_limit, - return_final_states, - activation, - rmsnorm_weight, - rmsnorm_eps, - outproj_weight, - outproj_bias, - headdim, - ngroups, - norm_before_gate, - ) - - -def mamba_split_conv1d_scan_ref( - zxbcdt, - conv1d_weight, - conv1d_bias, - dt_bias, - A, - D, - chunk_size, - dt_limit=(0.0, float("inf")), - activation="silu", - rmsnorm_weight=None, - rmsnorm_eps=1e-6, - outproj_weight=None, - outproj_bias=None, - headdim=None, - ngroups=1, - norm_before_gate=True, -): - """ - Argument: - zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim - conv1d_weight: (dim + 2 * ngroups * dstate, width) - conv1d_bias: (dim + 2 * ngroups * dstate,) - dt_bias: (nheads,) - A: (nheads) - D: (nheads, headdim) or (nheads,) - rmsnorm_weight: (dim,) - outproj_weight: (out_dim, dim) - outproj_bias: (out_dim,) - headdim: if D is 1D, headdim must be passed in - norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z)) - Return: - out: (batch, seqlen, dim) - """ - if D.dim() == 1: - assert headdim is not None - (nheads,) = D.shape - else: - nheads, headdim = D.shape - assert nheads % ngroups == 0 - batch, seqlen, _ = zxbcdt.shape - dim = nheads * headdim - dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2 - assert tuple(zxbcdt.shape) == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) - assert dt_bias.shape[0] == nheads - assert A.shape[0] == nheads - if rmsnorm_weight is not None: - assert rmsnorm_weight.shape[0] == dim - z, xBC, dt = paddle.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], axis=-1) - xBC = rearrange( - causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation), - "b d s -> b s d", - ) - x, B, C = paddle.split(xBC, [dim, ngroups * dstate, ngroups * dstate], axis=-1) - x = rearrange(x, "b l (h p) -> b l h p", h=nheads) - B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) - C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) - z = rearrange(z, "b l (h p) -> b l h p", h=nheads) - out = ssd_selective_scan( - x, - dt.cast(x.dtype), - A, - B, - C, - D=D.cast("float32"), - z=z if rmsnorm_weight is None else None, - dt_bias=dt_bias, - dt_softplus=True, - dt_limit=dt_limit, - ) - out = rearrange(out, "b s h p -> b s (h p)") - if rmsnorm_weight is not None: - out = rmsnorm_fn( - out, - rmsnorm_weight, - None, - z=rearrange(z, "b l h p -> b l (h p)"), - eps=rmsnorm_eps, - norm_before_gate=norm_before_gate, - ) - if outproj_weight is not None: - out = F.linear(out, outproj_weight, outproj_bias) - return out diff --git a/ops/src/paddlenlp_kernel/triton/mamba/ssd_state_passing.py b/ops/src/paddlenlp_kernel/triton/mamba/ssd_state_passing.py deleted file mode 100644 index d84cd7ac6fd7..000000000000 --- a/ops/src/paddlenlp_kernel/triton/mamba/ssd_state_passing.py +++ /dev/null @@ -1,457 +0,0 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. -""" -this code is modified from https://github.com/state-spaces/mamba/tree/main/mamba_ssm/ops/triton -""" -"""We want triton==2.1.0 or 2.2.0 for this""" - -import paddle -import paddle.nn.functional as F -import triton -import triton.language as tl -from einops import rearrange - -from ...utils import custom_bwd, custom_fwd - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - ], - key=["dim"], -) -@triton.jit -def _state_passing_fwd_kernel( - # Pointers to matrices - states_ptr, - out_ptr, - final_states_ptr, - dA_cs_ptr, - initstates_ptr, - seq_idx_ptr, - # Matrix dimensions - dim, - nchunks, - seqlen, - chunk_size, - # Strides - stride_states_batch, - stride_states_chunk, - stride_states_head, - stride_states_dim, - stride_out_batch, - stride_out_chunk, - stride_out_head, - stride_out_dim, - stride_final_states_batch, - stride_final_states_head, - stride_final_states_dim, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_initstates_batch, - stride_initstates_head, - stride_initstates_dim, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - # Meta-parameters - HAS_INITSTATES: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - pid_b = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) - pid_m = tl.program_id(axis=0) - states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head - dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head - out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head - final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head - if HAS_INITSTATES: - initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch - - offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - states_ptrs = states_ptr + offs_m * stride_states_dim - out_ptrs = out_ptr + offs_m * stride_out_dim - final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim - - if not HAS_INITSTATES: - states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) - else: - initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim - states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - tl.store(out_ptrs, states, mask=offs_m < dim) - out_ptrs += stride_out_chunk - seq_idx = 0 - for c in range(nchunks): - new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale = tl.exp(dA_cs) - if HAS_SEQ_IDX: - seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) - scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) - seq_idx = seq_idx_new - states = scale * states + new_states - if c < nchunks - 1: - tl.store(out_ptrs, states, mask=offs_m < dim) - else: - tl.store(final_states_ptrs, states, mask=offs_m < dim) - states_ptrs += stride_states_chunk - dA_cs_ptr += stride_dA_cs_chunk - out_ptrs += stride_out_chunk - - -@triton.paddle_autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - ], - key=["dim"], -) -@triton.jit -def _state_passing_bwd_kernel( - # Pointers to matrices - dout_ptr, - out_ptr, - dA_cs_ptr, - dfinal_states_ptr, - seq_idx_ptr, - dstates_ptr, - ddA_cs_ptr, - dinitstates_ptr, - states_converted_ptr, - # Matrix dimensions - dim, - nchunks, - seqlen, - chunk_size, - # Strides - stride_dout_batch, - stride_dout_chunk, - stride_dout_head, - stride_dout_dim, - stride_out_batch, - stride_out_chunk, - stride_out_head, - stride_out_dim, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dfinal_states_batch, - stride_dfinal_states_head, - stride_dfinal_states_dim, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - stride_dstates_batch, - stride_dstates_chunk, - stride_dstates_head, - stride_dstates_dim, - stride_ddA_cs_batch, - stride_ddA_cs_chunk, - stride_ddA_cs_head, - stride_dinitstates_batch, - stride_dinitstates_head, - stride_dinitstates_dim, - # Meta-parameters - CONVERT_STATES: tl.constexpr, - HAS_DFINAL_STATES: tl.constexpr, - HAS_DINITSTATES: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - pid_b = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) - pid_m = tl.program_id(axis=0) - dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk - dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk - ddA_cs_ptr += ( - pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m - ) - out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk - dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk - if CONVERT_STATES: - states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk - if HAS_DFINAL_STATES: - dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head - if HAS_DINITSTATES: - dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch - - offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim - out_ptrs = out_ptr + offs_m * stride_out_dim - dout_ptrs = dout_ptr + offs_m * stride_dout_dim - if CONVERT_STATES: - states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim - - if HAS_DFINAL_STATES: - dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to( - tl.float32 - ) - else: - dstates = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) - tl.store(dstates_ptrs, dstates, mask=offs_m < dim) - if HAS_SEQ_IDX: - seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen) - dstates_ptrs -= stride_dstates_chunk - for c in range(nchunks - 1): - dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale = tl.exp(dA_cs) - if HAS_SEQ_IDX: - seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen)) - scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) - seq_idx = seq_idx_new - out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if CONVERT_STATES: - tl.store(states_converted_ptrs, out, mask=offs_m < dim) - ddA = tl.sum(out * dstates) * scale - tl.store(ddA_cs_ptr, ddA) - dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - dstates = scale * dstates + dout - tl.store(dstates_ptrs, dstates, mask=offs_m < dim) - dout_ptrs -= stride_dout_chunk - dstates_ptrs -= stride_dstates_chunk - dA_cs_ptr -= stride_dA_cs_chunk - ddA_cs_ptr -= stride_ddA_cs_chunk - out_ptrs -= stride_out_chunk - if CONVERT_STATES: - states_converted_ptrs -= stride_out_chunk - if CONVERT_STATES: - out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - tl.store(states_converted_ptrs, out, mask=offs_m < dim) - if not HAS_DINITSTATES: - tl.store(ddA_cs_ptr, 0.0) - else: - dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale = tl.exp(dA_cs) - if HAS_SEQ_IDX: - scale = tl.where(seq_idx == 0, scale, 0.0) - out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - ddA = tl.sum(out * dstates) * scale - tl.store(ddA_cs_ptr, ddA) - dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - dstates = scale * dstates + dout - tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim) - - -def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None, out_dtype=None): - batch, nchunks, nheads, dim = states.shape - assert tuple(dA_chunk_cumsum.shape) == (batch, nheads, nchunks) - if initial_states is not None: - assert tuple(initial_states.shape) == (batch, nheads, dim) - if seq_idx is not None: - assert chunk_size is not None - seqlen = seq_idx.shape[-1] - assert tuple(seq_idx.shape) == (batch, seqlen) - out_dtype = states.dtype if out_dtype is None else out_dtype - out = paddle.empty((batch, nchunks, nheads, dim), dtype=out_dtype) - final_states = paddle.empty((batch, nheads, dim), dtype=paddle.float32) - grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads) - _state_passing_fwd_kernel[grid]( - states, - out, - final_states, - dA_chunk_cumsum, - initial_states, - seq_idx, - dim, - nchunks, - seqlen if seq_idx is not None else 0, - chunk_size if seq_idx is not None else 0, - states.strides[0], - states.strides[1], - states.strides[2], - states.strides[3], - out.strides[0], - out.strides[1], - out.strides[2], - out.strides[3], - final_states.strides[0], - final_states.strides[1], - final_states.strides[2], - dA_chunk_cumsum.strides[0], - dA_chunk_cumsum.strides[2], - dA_chunk_cumsum.strides[1], - *( - (initial_states.strides[0], initial_states.strides[1], initial_states.strides[2]) - if initial_states is not None - else (0, 0, 0) - ), - *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)), - HAS_INITSTATES=initial_states is not None, - HAS_SEQ_IDX=seq_idx is not None, - ) - return out, final_states - - -def _state_passing_bwd( - states, - dA_chunk_cumsum, - dout, - dfinal_states=None, - seq_idx=None, - has_initial_states=None, - dstates_dtype=None, - states_dtype=None, - chunk_size=None, -): - """ - states contains the initial_states at index 0. The final states are not included in states. - """ - batch, nchunks, nheads, dim = states.shape - assert tuple(dA_chunk_cumsum.shape) == (batch, nheads, nchunks) - assert tuple(dout.shape) == (batch, nchunks, nheads, dim) - if seq_idx is not None: - assert chunk_size is not None - seqlen = seq_idx.shape[-1] - assert tuple(seq_idx.shape) == (batch, seqlen) - dstates = paddle.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype) - if states_dtype is not None and states_dtype != states.dtype: - states_converted = paddle.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype) - assert states_converted.strides == states.strides - else: - states_converted = None - if has_initial_states: - dinitstates = paddle.empty_like(dstates[:, 0]) - else: - dinitstates = None - if dfinal_states is not None: - assert tuple(dfinal_states.shape) == (batch, nheads, dim) - BLOCK_SIZE_min = 64 - n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min - ddA_chunk_cumsum = paddle.empty([batch, nheads, nchunks, n_blocks], dtype=paddle.float32) - grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads) - _state_passing_bwd_kernel[grid]( - dout, - states, - dA_chunk_cumsum, - dfinal_states, - seq_idx, - dstates, - ddA_chunk_cumsum, - dinitstates, - states_converted, - dim, - nchunks, - seqlen if seq_idx is not None else 0, - chunk_size if seq_idx is not None else 0, - dout.strides[0], - dout.strides[1], - dout.strides[2], - dout.strides[3], - states.strides[0], - states.strides[1], - states.strides[2], - states.strides[3], - dA_chunk_cumsum.strides[0], - dA_chunk_cumsum.strides[2], - dA_chunk_cumsum.strides[1], - *( - (dfinal_states.strides[0], dfinal_states.strides[1], dfinal_states.strides[2]) - if dfinal_states is not None - else (0, 0, 0) - ), - *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)), - dstates.strides[0], - dstates.strides[1], - dstates.strides[2], - dstates.strides[3], - ddA_chunk_cumsum.strides[0], - ddA_chunk_cumsum.strides[2], - ddA_chunk_cumsum.strides[1], - *( - (dinitstates.strides[0], dinitstates.strides[1], dinitstates.strides[2]) - if dinitstates is not None - else (0, 0, 0) - ), - CONVERT_STATES=states_converted is not None, - HAS_DFINAL_STATES=dfinal_states is not None, - HAS_DINITSTATES=dinitstates is not None, - HAS_SEQ_IDX=seq_idx is not None, - ) - BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"] - n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual - ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(axis=-1).cast(dtype=dA_chunk_cumsum.dtype) - if states_dtype is not None and states_dtype == states.dtype: - states_converted = states - return ( - (dstates, ddA_chunk_cumsum, dinitstates) - if states_dtype is None - else (dstates, ddA_chunk_cumsum, dinitstates, states_converted) - ) - - -class StatePassingFn(paddle.autograd.PyLayer): - @staticmethod - @custom_fwd - def forward(ctx, states, dA_chunk_cumsum, initial_states=None): - batch, nchunks, nheads, dim = states.shape - assert tuple(dA_chunk_cumsum.shape) == (batch, nheads, nchunks) - if states.strides[-1] != 1: - states = states.contiguous() - out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states) - ctx.save_for_backward(out, dA_chunk_cumsum) - ctx.has_initial_states = initial_states is not None - return out, final_states - - @staticmethod - @custom_bwd - def backward(ctx, dout, dfinal_states): - out, dA_chunk_cumsum = ctx.saved_tensor() - batch, nchunks, nheads, dim = out.shape - assert tuple(dout.shape) == (batch, nchunks, nheads, dim) - assert tuple(dA_chunk_cumsum.shape) == (batch, nheads, nchunks) - assert tuple(dfinal_states.shape) == (batch, nheads, dim) - if dout.strides[-1] != 1: - dout = dout.contiguous() - dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd( - out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states, has_initial_states=ctx.has_initial_states - ) - return dstates, ddA_chunk_cumsum, dinitstates - - -def state_passing(states, dA_chunk_cumsum, initial_states=None): - """ - Argument: - states: (batch, nchunks, nheads, dim) - dA_chunk_cumsum: (batch, nheads, nchunks) - initial_states: (batch, nheads, dim) - Return: - out: (batch, nchunks, nheads, dim) final_states: (batch, nheads, dim) - """ - return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states) - - -def state_passing_ref(states, dA_chunk_cumsum, initial_states=None): - """ - Argument: - states: (batch, nchunks, nheads, dim) - dA_chunk_cumsum: (batch, nheads, nchunks) - initial_states: (batch, nheads, dim) - Return: - out: (batch, nchunks, nheads, dim) final_states: (batch, nheads, dim) - """ - if initial_states is None: - initial_states = paddle.zeros_like(states[:, 0]) - states = paddle.concat([rearrange(initial_states, "b h d -> b 1 h d"), states], axis=1) - dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0), data_format="NCL") - dA_chunk_cumsum = paddle.cumsum(dA_chunk_cumsum, axis=-1) - nchunks = dA_chunk_cumsum.shape[-1] - # (batch, nheads, nchunks, nchunks) - dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :] - # (batch, nheads, nchunks, nchunks) - decay_chunk = paddle.exp(dt_chunk_segment_sum) - causal_mask = paddle.tril(paddle.ones([nchunks, nchunks], dtype=bool), diagonal=0) - decay_chunk = decay_chunk.masked_fill(~causal_mask, 0) - out = paddle.einsum("bhzc,bchd->bzhd", decay_chunk.cast(dtype=states.dtype), states) - return out[:, :-1], out[:, -1] diff --git a/ops/src/paddlenlp_kernel/triton/optimizer/__init__.py b/ops/src/paddlenlp_kernel/triton/optimizer/__init__.py deleted file mode 100644 index 95d81b31b1f4..000000000000 --- a/ops/src/paddlenlp_kernel/triton/optimizer/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .adamw_16bit_moment import adamw_16bit_moment diff --git a/ops/src/paddlenlp_kernel/triton/optimizer/adamw_16bit_moment.py b/ops/src/paddlenlp_kernel/triton/optimizer/adamw_16bit_moment.py deleted file mode 100644 index a6b6c1e034a3..000000000000 --- a/ops/src/paddlenlp_kernel/triton/optimizer/adamw_16bit_moment.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import triton -import triton.language as tl - - -@triton.jit -def adamw_kernel( - param_ptr, - grad_ptr, - moment1_ptr, - moment2_ptr, - lr_ptr, - beta1, - beta2, - epsilon, - coeff, - beta1_pow_ptr, - beta2_pow_ptr, - master_weight_ptr, - dtype, - N, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < N - - if master_weight_ptr is not None: - param = tl.load(master_weight_ptr + offsets, mask=mask) - else: - param = tl.load(param_ptr + offsets, mask=mask).to(tl.float32) - grad = tl.load(grad_ptr + offsets, mask=mask).to(tl.float32) - moment1 = tl.load(moment1_ptr + offsets, mask=mask).to(tl.float32) - moment2 = tl.load(moment2_ptr + offsets, mask=mask).to(tl.float32) - lr = tl.load(lr_ptr) - beta1_pow = tl.load(beta1_pow_ptr) - beta2_pow = tl.load(beta2_pow_ptr) - - # Weight Decay - param *= 1.0 - lr * coeff - - # AdamW - moment1 = beta1 * moment1 + (1.0 - beta1) * grad - moment2 = beta2 * moment2 + (1.0 - beta2) * grad * grad - denom = tl.sqrt(moment2) / tl.sqrt(1.0 - beta2_pow) + epsilon - param += (moment1 / denom) * (-lr / (1 - beta1_pow)) - if dtype == 0: - target_dtype = tl.float16 - elif dtype == 1: - target_dtype = tl.bfloat16 - else: - target_dtype = tl.float32 - target_dtype = tl.bfloat16 - - # Update param - if master_weight_ptr is not None: - tl.store(master_weight_ptr + offsets, param, mask=mask) - tl.store(param_ptr + offsets, param.to(target_dtype), mask=mask) - else: - tl.store(param_ptr + offsets, param.to(target_dtype), mask=mask) - tl.store(moment1_ptr + offsets, moment1.to(target_dtype), mask=mask) - tl.store(moment2_ptr + offsets, moment2.to(target_dtype), mask=mask) - - -def adamw_16bit_moment( - param, - grad, - learning_rate, - moment1, - moment2, - beta1_pow, - beta2_pow, - master_weight, - skip_update, - beta1, - beta2, - epsilon, - lr_ratio, - coeff, - with_decay, - multi_precision, -): - if skip_update: - return - if not with_decay: - coeff = 0.0 - if not multi_precision: - master_weight = None - lr = learning_rate * lr_ratio - - N = param.numel().item() - BLOCK_SIZE = 512 - grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE),) - if str(param.dtype) == "paddle.float16": - dtype = 0 - elif str(param.dtype) == "paddle.bfloat16": - dtype = 1 - else: - dtype = 2 - adamw_kernel[grid]( - param, - grad, - moment1, - moment2, - lr, - beta1, - beta2, - epsilon, - coeff, - beta1_pow, - beta2_pow, - master_weight, - dtype, - N, - BLOCK_SIZE, - ) - beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:] diff --git a/ops/src/paddlenlp_kernel/triton/triton_patch.py b/ops/src/paddlenlp_kernel/triton/triton_patch.py deleted file mode 100644 index 5e8c908d2582..000000000000 --- a/ops/src/paddlenlp_kernel/triton/triton_patch.py +++ /dev/null @@ -1,321 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import inspect -from typing import Dict - -import triton -from triton.runtime import autotuner - -__all__ = ["paddle_autotune", "PaddleAutotuner"] - - -def paddle_autotune( - configs, - key, - prune_configs_by=None, - reset_to_zero=None, - restore_value=None, - pre_hook=None, - post_hook=None, - warmup=25, - rep=100, - use_cuda_graph=False, -): - """ - Decorator for auto-tuning a :code:`triton.jit`'d function. - - .. highlight:: python - .. code-block:: python - - @triton.paddle_autotune(configs=[ - triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), - triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), - ], - key=['x_size'] # the two above configs will be evaluated anytime - # the value of x_size changes - ) - @triton.jit - def kernel(x_ptr, x_size, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] - :note: When all the configurations are evaluated, the kernel will run multiple times. - This means that whatever value the kernel updates will be updated multiple times. - To avoid this undesired behavior, you can use the `reset_to_zero` argument, which - resets the value of the provided tensor to `zero` before running any configuration. - - If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to - :code:`"1"`, Triton will print a message to stdout after autotuning each - kernel, including the time spent autotuning and the best configuration. - - :param configs: a list of :code:`triton.Config` objects - :type configs: list[triton.Config] - :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. - :type key: list[str] - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. - :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. - :type reset_to_zero: list[str] - :param restore_value: a list of argument names whose value will be restored after evaluating any configs. - :type restore_value: list[str] - :param pre_hook: a function that will be called before the kernel is called. - This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. - 'args': a list of arguments passed to the kernel. - 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. - :type pre_hook: lambda args, reset_only - :param post_hook: a function that will be called after the kernel is called. - This overrides the default post_hook used for 'restore_value'. - 'args': a list of arguments passed to the kernel. - 'exception': the exception raised by the kernel in case of a compilation or runtime error. - :type post_hook: lambda args, exception - :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25. - :type warmup: int - :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100. - :type rep: int - """ - - def decorator(fn): - return PaddleAutotuner( - fn, - fn.arg_names, - configs, - key, - reset_to_zero, - restore_value, - pre_hook=pre_hook, - post_hook=post_hook, - prune_configs_by=prune_configs_by, - warmup=warmup, - rep=rep, - use_cuda_graph=use_cuda_graph, - ) - - return decorator - - -class PaddleAutotuner(autotuner.Autotuner): - def __init__( - self, - fn, - arg_names, - configs, - key, - reset_to_zero, - restore_value, - pre_hook=None, - post_hook=None, - prune_configs_by: Dict = None, - warmup=25, - rep=100, - use_cuda_graph=False, - ): - """ - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. - """ - if not configs: - self.configs = [autotuner.Config({}, num_warps=4, num_stages=2, num_ctas=1)] - else: - self.configs = configs - self.key_idx = [arg_names.index(k) for k in key] - self.cache = {} - self.arg_names = arg_names - - # Reset to zero or restore values - self.reset_idx = [] - if reset_to_zero is not None: - self.reset_idx = [arg_names.index(k) for k in reset_to_zero] - self.restore_idx = [] - if restore_value is not None: - self.restore_idx = [arg_names.index(k) for k in restore_value] - - # Hook to reset or restore for required tensors - self.pre_hook = lambda args, reset_only=False: 0 - self.post_hook = lambda args, exception: 0 - if pre_hook: - self.pre_hook = pre_hook - elif len(self.reset_idx) > 0 or len(self.restore_idx) > 0: - - def _pre_hook(args, reset_only=False): - for i in self.reset_idx: - args[i].zero_() - if not reset_only: - self.restore_copies = [args[i].clone() for i in self.restore_idx] - - self.pre_hook = _pre_hook - - if post_hook: - self.post_hook = post_hook - elif len(self.restore_idx) > 0: - - def _post_hook(args, exception): - for i, j in enumerate(self.restore_idx): - args[j].copy_(self.restore_copies[i], False) - self.restore_copies = [] - - self.post_hook = _post_hook - - self.perf_model = None - self.configs_top_k = 1.0 - self.early_config_prune = None - if prune_configs_by: - self.perf_model = prune_configs_by.get("perf_model", self.perf_model) - self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) - self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune) - - self.fn = fn - self.base_fn = fn - while not inspect.isfunction(self.base_fn): - self.base_fn = self.base_fn.fn - self.num_warmups = warmup - self.num_reps = rep - if use_cuda_graph: - use_cuda_graph = False - self.use_cuda_graph = use_cuda_graph - - def _bench(self, *args, config, **meta): - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError( - f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols." - ) - # augment meta-parameters with tunable ones - current = dict(meta, **config.all_kwargs()) - full_nargs = {**self.nargs, **current} - - def kernel_call(): - if config.pre_hook: - config.pre_hook(full_nargs) - self.pre_hook(args) - try: - self.fn.run( - *args, - **current, - ) - except Exception as e: - try: - self.post_hook(args, exception=e) - finally: - # Throw exception raised by `self.fn.run` - raise - - self.post_hook(args, exception=None) - - try: - return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8)) - except Exception: - return float("inf") if self.use_cuda_graph else [float("inf"), float("inf"), float("inf")] - - -def do_bench( - fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean", device_type="gpu" -): - """ - Benchmark the runtime of the provided function. By default, return the median runtime of `fn` along with - the 20-th and 80-th performance percentile. - - :param fn: Function to benchmark - :type fn: Callable - :param warmup: Warmup time (in ms) - :type warmup: int - :param rep: Repetition time (in ms) - :type rep: int - :param grad_to_none: Reset the gradient of the provided tensor to None - :type grad_to_none: paddle.Tensor, optional - :param quantiles: Performance percentile to return in addition to the median. - :type quantiles: list[float] - :param fast_flush: Use faster kernel to flush L2 between measurements - :type fast_flush: bool - """ - if device_type == "cuda": - device_type = "gpu" - import paddle - - assert return_mode in ["min", "max", "mean", "median"] - - fn() - paddle.device.cuda.synchronize() if device_type == "gpu" else None - - # We maintain a buffer of 256 MB that we clear - # before each kernel call to make sure that the L2 - # doesn't contain any input data before the run - if fast_flush: - cache = paddle.empty([int(256e6 // 4)], dtype="int32") - else: - cache = paddle.empty([int(256e6)], dtype="int8") - - # Estimate the runtime of the function - start_time = paddle.device.cuda.Event(enable_timing=True) if device_type == "gpu" else None - end_time = paddle.device.cuda.Event(enable_timing=True) if device_type == "gpu" else None - start_time.record() if device_type == "gpu" else None - for _ in range(5): - cache.clear_gradient() - fn() - end_time.record() if device_type == "gpu" else None - paddle.device.cuda.synchronize() if device_type == "gpu" else None - estimate_ms = start_time.elapsed_time(end_time) / 5 if device_type == "gpu" else 0 - - # compute number of warmup and repeat - n_warmup = max(1, int(warmup / estimate_ms)) - n_repeat = max(1, int(rep / estimate_ms)) - start_time = ( - [paddle.device.cuda.Event(enable_timing=True) for _ in range(n_repeat)] if device_type == "gpu" else None - ) - end_time = ( - [paddle.device.cuda.Event(enable_timing=True) for _ in range(n_repeat)] if device_type == "gpu" else None - ) - - # Warm-up - for _ in range(n_warmup): - fn() - - # Benchmark - times = [] - for i in range(n_repeat): - # we don't want `fn` to accumulate gradient values - # if it contains a backward pass. So we clear the - # provided gradients - if grad_to_none is not None: - for x in grad_to_none: - x.clear_gradient() - # we clear the L2 cache before each run - cache.clear_gradient() - # record time of `fn` - start_time[i].record() if device_type == "gpu" else None - fn() - end_time[i].record() if device_type == "gpu" else None - - # Record clocks - paddle.device.cuda.synchronize() if device_type == "gpu" else None - times = [s.elapsed_time(e) for s, e in zip(start_time, end_time)] if device_type == "gpu" else [0] * n_repeat - times = paddle.to_tensor(times, dtype="float32") - - if quantiles is not None: - ret = paddle.quantile(times, paddle.to_tensor(quantiles, dtype="float32")).numpy().tolist() - if len(ret) == 1: - ret = ret[0] - return ret - - return getattr(paddle.tensor, return_mode)(times).item() - - -if not hasattr(triton, "paddle_autotune"): - triton.paddle_autotune = PaddleAutotuner -if not hasattr(autotuner, "PaddleAutotuner"): - autotuner.PaddleAutotuner = PaddleAutotuner diff --git a/ops/src/paddlenlp_kernel/utils.py b/ops/src/paddlenlp_kernel/utils.py deleted file mode 100644 index 4699fc93e211..000000000000 --- a/ops/src/paddlenlp_kernel/utils.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -from typing import Callable - -import paddle -from packaging.version import Version - - -def custom_fwd(func): - def wrapper(*args, **kwargs): - ctx = args[0] - if len(args) == 1: - all_args = tuple(kwargs.values()) - else: - all_args = args[1:] + tuple(kwargs.values()) - - if not hasattr(ctx, "needs_input_grad"): - ctx.needs_input_grad = [False] * len(all_args) - for i, arg in enumerate(all_args): - if isinstance(arg, paddle.Tensor): - if not arg.stop_gradient: - ctx.needs_input_grad[i] = True - else: - ctx.needs_input_grad[i] = "not_tensor" - return func(*args, **kwargs) - - return wrapper - - -def custom_bwd(func): - def wrapper(*args, **kwargs): - ctx = args[0] - output = func(*args, **kwargs) - result = [] - for each, need_input_grad in zip(output, ctx.needs_input_grad): - if isinstance(need_input_grad, str) and need_input_grad == "not_tensor": - continue - if need_input_grad: - result.append(each) - else: - result.append(None) - while result and result[-1] is None: - result.pop() - return tuple(result) - - return wrapper - - -def compare_version(package: str, operator: Callable, target: str): - try: - pkg = importlib.import_module(package) - except ImportError: - return False - pkg_version = Version(pkg.__version__) - return operator(pkg_version, Version(target)) - - -def is_autocast_enabled(): - tracer = paddle.framework._dygraph_tracer() - return False if tracer._amp_level == paddle.core.AmpLevel.O0 else True - - -def get_autocast_gpu_dtype(): - from paddle.amp.auto_cast import amp_global_state - - return amp_global_state().amp_dtype diff --git a/ops/tests/cuda/test_causal_conv1d.py b/ops/tests/cuda/test_causal_conv1d.py deleted file mode 100644 index 454ef8ad4c33..000000000000 --- a/ops/tests/cuda/test_causal_conv1d.py +++ /dev/null @@ -1,364 +0,0 @@ -# Copyright (C) 2024, Tri Dao. - -import paddle -import paddle.nn.functional as F -import pytest -from einops import rearrange -from paddlenlp_kernel.cuda.causal_conv1d import ( - causal_conv1d_fn, - causal_conv1d_ref, - causal_conv1d_update, - causal_conv1d_update_ref, -) -from paddlenlp_kernel.triton.causal_conv1d_varlen import ( - causal_conv1d_varlen_states, - causal_conv1d_varlen_states_ref, -) - -####################################################################################################################################### -# patch paddle.allclose -old_allclose = paddle.allclose - - -def allclose(a, b, **kwargs): - return old_allclose(a.cast("float32"), b.cast("float32"), **kwargs) - - -paddle.allclose = allclose - -old_equal_all = paddle.equal_all - - -def equal_all(a, b): - return old_equal_all(a.cast("float32"), b.cast("float32")) - - -paddle.equal_all = equal_all - - -def requires_grad_(self, value=True): - self.stop_gradient = not value - return self - - -paddle.Tensor.requires_grad_ = requires_grad_ -####################################################################################################################################### - - -@pytest.mark.parametrize("return_final_states", [False, True]) -# @pytest.mark.parametrize("return_final_states", [True]) -@pytest.mark.parametrize("has_initial_states", [False, True]) -# @pytest.mark.parametrize("has_initial_states", [False]) -@pytest.mark.parametrize("channel_last", [False, True]) -# @pytest.mark.parametrize('channel_last', [True]) -@pytest.mark.parametrize("itype", [paddle.float32, paddle.float16, paddle.bfloat16]) -# @pytest.mark.parametrize('itype', [paddle.float16]) -@pytest.mark.parametrize("silu_activation", [False, True]) -# @pytest.mark.parametrize('silu_activation', [True]) -@pytest.mark.parametrize("has_bias", [False, True]) -# @pytest.mark.parametrize('has_bias', [True]) -@pytest.mark.parametrize("width", [2, 3, 4]) -# @pytest.mark.parametrize('width', [3]) -@pytest.mark.parametrize("seqlen", [2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) -# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) -# @pytest.mark.parametrize('seqlen', [128]) -@pytest.mark.parametrize("dim", [64, 4096 + 32]) -# @pytest.mark.parametrize('dim', [64]) -def test_causal_conv1d( - dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states -): - if not channel_last and (has_initial_states or return_final_states): - pytest.skip("Only channel_last support initial_states or return_final_states") - - rtol, atol = (3e-4, 1e-3) if itype == paddle.float32 else (3e-3, 5e-3) - if itype == paddle.bfloat16: - rtol, atol = 1e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) - # set seed - paddle.seed(0) - batch = 2 - # batch = 1 - if not channel_last: - x = paddle.randn([batch, 4096 + dim + 64, seqlen], dtype=itype)[:, 4096 : 4096 + dim, :].requires_grad_() - else: - x = rearrange( - paddle.randn([batch, seqlen, 4096 + dim + 64], dtype=itype)[:, :, 4096 : 4096 + dim], "b s d -> b d s" - ).requires_grad_() - weight = paddle.randn([dim, width], dtype=paddle.float32).requires_grad_() - if has_bias: - bias = paddle.randn( - [ - dim, - ], - dtype=paddle.float32, - ).requires_grad_() - else: - bias = None - if has_initial_states: - initial_states = paddle.randn([batch, width - 1, dim], dtype=itype).transpose([0, 2, 1]).requires_grad_() - else: - initial_states = None - x_ref = x.detach().clone().requires_grad_() - weight_ref = weight.detach().clone().requires_grad_() - bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None - initial_states_ref = initial_states.detach().clone().requires_grad_() if initial_states is not None else None - activation = None if not silu_activation else "silu" - out = causal_conv1d_fn( - x, weight, bias, initial_states=initial_states, return_final_states=return_final_states, activation=activation - ) - out_ref = causal_conv1d_ref( - x_ref, - weight_ref, - bias_ref, - initial_states=initial_states_ref, - return_final_states=return_final_states, - activation=activation, - ) - if return_final_states: - out, final_states = out - out_ref, final_states_ref = out_ref - print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}") - print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}") - assert paddle.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - assert paddle.allclose(out, out_ref, rtol=rtol, atol=atol) - - if return_final_states: - out += F.sigmoid(final_states).sum(axis=-1, keepdim=True) - out_ref += F.sigmoid(final_states_ref).sum(axis=-1, keepdim=True) - - g = paddle.randn(out.shape, dtype=out.dtype) - out.backward(g) - out_ref.backward(g) - - print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") - print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") - if has_bias: - print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") - if has_initial_states: - print(f"dinitial_states max diff: {(initial_states.grad - initial_states_ref.grad).abs().max().item()}") - - assert paddle.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) - assert paddle.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) - if has_bias: - assert paddle.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) - if has_initial_states: - assert paddle.allclose(initial_states.grad, initial_states_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) - - -@pytest.mark.parametrize("itype", [paddle.float32, paddle.float16, paddle.bfloat16]) -# @pytest.mark.parametrize('itype', [paddle.float16]) -@pytest.mark.parametrize("silu_activation", [False, True]) -# @pytest.mark.parametrize('silu_activation', [True]) -@pytest.mark.parametrize("has_bias", [False, True]) -# @pytest.mark.parametrize('has_bias', [True]) -@pytest.mark.parametrize("has_cache_seqlens", [False, True]) -# @pytest.mark.parametrize('has_cache_seqlens', [True]) -@pytest.mark.parametrize("seqlen", [1, 4, 5]) -# @pytest.mark.parametrize('seqlen', [4]) -@pytest.mark.parametrize("width", [2, 3, 4]) -# @pytest.mark.parametrize('width', [4]) -@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -# @pytest.mark.parametrize("dim", [2048]) -def test_causal_conv1d_update(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype): - rtol, atol = (3e-4, 1e-3) if itype == paddle.float32 else (3e-3, 5e-3) - if itype == paddle.bfloat16: - rtol, atol = 1e-2, 5e-2 - # rtolw, atolw = (1e-3, 1e-3) - # set seed - paddle.seed(0) - batch = 64 - # batch = 1 - # dim = 64 - x = paddle.randn([batch, seqlen, dim], dtype=itype).transpose([0, 2, 1]) - state_len = paddle.randint(width - 1, width + 10, (1,)).item() - conv_state = paddle.randn([batch, state_len, dim], dtype=itype).transpose([0, 2, 1]) - weight = paddle.randn([dim, width], dtype=paddle.float32).requires_grad_() - if has_bias: - bias = paddle.randn( - [ - dim, - ], - dtype=paddle.float32, - ).requires_grad_() - else: - bias = None - conv_state_ref = conv_state.detach().clone() - activation = None if not silu_activation else "silu" - cache_seqlens = paddle.randint(0, 1024, (batch,), dtype=paddle.int32) if has_cache_seqlens else None - out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation, cache_seqlens=cache_seqlens) - out_ref = causal_conv1d_update_ref( - x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens - ) - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - assert paddle.equal_all(conv_state, conv_state_ref) - assert paddle.allclose(out, out_ref, rtol=rtol, atol=atol) - - -@pytest.mark.parametrize("itype", [paddle.float32, paddle.float16, paddle.bfloat16]) -# @pytest.mark.parametrize('itype', [paddle.float16]) -@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -# @pytest.mark.parametrize("dim", [2048]) -def test_causal_conv1d_get_states(dim, itype): - # set seed - paddle.seed(0) - seqlens = paddle.randint(1, 32, (100,)) - total_seqlen = seqlens.sum().item() - x = paddle.randn([total_seqlen, dim], dtype=itype) - cu_seqlens = F.pad(seqlens.cumsum(0).unsqueeze([0, 1]), (1, 0), data_format="NCL").squeeze([0, 1]) - state_len = 20 - out = causal_conv1d_varlen_states(x, cu_seqlens, state_len) - out_ref = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len) - assert paddle.equal_all(out, out_ref) - - -# @pytest.mark.parametrize("channel_last", [False, True]) -@pytest.mark.parametrize("channel_last", [True]) -# @pytest.mark.parametrize("itype", [paddle.float32, paddle.float16, paddle.bfloat16]) -@pytest.mark.parametrize("itype", [paddle.bfloat16]) -# @pytest.mark.parametrize("silu_activation", [False, True]) -@pytest.mark.parametrize("silu_activation", [True]) -# @pytest.mark.parametrize("has_bias", [False, True]) -@pytest.mark.parametrize("has_bias", [True]) -# @pytest.mark.parametrize("width", [2, 3, 4]) -@pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize( - # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] - "seqlen", - [2048], -) -# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) -# @pytest.mark.parametrize('seqlen', [128]) -def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last): - # set seed - paddle.seed(0) - batch = 2 - # batch = 1 - dim = 4096 + 32 # Try dim not divisible by 64 - # dim = 64 - if not channel_last: - x = paddle.randn([batch, 4096 + dim + 64, seqlen], dtype=itype)[:, 4096 : 4096 + dim, :].requires_grad_() - else: - x = rearrange( - paddle.randn([batch, seqlen, 4096 + dim + 64], dtype=itype)[:, :, 4096 : 4096 + dim], "b s d -> b d s" - ).requires_grad_() - weight = paddle.randn([dim, width], dtype=paddle.float32).requires_grad_() - if has_bias: - bias = paddle.randn( - [ - dim, - ], - dtype=paddle.float32, - ).requires_grad_() - else: - bias = None - activation = None if not silu_activation else "silu" - out0 = causal_conv1d_fn(x, weight, bias, activation=activation) - g = paddle.randn(out0.shape, dtype=out0.dtype) - dx0, dw0, db0 = paddle.autograd.grad(out0, (x, weight, bias), g) - dw_atol = 1e-4 - # db_atol = 1e-4 - - for i in range(10000): - out = causal_conv1d_fn(x, weight, bias, activation=activation) - dx, dw, db = paddle.autograd.grad(out, (x, weight, bias), g) - dw_equal = paddle.allclose(dw, dw0, atol=dw_atol) - # if not dw_equal: - # breakpoint() - if has_bias: - pass - # db_equal = paddle.allclose(db, db0, atol=db_atol) - # if not db_equal: - # breakpoint() - assert paddle.equal_all(out, out0) - assert paddle.equal_all(dx, dx0) - assert dw_equal - if has_bias: - assert dw_equal - - -@pytest.mark.parametrize("itype", [paddle.float32, paddle.float16, paddle.bfloat16]) -# @pytest.mark.parametrize('itype', [paddle.float16]) -@pytest.mark.parametrize("silu_activation", [False, True]) -# @pytest.mark.parametrize('silu_activation', [False]) -@pytest.mark.parametrize("has_bias", [False, True]) -# @pytest.mark.parametrize('has_bias', [False]) -@pytest.mark.parametrize("width", [2, 3, 4]) -# @pytest.mark.parametrize('width', [2]) -@pytest.mark.parametrize("seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) -# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) -# @pytest.mark.parametrize('seqlen', [2048]) -@pytest.mark.parametrize("dim", [64, 4096 + 32]) -# @pytest.mark.parametrize('dim', [64]) -def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, itype): - rtol, atol = (3e-4, 1e-3) if itype == paddle.float32 else (3e-3, 5e-3) - if itype == paddle.bfloat16: - rtol, atol = 1e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) - # set seed - paddle.seed(seqlen + dim + width) - batch = 3 - seqlens = [] - for b in range(batch): - nsplits = paddle.randint(1, 5, (1,)).item() - eos_pos = paddle.randperm(seqlen - 1)[:nsplits].sort() - seqlens.append( - paddle.diff(paddle.concat([paddle.to_tensor([-1]), eos_pos, paddle.to_tensor([seqlen - 1])])).tolist() - ) - assert sum(seqlens[-1]) == seqlen - assert all(s > 0 for s in seqlens[-1]) - # Only support channel_last - x = rearrange( - paddle.randn([batch, seqlen, 4096 + dim + 64], dtype=itype)[:, :, 4096 : 4096 + dim], "b s d -> b d s" - ).requires_grad_() - weight = paddle.randn([dim, width], dtype=paddle.float32).requires_grad_() - if has_bias: - bias = paddle.randn( - [ - dim, - ], - dtype=paddle.float32, - ).requires_grad_() - else: - bias = None - seq_idx = paddle.stack( - [ - paddle.concat([paddle.full((s,), i, dtype=paddle.int32) for i, s in enumerate(sl)], axis=0) - for sl in seqlens - ], - axis=0, - ) - x_ref = x.detach().clone().requires_grad_() - weight_ref = weight.detach().clone().requires_grad_() - bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None - activation = None if not silu_activation else "silu" - out = causal_conv1d_fn(x, weight, bias, seq_idx=seq_idx, activation=activation) - out_ref = [] - for b in range(batch): - out_ref_b = [] - for x_s in paddle.split(x_ref[[b]], seqlens[b], axis=2): - out_ref_b.append(causal_conv1d_ref(x_s, weight_ref, bias_ref, activation=activation)) - out_ref.append(paddle.concat(out_ref_b, axis=2)) - out_ref = paddle.concat(out_ref, axis=0) - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - assert paddle.allclose(out, out_ref, rtol=rtol, atol=atol) - - g = paddle.randn(out.shape, dtype=out.dtype) - out_ref.backward(g) - out.backward(g) - - print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") - print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") - if has_bias: - print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") - - assert paddle.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) - assert paddle.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) - if has_bias: - assert paddle.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) diff --git a/ops/tests/cuda/test_selective_scan.py b/ops/tests/cuda/test_selective_scan.py deleted file mode 100644 index 6aa2cb9f823f..000000000000 --- a/ops/tests/cuda/test_selective_scan.py +++ /dev/null @@ -1,360 +0,0 @@ -# Copyright (C) 2023, Tri Dao. - - -import paddle -import pytest -from paddlenlp_kernel.cuda.selective_scan import ( - mamba_inner_fn, - mamba_inner_ref, - selective_scan_fn, - selective_scan_ref, -) - -####################################################################################################################################### -# patch paddle.allclose -old_allclose = paddle.allclose - - -def allclose(a, b, **kwargs): - return old_allclose(a.cast("float32"), b.cast("float32"), **kwargs) - - -paddle.allclose = allclose - -old_equal_all = paddle.equal_all - - -def equal_all(a, b): - return old_equal_all(a.cast("float32"), b.cast("float32")) - - -paddle.equal_all = equal_all - - -def requires_grad_(self, value=True): - self.stop_gradient = not value - return self - - -paddle.Tensor.requires_grad_ = requires_grad_ -####################################################################################################################################### - - -@pytest.mark.parametrize("wtype", [paddle.float32]) -# @pytest.mark.parametrize('wtype', [paddle.float32]) -@pytest.mark.parametrize("itype", [paddle.float32]) -# @pytest.mark.parametrize('itype', [paddle.float32]) -# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) -@pytest.mark.parametrize("seqlen", [128, 256, 512, 1024, 2048, 4096]) -# @pytest.mark.parametrize('seqlen', [128]) -# @pytest.mark.parametrize("return_last_state", [False, True]) -@pytest.mark.parametrize("return_last_state", [True]) -# @pytest.mark.parametrize('has_delta_bias', [False, True]) -@pytest.mark.parametrize("has_delta_bias", [True]) -# @pytest.mark.parametrize('delta_softplus', [False, True]) -@pytest.mark.parametrize("delta_softplus", [True]) -# @pytest.mark.parametrize('has_z', [False, True]) -@pytest.mark.parametrize("has_z", [True]) -# @pytest.mark.parametrize('has_D', [False, True]) -@pytest.mark.parametrize("has_D", [True]) -@pytest.mark.parametrize("varBC_groups", [1, 2]) -# @pytest.mark.parametrize("varBC_groups", [1]) -# @pytest.mark.parametrize("is_variable_C", [False, True]) -@pytest.mark.parametrize("is_variable_C", [True]) -# @pytest.mark.parametrize("is_variable_B", [False, True]) -@pytest.mark.parametrize("is_variable_B", [True]) -def test_selective_scan( - is_variable_B, - is_variable_C, - varBC_groups, - has_D, - has_z, - has_delta_bias, - delta_softplus, - return_last_state, - seqlen, - itype, - wtype, -): - if varBC_groups > 1 and (not is_variable_B or not is_variable_C): - pytest.skip() # This config is not applicable - - rtol, atol = (6e-4, 2e-3) if itype == paddle.float32 else (3e-3, 5e-3) - if itype == paddle.bfloat16: - rtol, atol = 3e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) - if has_z: # If we have z, the errors on the weights seem higher - rtolw = max(rtolw, rtol) - atolw = max(atolw, atol) - # set seed - paddle.seed(0) - batch_size = 2 - dim = 4 - dstate = 8 - is_complex = wtype == paddle.complex64 - - if is_complex: - A = (-0.5 * paddle.rand([dim, dstate], dtype="float32")).cast(wtype).requires_grad_() - else: - A = (-0.5 * paddle.rand([dim, dstate], dtype=wtype)).requires_grad_() - if not is_variable_B: - B_shape = (dim, dstate) - elif varBC_groups == 1: - B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) - else: - B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) - B = paddle.randn(B_shape, dtype=wtype if not is_variable_B else itype).requires_grad_() - if not is_variable_C: - C_shape = (dim, dstate) - elif varBC_groups == 1: - C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) - else: - C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) - C = paddle.randn(C_shape, dtype=wtype if not is_variable_C else itype).requires_grad_() - if has_D: - D = paddle.randn( - [ - dim, - ], - dtype=paddle.float32, - ).requires_grad_() - else: - D = None - if has_z: - z = paddle.randn([batch_size, dim, seqlen], dtype=itype).requires_grad_() - else: - z = None - if has_delta_bias: - delta_bias = ( - 0.5 - * paddle.rand( - [ - dim, - ], - dtype=paddle.float32, - ) - ).requires_grad_() - else: - delta_bias = None - u = paddle.randn([batch_size, dim, seqlen], dtype=itype).requires_grad_() - delta = (0.5 * paddle.rand([batch_size, dim, seqlen], dtype=itype)).requires_grad_() - A_ref = A.detach().clone().requires_grad_() - B_ref = B.detach().clone().requires_grad_() - C_ref = C.detach().clone().requires_grad_() - D_ref = D.detach().clone().requires_grad_() if D is not None else None - z_ref = z.detach().clone().requires_grad_() if z is not None else None - u_ref = u.detach().clone().requires_grad_() - delta_ref = delta.detach().clone().requires_grad_() - delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None - out, *rest = selective_scan_fn( - u, - delta, - A, - B, - C, - D, - z=z, - delta_bias=delta_bias, - delta_softplus=delta_softplus, - return_last_state=return_last_state, - ) - if return_last_state: - state = rest[0] - out_ref, *rest = selective_scan_ref( - u_ref, - delta_ref, - A_ref, - B_ref, - C_ref, - D_ref, - z=z_ref, - delta_bias=delta_bias_ref, - delta_softplus=delta_softplus, - return_last_state=return_last_state, - ) - if return_last_state: - state_ref = rest[0] - # dA = paddle.exp(paddle.einsum('bdl,dn->bdln', delta, A)) - # dt_u = delta * u - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - assert paddle.allclose(out, out_ref, rtol=rtol, atol=atol) - if return_last_state: - print(f"State max diff: {(state - state_ref).abs().max().item()}") - assert paddle.allclose(state, state_ref, rtol=rtol, atol=atol) - - g = paddle.randn(out.shape, dtype=out.dtype) - out_ref.backward(g.cast(out_ref.dtype)) - out.backward(g.cast(out.dtype)) - - print(f"du max diff: {(u.grad - u_ref.grad).abs().max().item()}") - print(f"ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}") - print(f"dA max diff: {(A.grad - A_ref.grad).abs().max().item()}") - print(f"dB max diff: {(B.grad - B_ref.grad).abs().max().item()}") - print(f"dC max diff: {(C.grad - C_ref.grad).abs().max().item()}") - if has_D: - print(f"dD max diff: {(D.grad - D_ref.grad).abs().max().item()}") - if has_z: - print(f"dz max diff: {(z.grad - z_ref.grad).abs().max().item()}") - if has_delta_bias: - print(f"ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}") - - assert paddle.allclose(u.grad, u_ref.grad.cast(dtype=itype), rtol=rtol * 2, atol=atol * 2) - assert paddle.allclose(delta.grad, delta_ref.grad.cast(dtype=itype), rtol=rtol * 5, atol=atol * 10) - assert paddle.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) - assert paddle.allclose( - B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, atol=atolw if not is_variable_B else atol - ) - assert paddle.allclose( - C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, atol=atolw if not is_variable_C else atol - ) - if has_D: - assert paddle.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) - if has_z: - assert paddle.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) - if has_delta_bias: - assert paddle.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) - - -# @pytest.mark.parametrize('wtype', [paddle.float32]) -@pytest.mark.parametrize("wtype", [paddle.float32]) -@pytest.mark.parametrize("itype", [paddle.float32]) -# @pytest.mark.parametrize('itype', [paddle.float32]) -# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) -@pytest.mark.parametrize("seqlen", [128]) -@pytest.mark.parametrize("is_variable_C", [False, True]) -# @pytest.mark.parametrize("is_variable_C", [False]) -@pytest.mark.parametrize("is_variable_B", [False, True]) -# @pytest.mark.parametrize("is_variable_B", [True]) -def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype): - rtol, atol = (6e-4, 2e-3) if itype == paddle.float32 else (3e-3, 5e-3) - if itype == paddle.bfloat16: - rtol, atol = 3e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) - # If we have z, the errors on the weights seem higher - rtolw = max(rtolw, rtol) - atolw = max(atolw, atol) - # set seed - paddle.seed(0) - batch_size = 2 - dim = 768 - dstate = 8 - dt_rank = 48 - is_complex = wtype == paddle.complex64 - - xz = paddle.randn([batch_size, 2 * dim, seqlen], dtype=itype).requires_grad_() - conv1d_weight = paddle.randn([dim, 1, 3], dtype=paddle.float32).requires_grad_() - conv1d_bias = paddle.randn( - [ - dim, - ], - dtype=paddle.float32, - ).requires_grad_() - x_proj_weight = paddle.randn( - [dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate * (1 if not is_complex else 2), dim], - dtype=itype, - ).requires_grad_() - delta_proj_weight = paddle.randn([dim, dt_rank], dtype=itype).requires_grad_() - out_proj_weight = paddle.randn([dim // 2, dim], dtype=itype).requires_grad_() - out_proj_bias = None - - if is_complex: - A = (-0.5 * paddle.rand([dim, dstate], dtype="float32")).cast(wtype).requires_grad_() - else: - A = (-0.5 * paddle.rand([dim, dstate], dtype=wtype)).requires_grad_() - B = paddle.randn([dim, dstate], dtype=wtype).requires_grad_() if not is_variable_B else None - C = paddle.randn([dim, dstate], dtype=wtype).requires_grad_() if not is_variable_C else None - D = paddle.randn( - [ - dim, - ], - dtype=paddle.float32, - ).requires_grad_() - delta_bias = ( - 0.5 - * paddle.rand( - [ - dim, - ], - dtype=paddle.float32, - ) - ).requires_grad_() - # B_proj_bias = None - # C_proj_bias = None - xz_ref = xz.detach().clone().requires_grad_() - conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_() - conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_() - x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_() - delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_() - out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_() - out_proj_bias_ref = out_proj_bias.detach().clone().requires_grad_() if out_proj_bias is not None else None - A_ref = A.detach().clone().requires_grad_() - B_ref = B.detach().clone().requires_grad_() if B is not None else None - C_ref = C.detach().clone().requires_grad_() if C is not None else None - D_ref = D.detach().clone().requires_grad_() - delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None - out = mamba_inner_fn( - xz, - conv1d_weight, - conv1d_bias, - x_proj_weight, - delta_proj_weight, - out_proj_weight, - out_proj_bias, - A, - B, - C, - D, - delta_bias=delta_bias, - delta_softplus=True, - ) - out_ref = mamba_inner_ref( - xz_ref, - conv1d_weight_ref, - conv1d_bias_ref, - x_proj_weight_ref, - delta_proj_weight_ref, - out_proj_weight_ref, - out_proj_bias_ref, - A_ref, - B_ref, - C_ref, - D_ref, - delta_bias=delta_bias_ref, - delta_softplus=True, - ) - # dA = paddle.exp(paddle.einsum('bdl,dn->bdln', delta, A)) - # dt_u = delta * u - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - assert paddle.allclose(out, out_ref, rtol=rtol, atol=atol) - - g = paddle.randn(out.shape, dtype=out.dtype) - out_ref.backward(g.cast(out_ref.dtype)) - out.backward(g.cast(out.dtype)) - - print(f"dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}") - print(f"dA max diff: {(A.grad - A_ref.grad).abs().max().item()}") - if not is_variable_B: - print(f"dB max diff: {(B.grad - B_ref.grad).abs().max().item()}") - if not is_variable_C: - print(f"dC max diff: {(C.grad - C_ref.grad).abs().max().item()}") - print(f"dD max diff: {(D.grad - D_ref.grad).abs().max().item()}") - print(f"ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}") - print(f"dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}") - print(f"ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}") - print(f"dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}") - print(f"dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}") - print(f"dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}") - - # assert paddle.allclose(xz.grad, xz_ref.grad.cast(dtype=itype), rtol=rtol * 2, atol=atol * 2) - # assert paddle.allclose(delta.grad, delta_ref.grad.cast(dtype=itype), rtol=rtol * 5, atol=atol * 10) - # assert paddle.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) - # assert paddle.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, - # atol=atolw if not is_variable_B else atol) - # assert paddle.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, - # atol=atolw if not is_variable_C else atol) - # assert paddle.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) - # assert paddle.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) diff --git a/ops/tests/triton/cut_cross_entropy/test_cce_indexed_dot.py b/ops/tests/triton/cut_cross_entropy/test_cce_indexed_dot.py deleted file mode 100644 index b575e5677c47..000000000000 --- a/ops/tests/triton/cut_cross_entropy/test_cce_indexed_dot.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -import paddle -import pytest -from paddlenlp_kernel.triton.cut_cross_entropy.indexed_dot import ( - indexed_neg_dot_forward_kernel, -) -from paddlenlp_kernel.triton.cut_cross_entropy.utils import softcapping - -skip_no_cuda = pytest.mark.skipif(not paddle.device.is_compiled_with_cuda(), reason="Test requires CUDA") - - -@skip_no_cuda -@pytest.mark.parametrize("dtype,error_tol", [(paddle.float32, 5e-7), (paddle.float16, 1e-3), (paddle.bfloat16, 1e-2)]) -@pytest.mark.parametrize("softcap", [None, 20.0]) -@pytest.mark.parametrize("shape", [(256, 512, 128), (255, 507, 128), (255, 507, 123)]) -def test_indexed_dot(dtype: paddle.dtype, error_tol: float, softcap: float, shape: tuple[int, int, int]): - paddle.seed(0) - - if dtype == paddle.bfloat16 and not paddle.device.is_compiled_with_cuda(): - pytest.skip(reason="BF16 not available") - - N, V, D = shape - e = paddle.randn((N, D), dtype=dtype) / (D**0.5) - c = paddle.randn((V, D), dtype=dtype) - - c[0 : min(N, V) // 2] = e[0 : min(N, V) // 2] - - inds = paddle.randint(0, V, shape=(N,)) - - gt = -(e.cast("float32") * c[inds].cast("float32")).sum(-1) - if softcap is not None: - gt = softcapping(gt, softcap) - - ref = -(e * c[inds]).sum(-1).cast("float32") - if softcap is not None: - ref = softcapping(ref, softcap) - - cce_neg_dot = indexed_neg_dot_forward_kernel(e, c, inds, softcap=softcap) - - expected_error = (gt - ref).abs() - cce_error = (gt - cce_neg_dot).abs() - - assert ( - cce_error <= (expected_error + error_tol) - ).all(), f"{paddle.nn.functional.relu(cce_error - expected_error).max()=}" diff --git a/ops/tests/triton/cut_cross_entropy/test_cce_loss_backward.py b/ops/tests/triton/cut_cross_entropy/test_cce_loss_backward.py deleted file mode 100644 index d15176e82024..000000000000 --- a/ops/tests/triton/cut_cross_entropy/test_cce_loss_backward.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -from typing import Union - -import paddle -import paddle.nn.functional as F -import pytest -from paddlenlp_kernel.triton.cut_cross_entropy import linear_cross_entropy -from paddlenlp_kernel.triton.cut_cross_entropy.constants import IGNORE_INDEX -from paddlenlp_kernel.triton.cut_cross_entropy.utils import softcapping - -skip_no_cuda = pytest.mark.skipif(not paddle.device.is_compiled_with_cuda(), reason="Test requires CUDA") - - -def cross_entropy( - input, - label, - weight=None, - ignore_index=-100, - reduction="mean", - soft_label=False, - axis=-1, - use_softmax=True, - label_smoothing=0.0, - name=None, -): - """ - NOTE: torch cross_entropy is not the same as paddle cross_entropy. - """ - if ignore_index < 0 and reduction == "mean": - loss = F.cross_entropy(input, label, reduction="none") - binary_sequence = paddle.where(loss > 0, paddle.ones_like(loss), paddle.zeros_like(loss)) - count = paddle.sum(binary_sequence) - if count == 0: - loss = paddle.sum(loss * binary_sequence) - else: - loss = paddle.sum(loss * binary_sequence) / count - return loss - return F.cross_entropy( - input, label, weight, ignore_index, reduction, soft_label, axis, use_softmax, label_smoothing, name - ) - - -def _grads( - e: paddle.Tensor, - c: paddle.Tensor, - targets: paddle.Tensor, - softcap: Union[float, None], - shift: bool, - reduction: str, - fp32: bool = False, -) -> tuple[paddle.Tensor, paddle.Tensor]: - orig_e, orig_c = e, c - set_to_zero = False - e.clear_gradient(set_to_zero) - c.clear_gradient(set_to_zero) - - N, T = targets.shape - if shift: - e = e[:, :-1] - targets = targets[:, 1:] - T = T - 1 - - e = e.flatten(0, -2) - targets = targets.flatten() - - if fp32: - e = e.cast("float32") - c = c.cast("float32") - - logits = e @ c.T - if softcap is not None: - logits = softcapping(logits, softcap) - - loss = cross_entropy(logits.cast("float32"), targets, ignore_index=IGNORE_INDEX, reduction=reduction) - - if reduction == "sum": - loss = loss / (targets != IGNORE_INDEX).count_nonzero() - - loss.mean().backward() - - assert orig_e.grad is not None - assert orig_c.grad is not None - - return orig_e.grad.detach().clone(), orig_c.grad.detach().clone() - - -@skip_no_cuda -@pytest.mark.parametrize("impl", ["cce"]) -@pytest.mark.parametrize("dtype,error_tol", [(paddle.float16, 1e-3), (paddle.bfloat16, 1e-2)]) -@pytest.mark.parametrize("softcap", [None, 20.0]) -@pytest.mark.parametrize("shift", [False, True]) -@pytest.mark.parametrize("invalids", [False, True]) -@pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) -@pytest.mark.parametrize("shape", [(256, 512, 128), (252, 507, 128), (252, 507, 123)]) -def test_loss_backward( - impl: str, - dtype: paddle.dtype, - error_tol: float, - softcap: Union[float, None], - shift: bool, - invalids: bool, - reduction: str, - shape: tuple[int, int, int], -): - # paddle.set_float32_matmul_precision("highest") - # paddle._dynamo.config.cache_size_limit = 256 - paddle.seed(0) - - if dtype == paddle.bfloat16 and not paddle.device.is_compiled_with_cuda(): - pytest.skip(reason="BF16 not available") - - N, V, D = shape - e = paddle.randn((N, D), dtype=dtype) / (D**0.5) - c = paddle.randn((V, D), dtype=dtype) - - c[0 : min(N, V) // 2] = e[0 : min(N, V) // 2] - - targets = paddle.randint(0, V, shape=(N,)) - - if invalids: - inds = paddle.randperm(len(targets))[0 : int(0.2 * len(targets))] - targets[inds] = IGNORE_INDEX - - e = e.reshape([4, -1, D]) - targets = targets.reshape(e.shape[0:-1]) - - e.stop_gradient = False - c.stop_gradient = False - - gt = _grads(e, c, targets, softcap, shift, reduction, fp32=True) - - ref = _grads(e, c, targets, softcap, shift, reduction) - - set_to_zero = False - e.clear_gradient(set_to_zero) - c.clear_gradient(set_to_zero) - loss = linear_cross_entropy(e, c, targets, softcap=softcap, shift=shift, reduction=reduction, impl=impl) - if reduction == "sum": - loss = loss / (targets != IGNORE_INDEX).count_nonzero() - loss.mean().backward() - assert e.grad is not None - assert c.grad is not None - - expected_error = tuple((vgt - vref).abs() for vgt, vref in zip(gt, ref)) - cce_error = tuple((vgt - vcce).abs() for vgt, vcce in zip(gt, (e.grad, c.grad))) - - for i in range(len(expected_error)): - assert ( - cce_error[i] <= (expected_error[i] + error_tol) - ).all(), f"{(paddle.nn.functional.relu(cce_error[i] - expected_error[i])).max()=}" diff --git a/ops/tests/triton/cut_cross_entropy/test_cce_loss_forward.py b/ops/tests/triton/cut_cross_entropy/test_cce_loss_forward.py deleted file mode 100644 index 94b29c8a8bfb..000000000000 --- a/ops/tests/triton/cut_cross_entropy/test_cce_loss_forward.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -from typing import Union - -import paddle -import pytest -from paddlenlp_kernel.triton.cut_cross_entropy import linear_cross_entropy -from paddlenlp_kernel.triton.cut_cross_entropy.constants import IGNORE_INDEX -from paddlenlp_kernel.triton.cut_cross_entropy.utils import softcapping - -skip_no_cuda = pytest.mark.skipif(not paddle.device.is_compiled_with_cuda(), reason="Test requires CUDA") - - -def _loss( - e: paddle.Tensor, - c: paddle.Tensor, - targets: paddle.Tensor, - softcap: Union[float, None], - shift: bool, -) -> paddle.Tensor: - N, T = targets.shape - if shift: - e = e[:, :-1] - targets = targets[:, 1:] - T = T - 1 - - e = e.flatten(0, -2) - targets = targets.flatten() - - logits = e @ c.T - if softcap is not None: - logits = softcapping(logits, softcap) - - loss = paddle.nn.functional.cross_entropy( - logits.cast("float32"), targets, ignore_index=IGNORE_INDEX, reduction="none" - ) - - return loss.reshape([N, T]) - - -@skip_no_cuda -@pytest.mark.parametrize("impl", ["cce"]) -@pytest.mark.parametrize("dtype,error_tol", [(paddle.float32, 1e-5), (paddle.float16, 1e-3), (paddle.bfloat16, 1e-2)]) -@pytest.mark.parametrize("softcap", [None, 20.0]) -@pytest.mark.parametrize("shift", [False, True]) -@pytest.mark.parametrize("invalids", [False, True]) -@pytest.mark.parametrize("shape", [(256, 512, 128), (252, 507, 128), (252, 507, 123)]) -def test_loss_forward( - impl: str, - dtype: paddle.dtype, - error_tol: float, - softcap: Union[float, None], - shift: bool, - invalids: bool, - shape: tuple[int, int, int], -): - # paddle.set_float32_matmul_precision("highest") - # paddle._dynamo.config.cache_size_limit = 256 - paddle.seed(0) - - if dtype == paddle.bfloat16 and not paddle.device.is_compiled_with_cuda(): - pytest.skip(reason="BF16 not available") - - N, V, D = shape - e = paddle.randn((N, D), dtype=dtype) / (D**0.5) - c = paddle.randn((V, D), dtype=dtype) - - c[0 : min(N, V) // 2] = e[0 : min(N, V) // 2] - - e = e.reshape([4, -1, D]) - - targets = paddle.randint(0, V, shape=(N,)) - - if invalids: - inds = paddle.randperm(len(targets))[0 : int(0.2 * len(targets))] - targets[inds] = IGNORE_INDEX - - targets = targets.reshape(e.shape[0:-1]) - - gt = _loss(e.cast("float32"), c.cast("float32"), targets, softcap, shift) - - # paddle.set_float32_matmul_precision("highest" if dtype == paddle.float32 else "high") - ref = _loss(e, c, targets, softcap, shift) - - cce_loss = linear_cross_entropy(e, c, targets, softcap=softcap, shift=shift, reduction="none", impl=impl) - - expected_error = (gt - ref).abs() - cce_error = (gt - cce_loss).abs() - - assert ( - cce_error <= (expected_error + error_tol) - ).all(), f"{paddle.nn.functional.relu((cce_error - expected_error)).max()=}" diff --git a/ops/tests/triton/cut_cross_entropy/test_cce_lse.py b/ops/tests/triton/cut_cross_entropy/test_cce_lse.py deleted file mode 100644 index 517d25d21fb6..000000000000 --- a/ops/tests/triton/cut_cross_entropy/test_cce_lse.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. -from typing import Union - -import paddle -import paddle.nn.functional as F -import pytest -from paddlenlp_kernel.triton.cut_cross_entropy.cce_lse_forward import ( - cce_lse_forward_kernel, -) -from paddlenlp_kernel.triton.cut_cross_entropy.utils import softcapping - -skip_no_cuda = pytest.mark.skipif(not paddle.device.is_compiled_with_cuda(), reason="Test requires CUDA") - - -def _lse(e: paddle.Tensor, c: paddle.Tensor, softcap: Union[float, None]) -> paddle.Tensor: - logits = e @ c.T - if softcap is not None: - logits = softcapping(logits, softcap) - return paddle.logsumexp(logits.cast("float32"), axis=-1) - - -@skip_no_cuda -@pytest.mark.parametrize("dtype", [paddle.float32, paddle.float16, paddle.bfloat16]) -@pytest.mark.parametrize("softcap", [None, 20.0]) -@pytest.mark.parametrize("shape", [(256, 512, 128), (255, 507, 128), (255, 507, 123)]) -def test_lse(dtype: paddle.dtype, softcap: Union[float, None], shape: tuple[int, int, int]): - # paddle.set_float32_matmul_precision("highest") - paddle.seed(0) - - if dtype == paddle.bfloat16 and not paddle.device.is_compiled_with_cuda(): - pytest.skip(reason="BF16 not available") - - N, V, D = shape - e = paddle.randn((N, D), dtype=dtype) / (D**0.5) - c = paddle.randn((V, D), dtype=dtype) - - c[0 : min(N, V) // 2] = e[0 : min(N, V) // 2] - - gt = _lse(e.cast("float32"), c.cast("float32"), softcap) - - # paddle.set_float32_matmul_precision("highest" if dtype == paddle.float32 else "high") - ref = _lse(e, c, softcap) - - cce_lse = cce_lse_forward_kernel(e, c, softcap=softcap) - - expected_error = (gt - ref).abs() - cce_error = (gt - cce_lse).abs() - - assert (cce_error <= (expected_error + 1e-5)).all(), f"{F.relu(cce_error - expected_error).max()=}" diff --git a/ops/tests/triton/mamba/test_layernorm_gated.py b/ops/tests/triton/mamba/test_layernorm_gated.py deleted file mode 100644 index 5395d2de581c..000000000000 --- a/ops/tests/triton/mamba/test_layernorm_gated.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle -import paddle.nn.functional as F -import pytest -from einops import rearrange -from paddlenlp_kernel.triton.mamba.layernorm_gated import layernorm_fn, rms_norm_ref - -####################################################################################################################################### -# patch paddle.allclose -old_allclose = paddle.allclose - - -def allclose(a, b, **kwargs): - return old_allclose(a.cast("float32"), b.cast("float32"), **kwargs) - - -paddle.allclose = allclose - -old_equal_all = paddle.equal_all - - -def equal_all(a, b): - return old_equal_all(a.cast("float32"), b.cast("float32")) - - -paddle.equal_all = equal_all - - -def requires_grad_(self, value=True): - self.stop_gradient = not value - return self - - -paddle.Tensor.requires_grad_ = requires_grad_ -####################################################################################################################################### - - -@pytest.mark.parametrize("norm_before_gate", [True, False]) -# @pytest.mark.parametrize("norm_before_gate", [False]) -@pytest.mark.parametrize("has_group", [False, True]) -# @pytest.mark.parametrize("has_group", [False]) -@pytest.mark.parametrize("is_rms_norm", [False, True]) -# @pytest.mark.parametrize("is_rms_norm", [True]) -@pytest.mark.parametrize("has_z", [False, True]) -# @pytest.mark.parametrize("has_z", [True]) -@pytest.mark.parametrize("has_bias", [False, True]) -# @pytest.mark.parametrize("has_bias", [False]) -# @pytest.mark.parametrize('dtype', [paddle.float32, paddle.float16, paddle.bfloat16]) -@pytest.mark.parametrize("dtype", [paddle.float16]) -# @pytest.mark.parametrize("wtype", [paddle.float32, paddle.float16, paddle.bfloat16]) -@pytest.mark.parametrize("wtype", [paddle.float32]) -@pytest.mark.parametrize("d", [2048, 4096]) -# @pytest.mark.parametrize('d', [4096]) -def test_layer_norm_gated(d, dtype, wtype, has_bias, has_z, is_rms_norm, has_group, norm_before_gate): - if not has_z and not norm_before_gate: - pytest.skip() - if not norm_before_gate and not is_rms_norm: # Reference LN isn't implemented for this case yet - pytest.skip() - - rtol, atol = (1e-5, 1e-5) if dtype == paddle.float32 else (1e-2, 8e-3) - group_size = None if not has_group else 64 - # set seed - paddle.seed(0) - batch = 16 - seqlen = 1024 - x = paddle.randn([batch, seqlen, d], dtype=dtype).requires_grad_() - if has_z: - z = paddle.randn([batch, seqlen, d], dtype=dtype).requires_grad_() - else: - z = None - weight = paddle.randn( - [ - d, - ], - dtype=wtype, - ).requires_grad_() - if has_bias: - bias = paddle.randn( - [ - d, - ], - dtype=wtype, - ).requires_grad_() - else: - bias = None - x_ref = x.detach().clone().requires_grad_() - x_pt = x.detach().clone().requires_grad_() - z_ref = z.detach().clone().requires_grad_() if z is not None else None - z_pt = z.detach().clone().requires_grad_() if z is not None else None - weight_ref = weight.detach().clone().requires_grad_() - weight_pt = weight.detach().clone().requires_grad_() - bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None - bias_pd = bias.detach().clone().requires_grad_() if bias is not None else None - out = layernorm_fn( - x, - weight, - bias, - z=z, - eps=1e-5, - group_size=group_size, - norm_before_gate=norm_before_gate, - is_rms_norm=is_rms_norm, - ) - if not is_rms_norm: - if not has_group: - out_ref = F.layer_norm( - x_ref.cast("float32"), - (d,), - weight=weight_ref.cast("float32"), - bias=bias_ref.cast("float32") if bias_ref is not None else None, - epsilon=1e-5, - ) - out_pd = F.layer_norm(x_pt.cast(wtype), (d,), weight=weight_pt, bias=bias_pd, epsilon=1e-5) - else: - out_ref = rearrange( - F.layer_norm( - rearrange(x_ref, "... (g d) -> ... g d", d=group_size).cast("float32"), (group_size,), epsilon=1e-5 - ), - "... g d -> ... (g d)", - ) * weight_ref.cast("float32") - if has_bias: - out_ref = out_ref + bias_ref.cast("float32") - out_pd = ( - rearrange( - F.layer_norm(rearrange(x_pt, "... (g d) -> ... g d", d=group_size), (group_size,), epsilon=1e-5), - "... g d -> ... (g d)", - ) - * weight_pt - ) - if has_bias: - out_pd = out_pd + bias_pd - if has_z and norm_before_gate: - out_ref = out_ref * F.silu(z_ref.cast("float32")) - out_pd = out_pd * F.silu(z_pt) - else: - out_ref = rms_norm_ref( - x_ref, weight_ref, bias_ref, z=z_ref, eps=1e-5, group_size=group_size, norm_before_gate=norm_before_gate - ) - out_pd = rms_norm_ref( - x_pt, - weight_pt, - bias_pd, - z=z_pt, - eps=1e-5, - group_size=group_size, - norm_before_gate=norm_before_gate, - upcast=False, - ) - print(f"Max diff = {(out - out_ref).abs().max().item()}") - print(f"Max diff Paddle = {(out_pd - out_ref).abs().max().item()}") - assert (out - out_ref).abs().max().item() <= 2 * (out_pd - out_ref).abs().max().item() + atol - - g = paddle.randn(out.shape, dtype=out.dtype) - out.backward(g) - out_ref.backward(g.cast(out_ref.dtype)) - out_pd.backward(g.cast(out_pd.dtype)) - print(f"Max dx diff = {(x.grad - x_ref.grad).abs().max().item()}") - print(f"Max dx diff Paddle = {(x_pt.grad - x_ref.grad).abs().max().item()}") - if has_z: - print(f"Max dz diff = {(z.grad - z_ref.grad).abs().max().item()}") - print(f"Max dz diff Paddle = {(z_pt.grad - z_ref.grad).abs().max().item()}") - print(f"Max dw diff = {(weight.grad - weight_ref.grad).abs().max().item()}") - print(f"Max dw diff Paddle = {(weight_pt.grad - weight_ref.grad).abs().max().item()}") - if has_bias: - print(f"Max db diff = {(bias.grad - bias_ref.grad).abs().max().item()}") - print(f"Max db diff Paddle = {(bias_pd.grad - bias_ref.grad).abs().max().item()}") - assert (x.grad - x_ref.grad).abs().max().item() <= 2 * (x_pt.grad - x_ref.grad).abs().max().item() + atol - if has_z: - assert (z.grad - z_ref.grad).abs().max().item() <= 2 * (z_pt.grad - z_ref.grad).abs().max().item() + atol - assert (weight.grad - weight_ref.grad).abs().max().item() <= 2 * ( - weight_pt.grad - weight_ref.grad - ).abs().max().item() + atol - if has_bias: - assert (bias.grad - bias_ref.grad).abs().max().item() <= 2 * ( - bias_pd.grad - bias_ref.grad - ).abs().max().item() + atol diff --git a/ops/tests/triton/mamba/test_selective_state_update.py b/ops/tests/triton/mamba/test_selective_state_update.py deleted file mode 100644 index 6719aecc3e1d..000000000000 --- a/ops/tests/triton/mamba/test_selective_state_update.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright (C) 2023, Tri Dao. - - -import paddle -import pytest -from einops import repeat -from paddlenlp_kernel.triton.mamba.selective_state_update import ( - selective_state_update, - selective_state_update_ref, -) - -####################################################################################################################################### -# patch paddle.allclose -old_allclose = paddle.allclose - - -def allclose(a, b, **kwargs): - return old_allclose(a.cast("float32"), b.cast("float32"), **kwargs) - - -paddle.allclose = allclose - -old_equal_all = paddle.equal_all - - -def equal_all(a, b): - return old_equal_all(a.cast("float32"), b.cast("float32")) - - -paddle.equal_all = equal_all - - -def requires_grad_(self, value=True): - self.stop_gradient = not value - return self - - -paddle.Tensor.requires_grad_ = requires_grad_ -####################################################################################################################################### - - -@pytest.mark.parametrize("itype", [paddle.float32, paddle.float16, paddle.bfloat16]) -# @pytest.mark.parametrize('itype', [paddle.float16]) -@pytest.mark.parametrize("has_z", [False, True]) -# @pytest.mark.parametrize('has_z', [True]) -@pytest.mark.parametrize("dstate", [16, 32, 64]) -# @pytest.mark.parametrize("dstate", [16]) -@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -# @pytest.mark.parametrize("dim", [2048]) -def test_selective_state_update(dim, dstate, has_z, itype): - rtol, atol = (3e-4, 1e-3) if itype == paddle.float32 else (5e-3, 1e-2) - if itype == paddle.bfloat16: - rtol, atol = 1e-2, 5e-2 - # if torch.version.hip: - # atol *= 2 - # set seed - paddle.seed(0) - batch_size = 2 - state = paddle.randn([batch_size, dim, dstate], dtype=itype) - x = paddle.randn([batch_size, dim], dtype=itype) - dt = paddle.randn([batch_size, dim], dtype=itype) - dt_bias = ( - paddle.rand( - [ - dim, - ] - ) - - 4.0 - ) - A = -paddle.rand([dim, dstate]) - 1.0 - B = paddle.randn([batch_size, dstate]) - C = paddle.randn([batch_size, dstate]) - D = paddle.randn( - [ - dim, - ] - ) - if has_z: - z = paddle.randn(x.shape, dtype=x.dtype) - else: - z = None - state_ref = state.detach().clone() - out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) - out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - assert paddle.allclose(state, state_ref, rtol=rtol, atol=atol) - assert paddle.allclose(out, out_ref, rtol=rtol, atol=atol) - - -@pytest.mark.parametrize("itype", [paddle.float32, paddle.float16, paddle.bfloat16]) -# @pytest.mark.parametrize('itype', [paddle.float16]) -@pytest.mark.parametrize("has_z", [False, True]) -# @pytest.mark.parametrize('has_z', [True]) -@pytest.mark.parametrize("tie_hdim", [False, True]) -# @pytest.mark.parametrize('tie_hdim', [True]) -@pytest.mark.parametrize("ngroups", [1, 2, 4]) -# @pytest.mark.parametrize("ngroups", [2]) -@pytest.mark.parametrize("dstate", [16, 32, 64]) -# @pytest.mark.parametrize("dstate", [16]) -@pytest.mark.parametrize("dim", [2048, 4096]) -# @pytest.mark.parametrize("dim", [2048]) -def test_selective_state_update_with_heads(dim, dstate, ngroups, has_z, tie_hdim, itype): - rtol, atol = (3e-4, 1e-3) if itype == paddle.float32 else (5e-3, 3e-2) - if itype == paddle.bfloat16: - rtol, atol = 1e-2, 1e-1 - # set seed - paddle.seed(0) - batch_size = 2 - headdim = 64 - nheads = dim // headdim - state = paddle.randn([batch_size, nheads, headdim, dstate], dtype=itype) - x = paddle.randn([batch_size, nheads, headdim], dtype=itype) - if not tie_hdim: - dt = paddle.randn([batch_size, nheads, headdim], dtype=itype) - dt_bias = paddle.rand([nheads, headdim]) - 4.0 - A = -paddle.rand([nheads, headdim, dstate]) - 1.0 - D = paddle.randn([nheads, headdim]) - else: - dt = repeat(paddle.randn([batch_size, nheads], dtype=itype), "b h -> b h p", p=headdim) - dt_bias = repeat( - paddle.rand( - [ - nheads, - ] - ) - - 4.0, - "h -> h p", - p=headdim, - ) - A = repeat( - -paddle.rand( - [ - nheads, - ] - ) - - 1.0, - "h -> h p n", - p=headdim, - n=dstate, - ) - D = repeat( - paddle.randn( - [ - nheads, - ] - ), - "h -> h p", - p=headdim, - ) - B = paddle.randn([batch_size, ngroups, dstate]) - C = paddle.randn([batch_size, ngroups, dstate]) - if has_z: - z = paddle.randn(x.shape, dtype=x.dtype) - else: - z = None - state_ref = state.detach().clone() - # state_og = state.detach().clone() - out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) - out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - assert paddle.allclose(state, state_ref, rtol=rtol, atol=atol) - assert paddle.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/ops/tests/triton/mamba/test_ssd.py b/ops/tests/triton/mamba/test_ssd.py deleted file mode 100644 index 1e88db794ae4..000000000000 --- a/ops/tests/triton/mamba/test_ssd.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle -import paddle.nn.functional as F -import pytest -from einops import rearrange -from paddlenlp_kernel.triton.mamba.ssd_chunk_state import ( - _chunk_cumsum_fwd, - _chunk_state_fwd, - chunk_state, - chunk_state_varlen, -) -from paddlenlp_kernel.triton.mamba.ssd_state_passing import _state_passing_fwd - -####################################################################################################################################### -# patch paddle.allclose -old_allclose = paddle.allclose - - -def allclose(a, b, **kwargs): - return old_allclose(a.cast("float32"), b.cast("float32"), **kwargs) - - -paddle.allclose = allclose - -old_equal_all = paddle.equal_all - - -def equal_all(a, b): - return old_equal_all(a.cast("float32"), b.cast("float32")) - - -paddle.equal_all = equal_all - - -def requires_grad_(self, value=True): - self.stop_gradient = not value - return self - - -paddle.Tensor.requires_grad_ = requires_grad_ -####################################################################################################################################### - - -def detach_clone(*args): - return tuple([arg.detach().clone().requires_grad_() if arg is not None else None for arg in args]) - - -@pytest.mark.parametrize("dtype", [paddle.float32, paddle.float16, paddle.bfloat16]) -# @pytest.mark.parametrize('dtype', [paddle.bfloat16]) -@pytest.mark.parametrize("ngroups", [1, 2, 8, "max"]) -# @pytest.mark.parametrize('ngroups', [1]) -@pytest.mark.parametrize("chunk_size", [64, 128]) -# @pytest.mark.parametrize('chunk_size', [128]) -def test_chunk_state_varlen(chunk_size, ngroups, dtype): - rtol, atol = (1e-2, 3e-3) - # set seed - paddle.seed(chunk_size + (ngroups if ngroups != "max" else 64)) - batch = 300 - seqlens = paddle.randint(1, 200, (batch,)) - # batch = 3 - # seqlens = paddle.tensor([201, 56, 5]) - cu_seqlens = F.pad(seqlens.cumsum(0).unsqueeze([0, 1]), (1, 0), data_format="NCL").squeeze([0, 1]) - total_seqlen = seqlens.sum().item() - seq_idx = paddle.concat( - [paddle.full((s,), i, dtype=paddle.int32) for i, s in enumerate(seqlens)], axis=0 - ).unsqueeze(0) - dim = 4096 - # dim = 64 - headdim = 64 - # dim = 32 - dstate = 32 - assert dim % headdim == 0 - nheads = dim // headdim - if ngroups == "max": - ngroups = nheads - assert nheads % ngroups == 0 - B = paddle.randn([total_seqlen, ngroups, dstate], dtype=dtype) / 5 - x = paddle.randn([total_seqlen, nheads, headdim], dtype=dtype) - A = -0.1 * ( - paddle.rand( - [ - nheads, - ] - ) - ) - dt = F.softplus(paddle.randn([total_seqlen, nheads], dtype=paddle.float32) - 4) - dA_cumsum, dt_rounded = _chunk_cumsum_fwd(dt.unsqueeze(0), A, chunk_size) - chunk_states = _chunk_state_fwd(B.unsqueeze(0), x.unsqueeze(0), dt_rounded, dA_cumsum, seq_idx=seq_idx) - chunk_states, _ = _state_passing_fwd( - rearrange(chunk_states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], seq_idx=seq_idx, chunk_size=chunk_size - ) - chunk_states = rearrange(chunk_states, "... (p n) -> ... p n", n=dstate) - chunk_states = chunk_states.squeeze(0) - dA_cumsum = dA_cumsum.squeeze(0) - dt_rounded = dt_rounded.squeeze(0) - out = chunk_state_varlen(B, x, dt_rounded, dA_cumsum, cu_seqlens, chunk_states) - out_ref = [] - for b in range(batch): - x_s = x[cu_seqlens[b] : cu_seqlens[b + 1]].unsqueeze(0) - B_s = B[cu_seqlens[b] : cu_seqlens[b + 1]].unsqueeze(0) - dt_s = dt[cu_seqlens[b] : cu_seqlens[b + 1]].unsqueeze(0) - dA_cumsum_s, dt_rounded_s = _chunk_cumsum_fwd(dt_s, A, chunk_size) - states = chunk_state(B_s, x_s, dt_rounded_s, dA_cumsum_s) - _, final_states = _state_passing_fwd( - rearrange(states, "... p n -> ... (p n)"), dA_cumsum_s[:, :, :, -1], chunk_size=chunk_size - ) - final_states = rearrange(final_states, "... (p n) -> ... p n", n=dstate) - out_ref.append(final_states) - out_ref = paddle.concat(out_ref, axis=0) - print(f"Max diff = {(out - out_ref).abs().max().item()}") - assert paddle.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/ops/utils/release.py b/ops/utils/release.py deleted file mode 100644 index c053b688dac3..000000000000 --- a/ops/utils/release.py +++ /dev/null @@ -1,122 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import os -import re - -import packaging.version - -PATH_TO_EXAMPLES = "examples/" -REPLACE_PATTERNS = { - "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'), - "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'), -} -REPLACE_FILES = { - "init": "src/paddlenlp_kernel/__init__.py", - "setup": "setup.py", -} -README_FILE = "README.md" - - -def update_version_in_file(fname, version, pattern): - """Update the version in one file using a specific pattern.""" - with open(fname, "r", encoding="utf-8", newline="\n") as f: - code = f.read() - re_pattern, replace = REPLACE_PATTERNS[pattern] - replace = replace.replace("VERSION", version) - code = re_pattern.sub(replace, code) - with open(fname, "w", encoding="utf-8", newline="\n") as f: - f.write(code) - - -def update_version_in_examples(version): - """Update the version in all examples files.""" - for folder, directories, fnames in os.walk(PATH_TO_EXAMPLES): - # Removing some of the folders with non-actively maintained examples from the walk - if "research_projects" in directories: - directories.remove("research_projects") - if "legacy" in directories: - directories.remove("legacy") - for fname in fnames: - if fname.endswith(".py"): - update_version_in_file(os.path.join(folder, fname), version, pattern="examples") - - -def global_version_update(version, patch=False): - """Update the version in all needed files.""" - for pattern, fname in REPLACE_FILES.items(): - update_version_in_file(fname, version, pattern) - if not patch: - update_version_in_examples(version) - - -def get_version(): - """Reads the current version in the __init__.""" - with open(REPLACE_FILES["init"], "r") as f: - code = f.read() - default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0] - return packaging.version.parse(default_version) - - -def pre_release_work(patch=False): - """Do all the necessary pre-release steps.""" - # First let's get the default version: base version if we are in dev, bump minor otherwise. - default_version = get_version() - if patch and default_version.is_devrelease: - raise ValueError("Can't create a patch version from the dev branch, checkout a released version!") - if default_version.is_devrelease: - default_version = default_version.base_version - elif patch: - default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}" - else: - default_version = f"{default_version.major}.{default_version.minor + 1}.0" - - # Now let's ask nicely if that's the right one. - version = input(f"Which version are you releasing? [{default_version}]") - if len(version) == 0: - version = default_version - - print(f"Updating version to {version}.") - global_version_update(version, patch=patch) - - -def post_release_work(): - """Do all the necessary post-release steps.""" - # First let's get the current version - current_version = get_version() - dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0" - current_version = current_version.base_version - - # Check with the user we got that right. - version = input(f"Which version are we developing now? [{dev_version}]") - if len(version) == 0: - version = dev_version - - print(f"Updating version to {version}.") - global_version_update(version) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.") - parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.") - args = parser.parse_args() - if not args.post_release: - pre_release_work(patch=args.patch) - elif args.patch: - print("Nothing to do after a patch :-)") - else: - post_release_work() diff --git a/paddlenlp/transformers/gemma/modeling.py b/paddlenlp/transformers/gemma/modeling.py index e73ccc1c4c20..e7b07ea8fc65 100644 --- a/paddlenlp/transformers/gemma/modeling.py +++ b/paddlenlp/transformers/gemma/modeling.py @@ -26,7 +26,6 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.utils import recompute -from paddle.utils import try_import try: from paddle.incubate.nn.functional import fused_rotary_position_embedding @@ -84,11 +83,6 @@ def _get_interleave_power_of_2(n): ) -def rms_norm_fused(x_in, w, eps): - fused_ln = try_import("fused_ln") - return fused_ln.fused_rms_norm(x_in, w, eps)[0] - - def assign_kv_heads(num_kv_heads: int, num_gpus: int): # Initialize the assignment list """ @@ -369,7 +363,9 @@ def _norm(self, x): def forward(self, x): if self.config.use_fused_rms_norm: - return rms_norm_fused(x, self.weight + 1, self.variance_epsilon) + return paddle.incubate.nn.functional.fused_rms_norm_ext(x, self.weight + 1, self.variance_epsilon)[ + 0 + ].astype(self.weight.dtype) output = self._norm(x.astype(paddle.float32)).astype(x.dtype) return output * (self.weight + 1) diff --git a/paddlenlp/transformers/gpt/modeling.py b/paddlenlp/transformers/gpt/modeling.py index 68109bbcdb61..3156c7cc171f 100644 --- a/paddlenlp/transformers/gpt/modeling.py +++ b/paddlenlp/transformers/gpt/modeling.py @@ -40,7 +40,6 @@ except: pass from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from paddle.utils import try_import from ...utils.converter import StateDictNameMapping from ...utils.log import logger @@ -140,11 +139,6 @@ def seed_guard_context(name=None): return contextlib.nullcontext() -def fast_layer_norm(input, weight, bias, eps): - fast_ln_lib = try_import("fast_ln") - return fast_ln_lib.fast_ln(input, weight, bias, eps)[0] - - def _make_causal_mask(input_ids_shape, past_key_values_length): """ Make causal mask used for self-attention @@ -785,11 +779,6 @@ def __init__(self, config, normalized_shape, epsilon=1e-05, weight_attr=None, bi self.config = config _check_normalized_shape(self._normalized_shape) - def forward(self, input): - if self.config.use_fast_layer_norm: - return fast_layer_norm(input, self.weight, self.bias, self._epsilon) - return super().forward(input) - class GPTPretrainedModel(PretrainedModel): """ diff --git a/paddlenlp/transformers/gpt/modeling_auto.py b/paddlenlp/transformers/gpt/modeling_auto.py index e21067ba42c3..3194c1dab567 100644 --- a/paddlenlp/transformers/gpt/modeling_auto.py +++ b/paddlenlp/transformers/gpt/modeling_auto.py @@ -30,7 +30,6 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.utils import recompute -from paddle.utils import try_import try: from paddle.distributed.fleet.utils.sequence_parallel_utils import ( @@ -93,11 +92,6 @@ def seed_guard_context(name=None): return contextlib.nullcontext() -def fast_layer_norm(input, weight, bias, eps): - fast_ln_lib = try_import("fast_ln") - return fast_ln_lib.fast_ln(input, weight, bias, eps)[0] - - class GPTLayerNorm(nn.LayerNorm): def __init__(self, config, normalized_shape, ipp=-1, epsilon=1e-05, weight_attr=None, bias_attr=None, name=None): super().__init__( @@ -114,11 +108,6 @@ def _check_normalized_shape(self, normalized_shape): if isinstance(normalized_shape, (list, tuple)): assert len(normalized_shape) == 1 - def forward(self, input): - if self.config.use_fast_layer_norm: - return fast_layer_norm(input, self.weight, self.bias, self._epsilon) - return super().forward(input) - def _make_causal_mask(input_ids_shape, past_key_values_length): """ diff --git a/paddlenlp/transformers/gpt/modeling_network.py b/paddlenlp/transformers/gpt/modeling_network.py index c558004dab6a..6b723b7f0d3a 100644 --- a/paddlenlp/transformers/gpt/modeling_network.py +++ b/paddlenlp/transformers/gpt/modeling_network.py @@ -27,7 +27,6 @@ import paddle.tensor as tensor from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.utils import recompute -from paddle.utils import try_import try: from paddle.distributed.fleet.utils.sequence_parallel_utils import ( @@ -82,11 +81,6 @@ def seed_guard_context(name=None): return contextlib.nullcontext() -def fast_layer_norm(input, weight, bias, eps): - fast_ln_lib = try_import("fast_ln") - return fast_ln_lib.fast_ln(input, weight, bias, eps)[0] - - class GPTLayerNorm(nn.LayerNorm): def __init__(self, config, normalized_shape, epsilon=1e-05, weight_attr=None, bias_attr=None, name=None): super().__init__( @@ -99,11 +93,6 @@ def _check_normalized_shape(self, normalized_shape): if isinstance(normalized_shape, (list, tuple)): assert len(normalized_shape) == 1 - def forward(self, input): - if self.config.use_fast_layer_norm: - return fast_layer_norm(input, self.weight, self.bias, self._epsilon) - return super().forward(input) - def _make_causal_mask(input_ids_shape, past_key_values_length): """ diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py index 62f3660a5bfe..cc6f7241f0c8 100644 --- a/paddlenlp/transformers/llama/fusion_ops.py +++ b/paddlenlp/transformers/llama/fusion_ops.py @@ -32,8 +32,6 @@ def swiglu(x, y=None): return F.silu(x) * y -from paddle.utils import try_import - from paddlenlp.utils.tools import get_env_device try: @@ -131,16 +129,7 @@ def fusion_rope( return query_states, key_states -def rms_norm_fused(x_in, w, eps, use_fast_ln=False): - if use_fast_ln: - fast_ln = try_import("fast_ln") - return fast_ln.fast_rms_norm(x_in, w, eps)[0] - else: - fused_ln = try_import("fused_ln") - return fused_ln.fused_rms_norm(x_in, w, eps)[0] - - -def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False): +def fusion_rms_norm(hidden_states, weight, variance_epsilon): if get_env_device() == "npu": return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0] if get_env_device() == "mlu": @@ -160,7 +149,9 @@ def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False): raise NotImplementedError( f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature" ) - return rms_norm_fused(hidden_states, weight, variance_epsilon, use_fast_ln) + return paddle.incubate.nn.functional.fused_rms_norm_ext(hidden_states, weight, variance_epsilon)[0].astype( + weight.dtype + ) def fusion_flash_attention( diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 84c9bcc7ff4e..6d79bcdb686e 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -98,8 +98,6 @@ def swiglu(x, y=None): flash_attention = None from . import fusion_ops -rms_norm_fused = fusion_ops.rms_norm_fused - __all__ = [ "LlamaModel", "LlamaLMHead", @@ -389,9 +387,7 @@ def __init__(self, config): def forward(self, hidden_states): if self.config.use_fused_rms_norm: - return fusion_ops.fusion_rms_norm( - hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm - ) + return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon) if paddle.in_dynamic_mode(): with paddle.amp.auto_cast(False): diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py index b5d1a71e3146..298d12dba139 100644 --- a/paddlenlp/transformers/llama/modeling_auto.py +++ b/paddlenlp/transformers/llama/modeling_auto.py @@ -237,9 +237,7 @@ def __init__(self, config, ipp): def forward(self, hidden_states): if self.config.use_fused_rms_norm: - return fusion_ops.fusion_rms_norm( - hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm - ) + return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon) with paddle.amp.auto_cast(False): variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) diff --git a/paddlenlp/transformers/llama/modeling_auto_pp.py b/paddlenlp/transformers/llama/modeling_auto_pp.py index bc4fed315131..13c778163748 100644 --- a/paddlenlp/transformers/llama/modeling_auto_pp.py +++ b/paddlenlp/transformers/llama/modeling_auto_pp.py @@ -247,9 +247,7 @@ def __init__(self, config, ipp): def forward(self, args): hidden_states, attention_mask, position_ids, alibi = parse_args(args) if self.config.use_fused_rms_norm: - hidden_states = fusion_ops.fusion_rms_norm( - hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm - ) + hidden_states = fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon) return return_args(hidden_states, attention_mask, position_ids, alibi) with paddle.amp.auto_cast(False): diff --git a/paddlenlp/transformers/llama/modeling_network.py b/paddlenlp/transformers/llama/modeling_network.py index 99aa98f63137..3481efcfc5fe 100644 --- a/paddlenlp/transformers/llama/modeling_network.py +++ b/paddlenlp/transformers/llama/modeling_network.py @@ -69,7 +69,6 @@ def swiglu(x, y=None): build_alibi_tensor, get_triangle_upper_mask, repeat_kv, - rms_norm_fused, ) try: @@ -358,7 +357,9 @@ def __init__(self, config): def forward(self, hidden_states): if self.config.use_fused_rms_norm: - return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon) + return paddle.incubate.nn.functional.fused_rms_norm_ext(hidden_states, self.weight, self.variance_epsilon)[ + 0 + ].astype(self.weight.dtype) with paddle.amp.auto_cast(False): variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py index 6f44737bc45a..756895b1b22d 100755 --- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -25,7 +25,6 @@ from paddle.distributed import fleet from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker from paddle.distributed.fleet.recompute.recompute import recompute -from paddle.utils import try_import from paddlenlp.transformers.refined_recompute import ( RRColumnParallelLinear, @@ -1266,11 +1265,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None): return q_embed, k_embed -def rms_norm_fused(x_in, w, eps): - fused_ln = try_import("fused_ln") - return fused_ln.fused_rms_norm(x_in, w, eps)[0] - - class QWenRMSNorm(nn.Layer): def __init__(self, config): super().__init__() @@ -1289,7 +1283,9 @@ def _norm(self, x): def forward(self, x): if self.config.use_fused_rms_norm: - return rms_norm_fused(x, self.weight, self.eps) + return paddle.incubate.nn.functional.fused_rms_norm_ext(x, self.weight, self.eps)[0].astype( + self.weight.dtype + ) output = self._norm(x.astype(paddle.float32)).astype(x.dtype) return output * self.weight diff --git a/paddlenlp/transformers/qwen/modeling_auto.py b/paddlenlp/transformers/qwen/modeling_auto.py index 79534161cf89..1f2ccb32740c 100644 --- a/paddlenlp/transformers/qwen/modeling_auto.py +++ b/paddlenlp/transformers/qwen/modeling_auto.py @@ -23,7 +23,6 @@ from paddle import nn from paddle.distributed import fleet from paddle.distributed.fleet.utils import recompute -from paddle.utils import try_import from paddlenlp.transformers.model_outputs import BaseModelOutputWithPast from paddlenlp.transformers.model_utils import PretrainedModel @@ -943,11 +942,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None): return q_embed, k_embed -def rms_norm_fused(x_in, w, eps): - fused_ln = try_import("fused_ln") - return fused_ln.fused_rms_norm(x_in, w, eps)[0] - - class QWenRMSNormAuto(nn.Layer): def __init__(self, config, ipp): super().__init__() @@ -965,7 +959,9 @@ def _norm(self, x): def forward(self, x): if self.config.use_fused_rms_norm: - return rms_norm_fused(x, self.weight, self.eps) + return paddle.incubate.nn.functional.fused_rms_norm_ext(x, self.weight, self.eps)[0].astype( + self.weight.dtype + ) output = self._norm(x.astype(paddle.float32)).astype(x.dtype) return output * self.weight diff --git a/paddlenlp/transformers/qwen/modeling_network.py b/paddlenlp/transformers/qwen/modeling_network.py index a8eddb3437ea..9855ab7faa86 100644 --- a/paddlenlp/transformers/qwen/modeling_network.py +++ b/paddlenlp/transformers/qwen/modeling_network.py @@ -21,7 +21,6 @@ import paddle.nn.functional as F from paddle import nn from paddle.distributed.fleet.utils import recompute -from paddle.utils import try_import from paddlenlp.transformers.model_outputs import BaseModelOutputWithPast from paddlenlp.transformers.model_utils import PretrainedModel @@ -804,11 +803,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None): return q_embed, k_embed -def rms_norm_fused(x_in, w, eps): - fused_ln = try_import("fused_ln") - return fused_ln.fused_rms_norm(x_in, w, eps)[0] - - class QWenRMSNormNet(nn.Layer): def __init__(self, config): super().__init__() @@ -825,7 +819,9 @@ def _norm(self, x): def forward(self, x): if self.config.use_fused_rms_norm: - return rms_norm_fused(x, self.weight, self.eps) + return paddle.incubate.nn.functional.fused_rms_norm_ext(x, self.weight, self.eps)[0].astype( + self.weight.dtype + ) output = self._norm(x.astype(paddle.float32)).astype(x.dtype) return output * self.weight diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index cb604d3a18b0..e2e544ecd366 100644 --- a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -313,7 +313,7 @@ def __init__(self, config: Qwen2Config): def forward(self, hidden_states): if self.config.use_fused_rms_norm: - return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon, False) + return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon) if paddle.in_dynamic_mode(): with paddle.amp.auto_cast(False): diff --git a/paddlenlp/transformers/qwen2_moe/modeling.py b/paddlenlp/transformers/qwen2_moe/modeling.py index e15cdf1373ab..7da412f06e6b 100644 --- a/paddlenlp/transformers/qwen2_moe/modeling.py +++ b/paddlenlp/transformers/qwen2_moe/modeling.py @@ -350,7 +350,7 @@ def __init__(self, config: Qwen2MoeConfig): def forward(self, hidden_states): if self.config.use_fused_rms_norm: - return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon, False) + return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon) if paddle.in_dynamic_mode(): with paddle.amp.auto_cast(False): diff --git a/paddlenlp/transformers/qwen3/modeling.py b/paddlenlp/transformers/qwen3/modeling.py index 0406ab20d565..a7aabbe3a80b 100644 --- a/paddlenlp/transformers/qwen3/modeling.py +++ b/paddlenlp/transformers/qwen3/modeling.py @@ -109,7 +109,7 @@ def __init__(self, config: Qwen3Config, hidden_size=None, rms_norm_eps=None): def forward(self, hidden_states): if self.config.use_fused_rms_norm: - return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon, False) + return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon) if paddle.in_dynamic_mode(): with paddle.amp.auto_cast(False): diff --git a/slm/model_zoo/gpt-3/ppfleetx/ops/fused_layers.py b/slm/model_zoo/gpt-3/ppfleetx/ops/fused_layers.py index 3e58e30e150d..d4e3493f71a9 100644 --- a/slm/model_zoo/gpt-3/ppfleetx/ops/fused_layers.py +++ b/slm/model_zoo/gpt-3/ppfleetx/ops/fused_layers.py @@ -13,7 +13,6 @@ # limitations under the License. import distutils.util -import importlib import os import paddle @@ -23,84 +22,18 @@ origin_linear = paddle.incubate.nn.functional.fused_linear -def try_import(module_name, func_name=None): - if func_name is None: - func_name = module_name - try: - m = importlib.import_module(module_name) - return m - # return getattr(m, func_name) - except ImportError: - return None - - -fast_ln_lib = try_import("fast_ln") -fused_ln_lib = try_import("fused_ln") - -if fast_ln_lib is not None: - fast_ln = fast_ln_lib.fast_ln - -if fused_ln_lib is not None: - fused_ln = fused_ln_lib.fused_ln - fused_rms_norm = fused_ln_lib.fused_rms_norm - - def check_normalized_shape(normalized_shape): if isinstance(normalized_shape, (list, tuple)): assert len(normalized_shape) == 1 -class FusedLayerNorm(OriginLayerNorm): - def __init__(self, - normalized_shape, - epsilon=1e-05, - weight_attr=None, - bias_attr=None, - name=None): - super().__init__( - normalized_shape=normalized_shape, - epsilon=epsilon, - weight_attr=weight_attr, - bias_attr=bias_attr) - check_normalized_shape(self._normalized_shape) - - def forward(self, input): - return fused_ln(input, self.weight, self.bias, self._epsilon)[0] - - class FusedRMSNorm(OriginLayerNorm): - def __init__(self, - normalized_shape, - epsilon=1e-05, - weight_attr=None, - name=None): - super().__init__( - normalized_shape=normalized_shape, - epsilon=epsilon, - weight_attr=weight_attr, - bias_attr=False) - check_normalized_shape(self._normalized_shape) - - def forward(self, input): - return fused_rms_norm(input, self.weight, self._epsilon)[0] - - -class FastLayerNorm(OriginLayerNorm): - def __init__(self, - normalized_shape, - epsilon=1e-05, - weight_attr=None, - bias_attr=None, - name=None): - super().__init__( - normalized_shape=normalized_shape, - epsilon=epsilon, - weight_attr=weight_attr, - bias_attr=bias_attr) + def __init__(self, normalized_shape, epsilon=1e-05, weight_attr=None, name=None): + super().__init__(normalized_shape=normalized_shape, epsilon=epsilon, weight_attr=weight_attr, bias_attr=False) check_normalized_shape(self._normalized_shape) def forward(self, input): - return fast_ln(input, self.weight, self.bias, self._epsilon)[0] + return paddle.incubate.nn.functional.fused_rms_norm(input, self.weight, self.bias, self._epsilon)[0] class FusedLinearWithGradAdd(paddle.autograd.PyLayer): @@ -117,21 +50,19 @@ def backward(ctx, y_grad): if bias is None: if hasattr(weight, "main_grad"): - weight.main_grad, _ = _C_ops.fused_linear_param_grad_add( - x, y_grad, weight.main_grad, None, True) + weight.main_grad, _ = _C_ops.fused_linear_param_grad_add(x, y_grad, weight.main_grad, None, True) return x_grad, None else: - weight_grad, _ = _C_ops.fused_linear_param_grad_add( - x, y_grad, None, None, False) + weight_grad, _ = _C_ops.fused_linear_param_grad_add(x, y_grad, None, None, False) return x_grad, weight_grad if hasattr(weight, "main_grad") and hasattr(bias, "main_grad"): weight.main_grad, bias.main_grad = _C_ops.fused_linear_param_grad_add( - x, y_grad, weight.main_grad, bias.main_grad, True) + x, y_grad, weight.main_grad, bias.main_grad, True + ) return x_grad, None, None else: - weight_grad, bias_grad = _C_ops.fused_linear_param_grad_add( - x, y_grad, None, None, False) + weight_grad, bias_grad = _C_ops.fused_linear_param_grad_add(x, y_grad, None, None, False) return x_grad, weight_grad, bias_grad @@ -144,11 +75,7 @@ def get_env(env_name, default_value=False): def mock_layers(): - if get_env("USE_FAST_LN"): - paddle.nn.LayerNorm = FastLayerNorm - elif get_env("USE_FUSED_LN"): - paddle.nn.LayerNorm = FusedLayerNorm - elif get_env("USE_FUSED_RMS_NORM"): + if get_env("USE_FUSED_RMS_NORM"): paddle.nn.LayerNorm = FusedRMSNorm if get_env("USE_LINEAR_WITH_GRAD_ADD"):