Source code for unagi.train

#!/usr/bin/python3

import argparse
import os
import random
import shutil
import time
from copy import deepcopy
from typing import Any, List, Optional, Tuple, Type

import cv2
import numpy as np
import PIL
import PIL.Image  # PIL.Image.fromarray needs the submodule loaded explicitly
from Augmentor import DataPipeline
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import Sequence
from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model

from unagi.utils.augmentor_utils import (
    GaussianNoiseAugmentor,
    InvertPartAugmentor,
    SaltPepperNoiseAugmentor,
)
from unagi.utils.callback_utils import create_callbacks
from unagi.utils.img_processing_utils import ImageUtils
from unagi.utils.metric_utils import ModelMetrics
from unagi.utils.model_utils import UNAGIModel


class ParallelDataGenerator(Sequence):
    """Generate images for training/validation/testing (parallel version).

    Each batch is read from disk as grayscale images, optionally augmented,
    normalized with ``ImageUtils`` and returned with a trailing channel axis.

    Parameters
    ----------
    fnames_in: List[str]
        list of input image paths
    fnames_gt: List[str]
        list of ground-truth image paths, aligned index-by-index with
        ``fnames_in``
    batch_size: int
        number of image pairs per batch (the last batch may be smaller)
    augmentate: bool
        apply augmentation to each batch of images
    """

    def __init__(
        self,
        fnames_in: List[str],
        fnames_gt: List[str],
        batch_size: int,
        augmentate: bool,
    ):
        """Generate images for training/validation/testing."""
        # Private copies so later changes to the caller's lists cannot
        # affect this generator.
        self.fnames_in = deepcopy(fnames_in)
        self.fnames_gt = deepcopy(fnames_gt)
        self.batch_size = batch_size
        self.augmentate = augmentate
        # Order in which samples are served; reshuffled in on_epoch_end.
        self.idxs = np.arange(len(self.fnames_in))

    def __len__(self):  # type: ignore
        """Return the number of batches, counting a final partial batch."""
        return int(np.ceil(float(self.idxs.shape[0]) / float(self.batch_size)))

    def on_epoch_end(self) -> None:
        """Shuffle the sample order at the end of each epoch."""
        np.random.shuffle(self.idxs)

    def _apply_augmentation(self, p: DataPipeline) -> List[List[np.ndarray]]:
        """Run every operation of pipeline ``p`` over each image group.

        Each entry of ``p.augmentor_images`` is a group of images that is
        transformed together, so a sampled operation is applied to the whole
        group at once.

        Parameters
        ----------
        p: DataPipeline
            pipeline holding the image groups and the configured operations.

        Returns
        -------
        List[List[numpy.ndarray]]
            one list of augmented arrays per input group, in input order.
        """
        batch = []
        for group in p.augmentor_images:
            images = [PIL.Image.fromarray(x) for x in group]
            for operation in p.operations:
                # NOTE: the draw is rounded to one decimal, so the effective
                # probability can differ slightly from operation.probability
                # (e.g. 0.5 behaves like ~0.55).
                r = round(random.uniform(0, 1), 1)  # nosec
                if r <= operation.probability:
                    images = operation.perform_operation(images)
            batch.append([np.asarray(x) for x in images])
        return batch

    # Backward-compatible alias: names with leading *and* trailing double
    # underscores are reserved for Python's own special methods.
    __apply_augmentation__ = _apply_augmentation

    def augmentate_batch(
        self, imgs_in: List[np.ndarray], imgs_gt: List[np.ndarray]
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Generate ordered augmented batch of images, using Augmentor.

        Geometric transformations are applied to each input/gt pair
        together. Noise, brightness/contrast and inversion are applied to the
        input images only.

        Parameters
        ----------
        imgs_in: List[numpy.ndarray]
            list of input images as array
        imgs_gt: List[numpy.ndarray]
            list of gt images as array, same length as ``imgs_in``

        Returns
        -------
        Tuple[List[numpy.ndarray], List[numpy.ndarray]]
            List of input images after applying augmentation
            List of gt images after applying augmentation

        Raises
        ------
        ValueError
            if ``imgs_in`` and ``imgs_gt`` have different lengths.
        """
        if len(imgs_in) != len(imgs_gt):
            raise ValueError(
                "imgs_in and imgs_gt differ in length: {0} != {1}".format(
                    len(imgs_in), len(imgs_gt)
                )
            )

        # Non-linear and linear (geometric) transformations on input+gt pairs.
        pipeline = DataPipeline(
            [[img_in, img_gt] for img_in, img_gt in zip(imgs_in, imgs_gt)]
        )
        pipeline.random_distortion(0.5, 6, 6, 4)
        pipeline.shear(0.75, 10.0, 10.0)
        pipeline.zoom(0.75, 1.0, 1.2)
        pipeline.skew(0.75, 0.75)
        pairs = self._apply_augmentation(pipeline)
        imgs_in = [pair[0] for pair in pairs]
        imgs_gt = [pair[1] for pair in pairs]

        # Input-only transformations: noise, brightness/contrast, inversion.
        pipeline = DataPipeline([[img] for img in imgs_in])
        pipeline.add_operation(GaussianNoiseAugmentor(0.25, 0, 10))
        pipeline.add_operation(SaltPepperNoiseAugmentor(0.25, 0.005))
        pipeline.random_brightness(0.75, 0.5, 1.5)
        pipeline.random_contrast(0.75, 0.5, 1.5)
        pipeline.add_operation(InvertPartAugmentor(0.25))
        pipeline.invert(0.5)
        imgs_in = [group[0] for group in self._apply_augmentation(pipeline)]

        return imgs_in, imgs_gt

    @staticmethod
    def _read_grayscale(path: str) -> np.ndarray:
        """Read ``path`` as a grayscale image.

        Raises
        ------
        FileNotFoundError
            if OpenCV cannot read the file (``cv2.imread`` returns ``None``
            instead of raising).
        """
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise FileNotFoundError("Could not read image: {0}".format(path))
        return img

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """Return batch ``idx`` as ``(inputs, ground_truths)`` numpy arrays.

        Both arrays are reshaped to ``(batch, dim1, dim2, 1)``.
        """
        start = idx * self.batch_size
        stop = min(start + self.batch_size, self.idxs.shape[0])

        imgs_in: List = []
        imgs_gt: List = []
        for sample in self.idxs[start:stop]:
            imgs_in.append(self._read_grayscale(self.fnames_in[sample]))
            imgs_gt.append(self._read_grayscale(self.fnames_gt[sample]))

        # Applying augmentations.
        if self.augmentate:
            imgs_in, imgs_gt = self.augmentate_batch(imgs_in, imgs_gt)

        # Normalization, then add a trailing single-channel axis.
        batch_in = np.array([ImageUtils.normalize_in(img) for img in imgs_in])
        batch_in = batch_in.reshape(
            batch_in.shape[0], batch_in.shape[1], batch_in.shape[2], 1
        )
        batch_gt = np.array([ImageUtils.normalize_gt(img) for img in imgs_gt])
        batch_gt = batch_gt.reshape(
            batch_gt.shape[0], batch_gt.shape[1], batch_gt.shape[2], 1
        )
        return batch_in, batch_gt
def _clean_vis_dir(vis: str) -> None:
    """Delete visualization artefacts left in ``vis`` by a previous run.

    Removes ``*.gif`` files and ``*_frames`` directories (including their
    contents). Every other entry is left untouched.

    Parameters
    ----------
    vis: str
        visualization directory; ``os.listdir`` raises if it does not exist.
    """
    for file_name in os.listdir(vis):
        path = os.path.join(vis, file_name)
        if file_name.endswith(".gif") and os.path.isfile(path):
            print("removing {0}".format(path))
            os.remove(path)
        elif file_name.endswith("_frames") and os.path.isdir(path):
            print("removing {0}".format(path))
            # os.rmdir only deletes empty directories; a frames directory may
            # still hold the frames of the previous run.
            shutil.rmtree(path)


def _select_loss(loss: str) -> Any:
    """Map a loss name to the corresponding ``ModelMetrics`` function.

    Unknown names fall back to ``ModelMetrics.dice_coef_loss`` (with a
    printed warning), preserving the previous lenient behaviour.
    """
    loss_functions: Any = {
        "dice": ModelMetrics.dice_coef_loss,
        "dice_coef_loss": ModelMetrics.dice_coef_loss,
        "wce": ModelMetrics.weighted_cross_entropy,
        "focal": ModelMetrics.focal_loss,
        "focal_tversky": ModelMetrics.focal_tversky,
    }
    if loss not in loss_functions:
        print("unknown loss '{0}', falling back to dice_coef_loss".format(loss))
    return loss_functions.get(loss, ModelMetrics.dice_coef_loss)


def main(
    input_path: str = os.path.join(".", "input"),
    vis: str = os.path.join(".", "vis"),
    debug: str = os.path.join(".", "train_logs"),
    loss: str = "dice_coef_loss",
    epochs: int = 1,
    batchsize: int = 32,
    augmentate: bool = True,
    train_split: int = 80,
    val_split: int = 10,
    test_split: int = 10,
    weights_path: str = os.path.join(".", "bin_weights.hdf5"),
    num_gpus: int = 1,
    extraprocesses: int = 0,
    queuesize: int = 10,
) -> None:
    """Train U-net with pairs of train and ground-truth images.

    Parameters
    ----------
    input_path: str, optional
        input dir with ``in`` and ``gt`` sub folders to train
        (default is os.path.join(".", "input")).
    vis: str, optional
        dir with images to use for train visualization; ``*.gif`` files and
        ``*_frames`` directories from a previous run are deleted
        (default is os.path.join(".", "vis")).
    debug: str, optional
        path to save training logs; the final weights are only written when
        this is non-empty (default is os.path.join(".", "train_logs")).
    loss: str, optional
        loss function: ``"dice"``/``"dice_coef_loss"``, ``"wce"``,
        ``"focal"`` or ``"focal_tversky"``; unknown names fall back to dice
        loss (default is ``"dice_coef_loss"``).
    epochs: int, optional
        number of epochs to train unagi (default is `1`).
    batchsize: int, optional
        batchsize to train unagi (default is `32`).
    augmentate: bool, optional
        augment the training images; validation and test images are never
        augmented, so their metrics are deterministic (default is `True`).
    train_split: int, optional
        train dataset split percentage (default is `80`).
    val_split: int, optional
        validation dataset split percentage (default is `10`).
    test_split: int, optional
        test dataset split percentage (default is `10`). Informational only:
        the test set receives every image left after the train and
        validation splits.
    weights_path: str, optional
        path to save final weights
        (default is os.path.join(".", "bin_weights.hdf5")).
    num_gpus: int, optional
        number of gpus to use for training unagi (default is `1`).
    extraprocesses: int, optional
        number of extra worker processes for batch generation; `0` disables
        multiprocessing (default is `0`).
    queuesize: int, optional
        number of batches to generate in queue while training
        (default is `10`).

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        if an argument is out of range, the ``in`` and ``gt`` folders hold a
        different number of files, or a split ends up empty.

    Note
    ----
    All train images should be in "in" directory, named ``<i>_in.png``.
    All ground-truth images should be in "gt" directory, named
    ``<i>_gt.png``, with ``i`` running from 0 to the file count minus one.

    Example
    -------
    unagi.train.main(input_path, vis, logs_dir, epochs=2, batchsize=4)
    """
    assert epochs > 0
    assert batchsize > 0
    assert train_split >= 0
    assert val_split >= 0
    assert test_split >= 0
    assert train_split + val_split <= 100, (
        "train_split + val_split must not exceed 100, got {0} + {1}".format(
            train_split, val_split
        )
    )
    assert num_gpus >= 1
    assert extraprocesses >= 0
    assert queuesize >= 0

    start_time = time.time()
    np.random.seed()

    # A bare file name has no directory component, so there is nothing to
    # create in that case.
    weights_path_dir = os.path.dirname(weights_path)
    if weights_path_dir:
        ImageUtils.mkdir_s(weights_path_dir)

    # Creating data for training, validation and testing. Files are expected
    # to be numbered consecutively: <i>_in.png / <i>_gt.png.
    n_in = len(os.listdir(os.path.join(input_path, "in")))
    n_gt = len(os.listdir(os.path.join(input_path, "gt")))
    assert n_in == n_gt, "'in' has {0} files but 'gt' has {1}".format(n_in, n_gt)
    fnames_in = [
        os.path.join(input_path, "in", "{0}_in.png".format(i)) for i in range(n_in)
    ]
    fnames_gt = [
        os.path.join(input_path, "gt", "{0}_gt.png".format(i)) for i in range(n_gt)
    ]
    n = n_in

    train_stop = int(n * (train_split / 100))
    validation_stop = train_stop + int(n * (val_split / 100))

    train_generator = ParallelDataGenerator(
        fnames_in[:train_stop], fnames_gt[:train_stop], batchsize, augmentate
    )
    # Evaluation data stays un-augmented so validation/test metrics measure
    # the model on the real images and are reproducible.
    validation_generator = ParallelDataGenerator(
        fnames_in[train_stop:validation_stop],
        fnames_gt[train_stop:validation_stop],
        batchsize,
        False,
    )
    test_generator = ParallelDataGenerator(
        fnames_in[validation_stop:], fnames_gt[validation_stop:], batchsize, False
    )

    # Every split needs at least one batch; an empty generator would make
    # fitting/evaluation fail.
    assert len(train_generator) > 0, "train split is empty"
    assert len(validation_generator) > 0, "validation split is empty"
    assert len(test_generator) > 0, "test split is empty"

    # Remove visualization output left over from a previous training run.
    _clean_vis_dir(vis)

    loss_fn = _select_loss(loss)

    # Creating model. With several GPUs the template model is wrapped; the
    # template (original_model) still owns the weights.
    original_model = UNAGIModel().unet()
    if num_gpus == 1:
        model = original_model
    else:
        model = multi_gpu_model(original_model, gpus=num_gpus)
    model.compile(
        # NOTE: `lr` is the legacy spelling kept for the old TF/Keras versions
        # this module targets (multi_gpu_model); newer Keras uses
        # `learning_rate`.
        optimizer=Adam(lr=1e-4),
        loss=loss_fn,
        metrics=[ModelMetrics.dice_coef, ModelMetrics.jacard_coef, "accuracy"],
    )
    model.summary()

    callbacks = create_callbacks(
        model, original_model, debug, num_gpus, batchsize, vis, weights_path
    )

    # Running training, validation and testing. workers=0 together with
    # use_multiprocessing=False keeps batch generation in the main process.
    use_multiprocessing = extraprocesses > 0
    model.fit_generator(
        generator=train_generator,
        # Step counts are passed explicitly for compatibility with old Keras
        # versions.
        steps_per_epoch=len(train_generator),
        validation_data=validation_generator,
        validation_steps=len(validation_generator),
        epochs=epochs,
        shuffle=True,
        callbacks=callbacks,
        use_multiprocessing=use_multiprocessing,
        workers=extraprocesses,
        max_queue_size=queuesize,
        verbose=1,
    )
    metrics = model.evaluate_generator(
        generator=test_generator,
        use_multiprocessing=use_multiprocessing,
        workers=extraprocesses,
        max_queue_size=queuesize,
        verbose=1,
    )

    print()
    print("total:")
    print("test_loss: {0:.4f}".format(metrics[0]))
    print("test_dice_coef: {0:.4f}".format(metrics[1]))
    print("test_jacard_coef: {0:.4f}".format(metrics[2]))
    print("test_accuracy: {0:.4f}".format(metrics[3]))

    # Saving model. Per the Keras multi_gpu_model docs, weights are saved
    # from the template model; for a single GPU it is the same object.
    if debug != "":
        original_model.save_weights(weights_path)

    print("finished in {0:.2f} seconds".format(time.time() - start_time))
# Command-line interface for the train module.


def _str2bool(value: str) -> bool:
    """Convert a command-line string such as ``"false"`` or ``"1"`` to bool.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string, so an explicit converter is required.

    Raises
    ------
    argparse.ArgumentTypeError
        if ``value`` is not a recognised boolean spelling.
    """
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {0!r}".format(value)
    )


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments for train module.

    Parameters
    ----------
    argv: Optional[List[str]], optional
        arguments to parse; ``None`` (default) parses ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        one attribute per keyword argument of ``main``.
    """
    parser = argparse.ArgumentParser(description="Train the unagi model on input data.")
    parser.add_argument(
        "--input_path",
        type=str,
        default=os.path.join(".", "input"),
        help="Input path for images.",
    )
    parser.add_argument(
        "--vis",
        type=str,
        default=os.path.join(".", "vis"),
        help="Visualization path for images.",
    )
    parser.add_argument(
        "--debug",
        type=str,
        default=os.path.join(".", "train_logs"),
        help="Directory to save training logs.",
    )
    parser.add_argument(
        "--loss",
        type=str,
        default="dice_coef_loss",
        help="Loss function to use (dice, wce, focal or focal_tversky).",
    )
    parser.add_argument(
        "--epochs", type=int, default=1, help="Number of epochs to train."
    )
    parser.add_argument(
        "--batchsize", type=int, default=32, help="Batch size to use in training."
    )
    parser.add_argument(
        "--augmentate",
        type=_str2bool,
        nargs="?",
        const=True,  # a bare `--augmentate` means True
        default=True,
        help="Apply augmentation to training data (true/false).",
    )
    parser.add_argument(
        "--train_split",
        type=int,
        default=80,
        help="Train dataset split percentage.",
    )
    parser.add_argument(
        "--val_split",
        type=int,
        default=10,
        help="Validation dataset split percentage.",
    )
    parser.add_argument(
        "--test_split",
        type=int,
        default=10,
        help="Test dataset split percentage.",
    )
    parser.add_argument(
        "--weights_path",
        type=str,
        default=os.path.join(".", "bin_weights.hdf5"),
        help="Path to weights file.",
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=1,
        help="Number of GPUs to use for training.",
    )
    parser.add_argument(
        "--extraprocesses",
        type=int,
        default=0,
        help="Number of extra processes to use.",
    )
    parser.add_argument(
        "--queuesize",
        type=int,
        default=10,
        help="Number of batches to generate in queue while training.",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # The CLI option names match main()'s keyword arguments one-to-one, so
    # the parsed namespace can be forwarded as keyword arguments directly.
    cli_options = vars(parse_args())
    main(**cli_options)