Skip to content

Commit 432c334

Browse files
authored
Merge pull request #482 from MilesCranmer/auto-convert-template
feat: handle mixed types in body of template expression
2 parents d05f559 + 1086320 commit 432c334

File tree

5 files changed

+134
-8
lines changed

5 files changed

+134
-8
lines changed

.github/workflows/CI.yml

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,30 @@ jobs:
8686
julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user -e 'import Coverage; import Pkg; Pkg.activate("."); Pkg.test(coverage=true)'
8787
julia --color=yes coverage.jl
8888
shell: bash
89-
- uses: codecov/codecov-action@v5
89+
- name: "Upload coverage artifacts"
90+
uses: actions/upload-artifact@v4
91+
with:
92+
name: coverage-${{ matrix.julia-version }}-${{ matrix.os }}-${{ matrix.test }}
93+
path: lcov.info
94+
retention-days: 1
95+
96+
upload-coverage:
97+
name: Upload Coverage to Codecov
98+
needs: test
99+
runs-on: ubuntu-latest
100+
# Only run on pushes to master or pull requests
101+
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.repository
102+
steps:
103+
- uses: actions/checkout@v5
104+
- name: Download all coverage artifacts
105+
uses: actions/download-artifact@v4
106+
with:
107+
pattern: coverage-*
108+
path: coverage
109+
- name: Upload to Codecov
110+
uses: codecov/codecov-action@v5
90111
with:
91112
token: ${{ secrets.CODECOV_TOKEN }}
92-
files: lcov.info
113+
directory: ./coverage
114+
fail_ci_if_error: true
115+
verbose: true

src/ComposableExpression.jl

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,22 +260,45 @@ end
260260
# Basically we want to vectorize every single operation on ValidVector,
261261
# so that the user can use it easily.
262262

263+
function _apply_operator(op::F, x::Vararg{Any,N}) where {F<:Function,N}
264+
vx = map(_get_value, x)
265+
safe_op = get_safe_op(op)
266+
result = safe_op.(vx...)
267+
return ValidVector(result, is_valid_array(result))
268+
end
269+
263270
function apply_operator(op::F, x::Vararg{Any,N}) where {F<:Function,N}
264271
if all(_is_valid, x)
265-
vx = map(_get_value, x)
266-
safe_op = get_safe_op(op)
267-
result = safe_op.(vx...)
268-
return ValidVector(result, is_valid_array(result))
272+
return _apply_operator(op, x...)
269273
else
270274
example_vector =
271275
something(map(xi -> xi isa ValidVector ? xi : nothing, x)...)::ValidVector
272-
return ValidVector(_get_value(example_vector), false)
276+
expected_return_type = Base.promote_op(
277+
_apply_operator, typeof(op), map(typeof, x)...
278+
)
279+
if expected_return_type !== Union{} &&
280+
expected_return_type <: ValidVector{<:AbstractArray}
281+
return ValidVector(
282+
_match_eltype(expected_return_type, example_vector.x), false
283+
)::expected_return_type
284+
else
285+
return ValidVector(example_vector.x, false)
286+
end
273287
end
274288
end
275289
_is_valid(x::ValidVector) = x.valid
276290
_is_valid(x) = true
277291
_get_value(x::ValidVector) = x.x
278292
_get_value(x) = x
293+
function _match_eltype(
294+
::Type{<:ValidVector{<:AbstractArray{T1}}}, x::AbstractArray{T2}
295+
) where {T1,T2}
296+
if T1 == T2
297+
return x
298+
else
299+
return Base.Fix1(convert, T1).(x)
300+
end
301+
end
279302

280303
struct ValidVectorMixError <: Exception end
281304
struct ValidVectorAccessError <: Exception end

src/TemplateExpression.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,18 @@ and this would automatically handle the validity and vectorization.
665665
)
666666
end
667667

