Commit 22b5b7b

Update docs for index selection API
1 parent 908141e commit 22b5b7b

2 files changed: 28 additions, 16 deletions

README.md

Lines changed: 10 additions & 6 deletions
@@ -153,16 +153,20 @@ predict(mach, X)
 ```
 
 This will make predictions using the expression
-selected using the function passed to `selection_method`.
-By default this selection is made a mix of accuracy and complexity.
-For example, we can make predictions using expression 2 with:
+selected by `model.selection_method`,
+which by default is a mix of accuracy and complexity.
+
+You can override this selection and select an equation from
+the Pareto front manually with:
 
 ```julia
-mach.model.selection_method = Returns(2)
-predict(mach, X)
+predict(mach, (data=X, idx=2))
 ```
 
-For fitting multiple outputs, one can use `MultitargetSRRegressor`.
+where here we choose to evaluate the second equation.
+
+For fitting multiple outputs, one can use `MultitargetSRRegressor`
+(and pass an array of indices to `idx` in `predict` for selecting specific equations).
 For a full list of options available to each regressor, see the [API page](https://astroautomata.com/SymbolicRegression.jl/dev/api/).
 
 ### Low-Level Interface
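
Reviewer note: the README change above replaces mutating `mach.model.selection_method` with a per-call override. A minimal end-to-end sketch of both call forms follows. The synthetic data, the `niterations` value, the operator lists, and the clamping of the index are assumptions made for illustration; only the `predict(mach, (data=X, idx=...))` form and the `report(mach)` fields come from the diff itself.

```julia
using MLJ                 # machine, fit!, predict, report
using SymbolicRegression  # SRRegressor

# Synthetic data, assumed for illustration only.
X = (a=rand(100), b=rand(100))
y = @. 2 * cos(X.a * 23.5) - X.b^2

# Hyperparameter values here are arbitrary choices for a quick run.
model = SRRegressor(;
    niterations=20,
    binary_operators=[+, -, *],
    unary_operators=[cos],
)
mach = machine(model, X, y)
fit!(mach)

# Default: predict with the expression picked by `model.selection_method`.
r = report(mach)
y_default = predict(mach, X)

# Manual override: pass a named tuple with keys `data` and `idx`.
# The index is clamped in case the front holds fewer than two expressions.
i = min(2, length(r.equations))
y_manual = predict(mach, (data=X, idx=i))
```

The override leaves the model's hyperparameters untouched, so a later `predict(mach, X)` still uses the default selection.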

src/MLJInterface.jl

Lines changed: 18 additions & 10 deletions
@@ -500,11 +500,14 @@ function tag_with_docstring(model_name::Symbol, description::String, bottom_matt
 Note that if you pass complex data `::Complex{L}`, then the loss
 type will automatically be set to `L`.
 - `selection_method::Function`: Function to selection expression from
-the Pareto frontier for use in `predict`. See `SymbolicRegression.MLJInterfaceModule.choose_best`
-for an example. This function should return a single integer specifying
-the index of the expression to use. By default, `choose_best` maximizes
+the Pareto frontier for use in `predict`.
+See `SymbolicRegression.MLJInterfaceModule.choose_best` for an example.
+This function should return a single integer specifying
+the index of the expression to use. By default, this maximizes
 the score (a pound-for-pound rating) of expressions reaching the threshold
-of 1.5x the minimum loss. To fix the index at `5`, you could just write `Returns(5)`.
+of 1.5x the minimum loss. To override this at prediction time, you can pass
+a named tuple with keys `data` and `idx` to `predict`. See the Operations
+section for details.
 - `dimensions_type::AbstractDimensions`: The type of dimensions to use when storing
 the units of the data. By default this is `DynamicQuantities.SymbolicDimensions`.
 """
@@ -515,7 +518,7 @@ function tag_with_docstring(model_name::Symbol, description::String, bottom_matt
 - `predict(mach, Xnew)`: Return predictions of the target given features `Xnew`, which
 should have same scitype as `X` above. The expression used for prediction is defined
 by the `selection_method` function, which can be seen by viewing `report(mach).best_idx`.
-- `predict(mach, (; data=Xnew, idx=i))`: Return predictions of the target given features
+- `predict(mach, (data=Xnew, idx=i))`: Return predictions of the target given features
 `Xnew`, which should have same scitype as `X` above. By passing a named tuple with keys
 `data` and `idx`, you are able to specify the equation you wish to evaluate in `idx`.
 
@@ -578,7 +581,8 @@ eval(
 Note that unlike other regressors, symbolic regression stores a list of
 trained models. The model chosen from this list is defined by the function
 `selection_method` keyword argument, which by default balances accuracy
-and complexity.
+and complexity. You can override this at prediction time by passing a named
+tuple with keys `data` and `idx`.
 
 """,
 r"^ " => "",
@@ -590,7 +594,8 @@ eval(
 The fields of `fitted_params(mach)` are:
 
 - `best_idx::Int`: The index of the best expression in the Pareto frontier,
-as determined by the `selection_method` function.
+as determined by the `selection_method` function. Override in `predict` by passing
+a named tuple with keys `data` and `idx`.
 - `equations::Vector{Node{T}}`: The expressions discovered by the search, represented
 in a dominating Pareto frontier (i.e., the best expressions found for
 each complexity). `T` is equal to the element type
@@ -701,7 +706,8 @@ eval(
 Note that unlike other regressors, symbolic regression stores a list of lists of
 trained models. The models chosen from each of these lists is defined by the function
 `selection_method` keyword argument, which by default balances accuracy
-and complexity.
+and complexity. You can override this at prediction time by passing a named
+tuple with keys `data` and `idx`.
 
 """,
 r"^ " => "",
@@ -713,7 +719,8 @@ eval(
 The fields of `fitted_params(mach)` are:
 
 - `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier,
-as determined by the `selection_method` function.
+as determined by the `selection_method` function. Override in `predict` by passing
+a named tuple with keys `data` and `idx`.
 - `equations::Vector{Vector{Node{T}}}`: The expressions discovered by the search, represented
 in a dominating Pareto frontier (i.e., the best expressions found for
 each complexity). The outer vector is indexed by target variable, and the inner
@@ -727,7 +734,8 @@ eval(
 The fields of `report(mach)` are:
 
 - `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier,
-as determined by the `selection_method` function.
+as determined by the `selection_method` function. Override in `predict` by passing
+a named tuple with keys `data` and `idx`.
 - `equations::Vector{Vector{Node{T}}}`: The expressions discovered by the search, represented
 in a dominating Pareto frontier (i.e., the best expressions found for
 each complexity). The outer vector is indexed by target variable, and the inner
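
Reviewer note: the `selection_method` docstring in this file describes the default rule as maximizing the score among expressions whose loss reaches 1.5x the minimum loss, with `idx` acting as an override at prediction time. The plain-Julia sketch below restates that rule so the interaction is easy to check by hand. The function names `choose_best_sketch` and `select_index` and the sample numbers are hypothetical; this is not the package's own `choose_best`, whose signature is not shown in the diff.

```julia
# Hypothetical restatement of the default rule from the docstring:
# among expressions with loss <= 1.5x the minimum loss, take the highest score.
function choose_best_sketch(losses::AbstractVector{<:Real}, scores::AbstractVector{<:Real})
    threshold = 1.5 * minimum(losses)
    # Indices of expressions whose loss reaches the threshold.
    candidates = findall(<=(threshold), losses)
    # Among those candidates, return the index with the highest score.
    return candidates[argmax(scores[candidates])]
end

# Hypothetical helper: an explicit `idx` wins over the default choice.
function select_index(losses, scores; idx=nothing)
    return idx === nothing ? choose_best_sketch(losses, scores) : idx
end

# Made-up Pareto front, ordered by increasing complexity.
losses = [1.00, 0.40, 0.12, 0.10]
scores = [0.0, 0.9, 1.2, 0.2]

default_choice = select_index(losses, scores)         # returns 3
manual_choice = select_index(losses, scores; idx=2)   # returns 2
```

With these numbers the threshold is 0.15, so only expressions 3 and 4 qualify, and expression 3 has the higher score. For `MultitargetSRRegressor`, the same choice would be made once per target, which is why the diff describes `best_idx` as a `Vector{Int}` there.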
