Semantic Image Segmentation - Part 1¶

Deep Learning Methods for Semantic Segmentation¶

So far, you have seen image classification, where the task of the network is to assign a label or class to an input image. However, suppose you want to know where an object is located in the image, what shape that object is, which pixel belongs to which object, etc. In this case, you will want to segment the image, i.e., give each pixel in the image a label. So, the task of image segmentation is to train a neural network to produce a pixel-wise mask of the image. This helps in understanding the image at a much lower level, i.e., the pixel level. Image segmentation has many applications in medical imaging, self-driving cars, and satellite imagery, to name a few.

image.png

More precisely, semantic image segmentation is the task of labeling each pixel in the image into a predefined set of classes. For example, various objects like cars, trees, people, traffic signs, etc. can be used as classes for semantic image segmentation. So, the task is to take an image (RGB or grayscale) and produce a W x H x 1 matrix, where W and H represent the width and height of the image, respectively. Each cell in this matrix would contain the predicted class IDs for each pixel in the image.

Image Classification vs. Semantic Segmentation¶

image.png

When it comes to semantic segmentation, we usually don’t require a fully connected layer at the end because our goal is not to predict the class label of the image.

In semantic segmentation, our goal is to extract features before using them to separate the image into multiple segments.

However, the problem with convolutional networks is that the image size gets reduced as it passes through the network due to the max-pooling layers.

To efficiently separate the image into multiple segments, we need to upsample using an interpolation technique, which is achieved using deconvolutional layers.

image.png

Labels in Semantic Segmentation¶

In deep learning, we express categorical class labels as one-hot encoded vectors. Similarly, in semantic segmentation, we can express the output matrix using a one-hot encoding scheme by essentially creating a channel for each class label and marking the cells by 1 that contain the pixel of the corresponding class and marking the remaining cells by 0.

image.png

Implementing a Semantic Segmentation Architecture¶

Like other computer vision tasks, using a CNN for semantic segmentation would be the obvious choice. When using CNN for semantic segmentation, the output would be an image with the same resolution as the input, unlike a fixed-length vector in the case of image classification.

A naive approach to building a neural network architecture for this task is to simply stack multiple convolutional layers (with the same padding to preserve dimensions) and output a final segmentation map. This directly learns a mapping from the input image to its corresponding segmentation through successive transformation of feature mappings; however, it is quite computationally expensive to preserve full resolution across the network.

image.png

Recall that for deep convolutional networks, earlier layers tend to learn low-level concepts, while later layers develop more high-level (and specialized) feature mappings. To maintain expressiveness, we typically need to increase the number of feature maps (channels) as we go deeper into the network. This is not necessarily a problem for the task of image classification, since for that task we only care about what the image contains (and not where it is located). Thus, we could alleviate the computational burden by periodically downsampling our feature maps via pooling or strided convolutions (i.e., compressing the spatial resolution) without worrying. However, for image segmentation, we would like our model to produce a full-resolution semantic prediction. A popular approach for image segmentation models is to follow an encoder/decoder framework where we reduce the spatial resolution of the input, developing lower resolution feature mappings that are highly efficient at discriminating between classes, and up-resolution the feature representations into a full resolution segmentation map.

image.png

Methods for upsampling:¶

There are a few different approaches we can use to increase the resolution of a feature map. While pooling operations decrease resolution by summarizing a local area with a single value (i.e., average or max pooling), unpooling operations increase resolution by distributing a single value across a higher resolution.

image.png

However, transposed convolutions are by far the most popular approach as they allow us to build on learned upsampling.

image.png

Whereas a typical convolution operation will take the dot product of the values ​​currently in the filter view and produce a single value for the corresponding output position, a transposition convolution essentially does the opposite. For a transposition convolution, we take a single value from the low-resolution feature map and multiply all of the weights in our filter by that value, projecting those weighted values ​​onto the output feature map.

For filter sizes that produce an overlap in the output feature map (e.g. 3x3 filter with stride 2 - as shown in the example below), the overlapping values ​​are simply summed. Unfortunately, this tends to produce a checkerboard artifact in the output and is undesirable, so it is best to ensure that the filter size does not produce an overlap.

Loss Functions for Semantic Segmentation¶

The loss function is used to guide the neural network towards optimization. Let’s discuss some popular loss functions for the semantic segmentation task. Semantic segmentation models usually use a simple categorical entropy loss function during training. However, if you are interested in getting the granular information of an image, you will need to revert to slightly more advanced loss functions.

Distribution-based loss¶

Cross Entropy Loss¶

The most commonly used loss function for the image segmentation task is a pixel-wise cross-entropy loss. This loss examines each pixel individually, comparing the class predictions (pixel depth vector) with our one-hot encoded target vector.

image.png

Since the cross-entropy loss evaluates the class predictions for each pixel vector individually and then averages them across all pixels, we are essentially asserting equal learning for each pixel in the image. This can be a problem if your various classes have an imbalanced representation in the image, as training may be dominated by the most prevalent class. Long et al. (FCN paper) discusses weighting this loss for each output channel in order to counteract a class imbalance present in the dataset.

Weighted cross entropy¶

It is an extension of the CE, which assigns different weights to each class. In general, classes not presented will have higher weights.

Focal Loss¶

It adapts the CE pattern to handle extreme foreground-background class imbalance, where the loss attributed to well-classified examples is reduced.

Region-based Loss¶

Region-based loss functions aim to minimize the mismatch or maximize the overlapping regions between the ground truth and the predicted segmentation.

Dice Loss¶

The Dice function is nothing but the F1 score. This loss function directly tries to optimize the F1 score. Similarly, the direct IOU score can also be used to perform optimization

image.png

image.png

IoU Loss¶

IoU Loss (also called Jaccard Loss), similar to Dice Loss, is also used to directly optimize the targeting metric.

Tversky loss¶

Sets different weights for false negative (FN) and false positive (FP), which is different from data loss using equal weights for FN and FP.

Focal Tversky loss¶

Applies the concept of focal loss to focus on difficult cases with low probabilities.

Metricas para Segmentação Semântica¶

Let's discuss the metrics that are commonly used to understand and evaluate the results of a model.

Pixel Accuracy¶

Pixel accuracy is the most basic metric that can be used to validate the results. Accuracy is obtained by taking the proportion of correctly classified pixels to the total pixels

Accuracy = (TP+TN)/(TP+TN+FP+FN)

The main drawback of using this technique is that the result may look good if one class dominates the other. Say, for example, the background class covers 90% of the input image, we can achieve 90% accuracy just by classifying every pixel as background

Metrics based on IoU¶

Intuitively, a successful prediction is one that maximizes the overlap between the predicted and true objects. Two related but different metrics for this goal are the Dice and Jaccard coefficients (or indices):

image.png

IOU is defined as the ratio of the intersection of the ground truth and predicted segmentation outputs over their union. If we are calculating for multiple classes, then the IOU of each class is calculated and its average is taken. It is a better metric compared to pixel-wise accuracy because if each pixel is given as background in a 2-class input, then the IOU value is (90/100+0/100)/2 i.e. 45% IOU which gives a better representation compared to 90% accuracy.

image.png

In terms of the confusion matrix, the metrics can be reformulated in terms of true/false positives/negatives:

image.png

Here is an illustration of the Dice and IoU metrics given two circles representing the ground truth and predicted masks for an arbitrary object class:

image.png

Frequency weighted IOU¶

This is an extension on top of the average IOU we discussed and is used to combat class imbalance. If a class dominates most of the images in a dataset, such as background, it needs to be weighted compared to other classes. So, instead of taking the average of all the results for the class, a weighted average is taken based on the frequency of the region of the class in the dataset.

