Code source de code_beatrix.ai.image_segmentation

"""
Extracts objects from an image based on deep learning.


:githublink:`%|py|5`
"""
from contextlib import redirect_stdout
import io
import os
import numpy
from PIL import Image
import skimage
from skimage.io._plugins.pil_plugin import pil_to_ndarray
import chainer  # pylint: disable=E0401
import fcn  # pylint: disable=E0401
from .dlbase import DeepLearningImage


class DLImageSegmentation(DeepLearningImage):
    """
    Segments an image.
    Inspired from `infer.py <https://github.com/wkentaro/fcn/blob/master/examples/voc/infer.py>`_.
    See notebook :ref:`imagesegmentationrst`.
    """

    def __init__(self, model="FCN8s", n_class=21, gpu=False,
                 class_name=None, fLOG=None):
        """
        :param model: model name
        :param n_class: number of classes
        :param gpu: use gpu
        :param class_name: class names, defaults to
            ``fcn.datasets.VOC2012ClassSeg.class_names``
        :param fLOG: logging function
        :raises NotImplementedError: if *model* is not a known model name

        List of known models:

        * ``'FCN8s'``: image segmentation
        """
        # self.log is used before the base class constructor is called,
        # so the logging function has to be stored first.
        self._fLOG = fLOG
        if model == "FCN8s":
            self.log(
                "[DLImageSegmentation] download model '{0}'".format(model))
            # Whatever the download prints on stdout is captured and
            # forwarded to the logging function instead.
            f = io.StringIO()
            with redirect_stdout(f):
                model_file = fcn.models.FCN8s.download()
            self.log('[DLImageSegmentation] {0}'.format(f.getvalue()))
            self._model_file = model_file
            net = fcn.models.FCN8s(n_class=n_class)
            self.log("[DLImageSegmentation] load_npz '{0}'".format(model_file))
            chainer.serializers.load_npz(  # pylint: disable=E1101
                model_file, net)
        else:
            raise NotImplementedError(
                "Unable to interpret '{0}'".format(model))

        DeepLearningImage.__init__(self, net, gpu=gpu, fLOG=fLOG)
        self._n_class = n_class
        if class_name is None:
            self._class_name = class_name = fcn.datasets.VOC2012ClassSeg.class_names
        else:
            self._class_name = class_name
        self.log("[DLImageSegmentation] class_name '{0}'".format(class_name))
        if gpu:
            self.log("[DLImageSegmentation] gpu")
            # NOTE(review): self._gpu is assumed to be set by
            # DeepLearningImage.__init__ — confirm in dlbase.
            chainer.cuda.get_device(self._gpu).use()  # pylint: disable=E1101
            net.to_gpu()
        else:
            self.log("[DLImageSegmentation] cpu")

    @property
    def ModelFile(self):
        """
        Returns the model file name (the path returned by the download
        step in the constructor).
        """
        return self._model_file

    @staticmethod
    def _new_size(old_size, new_size):
        """
        Computes a new size.

        :param old_size: current size, a tuple of two integers (callers in
            this class pass a PIL ``Image.size``)
        :param new_size: new desired size
        :return: new size

        *new_size* can be of:

        * (int, int): this is the new size
        * ('max2', int): both dimensions are divided by the same power of 2,
          the smallest one making the greater dimension lower than or equal
          to the threshold (which must be strictly positive)

        :raises TypeError: if a size is not a tuple or *new_size[0]* is
            neither a str nor an int
        :raises ValueError: if a size does not have two values, the mode is
            unknown or the ``'max2'`` threshold is not strictly positive
        """
        if not isinstance(new_size, tuple):
            raise TypeError("new_size must be a tuple")
        if not isinstance(old_size, tuple):
            raise TypeError("old_size must be a tuple")
        if len(old_size) != 2:
            raise ValueError("old_size must have two values")
        if len(new_size) != 2:
            raise ValueError("new_size must have two values")
        if isinstance(new_size[0], str):
            if new_size[0] != 'max2':
                raise ValueError(
                    "Unable to interpret '{0}'".format(new_size[0]))
            threshold = new_size[1]
            if threshold <= 0:
                # mx //= 2 stops decreasing at 0: a negative threshold
                # would loop forever, a null one would return (0, 0).
                raise ValueError(
                    "new_size[1] must be strictly positive for 'max2'")
            mx = max(old_size)
            p = 1
            while mx > threshold:
                mx //= 2
                p *= 2
            return (old_size[0] // p, old_size[1] // p)
        if isinstance(new_size[0], int):
            return new_size
        raise TypeError("new_size[0] must be an int")

    def _load_image(self, img, resize=None):
        """
        Loads an image as a :epkg:`numpy:array`.

        :param img: image, a filename or a :epkg:`numpy:array`
        :param resize: resize the image before predicting, see
            :meth:`_new_size <code_beatrix.ai.image_segmentation.DLImageSegmentation._new_size>`
        :return: :epkg:`numpy:array`
        :raises FileNotFoundError: if *img* is a filename which does not exist
        :raises NotImplementedError: if *img* has an unsupported type
        """
        if isinstance(img, str):
            if not os.path.exists(img):
                raise FileNotFoundError(img)
            if resize is None:
                return skimage.io.imread(img)
            # The context manager closes the file handle opened by PIL.
            with Image.open(img) as pilimg:
                si = DLImageSegmentation._new_size(pilimg.size, resize)
                resized = pilimg.resize(si)
            return pil_to_ndarray(resized)
        if isinstance(img, numpy.ndarray):
            if resize is None:
                return img
            # skimage.transform.resize(img, resize) did not work here,
            # PIL is used instead.
            pilimg = Image.fromarray(img).convert('RGB')
            si = DLImageSegmentation._new_size(pilimg.size, resize)
            # The resized image is the one converted back to an array.
            return pil_to_ndarray(pilimg.resize(si))
        raise NotImplementedError(
            "Not implemented for type '{0}'".format(type(img)))

    def _preprocess(self, feat, preprocess=True):
        """
        Preprocesses the image before prediction.

        :param feat: image (output of :meth:`_load_image <code_beatrix.ai.image_segmentation.DLImageSegmentation._load_image>`)
        :param preprocess: applies some preprocessing or not
        :return: preprocessed image with a new leading axis of length 1
            (presumably the batch dimension expected by the model),
            or *feat* unchanged if *preprocess* is False
        """
        if not preprocess:
            return feat
        transformed, = fcn.datasets.transform_lsvrc2012_vgg16(  # pylint: disable=W0632
            (feat,))
        return transformed[numpy.newaxis, :, :, :]  # pylint: disable=E1126

    def predict(self, img, resize=None):
        """
        Applies the segmentation model on an image.

        :param img: image, a filename or a :epkg:`numpy:array`
        :param resize: resize the image before predicting, see
            :meth:`_new_size <code_beatrix.ai.image_segmentation.DLImageSegmentation._new_size>`
        :return: (image, prediction), the loaded (and possibly resized)
            image and the label array obtained by an argmax over axis 1 of
            the model scores, moved back to the CPU
        """
        feat = self._load_image(img, resize=resize)
        batch = self._preprocess(feat, preprocess=True)
        if self._gpu:
            batch = chainer.cuda.to_gpu(batch)  # pylint: disable=E1101
        with chainer.no_backprop_mode():  # pylint: disable=E1101
            var = chainer.Variable(batch)  # pylint: disable=E1101
            with chainer.using_config('train', False):  # pylint: disable=E1101
                self._model(var)
                lbl_pred = chainer.functions.argmax(  # pylint: disable=E1101
                    self._model.score, axis=1)[0]
            lbl_pred = chainer.cuda.to_cpu(  # pylint: disable=E1101
                lbl_pred.data)
        return feat, lbl_pred

    def plot(self, img, pred, resize=None):
        """
        Displays the segmentation.

        :param img: initial image, a filename or a :epkg:`numpy:array`
            (the image returned by :meth:`predict` can be given directly)
        :param pred: predicted labels, second element returned by
            :meth:`predict`
        :param resize: resize applied to *img* before drawing; it must be
            the value given to :meth:`predict` so that *img* and *pred*
            have the same size
        :return: new image
        """
        img = self._load_image(img, resize=resize)
        return fcn.utils.visualize_segmentation(
            lbl_pred=pred, img=img, n_class=self._n_class,
            label_names=self._class_name)