Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions cellpose/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@

io_logger = logging.getLogger(__name__)


def logger_setup():
cp_dir = pathlib.Path.home().joinpath(".cellpose")
def logger_setup(cp_path=".cellpose", logfile_name="run.log"):
cp_dir = pathlib.Path.home().joinpath(cp_path)
cp_dir.mkdir(exist_ok=True)
log_file = cp_dir.joinpath("run.log")
log_file = cp_dir.joinpath(logfile_name)
try:
log_file.unlink()
except:
Expand All @@ -70,8 +69,7 @@ def logger_setup():

from . import utils, plot, transforms


# helper function to check for a path; if it doesn"t exist, make it
# helper function to check for a path; if it doesn't exist, make it
def check_dir(path):
if not os.path.isdir(path):
os.mkdir(path)
Expand Down Expand Up @@ -338,7 +336,6 @@ def get_image_files(folder, mask_filter, imf=None, look_one_level_down=False):

return image_names


def get_label_files(image_names, mask_filter, imf=None):
"""
Get the label files corresponding to the given image names and mask filter.
Expand Down Expand Up @@ -440,7 +437,6 @@ def load_images_labels(tdir, mask_filter="_masks", image_filter=None,
io_logger.info(f"{k} / {nimg} images in {tdir} folder have labels")
return images, labels, image_names


def load_train_test_data(train_dir, test_dir=None, image_filter=None,
mask_filter="_masks", look_one_level_down=False):
"""
Expand All @@ -463,7 +459,6 @@ def load_train_test_data(train_dir, test_dir=None, image_filter=None,
"""
images, labels, image_names = load_images_labels(train_dir, mask_filter,
image_filter, look_one_level_down)

# testing data
test_images, test_labels, test_image_names = None, None, None
if test_dir is not None:
Expand Down Expand Up @@ -568,12 +563,11 @@ def masks_flows_to_seg(images, masks, flows, file_names, diams=30., channels=Non

np.save(base + "_seg.npy", dat)


def save_to_png(images, masks, flows, file_names):
""" deprecated (runs io.save_masks with png=True)
""" deprecated (runs io.save_masks with png=True)

does not work for 3D images

"""
save_masks(images, masks, flows, file_names, png=True)

Expand Down