Skip to content

Commit 22cfc0b

Browse files
authored
Feature: add python wrapper for math sphbes (#3475)
* recommit for review * add python wrapper * remove timer since performace tests add
1 parent af8ccfb commit 22cfc0b

File tree

8 files changed

+105
-38
lines changed

8 files changed

+105
-38
lines changed

python/pyabacus/CMakeLists.txt

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,14 @@ set(BASE_PATH "${PROJECT_SOURCE_DIR}/../../source/module_base")
1212
set(ABACUS_SOURCE_DIR "${PROJECT_SOURCE_DIR}/../../source")
1313
include_directories(${BASE_PATH} ${ABACUS_SOURCE_DIR})
1414
list(APPEND _sources
15-
${ABACUS_SOURCE_DIR}/module_basis/module_nao/numerical_radial.h
16-
${ABACUS_SOURCE_DIR}/module_basis/module_nao/numerical_radial.cpp
17-
${PROJECT_SOURCE_DIR}/src/py_numerical_radial.cpp)
15+
#${ABACUS_SOURCE_DIR}/module_basis/module_nao/numerical_radial.h
16+
#${ABACUS_SOURCE_DIR}/module_basis/module_nao/numerical_radial.cpp
17+
${ABACUS_SOURCE_DIR}/module_base/constants.h
18+
${ABACUS_SOURCE_DIR}/module_base/math_sphbes.h
19+
${ABACUS_SOURCE_DIR}/module_base/math_sphbes.cpp
20+
${PROJECT_SOURCE_DIR}/src/py_abacus.cpp
21+
#${PROJECT_SOURCE_DIR}/src/py_numerical_radial.cpp
22+
${PROJECT_SOURCE_DIR}/src/py_math_base.cpp)
1823
python_add_library(_core MODULE ${_sources} WITH_SOABI)
1924
target_link_libraries(_core PRIVATE pybind11::headers)
2025
target_compile_definitions(_core PRIVATE VERSION_INFO=${PROJECT_VERSION})

python/pyabacus/src/py_abacus.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#include <pybind11/numpy.h>
2+
#include <pybind11/pybind11.h>
3+
4+
namespace py = pybind11;
5+
6+
void bind_numerical_radial(py::module& m);
7+
void bind_math_base(py::module& m);
8+
9+
PYBIND11_MODULE(_core, m)
10+
{
11+
// bind_numerical_radial(m);
12+
bind_math_base(m);
13+
}

python/pyabacus/src/py_math_base.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#include <pybind11/numpy.h>
2+
#include <pybind11/pybind11.h>
3+
4+
#include "module_base/math_sphbes.h"
5+
6+
namespace py = pybind11;
7+
using namespace pybind11::literals;
8+
template <typename... Args>
9+
using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;
10+
11+
void bind_math_base(py::module& m)
12+
{
13+
py::module module_base = m.def_submodule("ModuleBase");
14+
15+
py::class_<ModuleBase::Sphbes>(module_base, "Sphbes")
16+
.def(py::init<>())
17+
.def_static("sphbesj", overload_cast_<const int, const double>()(&ModuleBase::Sphbes::sphbesj), "l"_a, "x"_a)
18+
.def_static("dsphbesj", overload_cast_<const int, const double>()(&ModuleBase::Sphbes::dsphbesj), "l"_a, "x"_a)
19+
.def_static("sphbesj",
20+
[](const int n, py::array_t<double> r, const double q, const int l, py::array_t<double> jl) {
21+
py::buffer_info r_info = r.request();
22+
if (r_info.ndim != 1)
23+
{
24+
throw std::runtime_error("r array must be 1-dimensional");
25+
}
26+
py::buffer_info jl_info = jl.request();
27+
if (jl_info.ndim != 1)
28+
{
29+
throw std::runtime_error("jl array must be 1-dimensional");
30+
}
31+
ModuleBase::Sphbes::sphbesj(n,
32+
static_cast<const double* const>(r_info.ptr),
33+
q,
34+
l,
35+
static_cast<double* const>(jl_info.ptr));
36+
})
37+
.def_static("dsphbesj",
38+
[](const int n, py::array_t<double> r, const double q, const int l, py::array_t<double> djl) {
39+
py::buffer_info r_info = r.request();
40+
if (r_info.ndim != 1)
41+
{
42+
throw std::runtime_error("r array must be 1-dimensional");
43+
}
44+
py::buffer_info djl_info = djl.request();
45+
if (djl_info.ndim != 1)
46+
{
47+
throw std::runtime_error("djl array must be 1-dimensional");
48+
}
49+
ModuleBase::Sphbes::dsphbesj(n,
50+
static_cast<const double* const>(r_info.ptr),
51+
q,
52+
l,
53+
static_cast<double* const>(djl_info.ptr));
54+
})
55+
.def_static("sphbes_zeros", [](const int l, const int n, py::array_t<double> zeros) {
56+
py::buffer_info zeros_info = zeros.request();
57+
if (zeros_info.ndim != 1)
58+
{
59+
throw std::runtime_error("zeros array must be 1-dimensional");
60+
}
61+
ModuleBase::Sphbes::sphbes_zeros(l, n, static_cast<double* const>(zeros_info.ptr));
62+
});
63+
}

