Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions pgx/_src/games/domineering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright 2026 The Pgx Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp
from jax import Array


class GameState(NamedTuple):
    """Immutable snapshot of a Domineering position on an 8x8 board.

    Fields:
        color: The side to move, 0 or 1. In this module color 0 places
            horizontal dominoes and color 1 places vertical ones.
        board: Flat, row-major vector of 64 flags; True means the square is
            still free, False means it is covered by a domino.
        winner: -1 while the game is undecided, otherwise the winning color.
    """

    # Side to move; color 0 opens the game.
    color: Array = jnp.int32(0)
    # Square indices of the flattened 8x8 board (row-major):
    #   [[ 0,  1,  2,  3,  4,  5,  6,  7],
    #    [ 8,  9, 10, 11, 12, 13, 14, 15],
    #    [16, 17, 18, 19, 20, 21, 22, 23],
    #    [24, 25, 26, 27, 28, 29, 30, 31],
    #    [32, 33, 34, 35, 36, 37, 38, 39],
    #    [40, 41, 42, 43, 44, 45, 46, 47],
    #    [48, 49, 50, 51, 52, 53, 54, 55],
    #    [56, 57, 58, 59, 60, 61, 62, 63]]
    # Every square starts out available (True); occupied squares are False.
    board: Array = jnp.ones((64,), dtype=jnp.bool_)
    # No winner yet.
    winner: Array = jnp.int32(-1)


class Game:
    """Rules engine for Domineering on an 8x8 board.

    Color 0 lays dominoes horizontally (the chosen square plus the square to
    its right); color 1 lays them vertically (the chosen square plus the
    square below it). A player who cannot move loses.
    """

    def init(self) -> GameState:
        """Return the starting position: empty board, color 0 to move."""
        return GameState()

    def step(self, state: GameState, action: Array) -> GameState:
        """Apply one move and hand the turn to the other color.

        Args:
            state: The position before the move.
            action: Index (0-63) of the top/left square of the domino being
                placed by ``state.color``.

        Returns:
            The resulting position. ``winner`` is set to the mover's color if
            the opponent is left without a legal move, otherwise it is -1.
        """
        mover_is_horizontal = state.color == 0

        # The second covered square is one to the right (horizontal mover)
        # or one row down (vertical mover).
        second_square = action + jax.lax.select(mover_is_horizontal, 1, 8)
        new_board = state.board.at[jnp.array([action, second_square])].set(False)

        # The opponent plays the other orientation, so check their masks:
        # a placement is possible iff both of its squares are still free.
        opponent_masks = jax.lax.select(mover_is_horizontal, MASK_CACHE_V, MASK_CACHE_H)
        free_squares_per_move = (new_board & opponent_masks).sum(axis=1)
        has_next_move = (free_squares_per_move == 2).any()

        # The game is over as soon as the player next to play is stuck.
        return state._replace(  # type: ignore
            color=1 - state.color,
            board=new_board,
            winner=jax.lax.select(has_next_move, -1, state.color),
        )

    def observe(self, state: GameState, _: Optional[Array] = None) -> Array:
        """Return the board as an 8x8 boolean grid (True = free square).

        The second argument is accepted for interface compatibility and is
        not used.
        """
        return state.board.reshape(8, 8)

    def legal_action_mask(self, state: GameState) -> Array:
        """Return a length-64 boolean mask of legal anchor squares.

        To be legal, a move must have its own square and the relevant
        neighbour free, and must not start on the far edge of the board. Which
        neighbour and which edge apply depends on the mover's orientation.
        """
        board = state.board

        # Horizontal: the square to the right must be free and the anchor
        # must not sit in the last column (the roll wraps within each row).
        right_neighbour_free = jnp.roll(board.reshape(8, 8), shift=-1, axis=1).flatten()
        horizontal_ok = EDGE_EXCLUDER_H & right_neighbour_free

        # Vertical: the square below must be free and the anchor must not sit
        # in the last row (the roll wraps around to the first row).
        below_neighbour_free = jnp.roll(board, shift=-8, axis=0)
        vertical_ok = EDGE_EXCLUDER_V & below_neighbour_free

        return board & jax.lax.select(state.color == 0, horizontal_ok, vertical_ok)

    def is_terminal(self, state: GameState) -> Array:
        """Return True once a winner is recorded; the game never ends drawn."""
        return state.winner >= 0

    def rewards(self, state: GameState) -> Array:
        """Return per-color rewards as a float32 vector of length 2.

        The winner's entry is +1 and the other entry is -1; both entries are
        0 while the game is still undecided.
        """
        decided = state.winner >= 0
        win_loss = jnp.float32([-1, -1]).at[state.winner].set(1)
        undecided = jnp.zeros(2, jnp.float32)
        return jax.lax.select(decided, win_loss, undecided)


def _make_mask_cache_horizontal():
    """Build the (56, 64) boolean table of all horizontal domino placements.

    Each row marks with True the two squares covered by one placement: the
    anchor square and the square immediately to its right. Rows are ordered
    by column (0-6) first and by row (0-7) second.
    """
    # Anchor (left-hand) squares; the last column can never be an anchor.
    anchors = jnp.array([row * 8 + col for col in range(7) for row in range(8)])
    squares = jnp.arange(64)
    covers_anchor = squares[None, :] == anchors[:, None]
    covers_right = squares[None, :] == (anchors[:, None] + 1)
    return covers_anchor | covers_right


