Skip to content

Commit 191a237

Browse files
Merge pull request #262 from avik-pal/ap/banded
Special Case for Banded Matrices
2 parents f1bb4a7 + 345ec2a commit 191a237

File tree

7 files changed

+32
-8
lines changed

7 files changed

+32
-8
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,17 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2525
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2626

2727
[weakdeps]
28+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
2829
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
2930
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
3031

3132
[extensions]
33+
NonlinearSolveBandedMatricesExt = "BandedMatrices"
3234
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
3335
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
3436

3537
[compat]
38+
BandedMatrices = "1"
3639
ADTypes = "0.2"
3740
ArrayInterface = "6.0.24, 7"
3841
ConcreteStructs = "0.2"
@@ -58,6 +61,7 @@ Zygote = "0.6"
5861
julia = "1.9"
5962

6063
[extras]
64+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
6165
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
6266
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
6367
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
@@ -77,4 +81,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7781
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7882

7983
[targets]
80-
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath"]
84+
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices"]
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module NonlinearSolveBandedMatricesExt
2+
3+
using BandedMatrices, LinearAlgebra, NonlinearSolve, SparseArrays
4+
5+
# This is used if we vcat a Banded Jacobian with a Diagonal Matrix in Levenberg
6+
@inline NonlinearSolve._vcat(B::BandedMatrix, D::Diagonal) = vcat(sparse(B), D)
7+
8+
end

src/levenberg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
208208
rhs_tmp = nothing
209209
else
210210
# Preserve Types
211-
mat_tmp = vcat(J, DᵀD)
211+
mat_tmp = _vcat(J, DᵀD)
212212
fill!(mat_tmp, zero(eltype(u)))
213213
rhs_tmp = vcat(_vec(fu1), _vec(u))
214214
fill!(rhs_tmp, zero(eltype(u)))

src/utils.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,17 @@ function _try_factorize_and_check_singular!(linsolve, X)
257257
end
258258
_try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false
259259

260-
_reshape(x, args...) = reshape(x, args...)
261-
_reshape(x::Number, args...) = x
260+
@inline _reshape(x, args...) = reshape(x, args...)
261+
@inline _reshape(x::Number, args...) = x
262262

263263
@generated function _axpy!(α, x, y)
264264
hasmethod(axpy!, Tuple{α, x, y}) && return :(axpy!(α, x, y))
265265
return :(@. y += α * x)
266266
end
267267

268-
_needs_square_A(_, ::Number) = true
269-
_needs_square_A(_, ::StaticArray) = true
270-
_needs_square_A(alg, _) = LinearSolve.needs_square_A(alg.linsolve)
268+
@inline _needs_square_A(_, ::Number) = true
269+
@inline _needs_square_A(_, ::StaticArray) = true
270+
@inline _needs_square_A(alg, _) = LinearSolve.needs_square_A(alg.linsolve)
271+
272+
# Define special concatenation for certain Array combinations
273+
@inline _vcat(x, y) = vcat(x, y)

test/GPU/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
[deps]
22
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
34
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
45

56
[compat]
67
CUDA = "5"
8+
LinearSolve = "2"
79
NonlinearSolve = "2"

test/gpu.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using CUDA, NonlinearSolve
1+
using CUDA, NonlinearSolve, LinearSolve
22

33
CUDA.allowscalar(false)
44

test/misc.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Miscellaneous Tests
2+
using BandedMatrices, LinearAlgebra, NonlinearSolve, SparseArrays, Test
3+
4+
b = BandedMatrix(Ones(5, 5), (1, 1))
5+
d = Diagonal(ones(5, 5))
6+
7+
@test NonlinearSolve._vcat(b, d) == vcat(sparse(b), d)

0 commit comments

Comments
 (0)