@@ -49,8 +49,9 @@ namespace matx
4949 float alpha_;
5050 float beta_;
5151 PermDims perm_;
52- std::array<index_t , OpA::Rank()> out_dims_;
53- mutable matx::tensor_t <typename OpA::scalar_type, OpA::Rank()> tmp_out_;
52+ static constexpr int out_rank = std::max(OpA::Rank(), OpB::Rank());
53+ std::array<index_t , out_rank> out_dims_;
54+ mutable matx::tensor_t <typename OpA::scalar_type, out_rank> tmp_out_;
5455
5556 public:
5657 using matxop = bool ;
@@ -59,7 +60,7 @@ namespace matx
5960 using matmul_xform_op = bool ;
6061
6162 __MATX_INLINE__ std::string str () const {
62- return " matmul(" + get_type_str (a_) + " )" ;
63+ return " matmul(" + get_type_str (a_) + " , " + get_type_str (b_) + " )" ;
6364 }
6465
6566 __MATX_INLINE__ MatMulOp (OpA a, OpB b, float alpha, float beta, PermDims perm) :
@@ -73,17 +74,17 @@ namespace matx
7374 out_dims_[r] = b_.Size (perm_[r]);
7475 }
7576 else {
76- out_dims_[r] = a_ .Size (r);
77+ out_dims_[r] = OpA::Rank () > OpB::Rank () ? a_. Size (r) : b_ .Size (r);
7778 }
7879 }
7980 }
8081 else {
8182 for (int r = 0 ; r < Rank () - 2 ; r++) {
82- out_dims_[r] = a_ .Size (r);
83+ out_dims_[r] = OpA::Rank () > OpB::Rank () ? a_. Size (r) : b_ .Size (r);
8384 }
8485
85- out_dims_[OpA:: Rank () - 2 ] = a_.Size (OpA::Rank () - 2 );
86- out_dims_[OpB:: Rank () - 1 ] = b_.Size (OpB::Rank () - 1 );
86+ out_dims_[Rank () - 2 ] = a_.Size (OpA::Rank () - 2 );
87+ out_dims_[Rank () - 1 ] = b_.Size (OpB::Rank () - 1 );
8788 }
8889 }
8990
@@ -96,7 +97,7 @@ namespace matx
9697
9798 static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank ()
9899 {
99- return OpA::Rank () ;
100+ return out_rank ;
100101 }
101102 constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size (int dim) const
102103 {
@@ -123,7 +124,7 @@ namespace matx
123124
124125 if constexpr (is_matx_op<OpB>()) {
125126 b_.PreRun (std::forward<ShapeType>(shape), std::forward<Executor>(ex));
126- }
127+ }
127128
128129 if constexpr (is_device_executor_v<Executor>) {
129130 make_tensor (tmp_out_, out_dims_, MATX_ASYNC_DEVICE_MEMORY, ex.getStream ());
0 commit comments