Skip to content

Commit ceab1e3

Browse files
committed
Revert "Allow shapecheck of PixelCNN++ (#2017)"
This reverts commit 8f538f4. Issue: #2245
1 parent bf784a4 commit ceab1e3

File tree

8 files changed

+260
-224
lines changed

8 files changed

+260
-224
lines changed

jax/api.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
import collections
2828
import functools
2929
import itertools as it
30+
import operator as op
3031
import os
31-
import string
3232
import threading
3333
from warnings import warn
3434

@@ -51,14 +51,14 @@
5151
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
5252
host_id, host_ids, host_count)
5353
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
54-
from .interpreters.masking import eval_polymorphic_shape, Poly, Mon
5554
from .interpreters import partial_eval as pe
5655
from .interpreters import xla
5756
from .interpreters import pxla
5857
from .interpreters import ad
5958
from .interpreters import batching
6059
from .interpreters import parallel
6160
from .interpreters import masking
61+
from .interpreters.masking import shapecheck, ensure_poly
6262
from .config import flags, config, bool_env
6363

6464
map = safe_map
@@ -1053,23 +1053,24 @@ def wrapped_fun(args, logical_env):
10531053
out_shapes = map(masking.finalize_spec, out_specs, map(onp.shape, outs))
10541054
if not out_shapes == list(out_shapes_):
10551055
raise masking.ShapeError
1056-
if not all(onp.shape(out) == eval_polymorphic_shape(shape, padded_env)
1057-
for out, shape in zip(outs, out_shapes)):
1056+
if not all(onp.shape(out) == masking.eval_shape_expr(padded_env, expr)
1057+
for out, expr in zip(outs, out_shapes)):
10581058
raise masking.ShapeError
10591059
return tree_unflatten(out_tree(), outs)
10601060
return wrapped_fun
10611061

10621062
def _remap_ids(names, shape_spec):
1063-
return masking.ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
1063+
ShapeSpec, Poly, Mon = masking.ShapeSpec, masking.Poly, masking.Mon
1064+
mdim = masking.monomorphic_dim
1065+
return ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
10641066
: coeff for mon, coeff in poly.items()})
1065-
if poly is not masking._monomorphic_dim else
1066-
masking._monomorphic_dim for poly in shape_spec)
1067+
if poly is not mdim else mdim for poly in shape_spec)
10671068

10681069
def _bind_shapes(shape_exprs, shapes):
10691070
env = {}
10701071
for shape_expr, shape in zip(shape_exprs, shapes):
10711072
for poly, d in zip(shape_expr, shape):
1072-
if type(poly) is not Poly or poly.is_constant:
1073+
if ensure_poly(poly).is_constant:
10731074
continue
10741075
else:
10751076
(binder,), = poly # TODO generalize to handle striding
@@ -1084,13 +1085,16 @@ def shapecheck(in_shapes, out_shape, fun):
10841085
out_shapes, out_tree = tree_flatten(out_shape)
10851086
out_shapes = map(masking.parse_spec, out_shapes)
10861087
flat_fun, out_tree_ = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
1087-
avals = map(partial(ShapedArray, dtype=onp.float32), in_shapes)
1088-
out_shapes_ = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
1088+
out_shapes_ = masking.shapecheck(flat_fun, in_shapes)
10891089
if out_tree != out_tree_(): raise TypeError("pytree mismatch")
1090-
if not all(map(masking._shape_spec_consistent, out_shapes, out_shapes_)):
1090+
if not all(map(_shape_spec_consistent, out_shapes, out_shapes_)):
10911091
raise masking.ShapeError
10921092
return fun
10931093

1094+
def _shape_spec_consistent(spec, expr):
1095+
return all(a == b for a, b in zip(spec, expr) if a is not masking.monomorphic_dim)
1096+
1097+
10941098
def jvp(fun, primals, tangents):
10951099
"""Computes a (forward-mode) Jacobian-vector product of `fun`.
10961100

0 commit comments

Comments
 (0)