Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit dc70eb6

Browse files
Merge pull request #65 from avik-pal/ap/format
Format and add a format CI
2 parents 4346380 + 9bfde85 commit dc70eb6

18 files changed

+261
-223
lines changed

.github/workflows/FormatPR.yml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
name: format-pr
2+
on:
3+
schedule:
4+
- cron: '0 0 * * *'
5+
jobs:
6+
build:
7+
runs-on: ubuntu-latest
8+
steps:
9+
- uses: actions/checkout@v3
10+
- name: Install JuliaFormatter and format
11+
run: |
12+
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))'
13+
julia -e 'using JuliaFormatter; format(".")'
14+
# https://github.com/marketplace/actions/create-pull-request
15+
# https://github.com/peter-evans/create-pull-request#reference-example
16+
- name: Create Pull Request
17+
id: cpr
18+
uses: peter-evans/create-pull-request@v5
19+
with:
20+
token: ${{ secrets.GITHUB_TOKEN }}
21+
commit-message: Format .jl files
22+
title: 'Automatic JuliaFormatter.jl run'
23+
branch: auto-juliaformatter-pr
24+
delete-branch: true
25+
labels: formatting, automated pr, no changelog
26+
- name: Check outputs
27+
run: |
28+
echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}"
29+
echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}"

ext/SimpleBatchedNonlinearSolveExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function _init_J_batched(x::AbstractMatrix{T}) where {T}
3131
end
3232

3333
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
34-
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
34+
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
3535
tc = alg.termination_condition
3636
mode = DiffEqBase.get_termination_mode(tc)
3737
f = Base.Fix2(prob.f, prob.p)
@@ -74,7 +74,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
7474
J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ)
7575
J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ) ./
7676
(_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))),
77-
_batched_mul(_batch_transpose(Δxₙ), J⁻¹))
77+
_batched_mul(_batch_transpose(Δxₙ), J⁻¹))
7878

7979
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
8080
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)

src/SimpleNonlinearSolve.jl

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ end
1616

1717
function __init__()
1818
@static if !isdefined(Base, :get_extension)
19-
@require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin include("../ext/SimpleBatchedNonlinearSolveExt.jl") end
19+
@require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin
20+
include("../ext/SimpleBatchedNonlinearSolveExt.jl")
21+
end
2022
end
2123
end
2224

@@ -42,31 +44,35 @@ include("alefeld.jl")
4244

4345
import PrecompileTools
4446

45-
PrecompileTools.@compile_workload begin for T in (Float32, Float64)
46-
prob_no_brack = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
47-
for alg in (SimpleNewtonRaphson, Halley, Broyden, Klement, SimpleTrustRegion,
48-
SimpleDFSane)
49-
solve(prob_no_brack, alg(), abstol = T(1e-2))
50-
end
47+
PrecompileTools.@compile_workload begin
48+
for T in (Float32, Float64)
49+
prob_no_brack = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
50+
for alg in (SimpleNewtonRaphson, Halley, Broyden, Klement, SimpleTrustRegion,
51+
SimpleDFSane)
52+
solve(prob_no_brack, alg(), abstol = T(1e-2))
53+
end
5154

52-
#=
53-
for alg in (SimpleNewtonRaphson,)
54-
for u0 in ([1., 1.], StaticArraysCore.SA[1.0, 1.0])
55-
u0 = T.(.1)
56-
probN = NonlinearProblem{false}((u,p) -> u .* u .- p, u0, T(2))
57-
solve(probN, alg(), tol = T(1e-2))
55+
#=
56+
for alg in (SimpleNewtonRaphson,)
57+
for u0 in ([1., 1.], StaticArraysCore.SA[1.0, 1.0])
58+
u0 = T.(.1)
59+
probN = NonlinearProblem{false}((u,p) -> u .* u .- p, u0, T(2))
60+
solve(probN, alg(), tol = T(1e-2))
61+
end
5862
end
59-
end
60-
=#
63+
=#
6164

62-
prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p, T.((0.0, 2.0)), T(2))
63-
for alg in (Bisection, Falsi, Ridder, Brent, Alefeld)
64-
solve(prob_brack, alg(), abstol = T(1e-2))
65+
prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p,
66+
T.((0.0, 2.0)),
67+
T(2))
68+
for alg in (Bisection, Falsi, Ridder, Brent, Alefeld)
69+
solve(prob_brack, alg(), abstol = T(1e-2))
70+
end
6571
end
66-
end end
72+
end
6773

6874
# DiffEq styled algorithms
6975
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement,
70-
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld
76+
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld
7177

7278
end # module

src/ad.jl

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,50 +29,50 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
2929
end
3030

