# Source code for med_dataloader.med_dataloader

"""Main module."""

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import SimpleITK as sitk
import json
import numpy as np
import sys

# Sentinel handed to tf.data (``num_parallel_calls`` / ``buffer_size``) so
# that the runtime chooses the value instead of it being fixed here.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Values accepted for the ``mode`` argument of ``DataLoader``.
__dataloader_modality__ = ["get", "gen"]

# Lookup table from a dtype name (as reported by ``numpy.dtype.name``) to the
# TensorFlow dtype object carrying the same name.
__dict_dtype__ = {
    dtype_name: getattr(tf, dtype_name)
    for dtype_name in (
        "int8",
        "int16",
        "int32",
        "int64",
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "float32",
        "float64",
    )
}


class DataLoader:
    """Create, or read back, an on-disk ``tf.data`` cache of paired images.

    Two modes are available:

    * ``"gen"``: images are listed from ``data_dir/imgA_label`` and
      ``data_dir/imgB_label``, inspected with SimpleITK, and the settings
      of the dataset are written to ``ds_property.json`` in ``output_dir``.
      The cache files themselves are produced by :py:meth:`get_imgs`.
    * ``"get"``: ``data_dir`` is treated as an already generated folder and
      every setting is restored from its ``ds_property.json``.
    """

    # Name of the JSON file that stores the settings of a generated dataset.
    _PROPERTY_FILE = "ds_property.json"

    def __init__(
        self,
        mode,
        imgA_label=None,
        imgB_label=None,
        input_size=None,
        data_dir="./Data",
        output_dir=None,
        is_B_categorical=False,
        num_classes=None,
        norm_boundsA=None,
        norm_boundsB=None,
        extract_only=None,
        use_3D=False,
    ):
        """Configure the loader for generation or for reading.

        Args:
            mode (str): Either ``"gen"`` (build from raw images) or
                ``"get"`` (read an existing generated folder).
            imgA_label (str): Identifier for class A. It's the name of the
                folder inside :py:attr:`data_dir` that contains images
                labeled as class A. Required in ``"gen"`` mode.
            imgB_label (str): Identifier for class B. It's the name of the
                folder inside :py:attr:`data_dir` that contains images
                labeled as class B. Required in ``"gen"`` mode.
            input_size (int): Dimension of a single image, defined as
                input_size x input_size. Currently, it supports only
                squared images. Required in ``"gen"`` mode.
            data_dir (str, optional): In ``"gen"`` mode, path to the
                directory that contains the Dataset; it **must** contain
                two subfolders named like :py:attr:`imgA_label` and
                :py:attr:`imgB_label`. In ``"get"`` mode, path to the
                generated folder holding ``ds_property.json``.
                Defaults to './Data'.
            output_dir (str, optional): Folder receiving the cache and the
                property file. When None, a sibling of :py:attr:`data_dir`
                named ``<data_dir name>_TF`` is used. Defaults to None.
            is_B_categorical (bool, optional): When True, class B images
                are one-hot encoded with :py:attr:`num_classes` classes.
                Defaults to False.
            num_classes (int, optional): Depth of the one-hot encoding.
                Required when :py:attr:`is_B_categorical` is True.
                Defaults to None.
            norm_boundsA (sequence, optional): ``(lower, upper)`` bounds
                used to normalise class A images. Defaults to None.
            norm_boundsB (sequence, optional): ``(lower, upper)`` bounds
                used to normalise class B images. Defaults to None.
            extract_only (int, optional): Indicate whether to partially
                cache a certain amount of elements in the dataset. Please
                remember that if :py:attr:`output_dir` folder is already
                populated, you need to clean this folder content to
                recreate a partial cache file. When it is set to None, the
                entire Dataset is cached. Defaults to None.
            use_3D (bool, optional): Indicate whether to use
                three-dimensional data in the cache (if True) or to extract
                two-dimensional slices from the 3D volumes (if False).
                Defaults to False.

        Raises:
            ValueError: If ``mode`` is unknown; if ``input_size`` or one of
                the labels is None; if ``use_3D`` is not a bool; if B is
                categorical but ``num_classes`` is None; if no image is
                left to process; if ``use_3D`` is True but the files are
                not 3D; if normalisation bounds are not strictly ordered.
            FileNotFoundError: If ``data_dir`` or one of the two label
                folders does not exist.
        """
        if mode not in __dataloader_modality__:
            raise ValueError(
                f"{mode} modality not recognized. Choose between 'gen' or 'get'")  # noqa
        self.mode = mode

        if mode == "gen":
            if not os.path.exists(data_dir):
                raise FileNotFoundError(f"{data_dir} does not exist")
            self.data_dir = data_dir

            if output_dir is None:
                self.output_dir = self._default_output_dir(data_dir)
            else:
                self.output_dir = output_dir
            os.makedirs(self.output_dir, exist_ok=True)

            if input_size is None:
                raise ValueError("input_size is None")
            self.input_size = input_size

            if not isinstance(use_3D, bool):
                raise ValueError("use_3D is not a Boolean value")
            self.use_3D = use_3D

            if imgA_label is None or imgB_label is None:
                raise ValueError("imgA_label or imgB_label is None.")
            self.imgA_label = imgA_label
            self.imgB_label = imgB_label

            for label in (imgA_label, imgB_label):
                if not os.path.exists(os.path.join(data_dir, label)):
                    raise FileNotFoundError(f"{label} does not exist")

            # Fail here rather than later inside the tf.data pipeline, where
            # int(None) would surface as an unrelated TypeError.
            if is_B_categorical and num_classes is None:
                raise ValueError(
                    "num_classes must be set when is_B_categorical is True")

            self.imgA_paths, self.imgB_paths = self.get_imgs_paths()
            if extract_only is not None:
                self.imgA_paths = self.imgA_paths[:extract_only]
                self.imgB_paths = self.imgB_paths[:extract_only]

            # The first image of each class is used as a probe below.
            if not self.imgA_paths:
                raise ValueError(f"No images to process in {data_dir}")

            self.is_3D = self.is_3D_data(self.imgA_paths[0])
            if self.is_3D:
                # A trailing axis of length 3 is interpreted as RGB
                # channels, in which case the data is handled as 2D.
                self.is_A_RGB = self.is_RGB_data(self.imgA_paths[0])
                self.is_B_RGB = self.is_RGB_data(self.imgB_paths[0])
                if self.is_A_RGB or self.is_B_RGB:
                    self.is_3D = False
            else:
                self.is_A_RGB = False
                self.is_B_RGB = False

            if (not self.is_3D) and self.use_3D:
                raise ValueError(
                    "Image files are not 3D but use_3D was set to True")

            self.imgA_type = self.check_type(self.imgA_paths[0])
            self.imgB_type = self.check_type(self.imgB_paths[0])

            self.is_B_categorical = is_B_categorical
            self.num_classes = num_classes

            self.norm_boundsA = self._validate_norm_bounds(norm_boundsA)
            self.norm_boundsB = self._validate_norm_bounds(norm_boundsB)

            dataset_property = {
                # Record what was actually detected so that "get" mode
                # restores the same value of is_3D.
                "is3D": self.is_3D,
                "input_size": self.input_size,
                "imgA_label": self.imgA_label,
                "imgB_label": self.imgB_label,
                "imgA_type": self.imgA_type,
                "imgB_type": self.imgB_type,
                "is_A_RGB": self.is_A_RGB,
                "is_B_RGB": self.is_B_RGB,
                "is_B_categorical": self.is_B_categorical,
                "num_classes": self.num_classes,
                "norm_boundsA": self.norm_boundsA,
                "norm_boundsB": self.norm_boundsB,
                "use_3D": self.use_3D,
            }

            if self._needs_property_file(os.listdir(self.output_dir)):
                property_path = os.path.join(self.output_dir,
                                             self._PROPERTY_FILE)
                with open(property_path, 'w') as property_file:
                    json.dump(dataset_property, property_file, indent=2)

        elif mode == "get":
            if not os.path.exists(data_dir):
                raise FileNotFoundError(f"{data_dir} does not exist")
            self.output_dir = data_dir

            # dummy variables for images path
            self.imgA_paths = []
            self.imgB_paths = []

            property_path = os.path.join(self.output_dir, self._PROPERTY_FILE)
            with open(property_path, 'r') as property_file:
                dataset_property = json.load(property_file)

            self.is_3D = dataset_property["is3D"]
            self.input_size = dataset_property["input_size"]
            self.imgA_label = dataset_property["imgA_label"]
            self.imgB_label = dataset_property["imgB_label"]
            self.imgA_type = dataset_property["imgA_type"]
            self.imgB_type = dataset_property["imgB_type"]
            self.is_A_RGB = dataset_property["is_A_RGB"]
            self.is_B_RGB = dataset_property["is_B_RGB"]
            self.is_B_categorical = dataset_property["is_B_categorical"]
            self.num_classes = dataset_property["num_classes"]
            self.norm_boundsA = dataset_property["norm_boundsA"]
            self.norm_boundsB = dataset_property["norm_boundsB"]
            self.use_3D = dataset_property["use_3D"]

    @staticmethod
    def _default_output_dir(data_dir):
        """Return the sibling ``<name>_TF`` folder used when no output is given.

        Trailing path separators are ignored, so ``Data`` and ``Data/`` map
        to the same folder instead of the latter producing ``Data/_TF``.
        """
        path = os.fspath(data_dir)
        separators = os.sep + (os.altsep or "")
        # Keep the original value if stripping would leave nothing (root).
        trimmed = path.rstrip(separators) or path
        return os.path.join(os.path.dirname(trimmed),
                            f"{os.path.basename(trimmed)}_TF")

    @staticmethod
    def _validate_norm_bounds(bounds):
        """Return ``bounds`` unchanged after checking ``lower < upper``.

        ``None`` is passed through and means "no normalisation".

        Raises:
            ValueError: If the lower bound is not strictly below the upper.
        """
        if bounds is None:
            return None
        if bounds[0] >= bounds[1]:
            raise ValueError(
                f"Lower lim for normalization ({bounds[0]}) must be lower than upper lim ({bounds[1]})")  # noqa
        return bounds

    @classmethod
    def _needs_property_file(cls, output_dir_content):
        """Tell whether the property file has to be (re)written.

        It is written when it is missing, or when it is the only entry of
        the output folder (an old property file without any cache).
        """
        return (cls._PROPERTY_FILE not in output_dir_content
                or len(output_dir_content) == 1)

    def get_dataset(self,
                    batch_size=32,
                    augmentation=False,
                    random_crop_size=None,
                    random_rotate=False,
                    random_flip=False):
        """Return the batched dataset of ``(imgA, imgB)`` pairs.

        Args:
            batch_size (int, optional): Size of the batches. Defaults to 32.
            augmentation (bool, optional): Master switch for the random
                transformations below. Defaults to False.
            random_crop_size (int, optional): Side of the random square
                crop; no crop when falsy. Defaults to None.
            random_rotate (bool, optional): Apply a random multiple of
                90 degrees rotation. Defaults to False.
            random_flip (bool, optional): Randomly flip left/right.
                Defaults to False.

        Returns:
            tf.data.Dataset: zipped, optionally augmented, batched and
            prefetched dataset.
        """
        ds_A = self.get_imgs(img_paths=self.imgA_paths,
                             img_label=self.imgA_label,
                             img_type=self.imgA_type,
                             is_RGB=self.is_A_RGB,
                             norm_bounds=self.norm_boundsA)
        ds_B = self.get_imgs(img_paths=self.imgB_paths,
                             img_label=self.imgB_label,
                             img_type=self.imgB_type,
                             is_RGB=self.is_B_RGB,
                             is_categorical=self.is_B_categorical,
                             num_classes=self.num_classes,
                             norm_bounds=self.norm_boundsB)
        ds = tf.data.Dataset.zip((ds_A, ds_B))

        if augmentation:
            if random_crop_size:
                ds = ds.map(
                    lambda imgA, imgB: self.random_crop(imgA,
                                                        imgB,
                                                        random_crop_size),
                    num_parallel_calls=AUTOTUNE)

            if random_rotate:
                ds = ds.map(self.random_rotate, num_parallel_calls=AUTOTUNE)

            if random_flip:
                ds = ds.map(self.random_flip, num_parallel_calls=AUTOTUNE)

        ds = ds.batch(batch_size)
        ds = ds.prefetch(buffer_size=AUTOTUNE)

        return ds

    def get_imgs(self,
                 img_paths,
                 img_label,
                 img_type,
                 is_RGB,
                 is_categorical=False,
                 num_classes=None,
                 norm_bounds=None):
        """Open image files for one class and store them inside the cache.

        The slow reading operations are declared here and their result is
        cached in ``<output_dir>/<img_label>.cache``. When the matching
        ``.index`` file is not found, the dataset is iterated once so that
        the cache gets written; otherwise that pass is skipped. Cache files
        can be as big as the dataset itself, and the first execution may
        take considerably longer.

        Args:
            img_paths (list): Paths to the images of a single class.
            img_label (str): Class label, used to name the cache file.
            img_type (str): dtype name of the images (a key of the module
                level dtype table).
            is_RGB (bool): Convert the images to grayscale when True.
            is_categorical (bool, optional): One-hot encode the images.
                Defaults to False.
            num_classes (int, optional): Depth of the one-hot encoding.
                Defaults to None.
            norm_bounds (sequence, optional): ``(lower, upper)`` bounds for
                normalisation. Defaults to None.

        Returns:
            tf.data.Dataset: dataset of the images of one class, backed by
            the cache file.
        """
        cache_file = os.path.join(self.output_dir, f"{img_label}.cache")
        index_file = f"{cache_file}.index"

        ds = tf.data.Dataset.from_tensor_slices(img_paths)
        # Reading goes through SimpleITK, hence the py_function wrapper.
        ds = ds.map(lambda path: tf.py_function(self.open_img,
                                                [path],
                                                [__dict_dtype__[img_type]],
                                                ),
                    num_parallel_calls=AUTOTUNE)

        # Volumes are split along their first axis when 2D data is wanted.
        if self.is_3D and (not self.use_3D):
            ds = ds.unbatch()

        if is_RGB:
            ds = ds.map(lambda img: tf.image.rgb_to_grayscale(img),
                        num_parallel_calls=AUTOTUNE)

        ds = ds.map(lambda img: self.fix_image_dims(img, self.input_size),
                    num_parallel_calls=AUTOTUNE)

        if is_categorical:
            ds = ds.map(lambda img: tf.one_hot(tf.squeeze(tf.cast(img,
                                                                  img_type)),
                                               depth=int(num_classes)),
                        num_parallel_calls=AUTOTUNE)

        ds = ds.map(lambda img: tf.cast(img, img_type),
                    num_parallel_calls=AUTOTUNE)

        if norm_bounds is not None:
            ds = ds.map(lambda img: self.norm_with_bounds(img, norm_bounds),
                        num_parallel_calls=AUTOTUNE)

        ds = ds.cache(cache_file)

        if not os.path.exists(index_file):
            self._populate_cache(ds, cache_file, len(img_paths))

        return ds

    def get_imgs_paths(self):
        """Get paths of every single image divided by classes.

        Returns:
            list, list: two lists containing the paths of every image for
            both classes. Each list is sorted alphabetically, which is
            useful when images are named with a progressive number inside
            a folder (e.g.: 001.xxx, 002.xxx, ..., 999.xxx)

        Raises:
            ValueError: If the two folders do not hold the same number of
                entries.
        """
        subset_dir_imgA = os.path.join(self.data_dir, self.imgA_label)
        subset_dir_imgB = os.path.join(self.data_dir, self.imgB_label)

        paths_imgA = sorted(os.path.join(subset_dir_imgA, img)
                            for img in os.listdir(subset_dir_imgA))
        paths_imgB = sorted(os.path.join(subset_dir_imgB, img)
                            for img in os.listdir(subset_dir_imgB))

        if len(paths_imgA) != len(paths_imgB):
            raise ValueError(
                f"Dimension mismatch: {len(paths_imgA)} != {len(paths_imgB)}")

        return paths_imgA, paths_imgB

    def open_img(self, path):
        """Open an image file and convert it to a tensor.

        Args:
            path (tf.Tensor): Tensor containing the path to the file to be
                opened.

        Returns:
            tf.Tensor: Tensor containing the actual image content. When
            ``use_3D`` is set, the axes of the array are reversed first.
        """
        path = path.numpy().decode("utf-8")
        image = sitk.GetArrayFromImage(sitk.ReadImage(path))
        if self.use_3D:
            image = np.transpose(image, axes=(2, 1, 0))
        return tf.convert_to_tensor(image)

    def _populate_cache(self, ds, cache_file, num_tot):
        """Iterate ``ds`` once so the cache is written, printing progress."""
        print(f"Caching decoded images in {cache_file}...")
        for count, _ in enumerate(ds, start=1):
            # Carriage return keeps the counter on a single console line.
            sys.stdout.write(f"\r{count}/{num_tot}")
            sys.stdout.flush()
        print(f"\nCached decoded images in {cache_file}.")

    @staticmethod
    def is_3D_data(path):
        """Return True if the file holds a 3-axis array, False if 2-axis.

        Raises:
            ValueError: If the array has neither 2 nor 3 axes.
        """
        image = sitk.GetArrayFromImage(sitk.ReadImage(path))
        num_axes = len(image.shape)
        if num_axes == 3:
            return True
        if num_axes == 2:
            return False
        raise ValueError("Work only with 2D or 3D files.")

    @staticmethod
    def is_RGB_data(path):
        """Return True if the last axis of the image array has length 3."""
        image = sitk.GetArrayFromImage(sitk.ReadImage(path))
        return image.shape[-1] == 3

    @staticmethod
    def check_type(path):
        """Return the NumPy dtype name (e.g. ``"int16"``) of the image."""
        image = sitk.GetArrayFromImage(sitk.ReadImage(path))
        return image.dtype.name

    def fix_image_dims(self, img, size):
        """Fix tensor dimensions so that they are of the proper size to
        carry out Tensorflow operations.

        This function performs three steps:

        #. `Squeeze <https://www.tensorflow.org/api_docs/python/tf/squeeze>`_
           to remove axis with dimension of 1
        #. `Expand <https://www.tensorflow.org/api_docs/python/tf/expand_dims>`_
           the dimensions of the tensor by adding one axis
        #. `Resize and pad <https://www.tensorflow.org/api_docs/python/tf/image/resize_with_pad>`_
           the tensor to a target width and height

        If `use_3D` was enabled, volume is not resized and padded.

        Args:
            img: image or volume to be processed
            size: desired size of image or volume in the two/three axis.

        Returns:
            The processed tensor.
        """
        img = tf.expand_dims(tf.squeeze(img), axis=-1)
        if not self.use_3D:
            img = tf.image.resize_with_pad(img, size, size)
        return img

    # -------------------------------------------------------------------------
    # Transformations
    # -------------------------------------------------------------------------

    @staticmethod
    def norm_with_bounds(image, bounds):
        """Image normalisation.

        Values are clamped to ``[bounds[0], bounds[1]]`` and then rescaled
        so that the lower bound maps to 0 and the upper bound to 1.
        """
        lb = tf.cast(bounds[0], dtype=image.dtype)
        ub = tf.cast(bounds[1], dtype=image.dtype)
        image = tf.where(image < lb, lb, image)
        image = tf.where(image > ub, ub, image)
        image = image - lb
        image /= (ub - lb)
        return image

    @staticmethod
    def random_crop(imgA, imgB, crop_size=256):
        """Crop the same random ``crop_size`` x ``crop_size`` window from
        both images.

        The two images are stacked so that a single random offset applies
        to both; the requested crop has one channel.
        """
        stacked_img = tf.stack([imgA, imgB], axis=0)
        cropped_img = tf.image.random_crop(
            stacked_img, size=[2, crop_size, crop_size, 1]
        )
        cropped_img = tf.split(cropped_img, 2)
        imgA = tf.expand_dims(tf.squeeze(cropped_img[0]), axis=-1)
        imgB = tf.expand_dims(tf.squeeze(cropped_img[1]), axis=-1)

        return imgA, imgB

    @staticmethod
    def random_flip(imgA, imgB):
        """Flip both images left/right together when a uniform draw
        exceeds 0.5."""
        if tf.random.uniform(()) > 0.5:
            imgA = tf.image.flip_left_right(imgA)
            imgB = tf.image.flip_left_right(imgB)

        return imgA, imgB

    @staticmethod
    def random_rotate(imgA, imgB):
        """Rotate both images by the same random multiple (0-3) of 90
        degrees."""
        rn = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
        imgA = tf.image.rot90(imgA, k=rn)
        imgB = tf.image.rot90(imgB, k=rn)

        return imgA, imgB
