Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
6804a7d
add svdvals_kernel
aquagull Nov 17, 2024
6183311
fix bug
aquagull Nov 17, 2024
f6dce63
fix bug
aquagull Nov 17, 2024
f27b4c7
fix bug
aquagull Nov 17, 2024
46d6fff
fix bug
aquagull Nov 17, 2024
cf42122
fix some bug
aquagull Nov 18, 2024
cb428b2
fix bug
aquagull Nov 18, 2024
3a1110c
fix bug
aquagull Nov 18, 2024
a729ef6
fix bug
aquagull Nov 18, 2024
99dbd3b
add include
aquagull Nov 18, 2024
a008105
fix bug in svdvals_kernel
aquagull Nov 19, 2024
963e03a
fix bug
aquagull Nov 19, 2024
9824e90
fix bug
aquagull Nov 19, 2024
9f8a548
fix bug in func SvdvalsInferMeta
aquagull Nov 19, 2024
681cad9
add test
aquagull Nov 19, 2024
fd97a58
Merge branch 'PaddlePaddle:develop' into addSvdvals
aquagull Nov 20, 2024
4d9c7ce
fix codestyle
aquagull Nov 20, 2024
cc1a1ae
fix lwork and int
aquagull Nov 20, 2024
e26cb76
fix
aquagull Nov 20, 2024
7ee7eb9
use guard to control enable/disable
aquagull Nov 20, 2024
d3131bf
add test_check_grad
aquagull Nov 20, 2024
c1a1c48
fix test_svdvals_op
aquagull Nov 21, 2024
aa92a7a
fix bug
aquagull Nov 21, 2024
f246347
fix bug in svdvals_kernel
aquagull Nov 21, 2024
0267c37
fix bug
aquagull Nov 22, 2024
b5470b8
fix bug
aquagull Nov 22, 2024
f86541e
fix bug
aquagull Nov 22, 2024
73a53d6
fix bug in svdvals_grad_kernel
aquagull Nov 22, 2024
e439a2b
fix
aquagull Nov 22, 2024
4135e24
fix
aquagull Nov 22, 2024
69d05c0
add debug
aquagull Nov 22, 2024
190ffe4
fix
aquagull Nov 22, 2024
3284bd2
fix
aquagull Nov 22, 2024
a619b93
fix
aquagull Nov 22, 2024
8e9c202
fix
aquagull Nov 22, 2024
b37f96b
fix bug in svdvals_grad_kernel
aquagull Nov 25, 2024
7a09e29
fix
aquagull Nov 26, 2024
290b201
fix
aquagull Nov 26, 2024
2ffa2a0
fix
aquagull Nov 26, 2024
46e4219
fix
aquagull Nov 27, 2024
6099878
delete VLOG
aquagull Nov 28, 2024
e5b0039
delete head
aquagull Nov 28, 2024
a4f8235
fix
aquagull Nov 28, 2024
dd419d2
fix op_gen
aquagull Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ void OperatorDialect::initialize() {
>();
RegisterOps<
#define GET_OP_LIST2
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op_info.cc" // NOLINT
>();
RegisterOps<
#define GET_OP_LIST3
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op_info.cc" // NOLINT
>();
RegisterOps<
#define GET_OP_LIST4
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op_info.cc" // NOLINT
>();
#else
Expand Down
18 changes: 16 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,12 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
#elif defined(GET_OP_LIST2)
#undef GET_OP_LIST2
{op_declare_second_part}
#elif defined(GET_OP_LIST3)
#undef GET_OP_LIST3
{op_declare_third_part}
#elif defined(GET_OP_LIST4)
#undef GET_OP_LIST4
{op_declare_fourth_part}
"""

CC_OP_INFO_FILE_TEMPLATE_PART2 = """
Expand Down Expand Up @@ -2390,9 +2396,11 @@ def OpGenerator(

if op_info_file is not None:
if sys.platform == "win32":
n = len(op_list_strs) // 2
n = len(op_list_strs) // 4
first_part_op_info = op_list_strs[:n]
second_part_op_info = op_list_strs[n:]
second_part_op_info = op_list_strs[n : 2 * n]
third_part_op_info = op_list_strs[2 * n : 3 * n]
fourth_part_op_info = op_list_strs[3 * n :]
Comment on lines +2399 to +2403
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n_parts = 4
part_size = math.ceil(len(op_list_strs) / n_parts)
op_list_parts = [op_list_strs[part_offset:part_offset + part_size] for part_offset in range(0, len(op_list_strs), part_size)]

后续可以考虑以更易于维护的方式来扩展

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n_parts = 4
part_size = math.ceil(len(op_list_strs) / n_parts)
op_list_parts = [op_list_strs[part_offset:part_offset + part_size] for part_offset in range(0, len(op_list_strs), part_size)]

后续可以考虑以更易于维护的方式来扩展

好的谢谢佬

CC_OP_INFO_FILE_TEMPLATE = (
CC_OP_INFO_FILE_TEMPLATE_WIN_PART1
+ CC_OP_INFO_FILE_TEMPLATE_PART2
Expand All @@ -2404,6 +2412,12 @@ def OpGenerator(
op_declare_second_part=",".join(second_part_op_info).replace(
"\n", ""
),
op_declare_third_part=",".join(third_part_op_info).replace(
"\n", ""
),
op_declare_fourth_part=",".join(fourth_part_op_info).replace(
"\n", ""
),
other_info=other_info_str,
h_file=op_def_h_file[:-4],
)
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,17 @@ void OperatorDialect::initialize() {
#define GET_OP_LIST2
#include "paddle/fluid/pir/dialect/operator/ir/pd_op_info.cc" // NOLINT
>();

RegisterOps<
#define GET_OP_LIST3
#include "paddle/fluid/pir/dialect/operator/ir/pd_op_info.cc" // NOLINT
>();

RegisterOps<
#define GET_OP_LIST4
#include "paddle/fluid/pir/dialect/operator/ir/pd_op_info.cc" // NOLINT
>();

#else
RegisterOps<
#define GET_OP_LIST
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ void OneDNNOperatorDialect::initialize() {
#define GET_OP_LIST2
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op_info.cc" // NOLINT
>();
RegisterOps<
#define GET_OP_LIST3
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op_info.cc" // NOLINT
>();
RegisterOps<
#define GET_OP_LIST4
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op_info.cc" // NOLINT
>();

#else
RegisterOps<
#define GET_OP_LIST
Expand Down
25 changes: 25 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4902,6 +4902,31 @@ void PartialConcatInferMeta(const std::vector<const MetaTensor*>& xs,
out->set_dtype(xs[0]->dtype());
}

void SvdvalsInferMeta(const MetaTensor& x, MetaTensor* s) {
  // Singular values of an [..., m, n] input form an [..., min(m, n)] tensor.
  const auto& in_dims = x.dims();
  const int64_t x_rank = in_dims.size();

  PADDLE_ENFORCE_GE(
      x_rank,
      2,
      common::errors::InvalidArgument("The rank of input tensor must be >= 2"));

  const int64_t rows = in_dims[x_rank - 2];
  const int64_t cols = in_dims[x_rank - 1];
  const int64_t k = std::min(rows, cols);

  // Drop the trailing (rows, cols) pair and append k to build S's shape.
  std::vector<int64_t> s_shape = common::vectorize(in_dims);
  s_shape.erase(s_shape.end() - 2, s_shape.end());
  s_shape.push_back(k);

  s->set_dims(common::make_ddim(s_shape));
  s->share_lod(x);
  s->set_dtype(x.dtype());
}

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,8 @@ void PartialSumInferMeta(const std::vector<const MetaTensor*>& xs,
MetaTensor* out,
MetaConfig config = MetaConfig());

void SvdvalsInferMeta(const MetaTensor& x, MetaTensor* s);

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
Expand Down
20 changes: 20 additions & 0 deletions paddle/phi/kernels/cpu/svdvals_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/svdvals_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/svdvals_grad_kernel_impl.h"

// Register the CPU svdvals_grad kernel for float and double inputs.
PD_REGISTER_KERNEL(
    svdvals_grad, CPU, ALL_LAYOUT, phi::SvdvalsGradKernel, float, double) {}
130 changes: 130 additions & 0 deletions paddle/phi/kernels/cpu/svdvals_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/svdvals_kernel.h"

#include <algorithm>
#include <cstdint>
#include <vector>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/transpose_kernel.h"

namespace phi {

template <typename T>
// Computes only the singular values of one (rows x cols) column-major matrix
// via LAPACK ?gesdd, writing min(rows, cols) values into S.
void LapackSvdvals(const T* X, T* S, int rows, int cols) {
  // jobz = 'N' tells ?gesdd to skip computing U and VH entirely.
  char jobz = 'N';
  // LAPACK overwrites the input buffer; the caller passes scratch data (a
  // transposed copy of the tensor), so dropping const here is safe.
  T* a = const_cast<T*>(X);
  int lda = rows;
  // lwork = -1 requests a workspace-size query: LAPACK writes the optimal
  // lwork into work[0] without performing the decomposition.
  int lwork = -1;
  std::vector<T> work(1);
  int info = 0;
  // NOTE(review): ?gesdd documents an iwork array of size 8 * min(rows, cols)
  // even for jobz = 'N'; passing nullptr assumes the lapackSvd wrapper
  // supplies or tolerates it — confirm against phi::funcs::lapackSvd.
  phi::funcs::lapackSvd<T>(jobz,
                           rows,
                           cols,
                           a,
                           lda,
                           S,
                           nullptr,  // U is not needed
                           1,        // dummy dimension for U
                           nullptr,  // VH is not needed
                           1,        // dummy dimension for VH
                           work.data(),
                           lwork,
                           nullptr,  // iwork is not needed
                           &info);
  if (info != 0) {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Error during LAPACK lwork query. Invalid matrix or arguments."));
  }
  lwork = static_cast<int>(work[0]);
  work.resize(lwork);
  // Second call performs the actual decomposition with the sized workspace.
  phi::funcs::lapackSvd<T>(jobz,
                           rows,
                           cols,
                           a,
                           lda,
                           S,
                           nullptr,  // U is not needed
                           1,        // dummy dimension for U
                           nullptr,  // VH is not needed
                           1,        // dummy dimension for VH
                           work.data(),
                           lwork,
                           nullptr,  // iwork is not needed
                           &info);
  if (info < 0) {
    // A negative info means the |info|-th argument had an illegal value, so
    // report -info (the original code printed the raw negative value through
    // a %s specifier).
    PADDLE_THROW(phi::errors::InvalidArgument(
        "The %d-th argument to LAPACK SVD has an illegal value.", -info));
  }
  if (info > 0) {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "SVD computation did not converge. Input matrix may be invalid."));
  }
}

template <typename T>
// Runs LapackSvdvals over `batches` contiguous matrices stored in X, writing
// min(rows, cols) singular values per matrix into S.
void BatchSvdvals(const T* X, T* S, int rows, int cols, int batches) {
  // 64-bit per-batch offsets: `int` arithmetic (rows * cols, then * i)
  // overflows for large matrices or many batches.
  const int64_t stride = static_cast<int64_t>(rows) * cols;
  const int64_t stride_s = std::min(rows, cols);
  for (int i = 0; i < batches; i++) {
    LapackSvdvals<T>(X + i * stride, S + i * stride_s, rows, cols);
  }
}

template <typename T, typename Context>
// Computes the singular values of X (shape [..., rows, cols]) into S
// (shape [k] for a single matrix or [batches, k], with k = min(rows, cols)).
void SvdvalsKernel(const Context& dev_ctx,
                   const DenseTensor& X,
                   DenseTensor* S) {
  auto x_dims = X.dims();
  int rows = static_cast<int>(x_dims[x_dims.size() - 2]);
  int cols = static_cast<int>(x_dims[x_dims.size() - 1]);
  // Validate dimensions
  PADDLE_ENFORCE_GT(
      rows,
      0,
      phi::errors::InvalidArgument("The row of Input(X) must be > 0."));
  PADDLE_ENFORCE_GT(
      cols,
      0,
      phi::errors::InvalidArgument("The column of Input(X) must be > 0."));
  int k = std::min(rows, cols);
  // Compute the per-matrix element count in 64 bits: `rows * cols` in plain
  // `int` can overflow before the division, corrupting the batch count.
  const int64_t matrix_numel = static_cast<int64_t>(rows) * cols;
  int batches = static_cast<int>(X.numel() / matrix_numel);
  PADDLE_ENFORCE_GT(
      batches,
      0,
      phi::errors::InvalidArgument("The batch size of Input(X) must be > 0."));
  DDim s_dims;
  if (batches == 1) {
    s_dims = {k};
  } else {
    s_dims = {batches, k};
  }
  S->Resize(s_dims);
  // Allocate memory for output
  auto* S_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(S);

  // Transpose the last two dimensions so each matrix is laid out
  // column-major, as LAPACK expects.
  DenseTensor trans_x = ::phi::TransposeLast2Dim<T>(dev_ctx, X);
  auto* x_data = trans_x.data<T>();
  // Perform batch SVD computation for singular values
  BatchSvdvals<T>(x_data, S_out, rows, cols, batches);
}

} // namespace phi

// Register the kernel for CPU
PD_REGISTER_KERNEL(
svdvals, CPU, ALL_LAYOUT, phi::SvdvalsKernel, float, double) {}
62 changes: 62 additions & 0 deletions paddle/phi/kernels/impl/svdvals_grad_kernel_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/diag_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/diag_embed_impl.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/slice_kernel.h"
#include "paddle/phi/kernels/svd_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"

namespace phi {

template <typename T, typename Context>
// Gradient of svdvals: for X = U * diag(S) * VH, dX = U * diag(dS) * VH.
// Since the forward pass only kept S, the factors U and VH are recomputed
// here with a full SVD of x.
void SvdvalsGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& s_grad,
                       DenseTensor* x_grad) {
  auto x_dims = x.dims();
  int rows = static_cast<int>(x_dims[x_dims.size() - 2]);
  int cols = static_cast<int>(x_dims[x_dims.size() - 1]);
  int batches = static_cast<int>(x.numel() / (rows * cols));
  // diag(dS): a plain Diag for one matrix, DiagEmbed over the last two axes
  // for the batched case.
  DenseTensor dX_term;
  if (batches == 1) {
    dX_term = Diag<T, Context>(dev_ctx, s_grad, 0, 0);
  } else {
    MetaTensor meta_dX(&dX_term);
    DiagEmbedInferMeta(s_grad, 0, -1, -2, &meta_dX);
    phi::DiagEmbedKernel<T, Context>(dev_ctx, s_grad, 0, -1, -2, &dX_term);
  }

  // Recompute the thin SVD (full_matrices = false) to recover U and VH,
  // which the svdvals forward pass discarded.
  DenseTensor U, VH, S_recomputed;
  MetaTensor meta_u(&U), meta_s(&S_recomputed), meta_vh(&VH);
  SvdInferMeta(x, false, &meta_u, &meta_s, &meta_vh);
  phi::SvdKernel<T, Context>(dev_ctx,
                             x,
                             false,
                             &U,
                             &S_recomputed,
                             &VH);  // Crucial: recomputing SVD
  // Assemble dX = U * diag(dS) * VH.
  *x_grad =
      Matmul<T, Context>(dev_ctx, Matmul<T, Context>(dev_ctx, U, dX_term), VH);
}
} // namespace phi
26 changes: 26 additions & 0 deletions paddle/phi/kernels/svdvals_grad_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

// Gradient kernel for svdvals.
//
// @param dev_ctx  Device context used for allocation and kernel dispatch.
// @param x        Forward input matrix (or batch of matrices).
// @param s_grad   Gradient w.r.t. the singular values from the forward pass.
// @param x_grad   Output: gradient w.r.t. x, same shape as x.
template <typename T, typename Context>
void SvdvalsGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& s_grad,
                       DenseTensor* x_grad);
}  // namespace phi
27 changes: 27 additions & 0 deletions paddle/phi/kernels/svdvals_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

// Computes the singular values of X.
//
// @param dev_ctx  Device context used for allocation and kernel dispatch.
// @param X        Input tensor with at least 2 dimensions [..., m, n].
// @param S        Output: singular values, last dimension min(m, n).
template <typename T, typename Context>
void SvdvalsKernel(const Context& dev_ctx,
                   const DenseTensor& X,
                   DenseTensor* S);

}  // namespace phi
Loading