3131
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
32-
iip,
33-
<:Dual{T, V, P}},
34-
alg::AbstractSimpleNonlinearSolveAlgorithm,
35-
args...; kwargs...) where {iip, T, V, P}
32+
iip,
33+
<:Dual{T, V, P}},
34+
alg::AbstractSimpleNonlinearSolveAlgorithm,
35+
args...; kwargs...) where {iip, T, V, P}
3636
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
3737
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
38-
retcode = sol.retcode)
38+
retcode = sol.retcode)
3939
end
4040
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
41-
iip,
42-
<:AbstractArray{<:Dual{T, V, P}}},
43-
alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
44-
kwargs...) where {iip, T, V, P}
41+
iip,
42+
<:AbstractArray{<:Dual{T, V, P}}},
43+
alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
44+
kwargs...) where {iip, T, V, P}
4545
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
4646
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
47-
retcode = sol.retcode)
47+
retcode = sol.retcode)
4848
end
4949

5050
# avoid ambiguities
5151
for Alg in [Bisection]
5252
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
53-
<:Dual{T, V, P}},
54-
alg::$Alg, args...;
55-
kwargs...) where {uType, iip, T, V, P}
53+
<:Dual{T, V, P}},
54+
alg::$Alg, args...;
55+
kwargs...) where {uType, iip, T, V, P}
5656
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
5757
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials),
58-
sol.resid; retcode = sol.retcode,
59-
left = Dual{T, V, P}(sol.left, partials),
60-
right = Dual{T, V, P}(sol.right, partials))
58+
sol.resid; retcode = sol.retcode,
59+
left = Dual{T, V, P}(sol.left, partials),
60+
right = Dual{T, V, P}(sol.right, partials))
6161
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
6262
end
6363
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
64-
<:AbstractArray{
65-
<:Dual{T,
66-
V,
67-
P}
68-
}},
69-
alg::$Alg, args...;
70-
kwargs...) where {uType, iip, T, V, P}
64+
<:AbstractArray{
65+
<:Dual{T,
66+
V,
67+
P},
68+
}},
69+
alg::$Alg, args...;
70+
kwargs...) where {uType, iip, T, V, P}
7171
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
7272
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials),
73-
sol.resid; retcode = sol.retcode,
74-
left = Dual{T, V, P}(sol.left, partials),
75-
right = Dual{T, V, P}(sol.right, partials))
73+
sol.resid; retcode = sol.retcode,
74+
left = Dual{T, V, P}(sol.left, partials),
75+
right = Dual{T, V, P}(sol.right, partials))
7676
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
7777
end
7878
end

src/alefeld.jl

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,24 @@ algorithm 4.1 because, in certain sense, the second algorithm(4.2) is an optimal
99
struct Alefeld <: AbstractBracketingAlgorithm end
1010

1111
function SciMLBase.solve(prob::IntervalNonlinearProblem,
12-
alg::Alefeld, args...; abstol = nothing,
13-
reltol = nothing,
14-
maxiters = 1000, kwargs...)
12+
alg::Alefeld, args...; abstol = nothing,
13+
reltol = nothing,
14+
maxiters = 1000, kwargs...)
1515
f = Base.Fix2(prob.f, prob.p)
1616
a, b = prob.tspan
1717
c = a - (b - a) / (f(b) - f(a)) * f(a)
1818

1919
fc = f(c)
2020
(a == c || b == c) &&
2121
return SciMLBase.build_solution(prob, alg, c, fc;
22-
retcode = ReturnCode.FloatingPointLimit,
23-
left = a,
24-
right = b)
22+
retcode = ReturnCode.FloatingPointLimit,
23+
left = a,
24+
right = b)
2525
iszero(fc) &&
2626
return SciMLBase.build_solution(prob, alg, c, fc;
27-
retcode = ReturnCode.Success,
28-
left = a,
29-
right = b)
27+
retcode = ReturnCode.Success,
28+
left = a,
29+
right = b)
3030
a, b, d = _bracket(f, a, b, c)
3131
e = zero(a) # Set e as 0 before iteration to avoid a non-value f(e)
3232

@@ -45,14 +45,14 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,
4545
ē, fc = d, f(c)
4646
(a == c || b == c) &&
4747
return SciMLBase.build_solution(prob, alg, c, fc;
48-
retcode = ReturnCode.FloatingPointLimit,
49-
left = a,
50-
right = b)
48+
retcode = ReturnCode.FloatingPointLimit,
49+
left = a,
50+
right = b)
5151
iszero(fc) &&
5252
return SciMLBase.build_solution(prob, alg, c, fc;
53-
retcode = ReturnCode.Success,
54-
left = a,
55-
right = b)
53+
retcode = ReturnCode.Success,
54+
left = a,
55+
right = b)
5656
ā, b̄, d̄ = _bracket(f, a, b, c)
5757

