#!/usr/bin/python3
import glob
import argparse
import os
import time
from typing import List, Optional
import cv2
import numpy as np
from tensorflow.keras.optimizers import Adam
from unagi.utils.img_processing_utils import ImageUtils
from unagi.utils.metric_utils import ModelMetrics
from unagi.utils.model_utils import UNAGIModel
def _default_weights_path() -> str:
    """Return the default weights file path, resolved relative to this module.

    The module is expected to live at ``<root>/unagi/binarize.py``; the weights
    are looked up at ``<root>/weights/bin_weights_file.hdf5``. Path components
    are joined with ``os.path`` so this also works with Windows separators.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    package_root = os.path.dirname(module_dir)
    return os.path.join(package_root, "weights", "bin_weights_file.hdf5")


def _output_filename(fname: str, output_path: str) -> str:
    """Map an input image path to ``<output_path>/<stem>_bin<ext>``."""
    # splitext only touches the final extension, so "a.png.png" becomes
    # "a.png_bin.png" instead of having every ".png" occurrence rewritten.
    stem, ext = os.path.splitext(os.path.basename(fname))
    return os.path.join(output_path, stem + "_bin" + ext)


def _load_model(weights_path: str):
    """Build the model via ``UNAGIModel().unet()``, compile it and load weights."""
    model = UNAGIModel().unet()
    model.compile(
        # ``learning_rate`` is the supported keyword; ``lr`` is a deprecated
        # alias that newer Keras versions reject.
        optimizer=Adam(learning_rate=1e-4),
        loss=ModelMetrics.dice_coef_loss,
        metrics=[ModelMetrics.dice_coef],
    )
    model.load_weights(weights_path)
    return model


def main(
    input_path: str = os.path.join(".", "input"),
    output_path: str = os.path.join(".", "output"),
    weights_path: Optional[str] = None,
    batchsize: int = 2,
) -> List[np.ndarray]:
    """Binarize images from input directory and write them to output directory.

    Parameters
    ----------
    input_path: str, optional
        input path for images; searched recursively for ``*.png`` files.
        (default is input folder in current directory)
    output_path: str, optional
        output path to save images. (default is output folder in current directory)
        It is only created when at least one input image is found.
    weights_path: str or None, optional
        path to weights file.
        if None, ``weights/bin_weights_file.hdf5`` in the directory that
        contains the ``unagi`` package directory is used.
    batchsize: int, optional
        batchsize to use in model prediction; must be a positive int.

    Returns
    -------
    List[numpy.ndarray]
        list of binary images, in the order the input files were found.
        Empty when ``input_path`` is not a directory or contains no png files.

    Raises
    ------
    ValueError
        if ``batchsize`` is not a positive int.
    OSError
        if an input image cannot be read or an output image cannot be written.

    Note
    ----
    All input image names should be in png format "sample_1.png".
    All output image names will end with "_bin" like "sample_1_bin.png".
    Output files are written flat into ``output_path`` using only the input
    file's base name, so inputs sharing a name in different subdirectories
    overwrite each other.

    Example
    -------
    unagi.binarize.main('input_path', 'output_path', batchsize=2)
    """
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O``. bool is a subclass of int, so it is rejected explicitly.
    if (
        isinstance(batchsize, bool)
        or not isinstance(batchsize, int)
        or batchsize <= 0
    ):
        raise ValueError(
            "batchsize should be > 0 and int but given {}".format(batchsize)
        )
    if not os.path.isdir(input_path):
        # Nothing can be found under a non-directory; report it and return an
        # empty result rather than raising.
        print("Input path is not valid")
        return []
    if weights_path is None:
        weights_path = _default_weights_path()

    start_time = time.time()
    # "*.png" (not "*.png*") so files like "x.png.bak" are not picked up.
    fnames_in = list(
        glob.iglob(os.path.join(input_path, "**", "*.png"), recursive=True)
    )
    bin_img_list = []
    if fnames_in:
        # Only create the output dir and load the model when there is work.
        ImageUtils.mkdir_s(output_path)
        model = _load_model(weights_path)
        for fname in fnames_in:
            print("binarizing -> {0}".format(fname))
            img = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise OSError("could not read image: {0}".format(fname))
            img = ImageUtils.binarize_img(img.astype(np.float32), model, batchsize)
            out_fname = _output_filename(fname, output_path)
            if not cv2.imwrite(out_fname, img):
                raise OSError("could not write image: {0}".format(out_fname))
            bin_img_list.append(img)
    print("finished in {0:.2f} seconds".format(time.time() - start_time))
    return bin_img_list
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments for binarize module.

    Parameters
    ----------
    argv: list of str or None, optional
        arguments to parse. If None (the default), ``sys.argv[1:]`` is used,
        which matches the previous zero-argument behavior.

    Returns
    -------
    argparse.Namespace
        namespace with ``input_path``, ``output_path``, ``weights_path`` and
        ``batchsize`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Binarize images from input dir and write them to output directory."
    )
    parser.add_argument(
        "--input_path",
        type=str,
        default=os.path.join(".", "input"),
        help="Input path for images.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=os.path.join(".", "output"),
        help="Output path to save images.",
    )
    parser.add_argument(
        "--weights_path",
        type=str,
        default=None,
        help="Path to weights file. If None, default will be loaded from package root.",
    )
    parser.add_argument(
        "--batchsize",
        type=int,
        default=2,
        help="Batch size to use in model prediction.",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse the CLI options and hand each one to main()
    # by keyword so the mapping stays explicit.
    cli_args = parse_args()
    main(
        batchsize=cli_args.batchsize,
        weights_path=cli_args.weights_path,
        input_path=cli_args.input_path,
        output_path=cli_args.output_path,
    )