【Paddle Tensor No.26】Svdvals new branch #69796
Merged: HydrogenSulfate merged 44 commits into PaddlePaddle:develop from aquagull:svdvals_new_branch on Nov 29, 2024.
Commits (44)
All 44 commits are by aquagull:

6804a7d add svdvals_kernel
6183311 fix bug
f6dce63 fix bug
f27b4c7 fix bug
46d6fff fix bug
cf42122 fix some bug
cb428b2 fix bug
3a1110c fix bug
a729ef6 fix bug
99dbd3b add include
a008105 fix bug in svdvals_kernel
963e03a fix bug
9824e90 fix bug
9f8a548 fix bug in func SvdvalsInferMeta
681cad9 add test
fd97a58 Merge branch 'PaddlePaddle:develop' into addSvdvals
4d9c7ce fix codestyle
cc1a1ae fix lwork and int
e26cb76 fix
7ee7eb9 use guard to control enable/disable
d3131bf add test_check_grad
c1a1c48 fix test_svdvals_op
aa92a7a fix bug
f246347 fix bug in svdvals_kernel
0267c37 fix bug
b5470b8 fix bug
f86541e fix bug
73a53d6 fix bug in svdvals_grad_kernel
e439a2b fix
4135e24 fix
69d05c0 add debug
190ffe4 dix
3284bd2 fix
a619b93 fix
8e9c202 fix
b37f96b fix bug in svdvals_grad_kernel
7a09e29 fix
290b201 fix
2ffa2a0 fix
46e4219 fix
6099878 delete VLOG
e5b0039 delete head
a4f8235 fix
dd419d2 fix op_gen
New file: CPU registration of the svdvals_grad kernel.

@@ -0,0 +1,20 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/svdvals_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/svdvals_grad_kernel_impl.h"

PD_REGISTER_KERNEL(
    svdvals_grad, CPU, ALL_LAYOUT, phi::SvdvalsGradKernel, float, double) {}
New file: CPU kernel for svdvals.

@@ -0,0 +1,130 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/svdvals_kernel.h"

#include <algorithm>
#include <vector>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/transpose_kernel.h"

namespace phi {

template <typename T>
void LapackSvdvals(const T* X, T* S, int rows, int cols) {
  // jobz = 'N' tells LAPACK to skip computing U and VH; only the singular
  // values are needed.
  char jobz = 'N';
  T* a = const_cast<T*>(X);
  int lda = rows;
  int lwork = -1;
  std::vector<T> work(1);
  // gesdd needs an integer workspace even when jobz == 'N'.
  std::vector<int> iwork(8 * std::min(rows, cols));
  int info = 0;
  // Workspace query: lwork = -1 asks LAPACK for the optimal work size.
  phi::funcs::lapackSvd<T>(jobz,
                           rows,
                           cols,
                           a,
                           lda,
                           S,
                           nullptr,  // U is not needed
                           1,        // dummy leading dimension for U
                           nullptr,  // VH is not needed
                           1,        // dummy leading dimension for VH
                           work.data(),
                           lwork,
                           iwork.data(),
                           &info);
  if (info != 0) {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Error during LAPACK lwork query. Invalid matrix or arguments."));
  }
  lwork = static_cast<int>(work[0]);
  work.resize(lwork);
  // Actual computation of the singular values.
  phi::funcs::lapackSvd<T>(jobz,
                           rows,
                           cols,
                           a,
                           lda,
                           S,
                           nullptr,  // U is not needed
                           1,        // dummy leading dimension for U
                           nullptr,  // VH is not needed
                           1,        // dummy leading dimension for VH
                           work.data(),
                           lwork,
                           iwork.data(),
                           &info);
  if (info < 0) {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "The %d-th argument of the LAPACK SVD call has an illegal value.",
        -info));
  }
  if (info > 0) {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "SVD computation did not converge. Input matrix may be invalid."));
  }
}

template <typename T>
void BatchSvdvals(const T* X, T* S, int rows, int cols, int batches) {
  int stride = rows * cols;
  int stride_s = std::min(rows, cols);
  for (int i = 0; i < batches; i++) {
    LapackSvdvals<T>(X + i * stride, S + i * stride_s, rows, cols);
  }
}

template <typename T, typename Context>
void SvdvalsKernel(const Context& dev_ctx,
                   const DenseTensor& X,
                   DenseTensor* S) {
  auto x_dims = X.dims();
  int rows = static_cast<int>(x_dims[x_dims.size() - 2]);
  int cols = static_cast<int>(x_dims[x_dims.size() - 1]);
  // Validate dimensions
  PADDLE_ENFORCE_GT(
      rows,
      0,
      phi::errors::InvalidArgument("The row of Input(X) must be > 0."));
  PADDLE_ENFORCE_GT(
      cols,
      0,
      phi::errors::InvalidArgument("The column of Input(X) must be > 0."));
  int k = std::min(rows, cols);
  int batches = static_cast<int>(X.numel() / (rows * cols));
  PADDLE_ENFORCE_GT(
      batches,
      0,
      phi::errors::InvalidArgument("The batch size of Input(X) must be > 0."));
  DDim s_dims;
  if (batches == 1) {
    s_dims = {k};
  } else {
    s_dims = {batches, k};
  }
  S->Resize(s_dims);
  // Allocate memory for the output singular values.
  auto* S_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(S);

  // Transpose the last two dimensions so each matrix is laid out
  // column-major for LAPACK.
  DenseTensor trans_x = ::phi::TransposeLast2Dim<T>(dev_ctx, X);
  auto* x_data = trans_x.data<T>();
  // Perform the batched computation of singular values.
  BatchSvdvals<T>(x_data, S_out, rows, cols, batches);
}

}  // namespace phi

// Register the CPU kernel for float and double.
PD_REGISTER_KERNEL(
    svdvals, CPU, ALL_LAYOUT, phi::SvdvalsKernel, float, double) {}
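For intuition, the kernel above computes the same thing as a compute-values-only SVD in NumPy. The sketch below is illustration only; the Python-side binding name paddle.linalg.svdvals is an assumption and is not part of this diff.

```python
# Reference semantics for the svdvals kernel, sketched with NumPy.
# Assumption: the op is exposed in Python as paddle.linalg.svdvals (not shown in this diff).
import numpy as np

x = np.random.rand(3, 4, 5).astype("float64")  # a batch of three 4x5 matrices
s = np.linalg.svd(x, compute_uv=False)         # singular values only, like jobz='N'
print(s.shape)                                 # (3, 4) == (batches, min(rows, cols))
```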
New file: shared implementation of the svdvals_grad kernel.

@@ -0,0 +1,62 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/diag_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/diag_embed_impl.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/slice_kernel.h"
#include "paddle/phi/kernels/svd_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"

namespace phi {

template <typename T, typename Context>
void SvdvalsGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& s_grad,
                       DenseTensor* x_grad) {
  auto x_dims = x.dims();
  int rows = static_cast<int>(x_dims[x_dims.size() - 2]);
  int cols = static_cast<int>(x_dims[x_dims.size() - 1]);
  int batches = static_cast<int>(x.numel() / (rows * cols));
  // Lift the incoming gradient dL/dS into a (batched) diagonal matrix.
  DenseTensor dX_term;
  if (batches == 1) {
    dX_term = Diag<T, Context>(dev_ctx, s_grad, 0, 0);
  } else {
    MetaTensor meta_dX(&dX_term);
    DiagEmbedInferMeta(s_grad, 0, -1, -2, &meta_dX);
    phi::DiagEmbedKernel<T, Context>(dev_ctx, s_grad, 0, -1, -2, &dX_term);
  }

  // Recompute the thin SVD of x: the forward pass only kept the singular
  // values, but U and VH are needed to assemble the gradient.
  DenseTensor U, VH, S_recomputed;
  MetaTensor meta_u(&U), meta_s(&S_recomputed), meta_vh(&VH);
  SvdInferMeta(x, false, &meta_u, &meta_s, &meta_vh);
  phi::SvdKernel<T, Context>(dev_ctx, x, false, &U, &S_recomputed, &VH);
  // dL/dX = U * diag(dL/dS) * VH
  *x_grad =
      Matmul<T, Context>(dev_ctx, Matmul<T, Context>(dev_ctx, U, dX_term), VH);
}
}  // namespace phi
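Since the forward pass only produces the singular values, the backward rule implemented above is dX = U * diag(dS) * VH, using the thin SVD of x. A minimal NumPy sketch of the same product, for illustration only:

```python
# Sketch of the gradient product formed by SvdvalsGradKernel, in NumPy.
import numpy as np

x = np.random.rand(4, 5)
u, s, vh = np.linalg.svd(x, full_matrices=False)  # thin SVD, as in the kernel (full_matrices=false)
ds = np.ones_like(s)        # incoming gradient with respect to the singular values
dx = u @ np.diag(ds) @ vh   # U * diag(dS) * VH, the product built with Matmul above
print(dx.shape)             # (4, 5), same shape as x
```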
New file: declaration of the svdvals_grad kernel.

@@ -0,0 +1,26 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void SvdvalsGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& s_grad,
                       DenseTensor* x_grad);
}  // namespace phi
New file: declaration of the svdvals kernel.

@@ -0,0 +1,27 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void SvdvalsKernel(const Context& dev_ctx,
                   const DenseTensor& X,
                   DenseTensor* S);

}  // namespace phi
Review comment: As a follow-up, consider extending this in a more maintainable way.
Reply: Got it, thank you!