#!/usr/bin/python3
import os
from typing import Any, Dict, List
import cv2
import imageio
import numpy as np
from alt_model_checkpoint.tensorflow import AltModelCheckpoint
from tensorflow.python.keras.callbacks import Callback, EarlyStopping, TensorBoard
from tensorflow.python.keras.models import Model as keras_model
from unagi.utils.img_processing_utils import ImageUtils
class Visualisation(Callback):
    """Custom Keras callback for visualizing training through GIFs.

    After selected epochs every image file found in ``dir_name`` is passed
    through ``ImageUtils.binarize_img`` together with the model being trained
    and the result is written to
    ``<dir_name>/<image stem>_frames/<epoch>_out.png``. When training ends the
    frames of each image are assembled into ``<dir_name>/<image stem>.gif``.

    Parameters
    ----------
    batchsize: int
        batchsize to generate samples
    dir_name: str, optional
        folder with images to visualize training
    monitor: str, optional
        metric to monitor
    save_best_epochs_only: bool, optional
        write a frame only when the monitored metric improves
    mode: str, optional
        mode for the metric to monitor, ``"min"`` or ``"max"``
    """

    def __init__(
        self,
        batchsize: int,
        dir_name: str = "vis",
        monitor: str = "val_loss",
        save_best_epochs_only: bool = False,
        mode: str = "min",
    ):
        """Collect the input images and create one frame folder per image."""
        super().__init__()
        self.dir_name = dir_name
        self.batchsize = batchsize
        self.epoch_number = 0
        # Keep regular files only. A previous run leaves "<stem>_frames"
        # directories and "<stem>.gif" outputs in the same folder; neither is
        # an input image, so treating them as one would break the first epoch.
        self.fnames = sorted(
            entry
            for entry in os.listdir(self.dir_name)
            if os.path.isfile(os.path.join(self.dir_name, entry))
            and not entry.lower().endswith(".gif")
        )
        for fname in self.fnames:
            ImageUtils.mkdir_s(self._frames_dir(fname))
        self.monitor = monitor
        self.save_best_epochs_only = save_best_epochs_only
        self.mode = mode
        self.curr_metric = None

    @staticmethod
    def _stem(fname: str) -> str:
        """Return the file name without its extension."""
        # splitext leaves names without a dot intact, whereas slicing at
        # rfind(".") == -1 would silently drop the last character.
        return os.path.splitext(fname)[0]

    def _frames_dir(self, fname: str) -> str:
        """Return the folder that holds the per-epoch frames of ``fname``."""
        return os.path.join(self.dir_name, self._stem(fname) + "_frames")

    @staticmethod
    def _frame_order(frame_name: str):
        """Sort key placing ``<epoch>_out.png`` frames in numeric epoch order."""
        prefix = frame_name.split("_", 1)[0]
        if prefix.isdigit():
            return (0, int(prefix), frame_name)
        # Files that do not follow the naming scheme go last, alphabetically.
        return (1, 0, frame_name)

    def _is_improvement(self, metric: Any) -> bool:
        """Tell whether ``metric`` beats the best value seen so far."""
        if self.curr_metric is None:
            return True
        if self.mode == "min":
            return metric < self.curr_metric
        if self.mode == "max":
            return metric > self.curr_metric
        # Unknown mode: only the very first epoch counts as an improvement.
        return False

    def on_train_end(self, logs=None) -> None:  # type: ignore
        """Assemble the saved frames of every image into a GIF.

        Parameters
        ----------
        logs: Dict[str, Any], optional
            dictionary of metrics, unused
        """
        for fname in self.fnames:
            frames_dir = self._frames_dir(fname)
            # A plain string sort would put "10_out.png" before "2_out.png"
            # and scramble the animation once training passes 9 epochs.
            frame_names = sorted(os.listdir(frames_dir), key=self._frame_order)
            frames = [
                imageio.imread(os.path.join(frames_dir, frame_name))
                for frame_name in frame_names
            ]
            if not frames:
                # No frame was ever written for this image; nothing to save.
                continue
            imageio.mimsave(
                os.path.join(self.dir_name, self._stem(fname) + ".gif"),
                frames,
                format="GIF",
                duration=0.5,
            )

    def on_epoch_end(self, epoch: int, logs: Dict[str, Any] = None) -> None:  # type: ignore
        """Saves prediction of image after epoch into a folder with the same name.

        Parameters
        ----------
        epoch: int
            current epoch number
        logs: Dict[str, Any]
            dictionary of metrics

        Raises
        ------
        KeyError
            if ``save_best_epochs_only`` is set and the monitored metric is
            missing from ``logs``
        ValueError
            if one of the input images cannot be read
        """
        logs = logs or {}
        self.epoch_number += 1
        if self.save_best_epochs_only:
            metric = logs[self.monitor]
            if not self._is_improvement(metric):
                return
            self.curr_metric = metric
        else:
            # The metric is not needed to decide anything here, so its
            # absence must not abort training.
            self.curr_metric = logs.get(self.monitor, self.curr_metric)

        frame_name = str(self.epoch_number) + "_out.png"
        for fname in self.fnames:
            img_path = os.path.join(self.dir_name, fname)
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise ValueError(
                    "Could not read visualisation image: {}".format(img_path)
                )
            img = ImageUtils.binarize_img(
                img.astype(np.float32), self.model, self.batchsize
            )
            cv2.imwrite(os.path.join(self._frames_dir(fname), frame_name), img)
