Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit f111bd3

Browse files
committed
Remove the static arrays special casing
1 parent ae0bf10 commit f111bd3

File tree

7 files changed

+25
-25
lines changed

7 files changed

+25
-25
lines changed

Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,18 @@ MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
1818
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1919
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2020
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
21+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2122
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2223

2324
[weakdeps]
2425
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2526
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
26-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2727
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2828
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2929

3030
[extensions]
3131
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
3232
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
33-
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
3433
SimpleNonlinearSolveTrackerExt = "Tracker"
3534
SimpleNonlinearSolveZygoteExt = "Zygote"
3635

@@ -40,7 +39,7 @@ AllocCheck = "0.1.1"
4039
Aqua = "0.8"
4140
ArrayInterface = "7.9"
4241
CUDA = "5.2"
43-
ChainRulesCore = "1.22"
42+
ChainRulesCore = "1.23"
4443
ConcreteStructs = "0.2.3"
4544
DiffEqBase = "6.149"
4645
DiffResults = "1.1"
@@ -59,13 +58,14 @@ PrecompileTools = "1.2"
5958
Random = "1.10"
6059
ReTestItems = "1.23"
6160
Reexport = "1.2"
62-
ReverseDiff = "1.15"
61+
ReverseDiff = "1.15.3"
6362
SciMLBase = "2.37.0"
6463
SciMLSensitivity = "7.58"
64+
Setfield = "1.1.1"
6565
StaticArrays = "1.9"
6666
StaticArraysCore = "1.4.2"
6767
Test = "1.10"
68-
Tracker = "0.2.32"
68+
Tracker = "0.2.33"
6969
Zygote = "0.6.69"
7070
julia = "1.10"
7171

ext/SimpleNonlinearSolveStaticArraysExt.jl

Lines changed: 0 additions & 7 deletions
This file was deleted.

src/SimpleNonlinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati
2424
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
2525
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
2626
build_solution, isinplace, _unwrap_val
27+
using Setfield: @set!
2728
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
2829
end
2930

src/ad.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,12 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
109109
end
110110
else
111111
# For small problems, nesting ForwardDiff is actually quite fast
112-
_f = Base.Fix2(prob.f, newprob.p)
113112
if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) 50)
114113
# TODO: Remove once DI has the value_and_pullback_split defined
115-
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(_f, u, p)
114+
_F = @closure (u, p) -> begin
115+
_f = Base.Fix2(prob.f, p)
116+
return __zygote_compute_nlls_vjp(_f, u, p)
117+
end
116118
else
117119
_F = @closure (u, p) -> begin
118120
_f = Base.Fix2(prob.f, p)

src/nlsolve/dfsane.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
7777
α_1 = one(T)
7878
f_1 = fx_norm
7979

80-
history_f_k = if x isa SArray ||
81-
(x isa Number && __is_extension_loaded(Val(:StaticArrays)))
82-
ones(SVector{M, T}) * fx_norm
83-
else
84-
fill(fx_norm, M)
85-
end
80+
history_f_k = x isa SArray ? ones(SVector{M, T}) * fx_norm :
81+
__history_vec(fx_norm, Val(M))
8682

8783
# Generate the cache
8884
@bb x_cache = similar(x)
@@ -150,6 +146,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
150146
# Store function value
151147
if history_f_k isa SVector
152148
history_f_k = Base.setindex(history_f_k, fx_norm_new, mod1(k, M))
149+
elseif history_f_k isa NTuple
150+
@set! history_f_k[mod1(k, M)] = fx_norm_new
153151
else
154152
history_f_k[mod1(k, M)] = fx_norm_new
155153
end
@@ -158,3 +156,9 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
158156

159157
return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
160158
end
159+
160+
@inline @generated function __history_vec(fx_norm, ::Val{M}) where {M}
161+
# Julia can't specialize here
162+
M 11 && return :(fill(fx_norm, M))
163+
return :(ntuple(Returns(fx_norm), $(M)))
164+
end

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function value_and_jacobian(
3232

3333
if isinplace(prob)
3434
if cache isa HasAnalyticJacobian
35-
prob.f.jac(J, x, p)
35+
prob.f.jac(J, x, prob.p)
3636
f(y, x)
3737
else
3838
DI.jacobian!(f, y, J, ad, x, cache)

test/core/rootfind_tests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ end
4242
SimpleTrustRegion,
4343
(args...; kwargs...) -> SimpleTrustRegion(
4444
args...; nlsolve_update_rule = Val(true), kwargs...))
45-
@testset "AutoDiff: $(nameof(typeof(autodiff))))" for autodiff in (
45+
@testset "AutoDiff: $(nameof(typeof(autodiff)))" for autodiff in (
4646
AutoFiniteDiff(), AutoForwardDiff(), AutoPolyesterForwardDiff())
4747
@testset "[OOP] u0: $(typeof(u0))" for u0 in (
4848
[1.0, 1.0], @SVector[1.0, 1.0], 1.0)
@@ -59,7 +59,7 @@ end
5959
end
6060
end
6161

62-
@testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
62+
@testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
6363
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
6464

6565
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@@ -79,7 +79,7 @@ end
7979
end
8080
end
8181

82-
@testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
82+
@testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
8383
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
8484

8585
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@@ -104,7 +104,7 @@ end
104104
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
105105
end
106106

107-
@testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
107+
@testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
108108
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
109109

110110
probN = NonlinearProblem(quadratic_f, u0, 2.0)

0 commit comments

Comments
 (0)