FCN - Fully Convolutional Networks¶

The general architecture of a CNN consists of a few convolutional and pooling layers followed by a few fully connected layers at the end. The Fully Convolutional Network paper released in 2014 argues that the final fully connected layer can be thought of as doing a 1x1 convolution that covers the entire region.

image.png

Thus, the final dense layers can be replaced with a convolution layer achieving the same result. But now the advantage of doing so is that the input size no longer needs to be fixed. By involving dense layers, the input size is restricted and hence when a different size input is to be provided, it has to be resized. But by replacing a dense layer with convolution, this restriction does not exist.

Furthermore, when a larger size of image is provided as input, the output produced will be a feature map and not just a class output as for a normal input size image. Furthermore, the observed behavior of the final feature map represents the heatmap of the required class, i.e. the position of the object is highlighted in the feature map. Since the output of the feature map is a heatmap of the required object, it is valid information for our segmentation use case.

Since the feature map obtained at the output layer is a downsampled due to the set of convolutions performed, we would like to upsample it using an interpolation technique. Bilinear upsampling works, but the paper proposes using learned upsampling with deconvolution, which can even learn nonlinear upsampling.

The downsampling part of the network is called the encoder, and the upsampling part is called the decoder. This is a pattern we will see in many architectures, namely downsampling with the encoder and then upsampling with the decoder. In an ideal world, we would not want to downsample using pooling and keep the same size, but this would lead to a huge amount of parameters and would be computationally infeasible.

image.png

Although the output results obtained were decent, the observed output is rough and not smooth. The reason for this is the loss of information in the final feature layer due to 32x downsampling using convolutional layers. Now it becomes very difficult for the network to do 32x upsampling using this little information. This architecture is called FCN-32

To solve this problem, the paper proposed 2 other architectures FCN-16, FCN-8. In FCN-16, the information from the previous pooling layer is used along with the final feature map and hence now the task of the network is to learn 16x upsampling which is better as compared to FCN-32. FCN-8 tries to make it even better by including information from one more previous pooling layer.

image.png

Use Case: Usage and Coverage Class Segmentation¶

Let's start with an example using the Landcover AI dataset:

The LandCover.ai (Land Cover from Aerial Imagery) dataset is a dataset for automatic mapping of buildings, forests, water and roads from aerial imagery.

Dataset features

  • land cover of Poland, Central Europe 1
  • three spectral bands - RGB
  • 33 orthophotos with 25 cm pixel resolution (~9000x9500 px)
  • 8 orthophotos with 50 cm pixel resolution (~4200x4700 px)
  • total area of ​​216.27 km2

Dataset format

  • rasters are three-channel GeoTiffs with EPSG:2180 spatial reference system
  • masks are single-channel GeoTiffs with EPSG:2180 spatial reference system

image.png

In [ ]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
In [ ]:
import glob
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

Mosaics are stored in Drive. The first step is to divide them into 512x512 pixel patches

In [ ]:
IMGS_DIR = "drive/MyDrive/Datasets/Landcover_AI/images"
MASKS_DIR = "drive/MyDrive/Datasets/Landcover_AI/masks"
OUTPUT_DIR = "./output"
TARGET_SIZE = 1024

img_paths = glob.glob(os.path.join(IMGS_DIR, "*.tif"))
mask_paths = glob.glob(os.path.join(MASKS_DIR, "*.tif"))

img_paths.sort()
mask_paths.sort()

os.makedirs(OUTPUT_DIR)
for i, (img_path, mask_path) in enumerate(zip(img_paths, mask_paths)):
    img_filename = os.path.splitext(os.path.basename(img_path))[0]
    mask_filename = os.path.splitext(os.path.basename(mask_path))[0]
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path)
    mask = mask[:,:,0]
    assert img_filename == mask_filename and img.shape[:2] == mask.shape[:2]

    k = 0
    for y in range(0, img.shape[0], TARGET_SIZE):
        for x in range(0, img.shape[1], TARGET_SIZE):
            img_tile = img[y:y + TARGET_SIZE, x:x + TARGET_SIZE]
            mask_tile = mask[y:y + TARGET_SIZE, x:x + TARGET_SIZE]

            if img_tile.shape[0] == TARGET_SIZE and img_tile.shape[1] == TARGET_SIZE:
                out_img_path = os.path.join(OUTPUT_DIR, "{}_{}.jpg".format(img_filename, k))
                cv2.imwrite(out_img_path, img_tile)

                out_mask_path = os.path.join(OUTPUT_DIR, "{}_{}_m.png".format(mask_filename, k))
                cv2.imwrite(out_mask_path, mask_tile)


            k += 1

    print("Processed {} {}/{}".format(img_filename, i + 1, len(img_paths)))
Processed M-33-20-D-c-4-2 1/41
Processed M-33-20-D-d-3-3 2/41
Processed M-33-32-B-b-4-4 3/41
Processed M-33-48-A-c-4-4 4/41
Processed M-33-7-A-d-2-3 5/41
Processed M-33-7-A-d-3-2 6/41
Processed M-34-32-B-a-4-3 7/41
Processed M-34-32-B-b-1-3 8/41
Processed M-34-5-D-d-4-2 9/41
Processed M-34-51-C-b-2-1 10/41
Processed M-34-51-C-d-4-1 11/41
Processed M-34-55-B-b-4-1 12/41
Processed M-34-56-A-b-1-4 13/41
Processed M-34-6-A-d-2-2 14/41
Processed M-34-65-D-a-4-4 15/41
Processed M-34-65-D-c-4-2 16/41
Processed M-34-65-D-d-4-1 17/41
Processed M-34-68-B-a-1-3 18/41
Processed M-34-77-B-c-2-3 19/41
Processed N-33-104-A-c-1-1 20/41
Processed N-33-119-C-c-3-3 21/41
Processed N-33-130-A-d-3-3 22/41
Processed N-33-130-A-d-4-4 23/41
Processed N-33-139-C-d-2-2 24/41
Processed N-33-139-C-d-2-4 25/41
Processed N-33-139-D-c-1-3 26/41
Processed N-33-60-D-c-4-2 27/41
Processed N-33-60-D-d-1-2 28/41
Processed N-33-96-D-d-1-1 29/41
Processed N-34-106-A-b-3-4 30/41
Processed N-34-106-A-c-1-3 31/41
Processed N-34-140-A-b-3-2 32/41
Processed N-34-140-A-b-4-2 33/41
Processed N-34-140-A-d-3-4 34/41
Processed N-34-140-A-d-4-2 35/41
Processed N-34-61-B-a-1-1 36/41
Processed N-34-66-C-c-4-3 37/41
Processed N-34-77-A-b-1-4 38/41
Processed N-34-94-A-b-2-4 39/41
Processed N-34-97-C-b-1-2 40/41
Processed N-34-97-D-c-2-4 41/41

Patches were created and stored in a folder in Colab content. Please note that once your session ends, this data will be lost.

In [ ]:
path = '/content/output'

Let's then import the images and their respective masks and convert them into numpy arrays:

In [ ]:
list_img = [f for f in os.listdir(path) if f.endswith('.jpg')]
In [ ]:
X = []
Y = []
NCLASSES = 5
for path_img in list_img:
  full_path = os.path.join(path, path_img)
  img = cv2.imread(full_path)
  img = cv2.resize(img, (256,256))
  id_img = path_img.split('.')[0]
  mask_id = id_img + '_m.png'
  mask_path = os.path.join(path, mask_id)
  mask = cv2.imread(mask_path)
  mask = cv2.resize(mask, (256,256))
  mask = mask[:,:,0]
  mask = to_categorical(mask,NCLASSES)
  X.append(img)
  Y.append(mask)
