@@ -274,22 +274,12 @@ function prediction_warn()
274
274
@warn " Evaluation failed either due to NaNs detected or due to unfinished search. Using 0s for prediction."
275
275
end
276
276
277
- @inline function wrap_units (v, y_units, i:: Integer )
278
- if y_units === nothing
279
- return v
280
- else
281
- return (yi -> Quantity (yi, y_units[i])). (v)
282
- end
283
- end
284
- @inline function wrap_units (v, y_units, :: Nothing )
285
- if y_units === nothing
286
- return v
287
- else
288
- return (yi -> Quantity (yi, y_units)). (v)
289
- end
290
- end
277
# Attach output units to the prediction values `v`.
#
# When no units are stored (`nothing` in the second position) the values are
# returned untouched. Otherwise every element of `v` is wrapped in a `Quantity`:
# with an integer `i` the unit `y_units[i]` is used, and with `nothing` in the
# third position `y_units` itself is used.
function wrap_units(v, ::Nothing, ::Integer)
    return v
end
function wrap_units(v, ::Nothing, ::Nothing)
    return v
end
function wrap_units(v, y_units, i::Integer)
    return broadcast(val -> Quantity(val, y_units[i]), v)
end
function wrap_units(v, y_units, ::Nothing)
    return broadcast(val -> Quantity(val, y_units), v)
end
291
281
292
- function prediction_fallback (:: Type{T} , m :: SRRegressor , Xnew_t, fitresult) where {T}
282
+ function prediction_fallback (:: Type{T} , :: SRRegressor , Xnew_t, fitresult, _ ) where {T}
293
283
prediction_warn ()
294
284
out = fill! (similar (Xnew_t, T, axes (Xnew_t, 2 )), zero (T))
295
285
return wrap_units (out, fitresult. y_units, nothing )
@@ -303,11 +293,11 @@ function prediction_fallback(
303
293
fill! (similar (Xnew_t, T, axes (Xnew_t, 2 )), zero (T)), fitresult. y_units, i
304
294
) for i in 1 : (fitresult. num_targets)
305
295
]
306
- out_matrix = reduce ( hcat, out_cols)
296
+ out_matrix = hcat ( out_cols... )
307
297
if ! fitresult. y_is_table
308
298
return out_matrix
309
299
else
310
- return MMI. table (out_matrix; names= fitresult. y_variable_names, prototype= prototype )
300
+ return MMI. table (out_matrix; names= fitresult. y_variable_names, prototype)
311
301
end
312
302
end
313
303
@@ -344,50 +334,58 @@ function MMI.fitted_params(m::AbstractSRRegressor, fitresult)
344
334
)
345
335
end
346
336
347
- function MMI. predict (m:: SRRegressor , fitresult, Xnew)
348
- params = full_report (m, fitresult; v_with_strings= Val (false ))
349
- Xnew_t, variable_names, X_units = get_matrix_and_info (Xnew, m. dimensions_type)
350
- T = promote_type (eltype (Xnew_t), fitresult. types. T)
351
- if length (params. equations) == 0
352
- return prediction_fallback (T, m, Xnew_t, fitresult)
353
- end
354
- X_units_clean = clean_units (X_units)
355
- validate_variable_names (variable_names, fitresult)
356
- validate_units (X_units_clean, fitresult. X_units)
357
- eq = params. equations[params. best_idx]
358
- out, completed = eval_tree_array (eq, Xnew_t, fitresult. options)
359
- if ! completed
360
- return prediction_fallback (T, m, Xnew_t, fitresult)
337
# Evaluate a single expression `tree` on the feature matrix `X_t` using
# `fitresult.options`.
#
# If the evaluation reports completion, the result is passed through
# `wrap_units` with `fitresult.y_units` and the index `i`. Otherwise the
# result of `prediction_fallback` for model `m` and element type `T` is
# returned instead.
function eval_tree_mlj(
    tree::Node, X_t, m::AbstractSRRegressor, ::Type{T}, fitresult, i, prototype
) where {T}
    values, finished = eval_tree_array(tree, X_t, fitresult.options)
    finished || return prediction_fallback(T, m, X_t, fitresult, prototype)
    return wrap_units(values, fitresult.y_units, i)
