Skip to content

Commit 1d0c13f

Browse files
authored
Fix typedef and undef ctors for ConcreteRArray and ConcreteRNumber (#1571)
* Fix typedef of ConcreteRNumber * Fix typedef and ctors of ConcreteRArray and add tests * Add tests for ConcreteRNumber * Remove deprecated ConcreteRArray undef ctor with T as argument * Don't restrict ConcreteRArray undef ctor to eltype Number * Remove type parameters in ConcreteRArray and ConcreteRNumber typedefs
1 parent 56c272a commit 1d0c13f

File tree

4 files changed

+42
-13
lines changed

4 files changed

+42
-13
lines changed

ext/ReactantKernelAbstractionsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function Base.getproperty(x::ReactantBackend, sym::Symbol)
2626
end
2727

2828
function KA.allocate(::ReactantBackend, ::Type{T}, dims::Tuple) where {T}
29-
return ConcreteRArray(undef, T, dims)
29+
return ConcreteRArray{T}(undef, dims)
3030
end
3131

3232
function KA.zeros(b::ReactantBackend, ::Type{T}, dims::Tuple) where {T}

src/ConcreteRArray.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,8 @@ end
378378
device::Union{Nothing,XLA.PJRT.Device}=nothing,
379379
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
380380
) where {S}
381-
return ConcretePJRTArray(
382-
undef, S, dims; client=client, idx=idx, device=device, sharding=sharding
381+
return ConcretePJRTArray{S}(
382+
undef, dims; client=client, idx=idx, device=device, sharding=sharding
383383
)
384384
end
385385

@@ -410,7 +410,7 @@ function Base.similar(a::ConcreteIFRTArray{T}, ::Type{S}=T, dims::Dims=size(a))
410410
end
411411
Base.similar(a::ConcreteIFRTArray, dims::Dims) = similar(a, eltype(a), dims)
412412
function Base.similar(::Type{ConcreteIFRTArray{T}}, dims) where {T}
413-
return ConcreteIFRTArray(undef, T, dims)
413+
return ConcreteIFRTArray{T}(undef, dims)
414414
end
415415

416416
# Broadcasting interface

src/Types.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -410,14 +410,16 @@ Base.@deprecate_binding ConcreteRNG ReactantRNG
410410
Base.@deprecate_binding TracedRNG ReactantRNG
411411

412412
"""
413-
ConcreteRArray(
414-
undef, ::Type{T}, shape::Dims;
413+
ConcreteRArray{T}(
414+
undef, shape::Dims;
415415
client::Union{Nothing,XLA.AbstractClient} = nothing,
416416
device::Union{Nothing,XLA.AbstractDevice} = nothing,
417417
sharding::Sharding.AbstractSharding = Sharding.NoSharding(),
418418
)
419419
420-
ConcretePJRTArray(data::Array, kwargs...)
420+
ConcretePJRTArray{T}(undef, shape::Integer...; kwargs...)
421+
422+
ConcretePJRTArray(data::Array; kwargs...)
421423
422424
Allocate an uninitialized `ConcreteRArray` of element type `T` and size
423425
`shape` or convert an `Array` to a `ConcreteRArray`.
@@ -434,6 +436,9 @@ elseif XLA.REACTANT_XLA_RUNTIME == "IFRT"
434436
ConcreteIFRTArray
435437
end
436438

439+
@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} =
440+
ConcreteRArray{T}(undef, Dims(shape); kwargs...)
441+
437442
"""
438443
ConcreteRNumber(
439444
x::Number;
@@ -460,16 +465,13 @@ end
460465

461466
## Other Aliases based on the set preferences
462467
@static if XLA.REACTANT_XLA_RUNTIME == "PJRT"
463-
const ConcreteRNumber = ConcretePJRTNumber
464468
const AnyConcreteRArray = AnyConcretePJRTArray
465469
elseif XLA.REACTANT_XLA_RUNTIME == "IFRT"
466-
const ConcreteRNumber = ConcreteIFRTNumber
467470
const AnyConcreteRArray = AnyConcreteIFRTArray
468471
end
469472

470-
function ConcretePJRTArray(
473+
function ConcretePJRTArray{T}(
471474
::UndefInitializer,
472-
::Type{T},
473475
shape::Dims;
474476
client::Union{Nothing,XLA.AbstractClient}=nothing,
475477
idx::Union{Int,Nothing}=nothing,
@@ -483,9 +485,8 @@ function ConcretePJRTArray(
483485
return ConcretePJRTArray{T,N,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo)
484486
end
485487

486-
function ConcreteIFRTArray(
488+
function ConcreteIFRTArray{T}(
487489
::UndefInitializer,
488-
::Type{T},
489490
shape::Dims;
490491
client::Union{Nothing,XLA.AbstractClient}=nothing,
491492
idx::Union{Int,Nothing}=nothing,

test/layout.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,34 @@ using Reactant
22
using Test
33

44
@testset "Layout" begin
5+
client = Reactant.XLA.default_backend()
6+
device = Reactant.XLA.default_device()
7+
sharding = Sharding.NoSharding()
8+
idx = 0
9+
10+
@test ConcreteRArray{Float32}(undef, (100, 10)) isa ConcreteRArray{Float32,2}
11+
12+
@test ConcreteRArray{Float32}(
13+
undef, (100, 10); client=client, idx=idx, device=device
14+
) isa ConcreteRArray{Float32,2}
15+
16+
@test ConcreteRArray{Float32}(
17+
undef, Int32(100), Int16(10); client=client, idx=idx, device=device
18+
) isa ConcreteRArray{Float32,2}
19+
20+
@test ConcreteRNumber(Float32(4.2)) isa ConcreteRNumber{Float32}
21+
22+
@test ConcreteRNumber(Float16(4.2); client=client, idx=idx, device=device) isa
23+
ConcreteRNumber{Float16}
24+
25+
@test ConcreteRNumber{Float32}(Float32(4.2); client=client, idx=idx, device=device) isa
26+
ConcreteRNumber{Float32}
27+
28+
@test ConcreteRNumber{Float16}(Float32(4.2)) isa ConcreteRNumber{Float16}
29+
30+
@test ConcreteRNumber{Float32}(Float16(4.2); client=client, idx=idx, device=device) isa
31+
ConcreteRNumber{Float32}
32+
533
x = reshape([1.0, 2.0, 3.0, 4.0], (2, 2))
634

735
y = Reactant.to_rarray(x)

0 commit comments

Comments
 (0)