In [ ]:
X = np.array(X)
Y = np.array(Y)
In [ ]:
X.shape
Out[ ]:
(2513, 256, 256, 3)

So we can plot an example of an image and its mask:

In [ ]:
plt.figure(figsize=[6,6])
plt.imshow(X[100])
plt.axis('off')
Out[ ]:
(-0.5, 255.5, 255.5, -0.5)
No description has been provided for this image
In [ ]:
plt.figure(figsize=[6,6])
plt.imshow(np.argmax(Y[100], axis=2))
plt.axis('off')
Out[ ]:
(-0.5, 255.5, 255.5, -0.5)
No description has been provided for this image

Now it's time to split the data into training and testing, rescaling the values ​​to a range from 0 to 1:

In [ ]:
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.3, random_state=10)
In [ ]:
del X, Y
In [ ]:
del list_img
In [ ]:
x_train = x_train/255

The next step is to import some keras and tensorflow functions for the FCN implementation:

In [ ]:
from keras.models import Model
from keras.regularizers import l2
from keras.layers import *
from keras.models import *
from tensorflow.keras import backend as K
import tensorflow as tf
#from tensorflow.keras.optimizers import Adam
#from tensorflow.keras.optimizers.legacy import Adam
from keras.optimizers import Adam
from tensorflow.keras.losses import Dice

This cell creates the BilinearUpsampling2D operation that will be used in the final part of the FCN architecture:

In [ ]:
def resize_images_bilinear(X, height_factor=1, width_factor=1, target_height=None, target_width=None, data_format='default'):
    '''Resizes the images contained in a 4D tensor of shape
    - [batch, channels, height, width] (for 'channels_first' data_format)
    - [batch, height, width, channels] (for 'channels_last' data_format)
    by a factor of (height_factor, width_factor). Both factors should be
    positive integers.
    '''
    if data_format == 'default':
        data_format = K.image_data_format()
    if data_format == 'channels_first':
        original_shape = K.int_shape(X)
        if target_height and target_width:
            new_shape = tf.constant(np.array((target_height, target_width)).astype('int32'))
        else:
            new_shape = tf.shape(X)[2:]
            new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32'))
        X = permute_dimensions(X, [0, 2, 3, 1])
        X = tf.image.resize_bilinear(X, new_shape)
        X = permute_dimensions(X, [0, 3, 1, 2])
        if target_height and target_width:
            X.set_shape((None, None, target_height, target_width))
        else:
            X.set_shape((None, None, original_shape[2] * height_factor, original_shape[3] * width_factor))
        return X
    elif data_format == 'channels_last':
        original_shape = K.int_shape(X)
        if target_height and target_width:
            new_shape = tf.constant(np.array((target_height, target_width)).astype('int32'))
        else:
            new_shape = tf.shape(X)[1:3]
            new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32'))
        X = tf.compat.v1.image.resize_bilinear(X, new_shape)
        if target_height and target_width:
            X.set_shape((None, target_height, target_width, None))
        else:
            X.set_shape((None, original_shape[1] * height_factor, original_shape[2] * width_factor, None))
        return X
    else:
        raise Exception('Invalid data_format: ' + data_format)