def generate_dataset(data_dir,
                     imgA_label,
                     imgB_label,
                     input_size,
                     output_dir=None,
                     extract_only=None,
                     norm_boundsA=None,
                     norm_boundsB=None,
                     is_B_categorical=False,
                     num_classes=None,
                     use_3D=False,
                     ):
    """Generate the cached dataset for a folder of paired images.

    A :py:class:`DataLoader` is created in ``"gen"`` mode with the given
    arguments and its ``get_dataset`` method is called once with default
    settings. The resulting dataset object is discarded; the call is made
    for its side effects only.

    Args:
        data_dir: Folder containing the two label subfolders.
        imgA_label: Name of the class A subfolder.
        imgB_label: Name of the class B subfolder.
        input_size: Side of the (squared) images.
        output_dir: Destination folder, forwarded to the loader.
        extract_only: Number of elements to keep, forwarded to the loader.
        norm_boundsA: Normalisation bounds for class A.
        norm_boundsB: Normalisation bounds for class B.
        is_B_categorical: Whether class B is one-hot encoded.
        num_classes: Number of classes of the one-hot encoding.
        use_3D: Whether volumes are kept three-dimensional.

    Returns:
        None
    """
    loader_settings = {
        "mode": "gen",
        "input_size": input_size,
        "imgA_label": imgA_label,
        "imgB_label": imgB_label,
        "data_dir": data_dir,
        "output_dir": output_dir,
        "is_B_categorical": is_B_categorical,
        "num_classes": num_classes,
        "norm_boundsA": norm_boundsA,
        "norm_boundsB": norm_boundsB,
        "extract_only": extract_only,
        "use_3D": use_3D,
    }
    DataLoader(**loader_settings).get_dataset()
