diff --git a/fmriprep/workflows/bold/fit.py b/fmriprep/workflows/bold/fit.py index aa21ca72f..a585d5646 100644 --- a/fmriprep/workflows/bold/fit.py +++ b/fmriprep/workflows/bold/fit.py @@ -53,7 +53,7 @@ init_ds_registration_wf, init_func_fit_reports_wf, ) -from .reference import init_raw_boldref_wf +from .reference import init_raw_boldref_wf, init_validation_and_dummies_wf from .registration import init_bold_reg_wf from .stc import init_bold_stc_wf from .t2s import init_bold_t2s_wf @@ -407,15 +407,18 @@ def init_bold_fit_wf( ]) # fmt:skip else: config.loggers.workflow.info('Found HMC boldref - skipping Stage 1') - - validate_bold = pe.Node(ValidateImage(), name='validate_bold') - validate_bold.inputs.in_file = bold_file - hmcref_buffer.inputs.boldref = precomputed['hmc_boldref'] + validation_and_dummies_wf = init_validation_and_dummies_wf(bold_file=bold_file) + workflow.connect([ - (validate_bold, hmcref_buffer, [('out_file', 'bold_file')]), - (validate_bold, func_fit_reports_wf, [('out_report', 'inputnode.validation_report')]), + (validation_and_dummies_wf, hmcref_buffer, [ + ('outputnode.bold_file', 'bold_file'), + ('outputnode.skip_vols', 'dummy_scans'), + ]), + (validation_and_dummies_wf, func_fit_reports_wf, [ + ('outputnode.validation_report', 'inputnode.validation_report'), + ]), (hmcref_buffer, hmc_boldref_source_buffer, [('boldref', 'in_file')]), ]) # fmt:skip @@ -591,6 +594,14 @@ def init_bold_fit_wf( config.loggers.workflow.info('Found coregistration reference - skipping Stage 3') regref_buffer.inputs.boldref = precomputed['coreg_boldref'] + # TODO: Allow precomputed bold masks to be passed + # Also needs consideration for how it interacts above + skullstrip_precomp_ref_wf = init_skullstrip_bold_wf(name='skullstrip_precomp_ref_wf') + skullstrip_precomp_ref_wf.inputs.inputnode.in_file = precomputed['coreg_boldref'] + workflow.connect([ + (skullstrip_precomp_ref_wf, regref_buffer, [('outputnode.mask_file', 'boldmask')]) + ]) # fmt:skip + if not boldref2anat_xform: # calculate BOLD registration to T1w bold_reg_wf = init_bold_reg_wf( diff --git a/fmriprep/workflows/bold/reference.py b/fmriprep/workflows/bold/reference.py index 68d53e935..b1f395e28 100644 --- a/fmriprep/workflows/bold/reference.py +++ b/fmriprep/workflows/bold/reference.py @@ -80,7 +80,6 @@ def init_raw_boldref_wf( beginning of ``bold_file`` """ - from niworkflows.interfaces.bold import NonsteadyStatesDetector from niworkflows.interfaces.images import RobustAverage workflow = Workflow(name=name) @@ -106,6 +105,96 @@ def init_raw_boldref_wf( name='outputnode', ) + # Simplify manually setting input image + if bold_file is not None: + inputnode.inputs.bold_file = bold_file + + validation_and_dummies_wf = init_validation_and_dummies_wf() + + gen_avg = pe.Node(RobustAverage(), name='gen_avg', mem_gb=1) + + workflow.connect([ + (inputnode, validation_and_dummies_wf, [ + ('bold_file', 'inputnode.bold_file'), + ('dummy_scans', 'inputnode.dummy_scans'), + ]), + (validation_and_dummies_wf, gen_avg, [ + ('outputnode.bold_file', 'in_file'), + ('outputnode.t_mask', 't_mask'), + ]), + (validation_and_dummies_wf, outputnode, [ + ('outputnode.bold_file', 'bold_file'), + ('outputnode.skip_vols', 'skip_vols'), + ('outputnode.algo_dummy_scans', 'algo_dummy_scans'), + ('outputnode.validation_report', 'validation_report'), + ]), + (gen_avg, outputnode, [('out_file', 'boldref')]), + ]) # fmt:skip + + return workflow + + +def init_validation_and_dummies_wf( + bold_file=None, + name='validation_and_dummies_wf', +): + """ + Build a workflow that validates a BOLD image and detects non-steady-state volumes. + + Workflow Graph + .. workflow:: + :graph2use: orig + :simple_form: yes + + from fmriprep.workflows.bold.reference import init_validation_and_dummies_wf + wf = init_validation_and_dummies_wf() + + Parameters + ---------- + bold_file : :obj:`str` + BOLD series NIfTI file + name : :obj:`str` + Name of workflow (default: ``validation_and_dummies_wf``) + + Inputs + ------ + bold_file : str + BOLD series NIfTI file + dummy_scans : int or None + Number of non-steady-state volumes specified by user at beginning of ``bold_file`` + + Outputs + ------- + bold_file : str + Validated BOLD series NIfTI file + skip_vols : int + Number of non-steady-state volumes selected at beginning of ``bold_file`` + algo_dummy_scans : int + Number of non-steady-state volumes agorithmically detected at + beginning of ``bold_file`` + + """ + from niworkflows.interfaces.bold import NonsteadyStatesDetector + + workflow = Workflow(name=name) + + inputnode = pe.Node( + niu.IdentityInterface(fields=['bold_file', 'dummy_scans']), + name='inputnode', + ) + outputnode = pe.Node( + niu.IdentityInterface( + fields=[ + 'bold_file', + 'skip_vols', + 'algo_dummy_scans', + 't_mask', + 'validation_report', + ] + ), + name='outputnode', + ) + # Simplify manually setting input image if bold_file is not None: inputnode.inputs.bold_file = bold_file @@ -117,7 +206,6 @@ def init_raw_boldref_wf( ) get_dummy = pe.Node(NonsteadyStatesDetector(), name='get_dummy') - gen_avg = pe.Node(RobustAverage(), name='gen_avg', mem_gb=1) calc_dummy_scans = pe.Node( niu.Function(function=pass_dummy_scans, output_names=['skip_vols_num']), @@ -126,22 +214,20 @@ def init_raw_boldref_wf( mem_gb=DEFAULT_MEMORY_MIN_GB, ) - # fmt: off workflow.connect([ (inputnode, val_bold, [('bold_file', 'in_file')]), - (inputnode, get_dummy, [('bold_file', 'in_file')]), - (inputnode, calc_dummy_scans, [('dummy_scans', 'dummy_scans')]), - (val_bold, gen_avg, [('out_file', 'in_file')]), - (get_dummy, gen_avg, [('t_mask', 't_mask')]), - (get_dummy, calc_dummy_scans, [('n_dummy', 'algo_dummy_scans')]), (val_bold, outputnode, [ ('out_file', 'bold_file'), ('out_report', 'validation_report'), ]), + (inputnode, get_dummy, [('bold_file', 'in_file')]), + (inputnode, calc_dummy_scans, [('dummy_scans', 'dummy_scans')]), + (get_dummy, calc_dummy_scans, [('n_dummy', 'algo_dummy_scans')]), + (get_dummy, outputnode, [ + ('n_dummy', 'algo_dummy_scans'), + ('t_mask', 't_mask'), + ]), (calc_dummy_scans, outputnode, [('skip_vols_num', 'skip_vols')]), - (gen_avg, outputnode, [('out_file', 'boldref')]), - (get_dummy, outputnode, [('n_dummy', 'algo_dummy_scans')]), - ]) - # fmt: on + ]) # fmt:skip return workflow