class BilinearUpSampling2D(Layer):
    def __init__(self, size=(1, 1), target_size=None, data_format='default', **kwargs):
        if data_format == 'default':
            data_format = K.image_data_format()
        self.size = tuple(size)
        if target_size is not None:
            self.target_size = tuple(target_size)
        else:
            self.target_size = None
        assert data_format in {'channels_last', 'channels_first'}, 'data_format must be in {tf, th}'
        self.data_format = data_format
        self.input_spec = [InputSpec(ndim=4)]
        super(BilinearUpSampling2D, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            width = int(self.size[0] * input_shape[2] if input_shape[2] is not None else None)
            height = int(self.size[1] * input_shape[3] if input_shape[3] is not None else None)
            if self.target_size is not None:
                width = self.target_size[0]
                height = self.target_size[1]
            return (input_shape[0],
                    input_shape[1],
                    width,
                    height)
        elif self.data_format == 'channels_last':
            width = int(self.size[0] * input_shape[1] if input_shape[1] is not None else None)
            height = int(self.size[1] * input_shape[2] if input_shape[2] is not None else None)
            if self.target_size is not None:
                width = self.target_size[0]
                height = self.target_size[1]
            return (input_shape[0],
                    width,
                    height,
                    input_shape[3])
        else:
            raise Exception('Invalid data_format: ' + self.data_format)

    def call(self, x, mask=None):
        if self.target_size is not None:
            return resize_images_bilinear(x, target_height=self.target_size[0], target_width=self.target_size[1], data_format=self.data_format)
        else:
            return resize_images_bilinear(x, height_factor=self.size[0], width_factor=self.size[1], data_format=self.data_format)

    def get_config(self):
        config = {'size': self.size, 'target_size': self.target_size}
        base_config = super(BilinearUpSampling2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

And then we can implement FCN:

In [ ]:
weight_decay = 0.
img_input = Input(shape=x_train.shape[1:])

# Block 1
x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1', kernel_regularizer=l2(weight_decay))(img_input)
x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2', kernel_regularizer=l2(weight_decay))(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

# Block 2
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2', kernel_regularizer=l2(weight_decay))(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)

# Block 3
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3', kernel_regularizer=l2(weight_decay))(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

# Block 4
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3', kernel_regularizer=l2(weight_decay))(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

# Block 5
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3', kernel_regularizer=l2(weight_decay))(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)

# Convolutional layers transfered from fully-connected layers
x = Conv2D(4096, (7, 7), activation='relu', padding='same', name='fc1', kernel_regularizer=l2(weight_decay))(x)
x = Dropout(0.5)(x)
x = Conv2D(4096, (1, 1), activation='relu', padding='same', name='fc2', kernel_regularizer=l2(weight_decay))(x)
x = Dropout(0.5)(x)
#classifying layer
x = Conv2D(NCLASSES, (1, 1), kernel_initializer='he_normal', activation='softmax', padding='valid', strides=(1, 1), kernel_regularizer=l2(weight_decay))(x)

x = BilinearUpSampling2D(size=(32, 32))(x)

model = Model(img_input, x)
In [ ]:
model.compile(optimizer=Adam(learning_rate = 1e-5), loss = 'categorical_crossentropy', metrics=['accuracy'])
model.summary()
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)             │ (None, 224, 224, 3)         │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block1_conv1 (Conv2D)                │ (None, 224, 224, 64)        │           1,792 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block1_conv2 (Conv2D)                │ (None, 224, 224, 64)        │          36,928 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block1_pool (MaxPooling2D)           │ (None, 112, 112, 64)        │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block2_conv1 (Conv2D)                │ (None, 112, 112, 128)       │          73,856 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block2_conv2 (Conv2D)                │ (None, 112, 112, 128)       │         147,584 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block2_pool (MaxPooling2D)           │ (None, 56, 56, 128)         │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block3_conv1 (Conv2D)                │ (None, 56, 56, 256)         │         295,168 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block3_conv2 (Conv2D)                │ (None, 56, 56, 256)         │         590,080 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block3_conv3 (Conv2D)                │ (None, 56, 56, 256)         │         590,080 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block3_pool (MaxPooling2D)           │ (None, 28, 28, 256)         │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block4_conv1 (Conv2D)                │ (None, 28, 28, 512)         │       1,180,160 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block4_conv2 (Conv2D)                │ (None, 28, 28, 512)         │       2,359,808 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block4_conv3 (Conv2D)                │ (None, 28, 28, 512)         │       2,359,808 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block4_pool (MaxPooling2D)           │ (None, 14, 14, 512)         │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block5_conv1 (Conv2D)                │ (None, 14, 14, 512)         │       2,359,808 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block5_conv2 (Conv2D)                │ (None, 14, 14, 512)         │       2,359,808 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block5_conv3 (Conv2D)                │ (None, 14, 14, 512)         │       2,359,808 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ block5_pool (MaxPooling2D)           │ (None, 7, 7, 512)           │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ fc1 (Conv2D)                         │ (None, 7, 7, 4096)          │     102,764,544 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dropout (Dropout)                    │ (None, 7, 7, 4096)          │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ fc2 (Conv2D)                         │ (None, 7, 7, 4096)          │      16,781,312 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dropout_1 (Dropout)                  │ (None, 7, 7, 4096)          │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ conv2d (Conv2D)                      │ (None, 7, 7, 5)             │          20,485 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ bilinear_up_sampling2d               │ (None, 224, 224, 5)         │               0 │
│ (BilinearUpSampling2D)               │                             │                 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
 Total params: 134,281,029 (512.24 MB)
 Trainable params: 134,281,029 (512.24 MB)
 Non-trainable params: 0 (0.00 B)
In [ ]:
history = model.fit(x = x_train, y= y_train, batch_size=8, epochs=100, verbose=1, shuffle=True, validation_split=0.25)
Epoch 1/100
165/165 [==============================] - 41s 177ms/step - loss: 1.0905 - accuracy: 0.5832 - val_loss: 1.0422 - val_accuracy: 0.5485
Epoch 2/100
165/165 [==============================] - 27s 161ms/step - loss: 0.9695 - accuracy: 0.5864 - val_loss: 1.0068 - val_accuracy: 0.5485
Epoch 3/100
165/165 [==============================] - 27s 161ms/step - loss: 0.8548 - accuracy: 0.6902 - val_loss: 0.8560 - val_accuracy: 0.7204
Epoch 4/100
165/165 [==============================] - 27s 161ms/step - loss: 0.7804 - accuracy: 0.7409 - val_loss: 0.7923 - val_accuracy: 0.7244
Epoch 5/100
165/165 [==============================] - 27s 163ms/step - loss: 0.7381 - accuracy: 0.7495 - val_loss: 0.7717 - val_accuracy: 0.7271
Epoch 6/100
165/165 [==============================] - 27s 161ms/step - loss: 0.7251 - accuracy: 0.7521 - val_loss: 0.7804 - val_accuracy: 0.7207
Epoch 7/100
165/165 [==============================] - 27s 161ms/step - loss: 0.7018 - accuracy: 0.7624 - val_loss: 0.7149 - val_accuracy: 0.7506
Epoch 8/100
165/165 [==============================] - 27s 161ms/step - loss: 0.6700 - accuracy: 0.7740 - val_loss: 0.7115 - val_accuracy: 0.7498
Epoch 9/100
165/165 [==============================] - 27s 161ms/step - loss: 0.6525 - accuracy: 0.7773 - val_loss: 0.6939 - val_accuracy: 0.7515
Epoch 10/100
165/165 [==============================] - 27s 161ms/step - loss: 0.6362 - accuracy: 0.7861 - val_loss: 0.6591 - val_accuracy: 0.7661
Epoch 11/100
165/165 [==============================] - 27s 161ms/step - loss: 0.6241 - accuracy: 0.7889 - val_loss: 0.6630 - val_accuracy: 0.7617
Epoch 12/100
165/165 [==============================] - 27s 161ms/step - loss: 0.5960 - accuracy: 0.7969 - val_loss: 0.6481 - val_accuracy: 0.7665
Epoch 13/100
165/165 [==============================] - 27s 161ms/step - loss: 0.5921 - accuracy: 0.7951 - val_loss: 0.6258 - val_accuracy: 0.7752
Epoch 14/100
165/165 [==============================] - 27s 161ms/step - loss: 0.5846 - accuracy: 0.8002 - val_loss: 0.6122 - val_accuracy: 0.7759
Epoch 15/100
165/165 [==============================] - 27s 161ms/step - loss: 0.5727 - accuracy: 0.8018 - val_loss: 0.6036 - val_accuracy: 0.7738
Epoch 16/100
165/165 [==============================] - 27s 161ms/step - loss: 0.5550 - accuracy: 0.8016 - val_loss: 0.5861 - val_accuracy: 0.7892
Epoch 17/100
165/165 [==============================] - 27s 161ms/step - loss: 0.5340 - accuracy: 0.8195 - val_loss: 0.5730 - val_accuracy: 0.8023
Epoch 18/100
165/165 [==============================] - 27s 161ms/step - loss: 0.5217 - accuracy: 0.8235 - val_loss: 0.5458 - val_accuracy: 0.8135
Epoch 19/100
165/165 [==============================] - 27s 161ms/step - loss: 0.5062 - accuracy: 0.8292 - val_loss: 0.5255 - val_accuracy: 0.8115
Epoch 20/100
165/165 [==============================] - 27s 161ms/step - loss: 0.5021 - accuracy: 0.8286 - val_loss: 0.5174 - val_accuracy: 0.8159
Epoch 21/100
165/165 [==============================] - 27s 161ms/step - loss: 0.5065 - accuracy: 0.8260 - val_loss: 0.5452 - val_accuracy: 0.8066
Epoch 22/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4868 - accuracy: 0.8362 - val_loss: 0.4854 - val_accuracy: 0.8329
Epoch 23/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4566 - accuracy: 0.8470 - val_loss: 0.5565 - val_accuracy: 0.7996
Epoch 24/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4757 - accuracy: 0.8405 - val_loss: 0.6245 - val_accuracy: 0.7993
Epoch 25/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4723 - accuracy: 0.8434 - val_loss: 0.5246 - val_accuracy: 0.8215
Epoch 26/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4704 - accuracy: 0.8395 - val_loss: 0.4847 - val_accuracy: 0.8290
Epoch 27/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4427 - accuracy: 0.8500 - val_loss: 0.4659 - val_accuracy: 0.8391
Epoch 28/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4352 - accuracy: 0.8537 - val_loss: 0.5479 - val_accuracy: 0.8169
Epoch 29/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4457 - accuracy: 0.8498 - val_loss: 0.4910 - val_accuracy: 0.8341
Epoch 30/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4205 - accuracy: 0.8600 - val_loss: 0.4809 - val_accuracy: 0.8327
Epoch 31/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4232 - accuracy: 0.8565 - val_loss: 0.4565 - val_accuracy: 0.8410
Epoch 32/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4317 - accuracy: 0.8526 - val_loss: 0.4814 - val_accuracy: 0.8335
Epoch 33/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4258 - accuracy: 0.8567 - val_loss: 0.4397 - val_accuracy: 0.8491
Epoch 34/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4129 - accuracy: 0.8595 - val_loss: 0.4518 - val_accuracy: 0.8442
Epoch 35/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4127 - accuracy: 0.8623 - val_loss: 0.4490 - val_accuracy: 0.8476
Epoch 36/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4287 - accuracy: 0.8491 - val_loss: 0.4880 - val_accuracy: 0.8325
Epoch 37/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4117 - accuracy: 0.8608 - val_loss: 0.4987 - val_accuracy: 0.8255
Epoch 38/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3940 - accuracy: 0.8666 - val_loss: 0.4560 - val_accuracy: 0.8453
Epoch 39/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4500 - accuracy: 0.8465 - val_loss: 0.5320 - val_accuracy: 0.8182
Epoch 40/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3978 - accuracy: 0.8653 - val_loss: 0.4477 - val_accuracy: 0.8453
Epoch 41/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3945 - accuracy: 0.8678 - val_loss: 0.4358 - val_accuracy: 0.8513
Epoch 42/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3925 - accuracy: 0.8687 - val_loss: 0.4972 - val_accuracy: 0.8295
Epoch 43/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3867 - accuracy: 0.8711 - val_loss: 0.4341 - val_accuracy: 0.8489
Epoch 44/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3717 - accuracy: 0.8751 - val_loss: 0.4177 - val_accuracy: 0.8598
Epoch 45/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3817 - accuracy: 0.8727 - val_loss: 0.4346 - val_accuracy: 0.8490
Epoch 46/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3861 - accuracy: 0.8686 - val_loss: 0.5097 - val_accuracy: 0.8237
Epoch 47/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3781 - accuracy: 0.8718 - val_loss: 0.4344 - val_accuracy: 0.8508
Epoch 48/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3717 - accuracy: 0.8746 - val_loss: 0.4155 - val_accuracy: 0.8589
Epoch 49/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3773 - accuracy: 0.8728 - val_loss: 0.4277 - val_accuracy: 0.8544
Epoch 50/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4029 - accuracy: 0.8619 - val_loss: 0.4199 - val_accuracy: 0.8544
Epoch 51/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3673 - accuracy: 0.8751 - val_loss: 0.4124 - val_accuracy: 0.8595
Epoch 52/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3544 - accuracy: 0.8813 - val_loss: 0.4291 - val_accuracy: 0.8520
Epoch 53/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3687 - accuracy: 0.8755 - val_loss: 0.4444 - val_accuracy: 0.8515
Epoch 54/100
165/165 [==============================] - 26s 161ms/step - loss: 0.3524 - accuracy: 0.8812 - val_loss: 0.4079 - val_accuracy: 0.8626
Epoch 55/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3541 - accuracy: 0.8797 - val_loss: 0.4564 - val_accuracy: 0.8452
Epoch 56/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3654 - accuracy: 0.8774 - val_loss: 0.4007 - val_accuracy: 0.8644
Epoch 57/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3455 - accuracy: 0.8837 - val_loss: 0.4594 - val_accuracy: 0.8446
Epoch 58/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3585 - accuracy: 0.8782 - val_loss: 0.4240 - val_accuracy: 0.8534
Epoch 59/100
165/165 [==============================] - 26s 161ms/step - loss: 0.3512 - accuracy: 0.8811 - val_loss: 0.4029 - val_accuracy: 0.8645
Epoch 60/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3585 - accuracy: 0.8782 - val_loss: 0.4735 - val_accuracy: 0.8411
Epoch 61/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3477 - accuracy: 0.8811 - val_loss: 0.4168 - val_accuracy: 0.8578
Epoch 62/100
165/165 [==============================] - 27s 161ms/step - loss: 0.4076 - accuracy: 0.8617 - val_loss: 0.4335 - val_accuracy: 0.8528
Epoch 63/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3782 - accuracy: 0.8718 - val_loss: 0.3981 - val_accuracy: 0.8646
Epoch 64/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3410 - accuracy: 0.8859 - val_loss: 0.4115 - val_accuracy: 0.8586
Epoch 65/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3413 - accuracy: 0.8842 - val_loss: 0.4061 - val_accuracy: 0.8631
Epoch 66/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3341 - accuracy: 0.8873 - val_loss: 0.4045 - val_accuracy: 0.8601
Epoch 67/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3303 - accuracy: 0.8880 - val_loss: 0.3982 - val_accuracy: 0.8642
Epoch 68/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3597 - accuracy: 0.8771 - val_loss: 0.3985 - val_accuracy: 0.8632
Epoch 69/100
165/165 [==============================] - 27s 163ms/step - loss: 0.3452 - accuracy: 0.8818 - val_loss: 0.3951 - val_accuracy: 0.8650
Epoch 70/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3252 - accuracy: 0.8900 - val_loss: 0.3883 - val_accuracy: 0.8682
Epoch 71/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3270 - accuracy: 0.8895 - val_loss: 0.4153 - val_accuracy: 0.8575
Epoch 72/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3311 - accuracy: 0.8872 - val_loss: 0.3993 - val_accuracy: 0.8648
Epoch 73/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3498 - accuracy: 0.8800 - val_loss: 0.3948 - val_accuracy: 0.8666
Epoch 74/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3568 - accuracy: 0.8795 - val_loss: 0.4228 - val_accuracy: 0.8534
Epoch 75/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3290 - accuracy: 0.8880 - val_loss: 0.4173 - val_accuracy: 0.8572
Epoch 76/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3258 - accuracy: 0.8892 - val_loss: 0.3878 - val_accuracy: 0.8678
Epoch 77/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3172 - accuracy: 0.8921 - val_loss: 0.3799 - val_accuracy: 0.8696
Epoch 78/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3168 - accuracy: 0.8926 - val_loss: 0.3936 - val_accuracy: 0.8669
Epoch 79/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3249 - accuracy: 0.8891 - val_loss: 0.3927 - val_accuracy: 0.8665
Epoch 80/100
165/165 [==============================] - 26s 161ms/step - loss: 0.3186 - accuracy: 0.8910 - val_loss: 0.3868 - val_accuracy: 0.8687
Epoch 81/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3138 - accuracy: 0.8931 - val_loss: 0.3955 - val_accuracy: 0.8674
Epoch 82/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3453 - accuracy: 0.8803 - val_loss: 0.4385 - val_accuracy: 0.8475
Epoch 83/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3368 - accuracy: 0.8844 - val_loss: 0.3986 - val_accuracy: 0.8652
Epoch 84/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3212 - accuracy: 0.8910 - val_loss: 0.3982 - val_accuracy: 0.8658
Epoch 85/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3103 - accuracy: 0.8947 - val_loss: 0.3842 - val_accuracy: 0.8692
Epoch 86/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3087 - accuracy: 0.8951 - val_loss: 0.3829 - val_accuracy: 0.8713
Epoch 87/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3050 - accuracy: 0.8961 - val_loss: 0.3820 - val_accuracy: 0.8706
Epoch 88/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3168 - accuracy: 0.8911 - val_loss: 0.4049 - val_accuracy: 0.8597
Epoch 89/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3054 - accuracy: 0.8959 - val_loss: 0.3844 - val_accuracy: 0.8690
Epoch 90/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3044 - accuracy: 0.8958 - val_loss: 0.3796 - val_accuracy: 0.8708
Epoch 91/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3144 - accuracy: 0.8915 - val_loss: 0.4657 - val_accuracy: 0.8440
Epoch 92/100
165/165 [==============================] - 26s 160ms/step - loss: 0.4228 - accuracy: 0.8547 - val_loss: 0.4615 - val_accuracy: 0.8361
Epoch 93/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3361 - accuracy: 0.8855 - val_loss: 0.4121 - val_accuracy: 0.8582
Epoch 94/100
165/165 [==============================] - 27s 162ms/step - loss: 0.3144 - accuracy: 0.8924 - val_loss: 0.3858 - val_accuracy: 0.8692
Epoch 95/100
165/165 [==============================] - 27s 161ms/step - loss: 0.3044 - accuracy: 0.8959 - val_loss: 0.3856 - val_accuracy: 0.8690
Epoch 96/100
165/165 [==============================] - 27s 161ms/step - loss: 0.2998 - accuracy: 0.8974 - val_loss: 0.3817 - val_accuracy: 0.8715
Epoch 97/100
165/165 [==============================] - 27s 161ms/step - loss: 0.2977 - accuracy: 0.8978 - val_loss: 0.3778 - val_accuracy: 0.8698
Epoch 98/100
165/165 [==============================] - 27s 161ms/step - loss: 0.2939 - accuracy: 0.8991 - val_loss: 0.3941 - val_accuracy: 0.8670
Epoch 99/100
165/165 [==============================] - 27s 161ms/step - loss: 0.2939 - accuracy: 0.8990 - val_loss: 0.3785 - val_accuracy: 0.8704
Epoch 100/100
165/165 [==============================] - 27s 161ms/step - loss: 0.2976 - accuracy: 0.8974 - val_loss: 0.3766 - val_accuracy: 0.8721

After training, we will get the loss and accuracy curves:

In [ ]:
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='lower right')
plt.show()

plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper right')
plt.show()
No description has been provided for this image
No description has been provided for this image

So we can calculate the accuracy for the test set:

In [ ]:
del x_train, y_train
In [ ]:
x_test = x_test/255
In [ ]:
predict = model.predict(x_test)
24/24 [==============================] - 6s 159ms/step
In [ ]:
predict = np.round(predict)
In [ ]:
pred = np.argmax(predict, axis=3)
In [ ]:
true = np.argmax(y_test, axis=3)
In [ ]:
accuracy = accuracy_score(true.flatten(),pred.flatten())
print(accuracy)
0.8773271069918767

Finally, let's plot an example of the predicted result compared to the original mask:

In [ ]:
i = 100
plt.figure(figsize=[20,8])
plt.subplot(131)
plt.imshow(x_test[i])
plt.title('RGB Image')
plt.axis('off')
plt.subplot(132)
plt.imshow(true[i], vmin=0,vmax=4)
plt.title('Original Mask')
plt.axis('off')
plt.subplot(133)
plt.imshow(pred[i], vmin=0,vmax=4)
plt.title('Predict Mask')
plt.axis('off')
Out[ ]:
(-0.5, 255.5, 255.5, -0.5)
No description has been provided for this image

SegNet¶

image.png

SegNet has an encoder network and a corresponding decoder network, followed by a final pixel-wise classification layer. This architecture is illustrated in the figure above. The encoder network consists of 13 convolutional layers that correspond to the first 13 convolutional layers in the VGG16 network designed for object classification. They discard the fully connected layers in favor of retaining higher resolution feature maps in the output of the deeper encoder. This also significantly reduces the number of parameters in the SegNet encoder network (from 134M to 14.7M) compared to other recent architectures. Each encoder layer has a corresponding decoder layer, and thus the decoder network has 13 layers. The final output of the decoder is fed to a multi-class soft-max classifier to produce class probabilities for each pixel independently.