python/pyabacus/src/py_numerical_radial.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using namespace pybind11::literals;
88
template <typename... Args>
99
using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;
1010

11-
PYBIND11_MODULE(_core, m)
11+
void bind_numerical_radial(py::module& m)
1212
{
1313
// Create the submodule for NumericalRadial
1414
py::module m_numerical_radial = m.def_submodule("NumericalRadial");
@@ -165,4 +165,4 @@ PYBIND11_MODULE(_core, m)
165165
.def_property_readonly("kgrid", overload_cast_<int>()(&NumericalRadial::kgrid, py::const_))
166166
.def_property_readonly("rvalue", overload_cast_<int>()(&NumericalRadial::rvalue, py::const_))
167167
.def_property_readonly("kvalue", overload_cast_<int>()(&NumericalRadial::kvalue, py::const_));
168-
}
168+
}
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from __future__ import annotations
2-
from ._core import __doc__, __version__, NumericalRadial
3-
__all__ = ["__doc__", "__version__", "NumericalRadial"]
2+
# from ._core import __doc__, __version__, NumericalRadial, ModuleBase
3+
from ._core import ModuleBase
4+
__all__ = ["ModuleBase"]
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from __future__ import annotations
2+
3+
import pyabacus as m
4+
import numpy as np
5+
6+
7+
def test_version():
8+
assert m.__version__ == "0.0.1"
9+
10+
def test_sphbes():
11+
s = m.ModuleBase.Sphbes()
12+
# test for sphbesj
13+
assert s.sphbesj(1, 0.0) == 0.0
14+
assert s.sphbesj(0, 0.0) == 1.0
15+

python/pyabacus/tests/test_nr.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

source/module_base/math_sphbes.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include "math_sphbes.h"
2-
#include "timer.h"
32
#include "constants.h"
43
#include <algorithm>
4+
#include <iostream>
55

66
#include <cassert>
77

@@ -425,7 +425,6 @@ void Sphbes::Spherical_Bessel
425425
double *jl // jl(1:msh) = j_l(q*r(i)),spherical bessel function
426426
)
427427
{
428-
ModuleBase::timer::tick("Sphbes","Spherical_Bessel");
429428
double x1=0.0;
430429

431430
int i=0;
@@ -598,7 +597,6 @@ void Sphbes::Spherical_Bessel
598597
}
599598
}
600599

601-
ModuleBase::timer::tick("Sphbes","Spherical_Bessel");
602600
return;
603601
}
604602

@@ -613,7 +611,6 @@ void Sphbes::Spherical_Bessel
613611
double *sjp
614612
)
615613
{
616-
ModuleBase::timer::tick("Sphbes","Spherical_Bessel");
617614

618615
//calculate jlx first
619616
Spherical_Bessel (msh, r, q, l, sj);
@@ -634,7 +631,6 @@ void Sphbes::dSpherical_Bessel_dx
634631
double *djl // jl(1:msh) = j_l(q*r(i)),spherical bessel function
635632
)
636633
{
637-
ModuleBase::timer::tick("Sphbes","dSpherical_Bessel_dq");
638634
if (l < 0 )
639635
{
640636
std::cout << "We temporarily only calculate derivative of l >= 0." << std::endl;
@@ -682,7 +678,6 @@ void Sphbes::dSpherical_Bessel_dx
682678
}
683679
delete[] jl;
684680
}
685-
ModuleBase::timer::tick("Sphbes","dSpherical_Bessel_dq");
686681
return;
687682
}
688683

0 commit comments

Comments
 (0)