
Add xtensor broadcast #1489


Open: wants to merge 22 commits into labeled_tensors

Conversation

@AllenDowney (Contributor) commented Jun 20, 2025

This replaces #1486. This one is based on a rebased labeled_tensors branch.


📚 Documentation preview 📚: https://pytensor--1489.org.readthedocs.build/en/1489/

@AllenDowney (Contributor, Author)

@ricardoV94 Here's my attempt to rebase on the changes you just force pushed. Looks like mypy is unhappy -- is that something you expected?

Other than that, I think this is ready for review.

@ricardoV94 (Member)

> @ricardoV94 Here's my attempt to rebase on the changes you just force pushed. Looks like mypy is unhappy -- is that something you expected?
>
> Other than that, I think this is ready for review.

Yeah I didn't make mypy pass yet

x_tensor = x_tensor.dimshuffle(shuffle_pattern)

# Now we are aligned with target dims and correct ndim
x_tensor = broadcast_to(x_tensor, out.type.shape)
Review comment (Member) on the diff above:

This won't work when the output shape is not statically known. The target shape has to be computed symbolically from the symbolic input shapes.

You can test by having an xtensor with shape=(None,) for a dim that only that tensor has.
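
For illustration, a minimal sketch of the symbolic route, with hypothetical tensors x and y standing in for the already-aligned value tensors (these names are not from the PR); the target shape comes from pt.broadcast_shape rather than from out.type.shape:

import pytensor.tensor as pt

# Hypothetical aligned inputs; x's first dim is only known at runtime,
# so the static output shape would contain None.
x = pt.tensor("x", shape=(None, 1))
y = pt.tensor("y", shape=(1, 3))

# Target shape computed symbolically from the inputs' runtime shapes
target_shape = pt.broadcast_shape(x, y)
x_bcast = pt.broadcast_to(x, target_shape)
y_bcast = pt.broadcast_to(y, target_shape)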

ricardoV94 force-pushed the labeled_tensors branch 4 times, most recently from 71bc4ef to 41d9be4 (June 21, 2025 17:24)
@AllenDowney (Contributor, Author)

@ricardoV94 I think I have symbolic dimensions working. My solution is more complicated than I think any of us would like, but I don't see a simpler solution. Maybe you will.

Should we continue working on this PR for now, and I'll rebase later?

@ricardoV94 (Member) commented Jun 21, 2025

Here is an idea:

# Sketch of the lowering rewrite; assumes pt = pytensor.tensor and as_xtensor
# (from the xtensor module) are already imported, as in the other rewrites.
def lower_broadcast(fgraph, node):
    excluded_dims = node.op.exclude
    broadcast_dims = tuple(dim for dim in node.outputs[0].type.dims if dim not in excluded_dims)
    all_dims = broadcast_dims + excluded_dims

    # Align inputs with all_dims like we do in other rewrites
    # (probably time to refactor this kind of logic into a helper)
    inp_tensors = []
    for inp in node.inputs:
        inp_dims = inp.type.dims
        order = tuple(inp_dims.index(dim) if dim in inp_dims else "x" for dim in all_dims)
        inp_tensors.append(inp.values.dimshuffle(order))

    if not excluded_dims:
        out_tensors = pt.broadcast_arrays(*inp_tensors)
    else:
        all_shape = tuple(pt.broadcast_shape(*inp_tensors))
        assert len(all_shape) == len(all_dims)
        out_tensors = []
        for inp_tensor, out in zip(inp_tensors, node.outputs):
            out_dims = out.type.dims
            out_shape = tuple(length for length, dim in zip(all_shape, all_dims) if dim in out_dims)
            out_tensors.append(pt.broadcast_to(inp_tensor, out_shape))

    new_outs = [as_xtensor(out_tensor, dims=out.type.dims) for out_tensor, out in zip(out_tensors, node.outputs)]
    return new_outs

Btw, the base branch is merged, so you can rebase / start from it. Note that you don't need to open a new PR: you can force-push your changes to your current remote after cleaning up the branch.

@AllenDowney (Contributor, Author)

@ricardoV94 I've added broadcast_like. Once we're happy with broadcast and broadcast_like, I will factor out some common code.

@AllenDowney (Contributor, Author)

Your version of lower_broadcast works when exclude is empty, but it fails on tests that have excluded dims.

I'll work on debugging it, but at the moment it's not clear to me whether there is a small error in your implementation or an actual problem with the logic.

@ricardoV94 (Member) commented Jun 21, 2025

I suspect there's a wrong assumption about the excluded dims alignment, but the general idea should work.

@AllenDowney (Contributor, Author)

I think the incorrect assumption is that all outputs have the same shape. When exclude is not empty, they don't, in general.

@ricardoV94 (Member)

Actually, there's a logical flaw: two inputs could have an excluded dim with the same name but different lengths, in which case they shouldn't be aligned for the broadcast shape.

We should add that as a test.

Still, the logic for each output should be something like broadcast_to(tensor, common_broadcast_shape + original_excluded_shape).
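
A minimal sketch of that per-output construction, under assumed inputs where the leading axes are the shared broadcast dims and the last axis of each input is its own excluded dim (all names here are hypothetical):

import pytensor.tensor as pt

# Hypothetical aligned inputs: the first two axes are shared broadcast dims,
# the last axis of each input is its own excluded dim.
a = pt.tensor("a", shape=(None, 1, None))
b = pt.tensor("b", shape=(1, None, None))

# Common broadcast shape over the shared dims only (excluded axes dropped)
common_shape = pt.broadcast_shape(a[..., 0], b[..., 0])

# Each output gets the common shape plus its own original excluded length
a_out = pt.broadcast_to(a, (*common_shape, a.shape[-1]))
b_out = pt.broadcast_to(b, (*common_shape, b.shape[-1]))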

@ricardoV94 (Member) commented Jun 21, 2025

> I think the incorrect assumption is that all outputs have the same shape. When exclude is not empty, they don't, in general.

I didn't assume that; the dimshuffle was supposed to take care of it, so that things were put on different axes for broadcasting. Still, as I just wrote, there was a wrong assumption that you could align shared excluded dims. They don't even come out in a uniform order, do they?
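
For reference, a small standalone sketch (with made-up dims, not the ones in the PR tests) of how that dimshuffle alignment reorders existing axes and inserts length-1 axes for missing dims:

import pytensor.tensor as pt

# Made-up dims: the aligned order is ("a", "b", "c"); this input only has ("c", "a")
inp = pt.tensor("inp", shape=(None, None))  # axes correspond to ("c", "a")

inp_dims = ("c", "a")
all_dims = ("a", "b", "c")
order = tuple(inp_dims.index(dim) if dim in inp_dims else "x" for dim in all_dims)
# order == (1, "x", 0): existing axes reordered, "b" inserted as a broadcastable axis
aligned = inp.dimshuffle(order)  # static shape (None, 1, None)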

@ricardoV94 (Member)

I don't think this logical flaw is why the tests are failing, though. We should test that case as well.