The encoder network performs convolution with a filter bank to produce a set of feature maps. These are then batch normalized. Next, an element-wise rectified linear nonlinearity (ReLU) max(0, x) is applied. Next, max pooling with a 2 × 2 window and stride 2 (non-overlapping window) is performed and the resulting output is subsampled by a factor of 2. Max pooling is used to achieve translation invariance at small spatial offsets in the input image. Subsampling results in a large input image context (spatial window) for each pixel in the feature map. Although multiple layers of max pooling and subsampling can achieve more translation invariance for robust classification, correspondingly there is a loss of spatial resolution of the feature maps. The increasingly lossy image representation (contour detail) is not beneficial for segmentation where contour delineation is vital. Therefore, it is necessary to capture and store threshold information on the encoder feature maps before performing downsampling. If memory during inference is not constrained, all encoder feature maps (after downsampling) can be stored. This is often not the case in practical applications, and so we propose a more efficient way to store this information. It involves storing only the max-pooling indices, i.e., the locations of the maximum feature value in each pooling window are memorized for each encoder feature map. In principle, this can be done using 2 bits for each 2×2 pooling window, and is therefore much more efficient to store compared to memorizing the feature map(s) in floating precision.

The decoder network samples its input feature map using the memorized max pooling indices from the corresponding encoder feature map(s). This step produces sparse feature map(s). This SegNet decoding technique is illustrated in the figure below. These feature maps are then combined with a trainable decoder filter bank to produce dense feature maps. A batch normalization step is then applied to each of these maps. Note that the decoder corresponding to the first encoder (closest to the input image) produces a multi-channel feature map, even though its encoder input has 3 channels (RGB). This is different from the other decoders in the network, which produce feature maps with the same size and number of channels as their encoder inputs. The high-dimensional feature representation at the final decoder output is fed to a trainable soft-max classifier. This soft-max classifies each pixel independently. The output of the soft-max classifier is an image with K channel probabilities where K is the number of classes. The predicted segmentation corresponds to the class with maximum probability at each pixel.

image.png

SegNet and FCN decoders. a, b, c, d correspond to values ​​in a feature map. SegNet uses the max pooling indices to upsample (without learning) the feature map and convolves it with a trainable decoder filter bank. FCN upsamples by learning to deconvolve the input feature map and adds the corresponding encoder feature map to produce the decoder output. This feature map is the output of the max pooling layer (includes downsampling) in the corresponding encoder. Note that there are no trainable decoder filters in FCN.

This is the SegNet implementation:

In [ ]:
droprate = 0.1
pool_size = (2,2)
inputs = Input(shape=x_train.shape[1:])
conv_1 = Conv2D(32, (3, 3), kernel_initializer='he_uniform', padding='same')(inputs)
conv_1 = BatchNormalization()(conv_1)
conv_1 = Activation("relu")(conv_1)
conv_1 = Dropout(droprate)(conv_1)
conv_2 = Conv2D(32, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_1)
conv_2 = BatchNormalization()(conv_2)
conv_2 = Activation("relu")(conv_2)


pool_1 = MaxPooling2D(pool_size=pool_size)(conv_2)

conv_3 = Conv2D(64, (3, 3), kernel_initializer='he_uniform', padding='same')(pool_1)
conv_3 = BatchNormalization()(conv_3)
conv_3 = Activation("relu")(conv_3)
conv_3 = Dropout(droprate)(conv_3)
conv_4 = Conv2D(64, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_3)
conv_4 = BatchNormalization()(conv_4)
conv_4 = Activation("relu")(conv_4)


pool_2 = MaxPooling2D(pool_size=pool_size)(conv_4)

conv_5 = Conv2D(128, (3, 3), kernel_initializer='he_uniform', padding='same')(pool_2)
conv_5 = BatchNormalization()(conv_5)
conv_5 = Activation("relu")(conv_5)
conv_5 = Dropout(droprate)(conv_5)
conv_6 = Conv2D(128, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_5)
conv_6 = BatchNormalization()(conv_6)
conv_6 = Activation("relu")(conv_6)
conv_6 = Dropout(droprate)(conv_6)
conv_7 = Conv2D(128, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_6)
conv_7 = BatchNormalization()(conv_7)
conv_7 = Activation("relu")(conv_7)



pool_3 = MaxPooling2D(pool_size=pool_size)(conv_7)

conv_8 = Conv2D(256, (3, 3), kernel_initializer='he_uniform', padding='same')(pool_3)
conv_8 = BatchNormalization()(conv_8)
conv_8 = Activation("relu")(conv_8)
conv_8 = Dropout(droprate)(conv_8)
conv_9 = Conv2D(256, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_8)
conv_9 = BatchNormalization()(conv_9)
conv_9 = Activation("relu")(conv_9)
conv_9 = Dropout(droprate)(conv_9)
conv_10 = Conv2D(256, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_9)
conv_10 = BatchNormalization()(conv_10)
conv_10 = Activation("relu")(conv_10)



