27
27
import collections
28
28
import functools
29
29
import itertools as it
30
+ import operator as op
30
31
import os
31
- import string
32
32
import threading
33
33
from warnings import warn
34
34
51
51
from .lib .xla_bridge import (device_count , local_device_count , devices , local_devices ,
52
52
host_id , host_ids , host_count )
53
53
from .abstract_arrays import ConcreteArray , ShapedArray , raise_to_shaped
54
- from .interpreters .masking import eval_polymorphic_shape , Poly , Mon
55
54
from .interpreters import partial_eval as pe
56
55
from .interpreters import xla
57
56
from .interpreters import pxla
58
57
from .interpreters import ad
59
58
from .interpreters import batching
60
59
from .interpreters import parallel
61
60
from .interpreters import masking
61
+ from .interpreters .masking import shapecheck , ensure_poly
62
62
from .config import flags , config , bool_env
63
63
64
64
map = safe_map
@@ -1053,23 +1053,24 @@ def wrapped_fun(args, logical_env):
1053
1053
out_shapes = map (masking .finalize_spec , out_specs , map (onp .shape , outs ))
1054
1054
if not out_shapes == list (out_shapes_ ):
1055
1055
raise masking .ShapeError
1056
- if not all (onp .shape (out ) == eval_polymorphic_shape ( shape , padded_env )
1057
- for out , shape in zip (outs , out_shapes )):
1056
+ if not all (onp .shape (out ) == masking . eval_shape_expr ( padded_env , expr )
1057
+ for out , expr in zip (outs , out_shapes )):
1058
1058
raise masking .ShapeError
1059
1059
return tree_unflatten (out_tree (), outs )
1060
1060
return wrapped_fun
1061
1061
1062
1062
def _remap_ids (names , shape_spec ):
1063
- return masking .ShapeSpec (Poly ({Mon ({names [id ] : deg for id , deg in mon .items ()})
1063
+ ShapeSpec , Poly , Mon = masking .ShapeSpec , masking .Poly , masking .Mon
1064
+ mdim = masking .monomorphic_dim
1065
+ return ShapeSpec (Poly ({Mon ({names [id ] : deg for id , deg in mon .items ()})
1064
1066
: coeff for mon , coeff in poly .items ()})
1065
- if poly is not masking ._monomorphic_dim else
1066
- masking ._monomorphic_dim for poly in shape_spec )
1067
+ if poly is not mdim else mdim for poly in shape_spec )
1067
1068
1068
1069
def _bind_shapes (shape_exprs , shapes ):
1069
1070
env = {}
1070
1071
for shape_expr , shape in zip (shape_exprs , shapes ):
1071
1072
for poly , d in zip (shape_expr , shape ):
1072
- if type (poly ) is not Poly or poly .is_constant :
1073
+ if ensure_poly (poly ).is_constant :
1073
1074
continue
1074
1075
else :
1075
1076
(binder ,), = poly # TODO generalize to handle striding
@@ -1084,13 +1085,16 @@ def shapecheck(in_shapes, out_shape, fun):
1084
1085
out_shapes , out_tree = tree_flatten (out_shape )
1085
1086
out_shapes = map (masking .parse_spec , out_shapes )
1086
1087
flat_fun , out_tree_ = flatten_fun_nokwargs (lu .wrap_init (fun ), in_tree )
1087
- avals = map (partial (ShapedArray , dtype = onp .float32 ), in_shapes )
1088
- out_shapes_ = [o .shape for o in pe .abstract_eval_fun (flat_fun .call_wrapped , * avals )]
1088
+ out_shapes_ = masking .shapecheck (flat_fun , in_shapes )
1089
1089
if out_tree != out_tree_ (): raise TypeError ("pytree mismatch" )
1090
- if not all (map (masking . _shape_spec_consistent , out_shapes , out_shapes_ )):
1090
+ if not all (map (_shape_spec_consistent , out_shapes , out_shapes_ )):
1091
1091
raise masking .ShapeError
1092
1092
return fun
1093
1093
1094
+ def _shape_spec_consistent (spec , expr ):
1095
+ return all (a == b for a , b in zip (spec , expr ) if a is not masking .monomorphic_dim )
1096
+
1097
+
1094
1098
def jvp (fun , primals , tangents ):
1095
1099
"""Computes a (forward-mode) Jacobian-vector product of `fun`.
1096
1100
0 commit comments