@@ -866,79 +866,69 @@ def __dask_postcompute__(self):
866
866
import dask
867
867
868
868
info = [
869
- (True , k , v .__dask_postcompute__ () )
869
+ (k , None ) + v .__dask_postcompute__ ()
870
870
if dask .is_dask_collection (v )
871
- else (False , k , v )
871
+ else (k , v , None , None )
872
872
for k , v in self ._variables .items ()
873
873
]
874
- args = (
875
- info ,
874
+ construct_direct_args = (
876
875
self ._coord_names ,
877
876
self ._dims ,
878
877
self ._attrs ,
879
878
self ._indexes ,
880
879
self ._encoding ,
881
880
self ._close ,
882
881
)
883
- return self ._dask_postcompute , args
882
+ return self ._dask_postcompute , ( info , construct_direct_args )
884
883
885
884
def __dask_postpersist__ (self ):
886
885
import dask
887
886
888
887
info = [
889
- (True , k , v .__dask_postpersist__ () )
888
+ (k , None , v .__dask_keys__ ()) + v . __dask_postpersist__ ( )
890
889
if dask .is_dask_collection (v )
891
- else (False , k , v )
890
+ else (k , v , None , None , None )
892
891
for k , v in self ._variables .items ()
893
892
]
894
- args = (
895
- info ,
893
+ construct_direct_args = (
896
894
self ._coord_names ,
897
895
self ._dims ,
898
896
self ._attrs ,
899
897
self ._indexes ,
900
898
self ._encoding ,
901
899
self ._close ,
902
900
)
903
- return self ._dask_postpersist , args
901
+ return self ._dask_postpersist , ( info , construct_direct_args )
904
902
905
903
@staticmethod
906
- def _dask_postcompute (results , info , * args ):
904
+ def _dask_postcompute (results , info , construct_direct_args ):
907
905
variables = {}
908
- results2 = list (results [::- 1 ])
909
- for is_dask , k , v in info :
910
- if is_dask :
911
- func , args2 = v
912
- r = results2 .pop ()
913
- result = func (r , * args2 )
906
+ results_iter = iter (results )
907
+ for k , v , rebuild , rebuild_args in info :
908
+ if v is None :
909
+ variables [k ] = rebuild (next (results_iter ), * rebuild_args )
914
910
else :
915
- result = v
916
- variables [k ] = result
911
+ variables [k ] = v
917
912
918
- final = Dataset ._construct_direct (variables , * args )
913
+ final = Dataset ._construct_direct (variables , * construct_direct_args )
919
914
return final
920
915
921
916
@staticmethod
922
- def _dask_postpersist (dsk , info , * args ):
917
+ def _dask_postpersist (dsk , info , construct_direct_args ):
918
+ from dask .optimization import cull
919
+
923
920
variables = {}
924
921
# postpersist is called in both dask.optimize and dask.persist
925
922
# When persisting, we want to filter out unrelated keys for
926
923
# each Variable's task graph.
927
- is_persist = len (dsk ) == len (info )
928
- for is_dask , k , v in info :
929
- if is_dask :
930
- func , args2 = v
931
- if is_persist :
932
- name = args2 [1 ][0 ]
933
- dsk2 = {k : v for k , v in dsk .items () if k [0 ] == name }
934
- else :
935
- dsk2 = dsk
936
- result = func (dsk2 , * args2 )
924
+ for k , v , dask_keys , rebuild , rebuild_args in info :
925
+ if v is None :
926
+ dsk2 , _ = cull (dsk , dask_keys )
927
+ variables [k ] = rebuild (dsk2 , * rebuild_args )
937
928
else :
938
- result = v
939
- variables [k ] = result
929
+ variables [k ] = v
940
930
941
- return Dataset ._construct_direct (variables , * args )
931
+ return Dataset ._construct_direct (variables , * construct_direct_args )
942
932
943
933
def compute (self , ** kwargs ) -> "Dataset" :
944
934
"""Manually trigger loading and/or computation of this dataset's data
0 commit comments