1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from typing import NamedTuple
16+
1517import jax
1618import jax .numpy as jnp
1719
2325TRUE = jnp .bool_ (True )
2426
2527
26- @dataclass
27- class State (core .State ):
28- current_player : Array = jnp .int32 (0 )
29- observation : Array = jnp .zeros ((8 , 8 , 2 ), dtype = jnp .bool_ )
30- rewards : Array = jnp .float32 ([0.0 , 0.0 ])
31- terminated : Array = FALSE
32- truncated : Array = FALSE
33- legal_action_mask : Array = jnp .ones (64 + 1 , dtype = jnp .bool_ )
34- _step_count : Array = jnp .int32 (0 )
35- # --- Othello specific ---
36- _turn : Array = jnp .int32 (0 )
28+ class GameState (NamedTuple ):
29+ turn : Array = jnp .int32 (0 )  # toggles between 0 and 1 on every step
3730 # 8x8 board
3831 # [[ 0, 1, 2, 3, 4, 5, 6, 7],
3932 # [ 8, 9, 10, 11, 12, 13, 14, 15],
@@ -43,8 +36,20 @@ class State(core.State):
4336 # [40, 41, 42, 43, 44, 45, 46, 47],
4437 # [48, 49, 50, 51, 52, 53, 54, 55],
4538 # [56, 57, 58, 59, 60, 61, 62, 63]]
46- _board : Array = jnp .zeros (64 , jnp .int32 ) # -1(opp), 0(empty), 1(self)
47- _passed : Array = FALSE
39+ board : Array = jnp .zeros (64 , jnp .int32 )  # -1(opp), 0(empty), 1(self), relative to the player to move
40+ passed : Array = jnp .bool_ (False )  # True when the previous action was a pass (action == 64)
41+
42+
43+ @dataclass
44+ class State (core .State ):
45+ current_player : Array = jnp .int32 (0 )
46+ observation : Array = jnp .zeros ((8 , 8 , 2 ), dtype = jnp .bool_ )
47+ rewards : Array = jnp .float32 ([0.0 , 0.0 ])
48+ terminated : Array = FALSE
49+ truncated : Array = FALSE
50+ legal_action_mask : Array = jnp .ones (64 + 1 , dtype = jnp .bool_ )
51+ _step_count : Array = jnp .int32 (0 )
52+ _x : GameState = GameState ()
4853
4954 @property
5055 def env_id (self ) -> core .EnvId :
@@ -107,7 +112,10 @@ def _init(rng: PRNGKey) -> State:
107112 current_player = jnp .int32 (jax .random .bernoulli (rng ))
108113 return State (
109114 current_player = current_player ,
110- _board = jnp .zeros (64 , dtype = jnp .int32 ).at [28 ].set (1 ).at [35 ].set (1 ).at [27 ].set (- 1 ).at [36 ].set (- 1 ),
115+ _x = GameState (
116+ turn = 0 ,
117+ board = jnp .zeros (64 , dtype = jnp .int32 ).at [28 ].set (1 ).at [35 ].set (1 ).at [27 ].set (- 1 ).at [36 ].set (- 1 ),
118+ ),
111119 legal_action_mask = jnp .zeros (64 + 1 , dtype = jnp .bool_ )
112120 .at [19 ]
113121 .set (TRUE )
@@ -121,7 +129,7 @@ def _init(rng: PRNGKey) -> State:
121129
122130
123131def _step (state , action ):
124- board = state ._board
132+ board = state ._x . board
125133 my = board > 0
126134 opp = board < 0
127135
@@ -167,19 +175,21 @@ def _make_legal(i, legal):
167175 legal_action = jax .lax .fori_loop (0 , 8 , _make_legal , jnp .zeros (64 , dtype = jnp .bool_ ))
168176
169177 reward , terminated = jax .lax .cond (
170- ((jnp .count_nonzero (my | opp ) == 64 ) | ~ opp .any () | (state ._passed & (action == 64 ))),
178+ ((jnp .count_nonzero (my | opp ) == 64 ) | ~ opp .any () | (state ._x . passed & (action == 64 ))),
171179 lambda : (_get_reward (my , opp , state .current_player ), TRUE ),
172180 lambda : (jnp .zeros (2 , jnp .float32 ), FALSE ),
173181 )
174182
175183 return state .replace (
176184 current_player = 1 - state .current_player ,
177- _turn = 1 - state ._turn ,
185+ _x = GameState (
186+ turn = 1 - state ._x .turn ,
187+ board = - jnp .where (jnp .int32 (opp ), - 1 , jnp .int32 (my )),
188+ passed = action == 64 ,
189+ ),
178190 legal_action_mask = state .legal_action_mask .at [:64 ].set (legal_action ).at [64 ].set (~ legal_action .any ()),
179191 rewards = reward ,
180192 terminated = terminated ,
181- _board = - jnp .where (jnp .int32 (opp ), - 1 , jnp .int32 (my )),
182- _passed = action == 64 ,
183193 )
184194
185195
@@ -208,8 +218,8 @@ def _get_reward(my, opp, curr_player):
208218def _observe (state , player_id ) -> Array :
209219 board = jax .lax .cond (
210220 player_id == state .current_player ,
211- lambda : state ._board .reshape ((8 , 8 )),
212- lambda : (state ._board * - 1 ).reshape ((8 , 8 )),
221+ lambda : state ._x . board .reshape ((8 , 8 )),
222+ lambda : (state ._x . board * - 1 ).reshape ((8 , 8 )),
213223 )
214224
215225 def make (color ):
@@ -219,4 +229,4 @@ def make(color):
219229
220230
221231def _get_abs_board (state ):
222- return jax .lax .cond (state ._turn == 0 , lambda : state ._board , lambda : state ._board * - 1 )
232+ return jax .lax .cond (state ._x . turn == 0 , lambda : state ._x . board , lambda : state ._x . board * - 1 )  # undo the per-step sign flip so stone signs are fixed to the turn-0 perspective
0 commit comments