end
365
- function MMI. predict (m:: MultitargetSRRegressor , fitresult, Xnew)
347
+
348
# Predict targets for `Xnew` using the expression(s) stored in `fitresult`.
#
# `Xnew` is either the feature data itself or a named tuple with exactly the
# keys `data` and `idx`; in the latter case `idx` selects which expression to
# evaluate instead of `best_idx` from the report. The same selection can be
# passed through the `idx` keyword.
#
# For `SRRegressor` the result of evaluating the single selected expression is
# returned. For `MultitargetSRRegressor` one expression per target is
# evaluated and the results are concatenated column-wise, then returned as a
# matrix or, when `fitresult.y_is_table` is set, as a table.
#
# If there are no equations, or any evaluation does not complete, the result
# of `prediction_fallback` is returned instead.
#
# Throws an `AssertionError` when a named tuple containing `idx` or `data`
# does not have exactly those two keys.
function MMI.predict(m::M, fitresult, Xnew; idx=nothing) where {M<:AbstractSRRegressor}
    if Xnew isa NamedTuple && (haskey(Xnew, :idx) || haskey(Xnew, :data))
        # Thrown explicitly rather than via `@assert`, so that this input
        # validation cannot be compiled out. The exception type is unchanged.
        if !(haskey(Xnew, :idx) && haskey(Xnew, :data) && length(keys(Xnew)) == 2)
            throw(
                AssertionError(
                    "If specifying an equation index during prediction, you must use a named tuple with keys `idx` and `data`.",
                ),
            )
        end
        return MMI.predict(m, fitresult, Xnew.data; idx=Xnew.idx)
    end

    params = full_report(m, fitresult; v_with_strings=Val(false))
    prototype = MMI.istable(Xnew) ? Xnew : nothing
    Xnew_t, variable_names, X_units = get_matrix_and_info(Xnew, m.dimensions_type)
    T = promote_type(eltype(Xnew_t), fitresult.types.T)

    if isempty(params.equations) || any(isempty, params.equations)
        @warn "Equations not found. Returning 0s for prediction."
        return prediction_fallback(T, m, Xnew_t, fitresult, prototype)
    end

    X_units_clean = clean_units(X_units)
    validate_variable_names(variable_names, fitresult)
    validate_units(X_units_clean, fitresult.X_units)

    # A user-supplied index takes precedence over the stored selection.
    selected = idx === nothing ? params.best_idx : idx

    if M <: SRRegressor
        return eval_tree_mlj(
            params.equations[selected], Xnew_t, m, T, fitresult, nothing, prototype
        )
    elseif M <: MultitargetSRRegressor
        # Evaluate every target first. The fallback already covers all
        # targets, so it must be returned as the complete prediction rather
        # than being concatenated as if it were a single target column.
        results = [
            eval_tree_array(params.equations[i][selected[i]], Xnew_t, fitresult.options)
            for i in eachindex(selected, params.equations)
        ]
        if !all(last, results)
            return prediction_fallback(T, m, Xnew_t, fitresult, prototype)
        end
        outs = [
            wrap_units(first(results[i]), fitresult.y_units, i) for
            i in eachindex(results)
        ]
        out_matrix = reduce(hcat, outs)
        if !fitresult.y_is_table
            return out_matrix
        else
            return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype)
        end
    end
