Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 56 additions & 15 deletions src/uproot/_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@

import uproot
from uproot._util import no_filter, unset
from uproot.behaviors.RNTuple import HasFields
from uproot.behaviors.RNTuple import (
_regularize_step_size as _RNTuple_regularize_step_size,
)
from uproot.behaviors.TBranch import HasBranches, TBranch, _regularize_step_size

if TYPE_CHECKING:
Expand Down Expand Up @@ -616,7 +620,13 @@ def real_filter_branch(branch):
recursive=recursive,
filter_name=filter_name,
filter_typename=filter_typename,
filter_branch=real_filter_branch,
**{
(
"filter_field"
if isinstance(obj, HasFields)
else "filter_branch"
): real_filter_branch
},
full_paths=full_paths,
ignore_duplicates=True,
)
Expand Down Expand Up @@ -756,7 +766,11 @@ def _get_dask_array_delay_open(
recursive=recursive,
filter_name=filter_name,
filter_typename=filter_typename,
filter_branch=filter_branch,
**{
(
"filter_field" if isinstance(obj, HasFields) else "filter_branch"
): filter_branch
},
full_paths=full_paths,
ignore_duplicates=True,
)
Expand Down Expand Up @@ -925,6 +939,13 @@ def load_buffers(
how=tuple,
)

if isinstance(tree, HasFields):
# Temporary workaround to have basic support for RNTuple
# This is needed since currently arrays only has top-level fields
# TODO: Ask people how they want this handled since the previous
# approach might not make sense for RNTuples
keys = [f for f in tree.field_names if f in keys]

awkward = uproot.extras.awkward()

# The subform generated by awkward.to_buffers() has different form keys
Expand Down Expand Up @@ -1472,7 +1493,13 @@ def real_filter_branch(branch):
recursive=recursive,
filter_name=filter_name,
filter_typename=filter_typename,
filter_branch=real_filter_branch,
**{
(
"filter_field"
if isinstance(obj, HasFields)
else "filter_branch"
): real_filter_branch
},
full_paths=full_paths,
ignore_duplicates=True,
)
Expand Down Expand Up @@ -1515,14 +1542,21 @@ def real_filter_branch(branch):
entry_start = 0
entry_stop = ttree.num_entries

branchid_interpretation = {}
for key in common_keys:
branch = ttree[key]
branchid_interpretation[branch.cache_key] = branch.interpretation
ttree_step = _regularize_step_size(
ttree, step_size, entry_start, entry_stop, branchid_interpretation
)
step_sum += int(ttree_step)
if isinstance(ttree, HasFields):
akform = ttree.to_akform(filter_name=common_keys)
ttree_step = _RNTuple_regularize_step_size(
ttree, akform, step_size, entry_start, entry_stop
)
step_sum += int(ttree_step)
else:
branchid_interpretation = {}
for key in common_keys:
branch = ttree[key]
branchid_interpretation[branch.cache_key] = branch.interpretation
ttree_step = _regularize_step_size(
ttree, step_size, entry_start, entry_stop, branchid_interpretation
)
step_sum += int(ttree_step)

entry_step = round(step_sum / len(ttrees))

Expand Down Expand Up @@ -1558,9 +1592,12 @@ def real_filter_branch(branch):
divisions.append(divisions[-1] + length)
partition_args.append((i, start, stop))

base_form = _get_ttree_form(
awkward, ttrees[0], common_keys, interp_options.get("ak_add_doc")
)
if isinstance(ttrees[0], HasFields):
base_form = ttrees[0].to_akform(filter_name=common_keys)
else:
base_form = _get_ttree_form(
awkward, ttrees[0], common_keys, interp_options.get("ak_add_doc")
)

