Skip to content

fix: Calculate bold mask and dummy scans in transform-only runs #3428

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions fmriprep/workflows/bold/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
init_ds_registration_wf,
init_func_fit_reports_wf,
)
from .reference import init_raw_boldref_wf
from .reference import init_raw_boldref_wf, init_validation_and_dummies_wf
from .registration import init_bold_reg_wf
from .stc import init_bold_stc_wf
from .t2s import init_bold_t2s_wf
Expand Down Expand Up @@ -407,15 +407,18 @@ def init_bold_fit_wf(
]) # fmt:skip
else:
config.loggers.workflow.info('Found HMC boldref - skipping Stage 1')

validate_bold = pe.Node(ValidateImage(), name='validate_bold')
validate_bold.inputs.in_file = bold_file

hmcref_buffer.inputs.boldref = precomputed['hmc_boldref']

validation_and_dummies_wf = init_validation_and_dummies_wf(bold_file=bold_file)

workflow.connect([
(validate_bold, hmcref_buffer, [('out_file', 'bold_file')]),
(validate_bold, func_fit_reports_wf, [('out_report', 'inputnode.validation_report')]),
(validation_and_dummies_wf, hmcref_buffer, [
('outputnode.bold_file', 'bold_file'),
('outputnode.skip_vols', 'dummy_scans'),
]),
(validation_and_dummies_wf, func_fit_reports_wf, [
('outputnode.validation_report', 'inputnode.validation_report'),
]),
(hmcref_buffer, hmc_boldref_source_buffer, [('boldref', 'in_file')]),
]) # fmt:skip

Expand Down Expand Up @@ -591,6 +594,14 @@ def init_bold_fit_wf(
config.loggers.workflow.info('Found coregistration reference - skipping Stage 3')
regref_buffer.inputs.boldref = precomputed['coreg_boldref']

# TODO: Allow precomputed bold masks to be passed
# Also needs consideration for how it interacts above
skullstrip_precomp_ref_wf = init_skullstrip_bold_wf(name='skullstrip_precomp_ref_wf')
skullstrip_precomp_ref_wf.inputs.inputnode.in_file = precomputed['coreg_boldref']
workflow.connect([
(skullstrip_precomp_ref_wf, regref_buffer, [('outputnode.mask_file', 'boldmask')])
]) # fmt:skip

if not boldref2anat_xform:
# calculate BOLD registration to T1w
bold_reg_wf = init_bold_reg_wf(
Expand Down
110 changes: 98 additions & 12 deletions fmriprep/workflows/bold/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def init_raw_boldref_wf(
beginning of ``bold_file``

"""
from niworkflows.interfaces.bold import NonsteadyStatesDetector
from niworkflows.interfaces.images import RobustAverage

workflow = Workflow(name=name)
Expand All @@ -106,6 +105,96 @@ def init_raw_boldref_wf(
name='outputnode',
)

# Simplify manually setting input image
if bold_file is not None:
inputnode.inputs.bold_file = bold_file

validation_and_dummies_wf = init_validation_and_dummies_wf()

gen_avg = pe.Node(RobustAverage(), name='gen_avg', mem_gb=1)

workflow.connect([
(inputnode, validation_and_dummies_wf, [
('bold_file', 'inputnode.bold_file'),
('dummy_scans', 'inputnode.dummy_scans'),
]),
(validation_and_dummies_wf, gen_avg, [
('outputnode.bold_file', 'in_file'),
('outputnode.t_mask', 't_mask'),
]),
(validation_and_dummies_wf, outputnode, [
('outputnode.bold_file', 'bold_file'),
('outputnode.skip_vols', 'skip_vols'),
('outputnode.algo_dummy_scans', 'algo_dummy_scans'),
('outputnode.validation_report', 'validation_report'),
]),
(gen_avg, outputnode, [('out_file', 'boldref')]),
]) # fmt:skip

return workflow


def init_validation_and_dummies_wf(
bold_file=None,
name='validation_and_dummies_wf',
):
"""
Build a workflow that validates a BOLD image and detects non-steady-state volumes.

Workflow Graph
.. workflow::
:graph2use: orig
:simple_form: yes

from fmriprep.workflows.bold.reference import init_validation_and_dummies_wf
wf = init_validation_and_dummies_wf()

Parameters
----------
bold_file : :obj:`str`
BOLD series NIfTI file
name : :obj:`str`
Name of workflow (default: ``validation_and_dummies_wf``)

Inputs
------
bold_file : str
BOLD series NIfTI file
dummy_scans : int or None
Number of non-steady-state volumes specified by user at beginning of ``bold_file``

Outputs
-------
bold_file : str
Validated BOLD series NIfTI file
skip_vols : int
Number of non-steady-state volumes selected at beginning of ``bold_file``
algo_dummy_scans : int
Number of non-steady-state volumes agorithmically detected at
beginning of ``bold_file``

"""
from niworkflows.interfaces.bold import NonsteadyStatesDetector

workflow = Workflow(name=name)

inputnode = pe.Node(
niu.IdentityInterface(fields=['bold_file', 'dummy_scans']),
name='inputnode',
)
outputnode = pe.Node(
niu.IdentityInterface(
fields=[
'bold_file',
'skip_vols',
'algo_dummy_scans',
't_mask',
'validation_report',
]
),
name='outputnode',
)

# Simplify manually setting input image
if bold_file is not None:
inputnode.inputs.bold_file = bold_file
Expand All @@ -117,7 +206,6 @@ def init_raw_boldref_wf(
)

get_dummy = pe.Node(NonsteadyStatesDetector(), name='get_dummy')
gen_avg = pe.Node(RobustAverage(), name='gen_avg', mem_gb=1)

calc_dummy_scans = pe.Node(
niu.Function(function=pass_dummy_scans, output_names=['skip_vols_num']),
Expand All @@ -126,22 +214,20 @@ def init_raw_boldref_wf(
mem_gb=DEFAULT_MEMORY_MIN_GB,
)

# fmt: off
workflow.connect([
(inputnode, val_bold, [('bold_file', 'in_file')]),
(inputnode, get_dummy, [('bold_file', 'in_file')]),
(inputnode, calc_dummy_scans, [('dummy_scans', 'dummy_scans')]),
(val_bold, gen_avg, [('out_file', 'in_file')]),
(get_dummy, gen_avg, [('t_mask', 't_mask')]),
(get_dummy, calc_dummy_scans, [('n_dummy', 'algo_dummy_scans')]),
(val_bold, outputnode, [
('out_file', 'bold_file'),
('out_report', 'validation_report'),
]),
(inputnode, get_dummy, [('bold_file', 'in_file')]),
(inputnode, calc_dummy_scans, [('dummy_scans', 'dummy_scans')]),
(get_dummy, calc_dummy_scans, [('n_dummy', 'algo_dummy_scans')]),
(get_dummy, outputnode, [
('n_dummy', 'algo_dummy_scans'),
('t_mask', 't_mask'),
]),
(calc_dummy_scans, outputnode, [('skip_vols_num', 'skip_vols')]),
(gen_avg, outputnode, [('out_file', 'boldref')]),
(get_dummy, outputnode, [('n_dummy', 'algo_dummy_scans')]),
])
# fmt: on
]) # fmt:skip

return workflow
Loading