end
393
391
@@ -508,11 +506,14 @@ function tag_with_docstring(model_name::Symbol, description::String, bottom_matt
508
506
Note that if you pass complex data `::Complex{L}`, then the loss
509
507
type will automatically be set to `L`.
510
508
- `selection_method::Function`: Function to select an expression from
511
- the Pareto frontier for use in `predict`. See `SymbolicRegression.MLJInterfaceModule.choose_best`
512
- for an example. This function should return a single integer specifying
513
- the index of the expression to use. By default, `choose_best` maximizes
509
+ the Pareto frontier for use in `predict`.
510
+ See `SymbolicRegression.MLJInterfaceModule.choose_best` for an example.
511
+ This function should return a single integer specifying
512
+ the index of the expression to use. By default, this maximizes
514
513
the score (a pound-for-pound rating) of expressions reaching the threshold
515
- of 1.5x the minimum loss. To fix the index at `5`, you could just write `Returns(5)`.
514
+ of 1.5x the minimum loss. To override this at prediction time, you can pass
515
+ a named tuple with keys `data` and `idx` to `predict`. See the Operations
516
+ section for details.
516
517
- `dimensions_type::AbstractDimensions`: The type of dimensions to use when storing
517
518
the units of the data. By default this is `DynamicQuantities.SymbolicDimensions`.
518
519
"""
@@ -523,6 +524,9 @@ function tag_with_docstring(model_name::Symbol, description::String, bottom_matt
523
524
- `predict(mach, Xnew)`: Return predictions of the target given features `Xnew`, which
524
525
should have same scitype as `X` above. The expression used for prediction is defined
525
526
by the `selection_method` function, which can be seen by viewing `report(mach).best_idx`.
527
+ - `predict(mach, (data=Xnew, idx=i))`: Return predictions of the target given features
528
+ `Xnew`, which should have same scitype as `X` above. By passing a named tuple with keys
529
+ `data` and `idx`, you are able to specify the equation you wish to evaluate in `idx`.
526
530
527
531
$(bottom_matter)
528
532
"""
@@ -583,7 +587,8 @@ eval(
583
587
Note that unlike other regressors, symbolic regression stores a list of
584
588
trained models. The model chosen from this list is defined by the
585
589
`selection_method` keyword argument, which by default balances accuracy
586
- and complexity.
590
+ and complexity. You can override this at prediction time by passing a named
591
+ tuple with keys `data` and `idx`.
587
592
588
593
""" ,
589
594
r" ^ " => " " ,
@@ -595,7 +600,8 @@ eval(
595
600
The fields of `fitted_params(mach)` are:
596
601
597
602
- `best_idx::Int`: The index of the best expression in the Pareto frontier,
598
- as determined by the `selection_method` function.
603
+ as determined by the `selection_method` function. Override in `predict` by passing
604
+ a named tuple with keys `data` and `idx`.
599
605
- `equations::Vector{Node{T}}`: The expressions discovered by the search, represented
600
606
in a dominating Pareto frontier (i.e., the best expressions found for
601
607
each complexity). `T` is equal to the element type
@@ -608,7 +614,8 @@ eval(
608
614
The fields of `report(mach)` are:
609
615
610
616
- `best_idx::Int`: The index of the best expression in the Pareto frontier,
611
- as determined by the `selection_method` function.
617
+ as determined by the `selection_method` function. Override in `predict` by passing
618
+ a named tuple with keys `data` and `idx`.
612
619
- `equations::Vector{Node{T}}`: The expressions discovered by the search, represented
613
620
in a dominating Pareto frontier (i.e., the best expressions found for
614
621
each complexity).
@@ -705,7 +712,8 @@ eval(
705
712
Note that unlike other regressors, symbolic regression stores a list of lists of
706
713
trained models. The models chosen from each of these lists are defined by the
707
714
`selection_method` keyword argument, which by default balances accuracy
708
- and complexity.
715
+ and complexity. You can override this at prediction time by passing a named
716
+ tuple with keys `data` and `idx`.
709
717
710
718
""" ,
711
719
r" ^ " => " " ,
@@ -717,7 +725,8 @@ eval(
717
725
The fields of `fitted_params(mach)` are:
718
726
719
727
- `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier,
720
- as determined by the `selection_method` function.
728
+ as determined by the `selection_method` function. Override in `predict` by passing
729
+ a named tuple with keys `data` and `idx`.
721
730
- `equations::Vector{Vector{Node{T}}}`: The expressions discovered by the search, represented
722
731
in a dominating Pareto frontier (i.e., the best expressions found for
723
732
each complexity). The outer vector is indexed by target variable, and the inner
@@ -731,7 +740,8 @@ eval(
731
740
The fields of `report(mach)` are:
732
741
733
742
- `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier,
734
- as determined by the `selection_method` function.
743
+ as determined by the `selection_method` function. Override in `predict` by passing
744
+ a named tuple with keys `data` and `idx`.
735
745
- `equations::Vector{Vector{Node{T}}}`: The expressions discovered by the search, represented
736
746
in a dominating Pareto frontier (i.e., the best expressions found for
737
747
each complexity). The outer vector is indexed by target variable, and the inner
0 commit comments