diff --git a/bioptim/examples/muscle_driven_ocp/muscle_activations_tracker.py b/bioptim/examples/muscle_driven_ocp/muscle_activations_tracker.py index f3f031958..f4f8d8306 100644 --- a/bioptim/examples/muscle_driven_ocp/muscle_activations_tracker.py +++ b/bioptim/examples/muscle_driven_ocp/muscle_activations_tracker.py @@ -153,6 +153,7 @@ def prepare_ocp( kin_data_to_track: str = "markers", use_residual_torque: bool = True, ode_solver: OdeSolver = OdeSolver.COLLOCATION(), + n_threads: int = 1, ) -> OptimalControlProgram: """ Prepare the ocp to solve @@ -177,6 +178,8 @@ def prepare_ocp( If residual torque are present or not in the dynamics ode_solver: OdeSolver The ode solver to use + n_threads: int + The number of threads Returns ------- @@ -240,6 +243,7 @@ def prepare_ocp( u_bounds, objective_functions, ode_solver=ode_solver, + n_threads=n_threads, ) diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index 0739477c4..0f9b26351 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -251,7 +251,7 @@ def get_x_and_u_at_idx(_penalty, _idx): if penalty.multi_thread: if penalty.target is not None and len(penalty.target[0].shape) != 2: raise NotImplementedError("multi_thread penalty with target shape != [n x m] is not implemented yet") - target = penalty.target if penalty.target is not None else [] + target = penalty.target[0] if penalty.target is not None else [] x = nlp.cx() u = nlp.cx() diff --git a/tests/test_global_muscle_tracking.py b/tests/test_global_muscle_tracking.py index 99852c0a5..e36f0b53a 100644 --- a/tests/test_global_muscle_tracking.py +++ b/tests/test_global_muscle_tracking.py @@ -12,7 +12,8 @@ @pytest.mark.parametrize("ode_solver", [OdeSolver.RK4, OdeSolver.COLLOCATION, OdeSolver.IRK]) -def test_muscle_activations_and_states_tracking(ode_solver): +@pytest.mark.parametrize("n_threads", [1, 2]) +def test_muscle_activations_and_states_tracking(ode_solver, n_threads): # Load muscle_activations_tracker from bioptim.examples.muscle_driven_ocp import muscle_activations_tracker as ocp_module @@ -42,6 +43,7 @@ def test_muscle_activations_and_states_tracking(ode_solver): use_residual_torque=use_residual_torque, kin_data_to_track="q", ode_solver=ode_solver(), + n_threads=n_threads, ) solver = Solver.IPOPT() # solver.set_maximum_iterations(10)