diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 65bc74e6aa..8942ad7e6f 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -674,7 +674,6 @@ def __init__( typeConstructor: Optional[TensorConstructorType] = None, truncate_gradient: int = -1, name: Optional[str] = None, - as_while: bool = False, profile: Optional[Union[str, bool]] = None, allow_gc: bool = True, strict: bool = True, @@ -1183,7 +1182,7 @@ def make_node(self, *inputs): # these are states that do not feed anything back in the recurrent # computation, and hence they do not have an initial state. The scan # node however receives an input for each such argument, the input - # in this case is just a int saying how many steps of this output we + # in this case is just an int saying how many steps of this output we # need to store. This input does not have the same dtype, nor is it the same # type of tensor as the output, it is always a scalar int. new_inputs += [as_tensor_variable(ons) for ons in self.outer_nitsot(inputs)] diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index 011ad2d208..da324d5df8 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -28,10 +28,18 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import compute_test_value from pytensor.graph.replace import clone_replace -from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter +from pytensor.graph.rewriting.basic import ( + GraphRewriter, + copy_stack_trace, + in2out, + node_rewriter, +) from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB +from pytensor.graph.rewriting.utils import get_clients_at_depth from pytensor.graph.type import HasShape from pytensor.graph.utils import InconsistencyError +from pytensor.raise_op import Assert +from pytensor.scalar import ScalarConstant from pytensor.scan.op import Scan, ScanInfo from pytensor.scan.utils import ( ScanArgs, @@ -1103,6 +1111,71 @@ def sanitize(x): return at.as_tensor_variable(x) +@node_rewriter([Scan]) +def while_scan_merge_subtensor_last_element(fgraph, scan_node): + """ + Replace while_scan_out[abs(min(tap)):][-1] by while_scan_out[-1], for + recurring outputs, asserting that at least one step occurs. + Only the first step can be ensured by the inputs alone (i.e., `n_steps > 0`), + as the while scan could abort at any point after that. This means it is + not possible to replace while_scan_out[abs(min(tap)):][-i] + by while_scan_out[-i], for -i != -1.
+ """ + op = scan_node.op + + if not op.info.as_while: + return None + + # Optimization is not implemented form mit-mot + recurrent_outputs = op.outer_mitsot_outs(scan_node.outputs) + op.outer_sitsot_outs( + scan_node.outputs + ) + recurrent_outputs_taps_slices = ( + op.info.mit_sot_in_slices + op.info.sit_sot_in_slices + ) + + n_steps = scan_node.inputs[0] + non_zero_steps_cond = n_steps > 0 + assert_non_zero_steps_op = Assert("n_steps > 0") + + subtensor_merge_replacements = {} + + # Iterate over all nodes that are two computations below the while scan + for node2 in get_clients_at_depth(fgraph, scan_node, depth=2): + if not isinstance(node2.op, Subtensor): + continue + + node1 = node2.inputs[0].owner + if not (node1 and isinstance(node1.op, Subtensor)): + continue + + x = node1.inputs[0] + if x not in recurrent_outputs: + continue + + slice1 = get_idx_list(node1.inputs, node1.op.idx_list) + slice2 = get_idx_list(node2.inputs, node2.op.idx_list) + + min_tap = abs(min(recurrent_outputs_taps_slices[recurrent_outputs.index(x)])) + + if ( + len(slice1) == 1 + and isinstance(slice1[0], slice) + and isinstance(slice1[0].start, aes.ScalarConstant) + and slice1[0].start.data == min_tap + and slice1[0].stop is None + and slice1[0].step is None + and len(slice2) == 1 + and isinstance(slice2[0], aes.ScalarConstant) + and slice2[0].data == -1 + ): + out = assert_non_zero_steps_op(x[-1], non_zero_steps_cond) + copy_stack_trace([node2.outputs[0], node2.inputs[0]], out) + subtensor_merge_replacements[node2.outputs[0]] = out + + return subtensor_merge_replacements + + @node_rewriter([Scan]) def save_mem_new_scan(fgraph, node): r"""Graph optimizer that reduces scan memory consumption. @@ -1124,6 +1197,17 @@ def save_mem_new_scan(fgraph, node): that SITSOT output. Only the most recently computed timestep ever needs to be kept in memory. + There are two ways in which the Scan buffer size is controlled: + 1. Each recurring output is saved in an input empty tensor x with the initial + state written at x[:abs(min(taps))]. The remaining x[abs(min(taps)):] + positions determine how many intermediate results should be stored. + This rewrite shortens x[abs(min(taps)):] to the smallest possible size. + 2. Each non-recurrent output (nit-sot) is associated with a scalar integer + input that determines how many steps should be saved in the perform method. + This rewrite reduces this number to the smallest possible. + + The scan perform implementation takes the output sizes into consideration, + saving the newest results over the oldest ones whenever the buffer is filled. """ if not isinstance(node.op, Scan): return False @@ -1172,13 +1256,16 @@ def save_mem_new_scan(fgraph, node): # index(step) for any output scan actually needs to compute # In other words n_steps should be equal to this maximal ! # Note: if we have a shared variable that gets updated at every step - # of the loop, reducing the number of steps will affect the the - # value of the shared variable after the loop so we need not to + # of the loop, reducing the number of steps will affect the + # value of the shared variable after the loop so we cannot # change the number of steps in that case. To do this we set # global_nsteps to None which is seen as a flag that nothing needs - # to be done + # to be done. + # Note: For simplicity while Scans also have global_nsteps set to None. + # All step optimizations require knowing the shape of the output, which + # cannot be determined from the inputs alone. 
assert len(node.outputs) >= c_outs - if len(node.outputs) == c_outs: + if len(node.outputs) == c_outs and not op.info.as_while: global_nsteps = {"real": -1, "sym": []} else: global_nsteps = None @@ -1257,7 +1344,7 @@ def save_mem_new_scan(fgraph, node): else: # there is a **gotcha** here ! Namely, scan returns an # array that contains the initial state of the output - # as well. Which means that if have a initial state of + # as well. Which means that if y has an initial state of # length 3, and you look for 5 steps you get an output # y of length 8. If you only use y[:5], this does not # mean that you only need to loop for 5 steps but @@ -1285,9 +1372,9 @@ def save_mem_new_scan(fgraph, node): # 2.3. Analyze global_nsteps to figure out for how many steps scan # needs to iterate - if global_nsteps is not None: + if global_nsteps is None: nw_steps = node.inputs[0] - + else: # there are some symbolic tensors that limit the number of # steps if len(global_nsteps["sym"]) == 0: @@ -1303,6 +1390,7 @@ def save_mem_new_scan(fgraph, node): real_steps = None nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0]) + # FIXME: This is not correct. Scan with 0 steps seems to be supported # Make sure the ScanSaveMem optimization never makes the new # number of steps to be 0 (this could happen, for instance, if # the optimization detects that the outputs of the Scan go through @@ -1310,9 +1398,6 @@ def save_mem_new_scan(fgraph, node): # 0 iterations are not supported. Make sure the new number of steps # is at least 1. nw_steps = select_max(nw_steps, 1) - else: - nw_steps = node.inputs[0] - global_nsteps = None # 2.4 Loop over the clients again now looking just to see how many # intermediate steps to store @@ -1335,19 +1420,33 @@ def save_mem_new_scan(fgraph, node): store_steps[i] = 0 break - if i > op_info.n_mit_mot: - length = node.inputs[0] + init_l[i] + # Special case for recurrent outputs where only the last result + # is requested. This is needed for this rewrite to apply to + # do-while Scans at all. Otherwise, `get_canonical_form_slice` in + # the `else` branch would reintroduce a shape dependency on the + # original Scan which would lead this rewrite to abort in the end. 
+ if ( + i <= op.info.n_mit_mot + and isinstance(this_slice[0], ScalarConstant) + and this_slice[0].value == -1 + ): + start = nw_steps - 1 else: - try: - length = shape_of[out][0] - except KeyError: - length = out.shape[0] - cf_slice = get_canonical_form_slice(this_slice[0], length) + if i <= op.info.n_mit_mot: + try: + length = shape_of[out][0] + except KeyError: + length = out.shape[0] + else: + length = node.inputs[0] + init_l[i] + + cf_slice = get_canonical_form_slice(this_slice[0], length) + + if isinstance(cf_slice[0], slice): + start = at.extract_constant(cf_slice[0].start) + else: + start = at.extract_constant(cf_slice[0]) - if isinstance(cf_slice[0], slice): - start = at.extract_constant(cf_slice[0].start) - else: - start = at.extract_constant(cf_slice[0]) if start == 0 or store_steps[i] == 0: store_steps[i] = 0 else: @@ -1498,6 +1597,7 @@ def save_mem_new_scan(fgraph, node): nw_input = expand_empty(_nw_input, nw_steps) nw_inputs[in_idx] = nw_input else: + # FIXME: This is never used nw_input = nw_inputs[in_idx][: (initl + nw_steps)] elif ( @@ -1554,8 +1654,8 @@ def save_mem_new_scan(fgraph, node): ) else: fslice = sanitize(cnf_slice[0]) - nw_slice = (fslice,) + tuple(old_slices[1:]) + nw_pos = inv_compress_map[idx] subtens = Subtensor(nw_slice) @@ -1604,9 +1704,16 @@ def save_mem_new_scan(fgraph, node): ) + tuple(old_slices[1:]) else: - position = ( - cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos] - ) + # Special case when only last value is requested + if ( + isinstance(old_slices[0], ScalarConstant) + and old_slices[0].value == -1 + ): + position = old_slices[0] + else: + position = ( + cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos] + ) nw_slice = (sanitize(position),) + tuple(old_slices[1:]) subtens = Subtensor(nw_slice) @@ -2403,6 +2510,12 @@ def push_out_dot1_scan(fgraph, node): position=5, ) +scan_eqopt2.register( + "while_scan_merge_subtensor_last_element", + in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True), + "fast_run", + "scan", +) scan_eqopt2.register( "constant_folding_for_scan2", diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 576a83e024..d121586860 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -479,6 +479,7 @@ def local_subtensor_merge(fgraph, node): expresses all slices in a canonical form, and then merges them together. """ + from pytensor.scan.op import Scan if isinstance(node.op, Subtensor): u = node.inputs[0] @@ -489,6 +490,16 @@ def local_subtensor_merge(fgraph, node): # slices of the first applied subtensor slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list) slices2 = get_idx_list(node.inputs, node.op.idx_list) + + # Don't try to do the optimization on do-while scan outputs, + # as it will create a dependency on the shape of the outputs + if ( + x.owner is not None + and isinstance(x.owner.op, Scan) + and x.owner.op.info.as_while + ): + return None + # Get the shapes of the vectors ! try: # try not to introduce new shape into the graph diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 707c86fe0f..7087cfd79a 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -85,7 +85,7 @@ def indices_from_subtensor( op_indices: Iterable[ScalarConstant], idx_list: Optional[List[Union[Type, slice, Variable]]], -) -> Union[slice, Variable]: +) -> Tuple[Union[slice, Variable], ...]: """Recreate the index tuple from which a ``*Subtensor**`` ``Op`` was created. 
Parameters diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index 5323f4b4c7..da53093b12 100644 --- a/tests/scan/test_rewriting.py +++ b/tests/scan/test_rewriting.py @@ -1395,6 +1395,98 @@ def f_pow2(x_tm1): rng = np.random.default_rng(utt.fetch_seed()) my_f(rng.uniform(size=(3,)), 4, np.int64([2, 2, 3])) + def test_while_scan_taps(self): + n_steps = scalar("n_steps", dtype="int64") + x0 = vector("x0") + + ys, _ = pytensor.scan( + # Fibonacci Sequence + lambda xtm2, xtm1: (xtm1 + xtm2, {}, until(xtm1 >= 34)), + outputs_info=[{"initial": x0, "taps": [-2, -1]}], + n_steps=n_steps, + ) + # Save memory is triggered by choosing only last value + y = ys[-1] + + f = pytensor.function( + [n_steps, x0], y, mode=get_default_mode().including("scan") + ) + + np.testing.assert_equal(f(n_steps=1000, x0=[1, 1]), 55) + np.testing.assert_equal(f(n_steps=1, x0=[1, 1]), 2) + with pytest.raises(AssertionError, match="n_steps > 0"): + f(n_steps=0, x0=[1, 1]) + + # ys_trace is an Alloc that controls the size of the inner buffer, + # it should have shape[0] == 3, with two entries for the taps and one + # entry for the intermediate output + [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan)) + _, ys_trace = scan_node.inputs + debug_fn = pytensor.function( + [n_steps, x0], ys_trace.shape[0], accept_inplace=True + ) + assert debug_fn(n_steps=1000, x0=[1, 1]) == 3 + + def test_while_scan_map(self): + xs = vector("xs") + ys, _ = pytensor.scan( + lambda x: (x + 1, {}, until(x + 1 >= 10)), + outputs_info=[None], + sequences=[xs], + ) + # Save memory is triggered by choosing only last value + y = ys[-1] + + f = pytensor.function([xs], y, mode=get_default_mode().including("scan")) + np.testing.assert_equal(f(xs=np.arange(100, dtype=config.floatX)), 10) + np.testing.assert_equal(f(xs=[0]), 1) + with pytest.raises(IndexError): + f(xs=[]) + + # len_ys is a numerical input that controls the shape of the inner buffer + # It should be 1, as only the last output is needed + [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan)) + _, _, len_ys = scan_node.inputs + debug_fn = pytensor.function([xs], len_ys, accept_inplace=True) + assert debug_fn(xs=np.zeros((100,), dtype=config.floatX)) == 1 + + def test_while_scan_taps_and_map(self): + x0 = scalar("x0") + seq = vector("seq") + n_steps = scalar("n_steps", dtype="int64") + + # while loop + [ys, zs], _ = pytensor.scan( + lambda s, xtm1: ((xtm1 + 1, xtm1 + 1 + s), {}, until(xtm1 >= 99)), + sequences=[seq], + outputs_info=[x0, None], + n_steps=n_steps, + ) + # Save memory is triggered by choosing only last value + y = ys[-1] + z = zs[-1] + + f = pytensor.function( + [x0, seq, n_steps], [y, z], mode=get_default_mode().including("scan") + ) + test_seq = np.zeros(200, dtype=config.floatX) + np.testing.assert_allclose(f(x0=0, seq=test_seq, n_steps=200), 100) + np.testing.assert_allclose(f(x0=1, seq=test_seq, n_steps=20), 21) + np.testing.assert_allclose(f(x0=np.e, seq=test_seq, n_steps=1), np.e + 1) + with pytest.raises(AssertionError, match="n_steps > 0"): + f(x0=0, seq=test_seq, n_steps=0) + + # Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly. 
+ # If a MissingInputError is raised, it means the rewrite failed + [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan)) + _, _, ys_trace, len_zs = scan_node.inputs + debug_fn = pytensor.function( + [n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True + ) + stored_ys_steps, stored_zs_steps = debug_fn(n_steps=200) + assert stored_ys_steps == 2 + assert stored_zs_steps == 1 + def test_inner_replace_dot(): """
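Illustrative sketch (not part of the patch above): a minimal do-while Scan that exercises the pattern the new rewrites target, using only the public pytensor API. The variable names and the final print are taken from the `test_while_scan_taps` case in the diff; everything else is an assumption for illustration.

# Sketch only: mirrors test_while_scan_taps from the patch, assuming the
# public pytensor API (pytensor.scan, pytensor.scan.utils.until).
import pytensor
import pytensor.tensor as pt
from pytensor.scan.utils import until

n_steps = pt.scalar("n_steps", dtype="int64")
x0 = pt.vector("x0")

# Do-while Scan: Fibonacci-like recurrence that stops once xtm1 >= 34.
ys, _ = pytensor.scan(
    lambda xtm2, xtm1: (xtm1 + xtm2, {}, until(xtm1 >= 34)),
    outputs_info=[{"initial": x0, "taps": [-2, -1]}],
    n_steps=n_steps,
)

# Requesting only the last element produces the ys[abs(min(taps)):][-1]
# pattern that `while_scan_merge_subtensor_last_element` replaces with
# ys[-1] guarded by Assert(n_steps > 0); ScanSaveMem can then shrink the
# output buffer to the two taps plus one new entry.
y_last = ys[-1]
f = pytensor.function([n_steps, x0], y_last)
print(f(1000, [1.0, 1.0]))  # expected value per the test: 55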