Skip to content

Commit ad2e7e0

Browse files
committed
improve the code
1 parent 6f8c194 commit ad2e7e0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,8 @@ inline bool is_scalar_on_cpu(const at::Tensor& t) {
254254
return t.defined() && t.is_cpu() && t.numel() == 1;
255255
}
256256

257-
inline bool check_scalar_on_cpu(const c10::optional<at::Tensor> t) {
258-
return t.has_value() && (*t).unsafeGetTensorImpl()->is_wrapped_number();
257+
inline bool check_scalar_on_cpu(const c10::optional<at::Tensor>& t) {
258+
return t.has_value() && ((*t).unsafeGetTensorImpl()->is_wrapped_number() || ((*t).is_cpu() && (*t).numel() == 1 ));
259259
}
260260

261261
} // namespace native

0 commit comments

Comments
 (0)