def _make_mask_cache_vertical():
    """Build the (56, 64) boolean table of all vertical domino placements.

    Each row marks with True the two squares covered by one placement: the
    anchor square and the square directly below it. Rows are ordered by
    column (0-7) first and by row (0-6) second.
    """
    # Anchor (top) squares; the last row can never be an anchor.
    anchors = jnp.array([row * 8 + col for col in range(8) for row in range(7)])
    squares = jnp.arange(64)
    covers_anchor = squares[None, :] == anchors[:, None]
    covers_below = squares[None, :] == (anchors[:, None] + 8)
    return covers_anchor | covers_below


# Precomputed placement masks, built once at import time. Each table has one
# row per possible domino placement (56 rows) and one column per board square
# (64 columns); a row is True exactly on the two squares that placement needs
# to be empty. `_H` holds horizontal placements, `_V` vertical ones.
MASK_CACHE_H = _make_mask_cache_horizontal()
MASK_CACHE_V = _make_mask_cache_vertical()

# Per-square flags that rule out anchor squares on the far edge for each
# orientation: a horizontal domino cannot start in the last column, and a
# vertical domino cannot start in the last row.
_SQUARE_INDEX = jnp.arange(64)
EDGE_EXCLUDER_H = _SQUARE_INDEX % 8 != 7  # False on squares 7, 15, ..., 63
EDGE_EXCLUDER_V = _SQUARE_INDEX < 56  # False on squares 56..63
5 changes: 5 additions & 0 deletions pgx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"bridge_bidding",
"chess",
"connect_four",
"domineering",
"gardner_chess",
"go_9x9",
"go_19x19",
Expand Down Expand Up @@ -357,6 +358,10 @@ def make(env_id: EnvId): # noqa: C901
from pgx.connect_four import ConnectFour

return ConnectFour()
elif env_id == "domineering":
from pgx.domineering import Domineering

return Domineering()
elif env_id == "gardner_chess":
from pgx.gardner_chess import GardnerChess

Expand Down
90 changes: 90 additions & 0 deletions pgx/domineering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2026 The Pgx Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import jax.numpy as jnp

import pgx.core as core
from pgx._src.games.domineering import Game, GameState
from pgx._src.struct import dataclass
from pgx._src.types import Array, PRNGKey


@dataclass
class State(core.State):
    """Pgx environment state for Domineering.

    Wraps the internal ``GameState`` (``_x``) together with the per-step
    fields declared here. The defaults describe the initial position: an
    empty 8x8 board on which every square outside the last column is a legal
    anchor for the first (horizontal) move.
    """

    current_player: Array = jnp.int32(0)
    # All squares free at the start.
    observation: Array = jnp.ones((8, 8), jnp.bool_)
    rewards: Array = jnp.zeros(2, jnp.float32)
    terminated: Array = jnp.bool_(False)
    truncated: Array = jnp.bool_(False)
    # Initial legal moves: every square except those in the last column.
    legal_action_mask: Array = jnp.arange(64) % 8 != 7
    _step_count: Array = jnp.int32(0)
    _x: GameState = GameState()

    @property
    def env_id(self) -> core.EnvId:
        """Identifier of the environment this state belongs to."""
        return "domineering"


class Domineering(core.Env):
    """Pgx environment wrapper around the Domineering ``Game`` rules."""

    def __init__(self):
        super().__init__()
        self._game = Game()

    def _init(self, key: PRNGKey) -> State:
        """Create the initial state, drawing which player moves first from `key`."""
        first_player = jnp.int32(jax.random.bernoulli(key))
        return State(current_player=first_player, _x=self._game.init())  # type:ignore

    def _step(self, state: core.State, action: Array, key) -> State:
        """Apply `action` for the current player and refresh derived fields.

        The `key` argument is unused. Rewards returned by the game are indexed
        by color, so they are flipped when the player to move does not
        coincide with the color to move, and zeroed unless the game has ended.
        """
        del key
        assert isinstance(state, State)
        next_x = self._game.step(state._x, action)
        state = state.replace(  # type: ignore
            current_player=1 - state.current_player,
            _x=next_x,
        )
        assert isinstance(state, State)

        game = self._game
        is_over = game.is_terminal(state._x)
        mask = game.legal_action_mask(state._x)

        # Convert color-indexed rewards into player-indexed rewards.
        color_rewards = game.rewards(state._x)
        players_swapped = state.current_player != state._x.color
        player_rewards = jax.lax.select(players_swapped, jnp.flip(color_rewards), color_rewards)
        final_rewards = jax.lax.select(is_over, player_rewards, jnp.zeros(2, jnp.float32))

        return state.replace(  # type: ignore
            legal_action_mask=mask,
            rewards=final_rewards,
            terminated=is_over,
        )

    def _observe(self, state: core.State, player_id: Array) -> Array:
        """Return the observation for `player_id`.

        The observer's color is worked out and handed to the game's
        ``observe`` method.
        """
        assert isinstance(state, State)
        color_to_move = state._x.color
        observer_color = jax.lax.select(
            player_id == state.current_player,
            color_to_move,
            1 - color_to_move,
        )
        return self._game.observe(state._x, observer_color)

    @property
    def id(self) -> core.EnvId:
        """Environment identifier."""
        return "domineering"

    @property
    def version(self) -> str:
        """Environment version string."""
        return "v0"

    @property
    def num_players(self) -> int:
        """Number of players in the game."""
        return 2
Loading