Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ jobs:
timeout-minutes: 90
name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.runtime }} - assertions=${{ matrix.assertions }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
container:
image: ${{ contains(matrix.os, 'linux') && 'ghcr.io/enzymead/reactant-docker-images:main' || '' }}
strategy:
fail-fast: false
matrix:
Expand All @@ -59,6 +61,11 @@ jobs:
assertions:
- false
include:
- os: linux-x86-ct6e-180-4tpu
version: "1.11"
assertions: false
test_group: core
runtime: "IFRT"
- os: ubuntu-24.04
version: "1.10"
assertions: true
Expand Down Expand Up @@ -86,9 +93,13 @@ jobs:
# libReactant: packaged
# version: '1.10'
# test_group: integration
env:
TMPDIR: ${{ github.workspace }}/tmp
steps:
- name: Set TMPDIR
# We have to use `${GITHUB_WORKSPACE}` instead of `github.workspace` because GitHub
# is terrible and the two don't match inside containers:
# https://github.com/actions/runner/issues/2058
run: |
echo "TMPDIR=${GITHUB_WORKSPACE}/tmp" >> ${GITHUB_ENV}
- uses: actions/checkout@v4
- name: Create TMPDIR
run: |
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
unzip_jll = "88f77b66-78eb-5ed0-bc16-ebba0796830d"

[weakdeps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down Expand Up @@ -102,6 +103,7 @@ Scratch = "1.2"
Sockets = "1.10"
SpecialFunctions = "2.4"
Statistics = "1.10"
unzip_jll = "6"
YaoBlocks = "0.13, 0.14"
julia = "1.10"

Expand Down
5 changes: 3 additions & 2 deletions src/accelerators/TPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using EnumX: @enumx
using Scratch: @get_scratch!
using HTTP
using Downloads
using unzip_jll: unzip

const libtpu_dir = Ref{Union{Nothing,String}}(nothing)
const RUNNING_IN_CLOUD_TPU_VM = Ref(false)
Expand Down Expand Up @@ -42,10 +43,10 @@ function download_libtpu_if_needed(path=nothing)
zip_file_path = joinpath(path, "tpu.zip")
tmp_dir = joinpath(path, "tmp")
Downloads.download(
"https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250415+nightly-py3-none-manylinux_2_31_x86_64.whl",
"https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250727+nightly-py3-none-manylinux_2_31_x86_64.whl",
zip_file_path,
)
run(`unzip -qq $(zip_file_path) -d $(tmp_dir)`)
run(`$(unzip()) -qq $(zip_file_path) -d $(tmp_dir)`)
mv(joinpath(tmp_dir, "libtpu", "libtpu.so"), libtpu_path)
rm(tmp_dir; recursive=true)
rm(zip_file_path; recursive=true)
Expand Down
15 changes: 9 additions & 6 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,13 +375,13 @@ end
x = rand(size...)

@testset "outer repeat" begin
@test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...)
@test (@jit repeat(Reactant.to_rarray(x), counts...)) repeat(x, counts...)
end

length(counts) < length(size) && continue

@testset "inner repeat" begin
@test (@jit repeat(Reactant.to_rarray(x); inner=counts)) ==
@test (@jit repeat(Reactant.to_rarray(x); inner=counts))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@test (@jit repeat(Reactant.to_rarray(x); inner=counts))
@test (@jit repeat(Reactant.to_rarray(x); inner=counts))

repeat(x; inner=counts)
end
end
Expand Down Expand Up @@ -419,10 +419,13 @@ end
end

@testset "Complex runtime: $CT" for CT in (ComplexF32, ComplexF64)
a = Reactant.to_rarray(ones(CT, 2))
b = Reactant.to_rarray(ones(CT, 2))
c = Reactant.compile(+, (a, b))(a, b)
@test c == ones(CT, 2) + ones(CT, 2)
# complex f64 not supported on tpu
if CT == ComplexF32 || !contains(string(Reactant.devices()[1]), "Tpu")
a = Reactant.to_rarray(ones(CT, 2))
b = Reactant.to_rarray(ones(CT, 2))
c = Reactant.compile(+, (a, b))(a, b)
@test c == ones(CT, 2) + ones(CT, 2)
end
end

@testset "Scalars" begin
Expand Down
Loading