def create_callbacks(
    model: keras_model,
    original_model: keras_model,
    debug: str,
    num_gpus: int,
    batchsize: int,
    vis: str,
    weights_path: str,
) -> List[Callback]:
    """Create Keras callbacks for training.

    The returned list always holds a model checkpoint and an early-stopping
    callback, both monitoring ``val_dice_coef`` in ``max`` mode. A TensorBoard
    callback is added when ``debug`` is not empty and a ``Visualisation``
    callback when ``vis`` is not empty.

    Parameters
    ----------
    model: keras_model
        keras model, checkpointed when num_gpus == 1
    original_model: keras_model
        model to checkpoint when num_gpus != 1
    debug: str
        path to save weights and tensorboard logs, "" to disable
    num_gpus: int
        number of gpus
    batchsize: int
        batchsize to use during training visualization
    vis: str
        folder with images to read for training visualization, "" to disable
    weights_path: str
        path to save final weights, used only when debug is ""

    Returns
    -------
    List[Callback]
        list of callbacks to use in training.

    See Also
    --------
    Visualisation()

    Example
    -------
    unagi.utils.callbacks_utils.create_callbacks(
        model,
        gpu_model,
        logs,
        1,
        32,
        vis_imgs,
        weights
    )
    """
    debug_enabled = debug != ""

    # Model checkpoint.
    # In debug mode one weights file is kept per improving epoch; otherwise
    # the single file at weights_path is overwritten.
    if debug_enabled:
        checkpoint_path = os.path.join(
            debug, "weights", "weights-improvement-{epoch:02d}.hdf5"
        )
    else:
        checkpoint_path = weights_path
    # Both branches of the original built the same checkpoint and differed
    # only in which model object is saved.
    checkpoint_model = model if num_gpus == 1 else original_model
    callbacks = [
        AltModelCheckpoint(
            checkpoint_path,
            checkpoint_model,
            monitor="val_dice_coef",
            mode="max",
            verbose=1,
            save_best_only=True,
            save_weights_only=True,
        )
    ]

    # Early stopping.
    callbacks.append(
        EarlyStopping(
            monitor="val_dice_coef",
            min_delta=0.001,
            patience=20,
            verbose=1,
            mode="max",
        )
    )

    # Tensorboard logs.
    if debug_enabled:
        logs_dir = os.path.join(debug, "logs")
        ImageUtils.mkdir_s(debug)
        ImageUtils.mkdir_s(os.path.join(debug, "weights"))
        ImageUtils.mkdir_s(logs_dir)
        callbacks.append(
            TensorBoard(
                log_dir=logs_dir,
                histogram_freq=0,
                write_graph=True,
                write_images=True,
            )
        )

    # Training visualisation.
    if vis != "":
        callbacks.append(
            Visualisation(
                dir_name=vis,
                batchsize=batchsize,
                monitor="val_dice_coef",
                save_best_epochs_only=True,
                mode="max",
            )
        )
    return callbacks