|
| 1 | +//////////////////////////////////////////////////////////////////////////////// |
| 2 | +// BSD 3-Clause License |
| 3 | +// |
| 4 | +// Copyright (c) 2021, NVIDIA Corporation |
| 5 | +// All rights reserved. |
| 6 | +// |
| 7 | +// Redistribution and use in source and binary forms, with or without |
| 8 | +// modification, are permitted provided that the following conditions are met: |
| 9 | +// |
| 10 | +// 1. Redistributions of source code must retain the above copyright notice, this |
| 11 | +// list of conditions and the following disclaimer. |
| 12 | +// |
| 13 | +// 2. Redistributions in binary form must reproduce the above copyright notice, |
| 14 | +// this list of conditions and the following disclaimer in the documentation |
| 15 | +// and/or other materials provided with the distribution. |
| 16 | +// |
| 17 | +// 3. Neither the name of the copyright holder nor the names of its |
| 18 | +// contributors may be used to endorse or promote products derived from |
| 19 | +// this software without specific prior written permission. |
| 20 | +// |
| 21 | +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| 22 | +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| 23 | +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| 24 | +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE |
| 25 | +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |
| 26 | +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
| 27 | +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
| 28 | +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |
| 29 | +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 30 | +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 31 | +///////////////////////////////////////////////////////////////////////////////// |
| 32 | + |
| 33 | +#pragma once |
| 34 | + |
| 35 | + |
| 36 | +#include "matx/core/type_utils.h" |
| 37 | +#include "matx/operators/scalar_ops.h" |
| 38 | +#include "matx/operators/base_operator.h" |
| 39 | + |
| 40 | +namespace matx |
| 41 | +{ |
| 42 | + |
| 43 | + namespace detail { |
| 44 | + template <typename Op1, typename Op2> |
| 45 | + class IsCloseOp : public BaseOp<IsCloseOp<Op1, Op2>> |
| 46 | + { |
| 47 | + public: |
| 48 | + using matxop = bool; |
| 49 | + using scalar_type = typename remove_cvref_t<Op2>::scalar_type; |
| 50 | + using inner_type = typename inner_op_type_t<scalar_type>::type; |
| 51 | + |
| 52 | + __MATX_INLINE__ std::string str() const { return "isclose()"; } |
| 53 | + |
| 54 | + __MATX_INLINE__ IsCloseOp(Op1 op1, Op2 op2, double rtol, double atol) : |
| 55 | + op1_(op1), op2_(op2), rtol_(static_cast<inner_type>(rtol)), atol_(static_cast<inner_type>(atol)) |
| 56 | + { |
| 57 | + static_assert(op1.Rank() == op2.Rank(), "Operator ranks must match in isclose()"); |
| 58 | + for (int32_t i = 0; i < op2.Rank(); i++) { |
| 59 | + MATX_ASSERT_STR(op1.Size(i) == op2.Size(i), matxInvalidDim, |
| 60 | + "Size of each dimension must match in isclose()"); |
| 61 | + } |
| 62 | + } |
| 63 | + |
| 64 | + template <typename... Is> |
| 65 | + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ int operator()([[maybe_unused]] Is... indices) const |
| 66 | + { |
| 67 | + |
| 68 | + return static_cast<int>(detail::_internal_abs(op1_(indices...) - op2_(indices...)) <= |
| 69 | + static_cast<inner_type>(atol_) + static_cast<inner_type>(rtol_) * detail::_internal_abs(op2_(indices...))); |
| 70 | + } |
| 71 | + |
| 72 | + static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() |
| 73 | + { |
| 74 | + return remove_cvref_t<Op1>::Rank(); |
| 75 | + } |
| 76 | + |
| 77 | + constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const |
| 78 | + { |
| 79 | + return op1_.Size(dim); |
| 80 | + } |
| 81 | + |
| 82 | + private: |
| 83 | + Op1 op1_; |
| 84 | + Op2 op2_; |
| 85 | + inner_type rtol_; |
| 86 | + inner_type atol_; |
| 87 | + |
| 88 | + }; |
| 89 | + } |
| 90 | + |
| 91 | + /** |
| 92 | + * @brief Returns an integer tensor where an element is 1 if: |
| 93 | + * abs(op1 - op2) <= atol + rtol * abs(op2) |
| 94 | + * |
| 95 | + * or 0 otherwise |
| 96 | + * |
| 97 | + * @tparam Op1 First operator type |
| 98 | + * @tparam Op2 Second operator type |
| 99 | + * @param op1 First operator |
| 100 | + * @param op2 Second operator |
| 101 | + * @param rtol Relative tolerance |
| 102 | + * @param atol Absolute tolerance |
| 103 | + * @return IsClose operator |
| 104 | + */ |
| 105 | + template <typename Op1, typename Op2> |
| 106 | + __MATX_INLINE__ auto isclose(Op1 op1, Op2 op2, double rtol = 1e-5, double atol = 1e-8) { |
| 107 | + return detail::IsCloseOp<Op1, Op2>(op1, op2, rtol, atol); |
| 108 | + } |
| 109 | +} // end namespace matx |
0 commit comments