pool_4 = MaxPooling2D(pool_size=pool_size)(conv_10)


conv_11 = Conv2D(512, (3, 3), kernel_initializer='he_uniform', padding='same')(pool_4)
conv_11 = BatchNormalization()(conv_11)
conv_11 = Activation("relu")(conv_11)
conv_11 = Dropout(droprate)(conv_11)
conv_12 = Conv2D(512, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_11)
conv_12 = BatchNormalization()(conv_12)
conv_12 = Activation("relu")(conv_12)
conv_12 = Dropout(droprate)(conv_12)
conv_13 = Conv2D(512, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_12)
conv_13 = BatchNormalization()(conv_13)
conv_13 = Activation("relu")(conv_13)

pool_5 = MaxPooling2D(pool_size=pool_size)(conv_13)

unpool_1 = UpSampling2D(size=pool_size)(pool_5)

conv_14 = Conv2D(512, (3, 3), kernel_initializer='he_uniform', padding='same')(unpool_1)
conv_14 = BatchNormalization()(conv_14)
conv_14 = Activation("relu")(conv_14)
conv_14 = Dropout(droprate)(conv_14)
conv_15 = Conv2D(512, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_14)
conv_15 = BatchNormalization()(conv_15)
conv_15 = Activation("relu")(conv_15)
conv_15 = Dropout(droprate)(conv_15)
conv_16 = Conv2D(256, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_15)
conv_16 = BatchNormalization()(conv_16)
conv_16 = Activation("relu")(conv_16)

unpool_2 = UpSampling2D(size=pool_size)(conv_16)

conv_17 = Conv2D(256, (3, 3), kernel_initializer='he_uniform', padding='same')(unpool_2)
conv_17 = BatchNormalization()(conv_17)
conv_17 = Activation("relu")(conv_17)
conv_17 = Dropout(droprate)(conv_17)
conv_18 = Conv2D(256, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_17)
conv_18 = BatchNormalization()(conv_18)
conv_18 = Activation("relu")(conv_18)
conv_18 = Dropout(droprate)(conv_18)
conv_19 = Conv2D(128, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_18)
conv_19 = BatchNormalization()(conv_19)
conv_19 = Activation("relu")(conv_19)

unpool_3 = UpSampling2D(size=pool_size)(conv_19)


conv_20 = Conv2D(128, (3, 3), kernel_initializer='he_uniform', padding='same')(unpool_3)
conv_20 = BatchNormalization()(conv_20)
conv_20 = Activation("relu")(conv_20)
conv_20 = Dropout(droprate)(conv_20)
conv_21 = Conv2D(128, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_20)
conv_21 = BatchNormalization()(conv_21)
conv_21 = Activation("relu")(conv_21)
conv_21 = Dropout(droprate)(conv_21)
conv_22 = Conv2D(64, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_21)
conv_22 = BatchNormalization()(conv_22)
conv_22 = Activation("relu")(conv_22)


unpool_4 = UpSampling2D(size=pool_size)(conv_22)


conv_23 = Conv2D(64, (3, 3), kernel_initializer='he_uniform', padding='same')(unpool_4)
conv_23 = BatchNormalization()(conv_23)
conv_23 = Activation("relu")(conv_23)
conv_23 = Dropout(droprate)(conv_23)
conv_24 = Conv2D(32, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_23)
conv_24 = BatchNormalization()(conv_24)
conv_24 = Activation("relu")(conv_24)


unpool_5 = UpSampling2D(size=pool_size)(conv_24)


conv_25 = Conv2D(32, (3, 3), kernel_initializer='he_uniform', padding='same')(unpool_5)
conv_25 = BatchNormalization()(conv_25)
conv_25 = Activation("relu")(conv_25)

conv_26 = Conv2D(NCLASSES, (1, 1), kernel_initializer='he_uniform', padding='same')(conv_25)
conv_26 = BatchNormalization()(conv_26)
outputs = Activation("softmax")(conv_26)


model = Model(inputs=inputs, outputs=outputs)

