Skip to content

Commit 7932e97

Browse files
authored
[Othello] Extract game specific attributes (#1296)
1 parent 17ea895 commit 7932e97

File tree

2 files changed

+33
-23
lines changed

2 files changed

+33
-23
lines changed

pgx/othello.py

Lines changed: 32 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import NamedTuple
16+
1517
import jax
1618
import jax.numpy as jnp
1719

@@ -23,17 +25,8 @@
2325
TRUE = jnp.bool_(True)
2426

2527

26-
@dataclass
27-
class State(core.State):
28-
current_player: Array = jnp.int32(0)
29-
observation: Array = jnp.zeros((8, 8, 2), dtype=jnp.bool_)
30-
rewards: Array = jnp.float32([0.0, 0.0])
31-
terminated: Array = FALSE
32-
truncated: Array = FALSE
33-
legal_action_mask: Array = jnp.ones(64 + 1, dtype=jnp.bool_)
34-
_step_count: Array = jnp.int32(0)
35-
# --- Othello specific ---
36-
_turn: Array = jnp.int32(0)
28+
class GameState(NamedTuple):
29+
turn: Array = jnp.int32(0)
3730
# 8x8 board
3831
# [[ 0, 1, 2, 3, 4, 5, 6, 7],
3932
# [ 8, 9, 10, 11, 12, 13, 14, 15],
@@ -43,8 +36,20 @@ class State(core.State):
4336
# [40, 41, 42, 43, 44, 45, 46, 47],
4437
# [48, 49, 50, 51, 52, 53, 54, 55],
4538
# [56, 57, 58, 59, 60, 61, 62, 63]]
46-
_board: Array = jnp.zeros(64, jnp.int32) # -1(opp), 0(empty), 1(self)
47-
_passed: Array = FALSE
39+
board: Array = jnp.zeros(64, jnp.int32)
40+
passed: Array = jnp.bool_(False)
41+
42+
43+
@dataclass
44+
class State(core.State):
45+
current_player: Array = jnp.int32(0)
46+
observation: Array = jnp.zeros((8, 8, 2), dtype=jnp.bool_)
47+
rewards: Array = jnp.float32([0.0, 0.0])
48+
terminated: Array = FALSE
49+
truncated: Array = FALSE
50+
legal_action_mask: Array = jnp.ones(64 + 1, dtype=jnp.bool_)
51+
_step_count: Array = jnp.int32(0)
52+
_x: GameState = GameState()
4853

4954
@property
5055
def env_id(self) -> core.EnvId:
@@ -107,7 +112,10 @@ def _init(rng: PRNGKey) -> State:
107112
current_player = jnp.int32(jax.random.bernoulli(rng))
108113
return State(
109114
current_player=current_player,
110-
_board=jnp.zeros(64, dtype=jnp.int32).at[28].set(1).at[35].set(1).at[27].set(-1).at[36].set(-1),
115+
_x=GameState(
116+
turn=0,
117+
board=jnp.zeros(64, dtype=jnp.int32).at[28].set(1).at[35].set(1).at[27].set(-1).at[36].set(-1),
118+
),
111119
legal_action_mask=jnp.zeros(64 + 1, dtype=jnp.bool_)
112120
.at[19]
113121
.set(TRUE)
@@ -121,7 +129,7 @@ def _init(rng: PRNGKey) -> State:
121129

122130

123131
def _step(state, action):
124-
board = state._board
132+
board = state._x.board
125133
my = board > 0
126134
opp = board < 0
127135

@@ -167,19 +175,21 @@ def _make_legal(i, legal):
167175
legal_action = jax.lax.fori_loop(0, 8, _make_legal, jnp.zeros(64, dtype=jnp.bool_))
168176

169177
reward, terminated = jax.lax.cond(
170-
((jnp.count_nonzero(my | opp) == 64) | ~opp.any() | (state._passed & (action == 64))),
178+
((jnp.count_nonzero(my | opp) == 64) | ~opp.any() | (state._x.passed & (action == 64))),
171179
lambda: (_get_reward(my, opp, state.current_player), TRUE),
172180
lambda: (jnp.zeros(2, jnp.float32), FALSE),
173181
)
174182

175183
return state.replace(
176184
current_player=1 - state.current_player,
177-
_turn=1 - state._turn,
185+
_x=GameState(
186+
turn=1 - state._x.turn,
187+
board=-jnp.where(jnp.int32(opp), -1, jnp.int32(my)),
188+
passed=action == 64,
189+
),
178190
legal_action_mask=state.legal_action_mask.at[:64].set(legal_action).at[64].set(~legal_action.any()),
179191
rewards=reward,
180192
terminated=terminated,
181-
_board=-jnp.where(jnp.int32(opp), -1, jnp.int32(my)),
182-
_passed=action == 64,
183193
)
184194

185195

@@ -208,8 +218,8 @@ def _get_reward(my, opp, curr_player):
208218
def _observe(state, player_id) -> Array:
209219
board = jax.lax.cond(
210220
player_id == state.current_player,
211-
lambda: state._board.reshape((8, 8)),
212-
lambda: (state._board * -1).reshape((8, 8)),
221+
lambda: state._x.board.reshape((8, 8)),
222+
lambda: (state._x.board * -1).reshape((8, 8)),
213223
)
214224

215225
def make(color):
@@ -219,4 +229,4 @@ def make(color):
219229

220230

221231
def _get_abs_board(state):
222-
return jax.lax.cond(state._turn == 0, lambda: state._board, lambda: state._board * -1)
232+
return jax.lax.cond(state._x.turn == 0, lambda: state._x.board, lambda: state._x.board * -1)

tests/test_othello.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def test_step():
3737
0, 0, 0, 0, 0, 0, 0, 0,
3838
0, 0, 0, 0, 0, 0, 0, 0])
3939
# fmt:on
40-
assert jnp.all(state._board == expected)
40+
assert jnp.all(state._x.board == expected)
4141

4242

4343
def test_terminated():

0 commit comments

Comments
 (0)