5858
# The second bracketing block
@@ -68,14 +68,14 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,
6868
fc = f(c)
6969
(ā == c ||== c) &&
7070
return SciMLBase.build_solution(prob, alg, c, fc;
71-
retcode = ReturnCode.FloatingPointLimit,
72-
left = ā,
73-
right = b̄)
71+
retcode = ReturnCode.FloatingPointLimit,
72+
left = ā,
73+
right = b̄)
7474
iszero(fc) &&
7575
return SciMLBase.build_solution(prob, alg, c, fc;
76-
retcode = ReturnCode.Success,
77-
left = ā,
78-
right = b̄)
76+
retcode = ReturnCode.Success,
77+
left = ā,
78+
right = b̄)
7979
ā, b̄, d̄ = _bracket(f, ā, b̄, c)
8080

8181
# The third bracketing block
@@ -91,14 +91,14 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,
9191
fc = f(c)
9292
(ā == c ||== c) &&
9393
return SciMLBase.build_solution(prob, alg, c, fc;
94-
retcode = ReturnCode.FloatingPointLimit,
95-
left = ā,
96-
right = b̄)
94+
retcode = ReturnCode.FloatingPointLimit,
95+
left = ā,
96+
right = b̄)
9797
iszero(fc) &&
9898
return SciMLBase.build_solution(prob, alg, c, fc;
99-
retcode = ReturnCode.Success,
100-
left = ā,
101-
right = b̄)
99+
retcode = ReturnCode.Success,
100+
left = ā,
101+
right = b̄)
102102
ā, b̄, d = _bracket(f, ā, b̄, c)
103103

104104
# The last bracketing block
@@ -110,14 +110,14 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,
110110
fc = f(c)
111111
(ā == c ||== c) &&
112112
return SciMLBase.build_solution(prob, alg, c, fc;
113-
retcode = ReturnCode.FloatingPointLimit,
114-
left = ā,
115-
right = b̄)
113+
retcode = ReturnCode.FloatingPointLimit,
114+
left = ā,
115+
right = b̄)
116116
iszero(fc) &&
117117
return SciMLBase.build_solution(prob, alg, c, fc;
118-
retcode = ReturnCode.Success,
119-
left = ā,
120-
right = b̄)
118+
retcode = ReturnCode.Success,
119+
left = ā,
120+
right = b̄)
121121
a, b, d = _bracket(f, ā, b̄, c)
122122
end
123123
end
@@ -132,7 +132,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,
132132

133133
# Reuturn solution when run out of max interation
134134
return SciMLBase.build_solution(prob, alg, c, fc; retcode = ReturnCode.MaxIters,
135-
left = a, right = b)
135+
left = a, right = b)
136136
end
137137

138138
# Define subrotine function bracket, check fc before bracket to return solution

src/bisection.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@ function Bisection(; exact_left = false, exact_right = false)
2020
end
2121

2222
function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...;
23-
maxiters = 1000,
24-
kwargs...)
23+
maxiters = 1000,
24+
kwargs...)
2525
f = Base.Fix2(prob.f, prob.p)
2626
left, right = prob.tspan
2727
fl, fr = f(left), f(right)
2828

2929
if iszero(fl)
3030
return SciMLBase.build_solution(prob, alg, left, fl;
31-
retcode = ReturnCode.ExactSolutionLeft, left = left,
32-
right = right)
31+
retcode = ReturnCode.ExactSolutionLeft, left = left,
32+
right = right)
3333
end
3434

3535
i = 1
@@ -38,8 +38,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
3838
mid = (left + right) / 2
3939
(mid == left || mid == right) &&
4040
return SciMLBase.build_solution(prob, alg, left, fl;
41-
retcode = ReturnCode.FloatingPointLimit,
42-
left = left, right = right)
41+
retcode = ReturnCode.FloatingPointLimit,
42+
left = left, right = right)
4343
fm = f(mid)
4444
if iszero(fm)
4545
right = mid
@@ -60,8 +60,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
6060
mid = (left + right) / 2
6161
(mid == left || mid == right) &&
6262
return SciMLBase.build_solution(prob, alg, left, fl;
63-
retcode = ReturnCode.FloatingPointLimit,
64-
left = left, right = right)
63+
retcode = ReturnCode.FloatingPointLimit,
64+
left = left, right = right)
6565
fm = f(mid)
6666
if iszero(fm)
6767
right = mid
@@ -74,5 +74,5 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
7474
end
7575

7676
return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters,
77-
left = left, right = right)
77+
left = left, right = right)
7878
end

0 commit comments

Comments
 (0)