if len(partition_args) == 0:
divisions.append(0)
Expand Down Expand Up @@ -1628,7 +1665,11 @@ def _get_dak_array_delay_open(
recursive=recursive,
filter_name=filter_name,
filter_typename=filter_typename,
filter_branch=filter_branch,
**{
(
"filter_field" if isinstance(obj, HasFields) else "filter_branch"
): filter_branch
},
full_paths=full_paths,
ignore_duplicates=True,
)
Expand Down
27 changes: 27 additions & 0 deletions src/uproot/behaviors/RNTuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,15 @@ def arrays(
See also :ref:`uproot.behaviors.RNTuple.HasFields.iterate` to iterate over
the array in contiguous ranges of entries.
"""
# This temporarily provides basic functionality while expressions are properly implemented
if expressions is not None:
if filter_name == no_filter:
filter_name = expressions
else:
raise ValueError(
"Expressions are not supported yet. They are currently equivalent to filter_name."
)

entry_start, entry_stop = (
uproot.behaviors.TBranch._regularize_entries_start_stop(
self.num_entries, entry_start, entry_stop
Expand Down Expand Up @@ -823,6 +832,15 @@ def iterate(
See also :doc:`uproot.behaviors.RNTuple.iterate` to iterate over many
files.
"""
# This temporarily provides basic functionality while expressions are properly implemented
if expressions is not None:
if filter_name == no_filter:
filter_name = expressions
else:
raise ValueError(
"Expressions are not supported yet. They are currently equivalent to filter_name."
)

entry_start, entry_stop = (
uproot.behaviors.TBranch._regularize_entries_start_stop(
self.ntuple.num_entries, entry_start, entry_stop
Expand Down Expand Up @@ -1316,6 +1334,15 @@ def num_entries_for(
:ref:`uproot.behaviors.RNTuple.HasFields.iterate` uses to convert a
``step_size`` expressed in memory units into a number of entries.
"""
# This temporarily provides basic functionality while expressions are properly implemented
if expressions is not None:
if filter_name == no_filter:
filter_name = expressions
else:
raise ValueError(
"Expressions are not supported yet. They are currently equivalent to filter_name."
)

target_num_bytes = uproot._util.memory_size(memory_size)

entry_start, entry_stop = (
Expand Down
2 changes: 1 addition & 1 deletion tests/test_1223_more_rntuple_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,6 @@ def test_invalid_variant():
with uproot.open(filename) as f:
obj = f["ntuple"]

a = obj.arrays("variant")
a = obj.arrays("variant.*")

assert a.variant.tolist() == [1, None, {"i": 2}]
37 changes: 37 additions & 0 deletions tests/test_1412_rntuple_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE

import os

import numpy
import pytest

import uproot

# Skip this whole test module when the optional dependencies cannot be imported.
ak = pytest.importorskip("awkward")
# The ``dask`` name is not referenced directly in the test below; presumably
# this acts as an availability guard for ``uproot.dask`` (confirm).
dask = pytest.importorskip("dask")

# Reference records written to every test file: an integer field, a
# floating-point field and a string field, five entries each.
data = ak.Array(
{
"ints": [1, 2, 3, 4, 5],
"floats": [1.1, 2.2, 3.3, 4.4, 5.5],
"strings": ["one", "two", "three", "four", "five"],
}
)


def test_dask(tmp_path):
    """Round-trip two identical RNTuple files through ``uproot.dask``.

    The module-level ``data`` array is written as an RNTuple named
    ``"ntuple"`` into ``test1.root`` and ``test2.root`` under ``tmp_path``.
    Both files are then opened through a single glob pattern, the resulting
    dask collection is computed, and each five-entry half of the result is
    compared against ``data``.
    """
    first_path = os.path.join(tmp_path, "test1.root")
    second_path = os.path.join(tmp_path, "test2.root")

    # Write the same content to both files, first file first.
    for root_path in (first_path, second_path):
        with uproot.recreate(root_path) as out_file:
            out_file.mkrntuple("ntuple", data)

    computed = uproot.dask(f"{tmp_path}/test*.root:ntuple").compute()

    # Each file contributes five entries, so both halves must equal ``data``.
    for half in (computed[:5], computed[5:]):
        assert ak.array_equal(half, data)