668+
function _match_input_eltype(
669+
::Type{<:AbstractMatrix{T1}}, result::AbstractVector{T2}
670+
) where {T1,T2}
671+
if T1 != T2 && T1 <: AbstractFloat && T2 <: AbstractFloat
672+
# Just to handle cases where the user might write
673+
# 0.5 in their template spec, but the data is Float32.
674+
return Base.Fix1(convert, T1).(result)
675+
else
676+
return result
677+
end
678+
end
679+
668680
@stable(
669681
default_mode = "disable",
670682
default_union_limit = 2,
@@ -695,7 +707,7 @@ end
695707
if !(result isa ValidVector)
696708
throw(TemplateReturnError())
697709
end
698-
return result.x, result.valid
710+
return _match_input_eltype(typeof(cX), result.x), result.valid
699711
end
700712
function (ex::TemplateExpression)(
701713
X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...

test/test_composable_expression.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,36 @@ end
182182
@test ex(ValidVector([1.0, 1.0], true), 2.0).x [3.0, 3.0]
183183
@test ex(ValidVector([1.0, 1.0], false), 2.0).valid == false
184184
end
185+
186+
@testitem "ValidVector operations with Union{} return type" tags = [:part2] begin
187+
using SymbolicRegression: ValidVector
188+
using SymbolicRegression.ComposableExpressionModule: apply_operator, _match_eltype
189+
190+
error_op(::Any, ::Any) = error("This should cause Union{} inference")
191+
192+
x = ValidVector([1.0, 2.0], false)
193+
y = ValidVector([3.0, 4.0], false)
194+
195+
result = apply_operator(error_op, x, y)
196+
@test result isa ValidVector
197+
@test !result.valid
198+
@test result.x == [1.0, 2.0]
199+
200+
a = ValidVector(Float32[1.0, 2.0], false)
201+
b = 1.0
202+
result2 = apply_operator(*, a, b)
203+
@test result2 isa ValidVector{<:AbstractArray{Float64}}
204+
205+
# Test apply_operator when all inputs are valid
206+
valid_x = ValidVector([1.0, 2.0], true)
207+
valid_y = ValidVector([3.0, 4.0], true)
208+
valid_result = apply_operator(+, valid_x, valid_y)
209+
@test valid_result.valid == true
210+
@test valid_result.x [4.0, 6.0]
211+
212+
# cover _match_eltype
213+
arr = [1.0, 2.0]
214+
@test _match_eltype(ValidVector{Vector{Float64}}, arr) === arr # Same type
215+
arr_f32 = Float32[1.0, 2.0]
216+
@test _match_eltype(ValidVector{Vector{Float64}}, arr_f32) isa Vector{Float64} # Different type
217+
end

test/test_template_expression.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,3 +692,38 @@ end
692692
@test contains(msg, "ValidVector is required")
693693
@test contains(msg, "ValidVector(my_data, computation_is_valid)")
694694
end
695+
696+
@testitem "Test Float32/Float64 type conversion in TemplateExpression" tags = [:part2] begin
697+
using SymbolicRegression
698+
using SymbolicRegression: eval_loss
699+
using SymbolicRegression.TemplateExpressionModule: _match_input_eltype
700+
701+
template = @template_spec(expressions = (f,)) do x1, x2
702+
0.5 * f(x1, x2) # 0.5 is Float64 literal
703+
end
704+
705+
options = Options(; binary_operators=[+, *, /, -], expression_spec=template)
706+
x1 = ComposableExpression(Node{Float32}(; feature=1); operators=options.operators)
707+
x2 = ComposableExpression(Node{Float32}(; feature=2); operators=options.operators)
708+
f_expr = x1 + x2
709+
710+
template_expr = TemplateExpression(
711+
(; f=f_expr); structure=template.structure, operators=options.operators
712+
)
713+
714+
X = Float32[1.0 2.0; 3.0 4.0]
715+
result = template_expr(X)
716+
@test result isa Vector{Float32}
717+
718+
y = Float32[2.0, 3.0]
719+
dataset = Dataset(X, y)
720+
loss = eval_loss(template_expr, dataset, options)
721+
@test loss isa Float32
722+
@test loss 0.0
723+
724+
# Test _match_input_eltype coverage (covers lines 675-676)
725+
result_f64 = [1.0, 2.0]
726+
@test _match_input_eltype(Matrix{Float64}, result_f64) === result_f64 # Same type
727+
result_int = [1, 2]
728+
@test _match_input_eltype(Matrix{Float64}, result_int) === result_int # Non-float type
729+
end

0 commit comments

Comments
 (0)