Skip to content
Merged

uh #64

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
[deps]
CompatHelperLocal = "5224ae11-6099-4aaa-941d-3aab004bd678"
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
35 changes: 34 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,41 @@ import CompatHelperLocal as CHL
CHL.@check()

using AlgorithmicRecourseDynamics
using AlgorithmicRecourseDynamics.Data
using AlgorithmicRecourseDynamics.Experiments
using AlgorithmicRecourseDynamics.Models
using AlgorithmicRecourseDynamics: run!
using CounterfactualExplanations
using Flux
using MLJBase
using Plots
using Random
using Test

@testset "AlgorithmicRecourseDynamics.jl" begin
# Write your tests here.

N = 1000
xmax = 2
X, ys = make_blobs(
N, 2;
centers=2, as_table=false, center_box=(-xmax => xmax), cluster_std=0.1
)
ys .= ys .== 2
X = X'
counterfactual_data = CounterfactualData(X, ys')

n_epochs = 100
model = Chain(Dense(2, 1))
mod = FluxModel(model)
generator = GenericGenerator()

data_train, data_test = Data.train_test_split(counterfactual_data)
Models.train(mod, data_train; n_epochs=n_epochs)

models = Dict(:mymodel => mod)
generators = Dict(:wachter => generator)
experiment = set_up_experiment(data_train, data_test, models, generators)

run!(experiment)

end