diff --git a/src/uproot/_dask.py b/src/uproot/_dask.py index 0b2b4633c..9145e0083 100644 --- a/src/uproot/_dask.py +++ b/src/uproot/_dask.py @@ -20,6 +20,10 @@ import uproot from uproot._util import no_filter, unset +from uproot.behaviors.RNTuple import HasFields +from uproot.behaviors.RNTuple import ( + _regularize_step_size as _RNTuple_regularize_step_size, +) from uproot.behaviors.TBranch import HasBranches, TBranch, _regularize_step_size if TYPE_CHECKING: @@ -616,7 +620,13 @@ def real_filter_branch(branch): recursive=recursive, filter_name=filter_name, filter_typename=filter_typename, - filter_branch=real_filter_branch, + **{ + ( + "filter_field" + if isinstance(obj, HasFields) + else "filter_branch" + ): real_filter_branch + }, full_paths=full_paths, ignore_duplicates=True, ) @@ -756,7 +766,11 @@ def _get_dask_array_delay_open( recursive=recursive, filter_name=filter_name, filter_typename=filter_typename, - filter_branch=filter_branch, + **{ + ( + "filter_field" if isinstance(obj, HasFields) else "filter_branch" + ): filter_branch + }, full_paths=full_paths, ignore_duplicates=True, ) @@ -925,6 +939,13 @@ def load_buffers( how=tuple, ) + if isinstance(tree, HasFields): + # Temporary workaround to have basic support for RNTuple + # This is needed since currently arrays only has top-level fields + # TODO: Ask people how they want this handled since the previous + # approach might not make sense for RNTuples + keys = [f for f in tree.field_names if f in keys] + awkward = uproot.extras.awkward() # The subform generated by awkward.to_buffers() has different form keys @@ -1472,7 +1493,13 @@ def real_filter_branch(branch): recursive=recursive, filter_name=filter_name, filter_typename=filter_typename, - filter_branch=real_filter_branch, + **{ + ( + "filter_field" + if isinstance(obj, HasFields) + else "filter_branch" + ): real_filter_branch + }, full_paths=full_paths, ignore_duplicates=True, ) @@ -1515,14 +1542,21 @@ def real_filter_branch(branch): entry_start = 0 
entry_stop = ttree.num_entries - branchid_interpretation = {} - for key in common_keys: - branch = ttree[key] - branchid_interpretation[branch.cache_key] = branch.interpretation - ttree_step = _regularize_step_size( - ttree, step_size, entry_start, entry_stop, branchid_interpretation - ) - step_sum += int(ttree_step) + if isinstance(ttree, HasFields): + akform = ttree.to_akform(filter_name=common_keys) + ttree_step = _RNTuple_regularize_step_size( + ttree, akform, step_size, entry_start, entry_stop + ) + step_sum += int(ttree_step) + else: + branchid_interpretation = {} + for key in common_keys: + branch = ttree[key] + branchid_interpretation[branch.cache_key] = branch.interpretation + ttree_step = _regularize_step_size( + ttree, step_size, entry_start, entry_stop, branchid_interpretation + ) + step_sum += int(ttree_step) entry_step = round(step_sum / len(ttrees)) @@ -1558,9 +1592,12 @@ def real_filter_branch(branch): divisions.append(divisions[-1] + length) partition_args.append((i, start, stop)) - base_form = _get_ttree_form( - awkward, ttrees[0], common_keys, interp_options.get("ak_add_doc") - ) + if isinstance(ttrees[0], HasFields): + base_form = ttrees[0].to_akform(filter_name=common_keys) + else: + base_form = _get_ttree_form( + awkward, ttrees[0], common_keys, interp_options.get("ak_add_doc") + ) if len(partition_args) == 0: divisions.append(0) @@ -1628,7 +1665,11 @@ def _get_dak_array_delay_open( recursive=recursive, filter_name=filter_name, filter_typename=filter_typename, - filter_branch=filter_branch, + **{ + ( + "filter_field" if isinstance(obj, HasFields) else "filter_branch" + ): filter_branch + }, full_paths=full_paths, ignore_duplicates=True, ) diff --git a/src/uproot/behaviors/RNTuple.py b/src/uproot/behaviors/RNTuple.py index 7bef644f0..ed688f703 100644 --- a/src/uproot/behaviors/RNTuple.py +++ b/src/uproot/behaviors/RNTuple.py @@ -627,6 +627,15 @@ def arrays( See also :ref:`uproot.behaviors.RNTuple.HasFields.iterate` to iterate over the array in 
contiguous ranges of entries. """ + # This temporarily provides basic functionality while expressions are properly implemented + if expressions is not None: + if filter_name == no_filter: + filter_name = expressions + else: + raise ValueError( + "Expressions are not supported yet. They are currently equivalent to filter_name." + ) + entry_start, entry_stop = ( uproot.behaviors.TBranch._regularize_entries_start_stop( self.num_entries, entry_start, entry_stop @@ -823,6 +832,15 @@ def iterate( See also :doc:`uproot.behaviors.RNTuple.iterate` to iterate over many files. """ + # This temporarily provides basic functionality while expressions are properly implemented + if expressions is not None: + if filter_name == no_filter: + filter_name = expressions + else: + raise ValueError( + "Expressions are not supported yet. They are currently equivalent to filter_name." + ) + entry_start, entry_stop = ( uproot.behaviors.TBranch._regularize_entries_start_stop( self.ntuple.num_entries, entry_start, entry_stop @@ -1316,6 +1334,15 @@ def num_entries_for( :ref:`uproot.behaviors.RNTuple.HasFields.iterate` uses to convert a ``step_size`` expressed in memory units into a number of entries. """ + # This temporarily provides basic functionality while expressions are properly implemented + if expressions is not None: + if filter_name == no_filter: + filter_name = expressions + else: + raise ValueError( + "Expressions are not supported yet. They are currently equivalent to filter_name." 
+                )
+
         target_num_bytes = uproot._util.memory_size(memory_size)
 
         entry_start, entry_stop = (
diff --git a/tests/test_1223_more_rntuple_types.py b/tests/test_1223_more_rntuple_types.py
index 95a791f57..965daf946 100644
--- a/tests/test_1223_more_rntuple_types.py
+++ b/tests/test_1223_more_rntuple_types.py
@@ -85,6 +85,6 @@ def test_invalid_variant():
     with uproot.open(filename) as f:
         obj = f["ntuple"]
 
-        a = obj.arrays("variant")
+        a = obj.arrays("variant.*")
 
         assert a.variant.tolist() == [1, None, {"i": 2}]
diff --git a/tests/test_1412_rntuple_dask.py b/tests/test_1412_rntuple_dask.py
new file mode 100644
index 000000000..871817cd6
--- /dev/null
+++ b/tests/test_1412_rntuple_dask.py
@@ -0,0 +1,40 @@
+# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE
+
+import os
+
+import numpy
+import pytest
+
+import uproot
+
+ak = pytest.importorskip("awkward")
+dask = pytest.importorskip("dask")
+# uproot.dask returns a dask-awkward collection by default, so skip (rather
+# than error) when that optional dependency is not installed.
+dask_awkward = pytest.importorskip("dask_awkward")
+
+data = ak.Array(
+    {
+        "ints": [1, 2, 3, 4, 5],
+        "floats": [1.1, 2.2, 3.3, 4.4, 5.5],
+        "strings": ["one", "two", "three", "four", "five"],
+    }
+)
+
+
+def test_dask(tmp_path):
+    """Read two identical RNTuple files through ``uproot.dask`` and compare."""
+    # Both files hold the same five entries, so the computed array is expected
+    # to be ``data`` repeated once per file.
+    for filename in ("test1.root", "test2.root"):
+        with uproot.recreate(os.path.join(tmp_path, filename)) as file:
+            file.mkrntuple("ntuple", data)
+
+    dask_arr = uproot.dask(f"{tmp_path}/test*.root:ntuple")
+
+    arr = dask_arr.compute()
+
+    # Check the total length first so a dropped or duplicated file is obvious.
+    assert len(arr) == 2 * len(data)
+    assert ak.array_equal(arr[:5], data)
+    assert ak.array_equal(arr[5:], data)