We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6f8c194 commit ad2e7e0Copy full SHA for ad2e7e0
dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp
@@ -254,8 +254,8 @@ inline bool is_scalar_on_cpu(const at::Tensor& t) {
254
return t.defined() && t.is_cpu() && t.numel() == 1;
255
}
256
257
-inline bool check_scalar_on_cpu(const c10::optional<at::Tensor> t) {
258
- return t.has_value() && (*t).unsafeGetTensorImpl()->is_wrapped_number();
+inline bool check_scalar_on_cpu(const c10::optional<at::Tensor>& t) {
+ return t.has_value() && ((*t).unsafeGetTensorImpl()->is_wrapped_number() || ((*t).is_cpu() && (*t).numel() == 1 ));
259
260
261
} // namespace native
0 commit comments