model.compile(optimizer=Adam(learning_rate=0.00001), loss = Dice, metrics=[ 'accuracy'])
model.summary()
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
conv2d_26 (Conv2D)           (None, 224, 224, 32)      896       
_________________________________________________________________
batch_normalization_26 (Batc (None, 224, 224, 32)      128       
_________________________________________________________________
activation_26 (Activation)   (None, 224, 224, 32)      0         
_________________________________________________________________
dropout_15 (Dropout)         (None, 224, 224, 32)      0         
_________________________________________________________________
conv2d_27 (Conv2D)           (None, 224, 224, 32)      9248      
_________________________________________________________________
batch_normalization_27 (Batc (None, 224, 224, 32)      128       
_________________________________________________________________
activation_27 (Activation)   (None, 224, 224, 32)      0         
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 112, 112, 32)      0         
_________________________________________________________________
conv2d_28 (Conv2D)           (None, 112, 112, 64)      18496     
_________________________________________________________________
batch_normalization_28 (Batc (None, 112, 112, 64)      256       
_________________________________________________________________
activation_28 (Activation)   (None, 112, 112, 64)      0         
_________________________________________________________________
dropout_16 (Dropout)         (None, 112, 112, 64)      0         
_________________________________________________________________
conv2d_29 (Conv2D)           (None, 112, 112, 64)      36928     
_________________________________________________________________
batch_normalization_29 (Batc (None, 112, 112, 64)      256       
_________________________________________________________________
activation_29 (Activation)   (None, 112, 112, 64)      0         
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 56, 56, 64)        0         
_________________________________________________________________
conv2d_30 (Conv2D)           (None, 56, 56, 128)       73856     
_________________________________________________________________
batch_normalization_30 (Batc (None, 56, 56, 128)       512       
_________________________________________________________________
activation_30 (Activation)   (None, 56, 56, 128)       0         
_________________________________________________________________
dropout_17 (Dropout)         (None, 56, 56, 128)       0         
_________________________________________________________________
conv2d_31 (Conv2D)           (None, 56, 56, 128)       147584    
_________________________________________________________________
batch_normalization_31 (Batc (None, 56, 56, 128)       512       
_________________________________________________________________
activation_31 (Activation)   (None, 56, 56, 128)       0         
_________________________________________________________________
dropout_18 (Dropout)         (None, 56, 56, 128)       0         
_________________________________________________________________
conv2d_32 (Conv2D)           (None, 56, 56, 128)       147584    
_________________________________________________________________
batch_normalization_32 (Batc (None, 56, 56, 128)       512       
_________________________________________________________________
activation_32 (Activation)   (None, 56, 56, 128)       0         
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 28, 28, 128)       0         
_________________________________________________________________
conv2d_33 (Conv2D)           (None, 28, 28, 256)       295168    
_________________________________________________________________
batch_normalization_33 (Batc (None, 28, 28, 256)       1024      
_________________________________________________________________
activation_33 (Activation)   (None, 28, 28, 256)       0         
_________________________________________________________________
dropout_19 (Dropout)         (None, 28, 28, 256)       0         
_________________________________________________________________
conv2d_34 (Conv2D)           (None, 28, 28, 256)       590080    
_________________________________________________________________
batch_normalization_34 (Batc (None, 28, 28, 256)       1024      
_________________________________________________________________
activation_34 (Activation)   (None, 28, 28, 256)       0         
_________________________________________________________________
dropout_20 (Dropout)         (None, 28, 28, 256)       0         
_________________________________________________________________
conv2d_35 (Conv2D)           (None, 28, 28, 256)       590080    
_________________________________________________________________
batch_normalization_35 (Batc (None, 28, 28, 256)       1024      
_________________________________________________________________
activation_35 (Activation)   (None, 28, 28, 256)       0         
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 14, 14, 256)       0         
_________________________________________________________________
conv2d_36 (Conv2D)           (None, 14, 14, 512)       1180160   
_________________________________________________________________
batch_normalization_36 (Batc (None, 14, 14, 512)       2048      
_________________________________________________________________
activation_36 (Activation)   (None, 14, 14, 512)       0         
_________________________________________________________________
dropout_21 (Dropout)         (None, 14, 14, 512)       0         
_________________________________________________________________
conv2d_37 (Conv2D)           (None, 14, 14, 512)       2359808   
_________________________________________________________________
batch_normalization_37 (Batc (None, 14, 14, 512)       2048      
_________________________________________________________________
activation_37 (Activation)   (None, 14, 14, 512)       0         
_________________________________________________________________
dropout_22 (Dropout)         (None, 14, 14, 512)       0         
_________________________________________________________________
conv2d_38 (Conv2D)           (None, 14, 14, 512)       2359808   
_________________________________________________________________
batch_normalization_38 (Batc (None, 14, 14, 512)       2048      
_________________________________________________________________
activation_38 (Activation)   (None, 14, 14, 512)       0         
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 7, 7, 512)         0         
_________________________________________________________________
up_sampling2d_5 (UpSampling2 (None, 14, 14, 512)       0         
_________________________________________________________________
conv2d_39 (Conv2D)           (None, 14, 14, 512)       2359808   
_________________________________________________________________
batch_normalization_39 (Batc (None, 14, 14, 512)       2048      
_________________________________________________________________
activation_39 (Activation)   (None, 14, 14, 512)       0         
_________________________________________________________________
dropout_23 (Dropout)         (None, 14, 14, 512)       0         
_________________________________________________________________
conv2d_40 (Conv2D)           (None, 14, 14, 512)       2359808   
_________________________________________________________________
batch_normalization_40 (Batc (None, 14, 14, 512)       2048      
_________________________________________________________________
activation_40 (Activation)   (None, 14, 14, 512)       0         
_________________________________________________________________
dropout_24 (Dropout)         (None, 14, 14, 512)       0         
_________________________________________________________________
conv2d_41 (Conv2D)           (None, 14, 14, 256)       1179904   
_________________________________________________________________
batch_normalization_41 (Batc (None, 14, 14, 256)       1024      
_________________________________________________________________
activation_41 (Activation)   (None, 14, 14, 256)       0         
_________________________________________________________________
up_sampling2d_6 (UpSampling2 (None, 28, 28, 256)       0         
_________________________________________________________________
conv2d_42 (Conv2D)           (None, 28, 28, 256)       590080    
_________________________________________________________________
batch_normalization_42 (Batc (None, 28, 28, 256)       1024      
_________________________________________________________________
activation_42 (Activation)   (None, 28, 28, 256)       0         
_________________________________________________________________
dropout_25 (Dropout)         (None, 28, 28, 256)       0         
_________________________________________________________________
conv2d_43 (Conv2D)           (None, 28, 28, 256)       590080    
_________________________________________________________________
batch_normalization_43 (Batc (None, 28, 28, 256)       1024      
_________________________________________________________________
activation_43 (Activation)   (None, 28, 28, 256)       0         
_________________________________________________________________
dropout_26 (Dropout)         (None, 28, 28, 256)       0         
_________________________________________________________________
conv2d_44 (Conv2D)           (None, 28, 28, 128)       295040    
_________________________________________________________________
batch_normalization_44 (Batc (None, 28, 28, 128)       512       
_________________________________________________________________
activation_44 (Activation)   (None, 28, 28, 128)       0         
_________________________________________________________________
up_sampling2d_7 (UpSampling2 (None, 56, 56, 128)       0         
_________________________________________________________________
conv2d_45 (Conv2D)           (None, 56, 56, 128)       147584    
_________________________________________________________________
batch_normalization_45 (Batc (None, 56, 56, 128)       512       
_________________________________________________________________
activation_45 (Activation)   (None, 56, 56, 128)       0         
_________________________________________________________________
dropout_27 (Dropout)         (None, 56, 56, 128)       0         
_________________________________________________________________
conv2d_46 (Conv2D)           (None, 56, 56, 128)       147584    
_________________________________________________________________
batch_normalization_46 (Batc (None, 56, 56, 128)       512       
_________________________________________________________________
activation_46 (Activation)   (None, 56, 56, 128)       0         
_________________________________________________________________
dropout_28 (Dropout)         (None, 56, 56, 128)       0         
_________________________________________________________________
conv2d_47 (Conv2D)           (None, 56, 56, 64)        73792     
_________________________________________________________________
batch_normalization_47 (Batc (None, 56, 56, 64)        256       
_________________________________________________________________
activation_47 (Activation)   (None, 56, 56, 64)        0         
_________________________________________________________________
up_sampling2d_8 (UpSampling2 (None, 112, 112, 64)      0         
_________________________________________________________________
conv2d_48 (Conv2D)           (None, 112, 112, 64)      36928     
_________________________________________________________________
batch_normalization_48 (Batc (None, 112, 112, 64)      256       
_________________________________________________________________
activation_48 (Activation)   (None, 112, 112, 64)      0         
_________________________________________________________________
dropout_29 (Dropout)         (None, 112, 112, 64)      0         
_________________________________________________________________
conv2d_49 (Conv2D)           (None, 112, 112, 32)      18464     
_________________________________________________________________
batch_normalization_49 (Batc (None, 112, 112, 32)      128       
_________________________________________________________________
activation_49 (Activation)   (None, 112, 112, 32)      0         
_________________________________________________________________
up_sampling2d_9 (UpSampling2 (None, 224, 224, 32)      0         
_________________________________________________________________
conv2d_50 (Conv2D)           (None, 224, 224, 32)      9248      
_________________________________________________________________
batch_normalization_50 (Batc (None, 224, 224, 32)      128       
_________________________________________________________________
activation_50 (Activation)   (None, 224, 224, 32)      0         
_________________________________________________________________
conv2d_51 (Conv2D)           (None, 224, 224, 5)       165       
_________________________________________________________________
batch_normalization_51 (Batc (None, 224, 224, 5)       20        
_________________________________________________________________
activation_51 (Activation)   (None, 224, 224, 5)       0         
=================================================================
Total params: 15,639,193
Trainable params: 15,628,687
Non-trainable params: 10,506
_________________________________________________________________
/usr/local/lib/python3.7/dist-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")

References:

https://nanonets.com/blog/how-to-do-semantic-segmentation-using-deep-learning/

https://nanonets.com/blog/semantic-image-segmentation-2020/

https://www.jeremyjordan.me/semantic-segmentation/

https://towardsdatascience.com/review-fcn-semantic-segmentation-eb8c9b50d2d1

https://www.meetshah.dev/semantic-segmentation/deep-learning/pytorch/visdom/2017/06/01/semantic-segmentation-over-the-years.html

http://ronny.rest/tutorials/module/seg_01/segmentation_04_training/

https://github.com/ykamikawa/tf-keras-SegNet/blob/master/model.py

https://ai-pool.com/m/segnet-1555409707

https://medium.com/analytics-vidhya/semantic-segmentation-in-pspnet-with-implementation-in-keras-4843d05fc025

https://www.kaggle.com/santhalnr/cityscapes-image-segmentation-pspnet

https://divamgupta.com/image-segmentation/2019/06/06/deep-learning-semantic-segmentation-keras.html