File tree Expand file tree Collapse file tree 1 file changed +9
-4
lines changed
Expand file tree Collapse file tree 1 file changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -50,9 +50,14 @@ def __init__(self):
5050 self ._game = Game ()
5151
5252 def _init (self , key : PRNGKey ) -> State :
53- state = State ()
54- player_order = jnp .array ([[0 , 1 ], [1 , 0 ]])[jax .random .bernoulli (key ).astype (jnp .int32 )]
55- return state .replace (_player_order = player_order ) # type: ignore
53+ x = GameState ()
54+ _player_order = jnp .array ([[0 , 1 ], [1 , 0 ]])[jax .random .bernoulli (key ).astype (jnp .int32 )]
55+ state = State ( # type: ignore
56+ current_player = _player_order [x .color ],
57+ _player_order = _player_order ,
58+ _x = x ,
59+ )
60+ return state
5661
5762 def _step (self , state : core .State , action : Array , key ) -> State :
5863 del key
@@ -83,4 +88,4 @@ def version(self) -> str:
8388
8489 @property
8590 def num_players (self ) -> int :
86- return 2
91+ return 2
You can’t perform that action at this time.
0 commit comments