Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Format and add a format CI #65

Merged
merged 1 commit into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .github/workflows/FormatPR.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: format-pr
on:
schedule:
- cron: '0 0 * * *'
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install JuliaFormatter and format
run: |
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))'
julia -e 'using JuliaFormatter; format(".")'
# https://github.com/marketplace/actions/create-pull-request
# https://github.com/peter-evans/create-pull-request#reference-example
- name: Create Pull Request
id: cpr
uses: peter-evans/create-pull-request@v5
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Format .jl files
title: 'Automatic JuliaFormatter.jl run'
branch: auto-juliaformatter-pr
delete-branch: true
labels: formatting, automated pr, no changelog
- name: Check outputs
run: |
echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}"
echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}"
4 changes: 2 additions & 2 deletions ext/SimpleBatchedNonlinearSolveExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function _init_J_batched(x::AbstractMatrix{T}) where {T}
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)
f = Base.Fix2(prob.f, prob.p)
Expand Down Expand Up @@ -74,7 +74,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ)
J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ) ./
(_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))),
_batched_mul(_batch_transpose(Δxₙ), J⁻¹))
_batched_mul(_batch_transpose(Δxₙ), J⁻¹))

if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)
Expand Down
46 changes: 26 additions & 20 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ end

function __init__()
@static if !isdefined(Base, :get_extension)
@require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin include("../ext/SimpleBatchedNonlinearSolveExt.jl") end
@require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin
include("../ext/SimpleBatchedNonlinearSolveExt.jl")
end
end
end

Expand All @@ -42,31 +44,35 @@ include("alefeld.jl")

import PrecompileTools

PrecompileTools.@compile_workload begin for T in (Float32, Float64)
prob_no_brack = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
for alg in (SimpleNewtonRaphson, Halley, Broyden, Klement, SimpleTrustRegion,
SimpleDFSane)
solve(prob_no_brack, alg(), abstol = T(1e-2))
end
PrecompileTools.@compile_workload begin
for T in (Float32, Float64)
prob_no_brack = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
for alg in (SimpleNewtonRaphson, Halley, Broyden, Klement, SimpleTrustRegion,
SimpleDFSane)
solve(prob_no_brack, alg(), abstol = T(1e-2))
end

#=
for alg in (SimpleNewtonRaphson,)
for u0 in ([1., 1.], StaticArraysCore.SA[1.0, 1.0])
u0 = T.(.1)
probN = NonlinearProblem{false}((u,p) -> u .* u .- p, u0, T(2))
solve(probN, alg(), tol = T(1e-2))
#=
for alg in (SimpleNewtonRaphson,)
for u0 in ([1., 1.], StaticArraysCore.SA[1.0, 1.0])
u0 = T.(.1)
probN = NonlinearProblem{false}((u,p) -> u .* u .- p, u0, T(2))
solve(probN, alg(), tol = T(1e-2))
end
end
end
=#
=#

prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p, T.((0.0, 2.0)), T(2))
for alg in (Bisection, Falsi, Ridder, Brent, Alefeld)
solve(prob_brack, alg(), abstol = T(1e-2))
prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p,
T.((0.0, 2.0)),
T(2))
for alg in (Bisection, Falsi, Ridder, Brent, Alefeld)
solve(prob_brack, alg(), abstol = T(1e-2))
end
end
end end
end

# DiffEq styled algorithms
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement,
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld

end # module
52 changes: 26 additions & 26 deletions src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,50 +29,50 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:Dual{T, V, P}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; kwargs...) where {iip, T, V, P}
iip,
<:Dual{T, V, P}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
retcode = sol.retcode)
retcode = sol.retcode)
end
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:AbstractArray{<:Dual{T, V, P}}},
alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
kwargs...) where {iip, T, V, P}
iip,
<:AbstractArray{<:Dual{T, V, P}}},
alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
retcode = sol.retcode)
retcode = sol.retcode)
end

