diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 10e3dedb87..1b8e1178b9 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -87,11 +87,15 @@ end @inline Base.ifelse(cond::Bool, a, b::CuTracedRNumber) = Base.ifelse(cond, a, b[]) @inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b) = Base.ifelse(cond, a[], b) -@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b::CuTracedRNumber) = Base.ifelse(cond, a[], b[]) +@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b::CuTracedRNumber) = + Base.ifelse(cond, a[], b[]) @inline Base.ifelse(cond::CuTracedRNumber, a, b) = Base.ifelse(cond[], a, b) -@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b) = Base.ifelse(cond[], a[], b) -@inline Base.ifelse(cond::CuTracedRNumber, a, b::CuTracedRNumber) = Base.ifelse(cond[], a, b[]) -@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b::CuTracedRNumber) = Base.ifelse(cond[], a[], b[]) +@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b) = + Base.ifelse(cond[], a[], b) +@inline Base.ifelse(cond::CuTracedRNumber, a, b::CuTracedRNumber) = + Base.ifelse(cond[], a, b[]) +@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b::CuTracedRNumber) = + Base.ifelse(cond[], a[], b[]) Base.@constprop :aggressive @inline Base.:^( a::CuTracedRNumber{T,A}, b::Integer