Skip to content

Commit fa51cbc

Browse files
committed
Fixed rank of output in matmul operator when A/B had 0 stride
1 parent 1b27fe6 commit fa51cbc

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

include/matx/operators/matmul.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ namespace matx
4949
float alpha_;
5050
float beta_;
5151
PermDims perm_;
52-
std::array<index_t, OpA::Rank()> out_dims_;
53-
mutable matx::tensor_t<typename OpA::scalar_type, OpA::Rank()> tmp_out_;
52+
static constexpr int out_rank = std::max(OpA::Rank(), OpB::Rank());
53+
std::array<index_t, out_rank> out_dims_;
54+
mutable matx::tensor_t<typename OpA::scalar_type, out_rank> tmp_out_;
5455

5556
public:
5657
using matxop = bool;
@@ -59,7 +60,7 @@ namespace matx
5960
using matmul_xform_op = bool;
6061

6162
__MATX_INLINE__ std::string str() const {
62-
return "matmul(" + get_type_str(a_) + ")";
63+
return "matmul(" + get_type_str(a_) + "," + get_type_str(b_) + ")";
6364
}
6465

6566
__MATX_INLINE__ MatMulOp(OpA a, OpB b, float alpha, float beta, PermDims perm) :
@@ -73,17 +74,17 @@ namespace matx
7374
out_dims_[r] = b_.Size(perm_[r]);
7475
}
7576
else {
76-
out_dims_[r] = a_.Size(r);
77+
out_dims_[r] = OpA::Rank() > OpB::Rank() ? a_.Size(r) : b_.Size(r);
7778
}
7879
}
7980
}
8081
else {
8182
for (int r = 0; r < Rank() - 2; r++) {
82-
out_dims_[r] = a_.Size(r);
83+
out_dims_[r] = OpA::Rank() > OpB::Rank() ? a_.Size(r) : b_.Size(r);
8384
}
8485

85-
out_dims_[OpA::Rank() - 2] = a_.Size(OpA::Rank() - 2);
86-
out_dims_[OpB::Rank() - 1] = b_.Size(OpB::Rank() - 1);
86+
out_dims_[Rank() - 2] = a_.Size(OpA::Rank() - 2);
87+
out_dims_[Rank() - 1] = b_.Size(OpB::Rank() - 1);
8788
}
8889
}
8990

@@ -96,7 +97,7 @@ namespace matx
9697

9798
static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
9899
{
99-
return OpA::Rank();
100+
return out_rank;
100101
}
101102
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
102103
{
@@ -123,7 +124,7 @@ namespace matx
123124

124125
if constexpr (is_matx_op<OpB>()) {
125126
b_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
126-
}
127+
}
127128

128129
if constexpr (is_device_executor_v<Executor>) {
129130
make_tensor(tmp_out_, out_dims_, MATX_ASYNC_DEVICE_MEMORY, ex.getStream());

0 commit comments

Comments
 (0)