@@ -43,9 +43,8 @@ function SciMLBase.__solve(prob::NonlinearProblem,
43
43
f = Base. Fix2 (prob. f, prob. p)
44
44
x = float (prob. u0)
45
45
fx = f (x)
46
- # fx = float(prob.u0)
47
- if ! isa (fx, Number) || ! isa (x, Number)
48
- error (" Halley currently only supports scalar-valued single-variable functions" )
46
+ if isa (x, AbstractArray)
47
+ n = length (x)
49
48
end
50
49
T = typeof (x)
51
50
@@ -65,22 +64,45 @@ function SciMLBase.__solve(prob::NonlinearProblem,
65
64
66
65
for i in 1 : maxiters
67
66
if alg_autodiff (alg)
68
- fx = f (x)
69
- dfdx (x) = ForwardDiff. derivative (f, x)
70
- dfx = dfdx (x)
71
- d2fx = ForwardDiff. derivative (dfdx, x)
67
+ if isa (x, Number)
68
+ fx = f (x)
69
+ dfx = ForwardDiff. derivative (f, x)
70
+ d2fx = ForwardDiff. derivative (x -> ForwardDiff. derivative (f, x), x)
71
+ else
72
+ fx = f (x)
73
+ dfx = ForwardDiff. jacobian (f, x)
74
+ d2fx = ForwardDiff. jacobian (x -> ForwardDiff. jacobian (f, x), x)
75
+ ai = - (dfx \ fx)
76
+ A = reshape (d2fx * ai, (n, n))
77
+ bi = (dfx) \ (A * ai)
78
+ ci = (ai .* ai) ./ (ai .+ (0.5 .* bi))
79
+ end
72
80
else
73
- fx = f (x)
74
- dfx = FiniteDiff. finite_difference_derivative (f, x, diff_type (alg), eltype (x),
75
- fx)
76
- d2fx = FiniteDiff. finite_difference_derivative (x -> FiniteDiff. finite_difference_derivative (f,
77
- x),
78
- x, diff_type (alg), eltype (x), fx)
81
+ if isa (x, Number)
82
+ fx = f (x)
83
+ dfx = FiniteDiff. finite_difference_derivative (f, x, diff_type (alg), eltype (x))
84
+ d2fx = FiniteDiff. finite_difference_derivative (x -> FiniteDiff. finite_difference_derivative (f, x), x,
85
+ diff_type (alg), eltype (x))
86
+ else
87
+ fx = f (x)
88
+ dfx = FiniteDiff. finite_difference_jacobian (f, x, diff_type (alg), eltype (x))
89
+ d2fx = FiniteDiff. finite_difference_jacobian (x -> FiniteDiff. finite_difference_jacobian (f, x), x,
90
+ diff_type (alg), eltype (x))
91
+ ai = - (dfx \ fx)
92
+ A = reshape (d2fx * ai, (n, n))
93
+ bi = (dfx) \ (A * ai)
94
+ ci = (ai .* ai) ./ (ai .+ (0.5 .* bi))
95
+ end
79
96
end
80
97
iszero (fx) &&
81
98
return SciMLBase. build_solution (prob, alg, x, fx; retcode = ReturnCode. Success)
82
- Δx = (2 * dfx^ 2 - fx * d2fx) \ (2 fx * dfx)
83
- x -= Δx
99
+ if isa (x, Number)
100
+ Δx = (2 * dfx^ 2 - fx * d2fx) \ (2 fx * dfx)
101
+ x -= Δx
102
+ else
103
+ Δx = ci
104
+ x += Δx
105
+ end
84
106
if isapprox (x, xo, atol = atol, rtol = rtol)
85
107
return SciMLBase. build_solution (prob, alg, x, fx; retcode = ReturnCode. Success)
86
108
end
0 commit comments