def get_dataset(data_dir,
                percentages,
                batch_size,
                train_augmentation=True,
                random_crop_size=None,
                random_rotate=True,
                random_flip=True
                ):
    """Read a generated dataset and split it into train/validation/test.

    The dataset is loaded twice from ``data_dir``: once with the requested
    augmentation (source of the training split) and once without
    augmentation (source of the validation and test splits). Split sizes
    are computed from the number of elements counted by iterating the
    first dataset.

    Args:
        data_dir: Folder of a generated dataset (read in ``"get"`` mode).
        percentages: Three fractions (train, validation, test) whose sum,
            rounded to one decimal, must be 1.
        batch_size: Batch size of the three returned datasets.
        train_augmentation: Enable augmentation on the training split.
        random_crop_size: Side of the random crop, if any.
        random_rotate: Enable random 90-degree rotations.
        random_flip: Enable random left/right flips.

    Returns:
        tuple: ``(train_ds, valid_ds, test_ds)``, each one batched.

    Raises:
        ValueError: If ``percentages`` does not hold 3 elements or its sum
            is not 1.
    """
    if len(percentages) != 3:
        raise ValueError("Percentages has to be a list of 3 elements")

    total_fraction = percentages[0] + percentages[1] + percentages[2]
    if round(total_fraction, 1) != 1.0:
        raise ValueError("Sum of percentages has to be 1")

    data_loader = DataLoader(mode="get", data_dir=data_dir)

    augmented_ds = data_loader.get_dataset(
        batch_size=batch_size,
        augmentation=train_augmentation,
        random_crop_size=random_crop_size,
        random_rotate=random_rotate,
        random_flip=random_flip,
    ).unbatch()

    # Compute length of dataset
    num_imgs = sum(1 for _ in augmented_ds)

    train_ends = int(num_imgs * percentages[0])
    valid_begins = train_ends
    valid_ends = valid_begins + int(num_imgs * percentages[1])

    # Train split: leading elements of the augmented dataset.
    train_ds = augmented_ds.take(train_ends).batch(batch_size)

    # Validation and test splits come from a second, non augmented, load.
    plain_ds = data_loader.get_dataset(batch_size=batch_size,
                                       augmentation=False).unbatch()

    valid_ds = plain_ds.take(valid_ends).skip(valid_begins).batch(batch_size)
    test_ds = plain_ds.skip(valid_ends).batch(batch_size)

    return train_ds, valid_ds, test_ds