#!/usr/bin/python3
import glob
import os
import time
from typing import List, Optional
import cv2
import numpy as np
from tensorflow.keras.optimizers import Adam
from .utils.img_processing_utils import binarize_img, mkdir_s
from .utils.metric_utils import dice_coef, dice_coef_loss
from .utils.model_utils import unet
[docs]def main(
input_path: str = os.path.join(".", "input"),
output_path: str = os.path.join(".", "output"),
weights_path: Optional[str] = None,
batchsize: int = 2,
) -> List[np.ndarray]:
"""Binarize images from input directory and write them to output directory.
Parameters
----------
input_path: str, optional
input path for images. (default is input folder in current directory)
output_path: str, optional
output path to save images. (default is output folder in current directory)
weights_path: str or None, optional
path to weights file.
if None default weights will be loaded from package root directory.
batchsize: int, optional
batchsize to use in model prediction
Returns
-------
List[numpy.ndarray]
list of binary images in np.ndarray format
Note
----
All input image names should be in png format "sample_1.png".
All output image names will end with "_bin" like "sample_1_bin.png".
Example
-------
unagi.binarize.main('input_path', 'output_path', 2)
"""
try:
assert (batchsize > 0) and isinstance(batchsize, int)
except Exception:
print("batchsize should be > 0 and int but given {}".format(batchsize))
try:
assert os.path.isdir(input_path)
except Exception:
print("Input path is not valid")
if weights_path is None:
weights_path = os.path.realpath(__file__)
weights_path = weights_path.replace(
"unagi/binarize.py", "weights/bin_weights_file.hdf5"
)
else:
pass
start_time = time.time()
fnames_in = list(
glob.iglob(os.path.join(input_path, "**", "*.png*"), recursive=True)
)
model = None
if len(fnames_in) != 0:
mkdir_s(output_path)
model = unet()
model.compile(optimizer=Adam(lr=1e-4), loss=dice_coef_loss, metrics=[dice_coef])
model.load_weights(weights_path)
bin_img_list = []
for fname in fnames_in:
print("binarizing -> {0}".format(fname))
img = cv2.imread(fname, cv2.IMREAD_GRAYSCALE).astype(np.float32)
img = binarize_img(img, model, batchsize)
cv2.imwrite(
os.path.join(
output_path, os.path.split(fname)[-1].replace(".png", "_bin.png")
),
img,
)
bin_img_list.append(img)
print("finished in {0:.2f} seconds".format(time.time() - start_time))
return bin_img_list
if __name__ == "__main__":
    import argparse

    # Every option defaults to SUPPRESS so omitted options are left out of the
    # namespace and main()'s own defaults apply; running with no arguments is
    # therefore the same as the previous bare ``main()`` call.
    parser = argparse.ArgumentParser(
        description="Binarize the png images found under an input directory.",
        argument_default=argparse.SUPPRESS,
    )
    parser.add_argument(
        "-i", "--input-path", dest="input_path",
        help="input image directory (default: ./input)",
    )
    parser.add_argument(
        "-o", "--output-path", dest="output_path",
        help="output image directory (default: ./output)",
    )
    parser.add_argument(
        "-w", "--weights-path", dest="weights_path",
        help="weights file to load (default: weights bundled with the package)",
    )
    parser.add_argument(
        "-b", "--batchsize", type=int,
        help="batch size used for model prediction (default: 2)",
    )
    main(**vars(parser.parse_args()))