# avoid ambiguities
for Alg in [Bisection]
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
<:Dual{T, V, P}},
alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
<:Dual{T, V, P}},
alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials),
sol.resid; retcode = sol.retcode,
left = Dual{T, V, P}(sol.left, partials),
right = Dual{T, V, P}(sol.right, partials))
sol.resid; retcode = sol.retcode,
left = Dual{T, V, P}(sol.left, partials),
right = Dual{T, V, P}(sol.right, partials))
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
end
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
<:AbstractArray{
<:Dual{T,
V,
P}
}},
alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
<:AbstractArray{
<:Dual{T,
V,
P},
}},
alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials),
sol.resid; retcode = sol.retcode,
left = Dual{T, V, P}(sol.left, partials),
right = Dual{T, V, P}(sol.right, partials))
sol.resid; retcode = sol.retcode,
left = Dual{T, V, P}(sol.left, partials),
right = Dual{T, V, P}(sol.right, partials))
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
end
end
68 changes: 34 additions & 34 deletions src/alefeld.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,24 @@ algorithm 4.1 because, in certain sense, the second algorithm(4.2) is an optimal
struct Alefeld <: AbstractBracketingAlgorithm end

function SciMLBase.solve(prob::IntervalNonlinearProblem,
alg::Alefeld, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
alg::Alefeld, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
a, b = prob.tspan
c = a - (b - a) / (f(b) - f(a)) * f(a)

fc = f(c)
(a == c || b == c) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.FloatingPointLimit,
left = a,
right = b)
retcode = ReturnCode.FloatingPointLimit,
left = a,
right = b)
iszero(fc) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.Success,
left = a,
right = b)
retcode = ReturnCode.Success,
left = a,
right = b)
a, b, d = _bracket(f, a, b, c)
e = zero(a) # Set e as 0 before iteration to avoid a non-value f(e)

Expand All @@ -45,14 +45,14 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,
ē, fc = d, f(c)
(a == c || b == c) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.FloatingPointLimit,
left = a,
right = b)
retcode = ReturnCode.FloatingPointLimit,
left = a,
right = b)
iszero(fc) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.Success,
left = a,
right = b)
retcode = ReturnCode.Success,
left = a,
right = b)
ā, b̄, d̄ = _bracket(f, a, b, c)

# The second bracketing block
Expand All @@ -68,14 +68,14 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,
fc = f(c)
(ā == c || b̄ == c) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.FloatingPointLimit,
left = ā,
right = b̄)
retcode = ReturnCode.FloatingPointLimit,
left = ā,
right = b̄)
iszero(fc) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.Success,
left = ā,
right = b̄)
retcode = ReturnCode.Success,
left = ā,
right = b̄)
ā, b̄, d̄ = _bracket(f, ā, b̄, c)

# The third bracketing block
Expand All @@ -91,14 +91,14 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,
fc = f(c)
(ā == c || b̄ == c) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.FloatingPointLimit,
left = ā,
right = b̄)
retcode = ReturnCode.FloatingPointLimit,
left = ā,
right = b̄)
iszero(fc) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.Success,
left = ā,
right = b̄)
retcode = ReturnCode.Success,
left = ā,
right = b̄)
ā, b̄, d = _bracket(f, ā, b̄, c)

# The last bracketing block
Expand All @@ -110,14 +110,14 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,
fc = f(c)
(ā == c || b̄ == c) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.FloatingPointLimit,
left = ā,
right = b̄)
retcode = ReturnCode.FloatingPointLimit,
left = ā,
right = b̄)
iszero(fc) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.Success,
left = ā,
right = b̄)
retcode = ReturnCode.Success,
left = ā,
right = b̄)
a, b, d = _bracket(f, ā, b̄, c)
end
end
Expand All @@ -132,7 +132,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,

# Reuturn solution when run out of max interation
return SciMLBase.build_solution(prob, alg, c, fc; retcode = ReturnCode.MaxIters,
left = a, right = b)
left = a, right = b)
end

# Define subrotine function bracket, check fc before bracket to return solution
Expand Down
18 changes: 9 additions & 9 deletions src/bisection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ function Bisection(; exact_left = false, exact_right = false)
end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...;
maxiters = 1000,
kwargs...)
maxiters = 1000,
kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.tspan
fl, fr = f(left), f(right)

if iszero(fl)
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.ExactSolutionLeft, left = left,
right = right)
retcode = ReturnCode.ExactSolutionLeft, left = left,
right = right)
end

i = 1
Expand All @@ -38,8 +38,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
mid = (left + right) / 2
(mid == left || mid == right) &&
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
fm = f(mid)
if iszero(fm)
right = mid
Expand All @@ -60,8 +60,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
mid = (left + right) / 2
(mid == left || mid == right) &&
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
fm = f(mid)
if iszero(fm)
right = mid
Expand All @@ -74,5 +74,5 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
end

return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters,
left = left, right = right)
left = left, right = right)
end
Loading