Skip to content

Commit d97a28a

Browse files
committed
feat: handle mixed types in body of template expression
1 parent 7d937c7 commit d97a28a

File tree

4 files changed

+88
-6
lines changed

4 files changed

+88
-6
lines changed

src/ComposableExpression.jl

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,22 +235,45 @@ end
235235
# Basically we want to vectorize every single operation on ValidVector,
236236
# so that the user can use it easily.
237237

238+
function _apply_operator(op::F, x::Vararg{Any,N}) where {F<:Function,N}
239+
vx = map(_get_value, x)
240+
safe_op = get_safe_op(op)
241+
result = safe_op.(vx...)
242+
return ValidVector(result, is_valid_array(result))
243+
end
244+
238245
function apply_operator(op::F, x::Vararg{Any,N}) where {F<:Function,N}
239246
if all(_is_valid, x)
240-
vx = map(_get_value, x)
241-
safe_op = get_safe_op(op)
242-
result = safe_op.(vx...)
243-
return ValidVector(result, is_valid_array(result))
247+
return _apply_operator(op, x...)
244248
else
245249
example_vector =
246250
something(map(xi -> xi isa ValidVector ? xi : nothing, x)...)::ValidVector
247-
return ValidVector(_get_value(example_vector), false)
251+
expected_return_type = Base.promote_op(
252+
_apply_operator, typeof(op), map(typeof, x)...
253+
)
254+
if expected_return_type !== Union{} &&
255+
expected_return_type <: ValidVector{<:AbstractArray}
256+
return ValidVector(
257+
_match_eltype(expected_return_type, example_vector.x), false
258+
)::expected_return_type
259+
else
260+
return ValidVector(example_vector.x, false)
261+
end
248262
end
249263
end
250264
_is_valid(x::ValidVector) = x.valid
251265
_is_valid(x) = true
252266
_get_value(x::ValidVector) = x.x
253267
_get_value(x) = x
268+
function _match_eltype(
269+
::Type{<:ValidVector{<:AbstractArray{T1}}}, x::AbstractArray{T2}
270+
) where {T1,T2}
271+
if T1 == T2
272+
return x
273+
else
274+
return Base.Fix1(convert, T1).(x)
275+
end
276+
end
254277

255278
#! format: off
256279
# First, binary operators:

src/TemplateExpression.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,18 @@ function HOF.make_prefix(::TemplateExpression, ::AbstractOptions, ::Dataset)
631631
return ""
632632
end
633633

634+
function _match_input_eltype(
635+
::Type{<:AbstractMatrix{T1}}, result::AbstractVector{T2}
636+
) where {T1,T2}
637+
if T1 != T2 && T1 <: AbstractFloat && T2 <: AbstractFloat
638+
# Just to handle cases where the user might write
639+
# 0.5 in their template spec, but the data is Float32.
640+
return Base.Fix1(convert, T1).(result)
641+
else
642+
return result
643+
end
644+
end
645+
634646
@stable(
635647
default_mode = "disable",
636648
default_union_limit = 2,
@@ -657,7 +669,7 @@ end
657669
extra_args...,
658670
map(x -> ValidVector(copy(x), true), eachrow(cX)),
659671
)
660-
return result.x, result.valid
672+
return _match_input_eltype(typeof(cX), result.x), result.valid
661673
end
662674
function (ex::TemplateExpression)(
663675
X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...

test/test_composable_expression.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,23 @@ end
129129
x2_val = ValidVector([1.0, 2.0], false)
130130
@test ex(x1_val, x2_val).valid == false
131131
end
132+
133+
@testitem "ValidVector operations with Union{} return type" tags = [:part2] begin
134+
using SymbolicRegression: ValidVector
135+
using SymbolicRegression.ComposableExpressionModule: apply_operator
136+
137+
error_op(::Any, ::Any) = error("This should cause Union{} inference")
138+
139+
x = ValidVector([1.0, 2.0], false)
140+
y = ValidVector([3.0, 4.0], false)
141+
142+
result = apply_operator(error_op, x, y)
143+
@test result isa ValidVector
144+
@test !result.valid
145+
@test result.x == [1.0, 2.0]
146+
147+
a = ValidVector(Float32[1.0, 2.0], false)
148+
b = 1.0
149+
result2 = apply_operator(*, a, b)
150+
@test result2 isa ValidVector{<:AbstractArray{Float64}}
151+
end

test/test_template_expression.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,30 @@ end
657657
parse_guesses(PopMember{Float64,Float64}, [bad_guess], [dataset], options)
658658
)
659659
end
660+
661+
@testitem "Test Float32/Float64 type conversion in TemplateExpression" tags = [:part2] begin
662+
using SymbolicRegression
663+
664+
template = @template_spec(expressions = (f,)) do x1, x2
665+
0.5 * f(x1, x2) # 0.5 is Float64 literal
666+
end
667+
668+
options = Options(; binary_operators=[+, *, /, -], expression_spec=template)
669+
x1 = ComposableExpression(Node{Float32}(; feature=1); operators=options.operators)
670+
x2 = ComposableExpression(Node{Float32}(; feature=2); operators=options.operators)
671+
f_expr = x1 + x2
672+
673+
template_expr = TemplateExpression(
674+
(; f=f_expr); structure=template.structure, operators=options.operators
675+
)
676+
677+
X = Float32[1.0 2.0; 3.0 4.0]
678+
result = template_expr(X)
679+
@test result isa Vector{Float32}
680+
681+
y = Float32[2.0, 3.0]
682+
dataset = Dataset(X, y)
683+
loss = eval_loss(template_expr, dataset, options)
684+
@test loss isa Float32
685+
@test loss 0.0
686+
end

0 commit comments

Comments
 (0)