Semantic Image Segmentation - Part 1¶
Deep Learning Methods for Semantic Segmentation¶
So far, you have seen image classification, where the task of the network is to assign a label or class to an input image. However, suppose you want to know where an object is located in the image, what shape that object is, which pixel belongs to which object, etc. In this case, you will want to segment the image, i.e., give each pixel in the image a label. So, the task of image segmentation is to train a neural network to produce a pixel-wise mask of the image. This helps in understanding the image at a much lower level, i.e., the pixel level. Image segmentation has many applications in medical imaging, self-driving cars, and satellite imagery, to name a few.
More precisely, semantic image segmentation is the task of labeling each pixel in the image with one of a predefined set of classes. For example, various objects like cars, trees, people, traffic signs, etc. can be used as classes for semantic image segmentation. So, the task is to take an image (RGB or grayscale) and produce a W x H x 1 matrix, where W and H represent the width and height of the image, respectively. Each cell in this matrix contains the predicted class ID for the corresponding pixel.
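As a minimal numpy sketch (the per-pixel scores are random stand-ins for a network's output), the class-ID matrix is obtained by taking an argmax over the class dimension:

```python
import numpy as np

# Hypothetical per-pixel class scores for a 4x4 image and 3 classes
# (random stand-ins for a real network's output).
np.random.seed(0)
scores = np.random.rand(4, 4, 3)

# The W x H class-ID matrix: the most likely class for each pixel.
mask = np.argmax(scores, axis=-1)
```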
Image Classification vs. Semantic Segmentation¶
When it comes to semantic segmentation, we usually don’t require a fully connected layer at the end because our goal is not to predict the class label of the image.
In semantic segmentation, our goal is to extract features before using them to separate the image into multiple segments.
However, the problem with convolutional networks is that the image size gets reduced as it passes through the network due to the max-pooling layers.
To separate the image into multiple segments at full resolution, we need to upsample the reduced feature maps, either with a fixed interpolation technique or with learnable deconvolutional (transposed convolution) layers.
Labels in Semantic Segmentation¶
In deep learning, we express categorical class labels as one-hot encoded vectors. Similarly, in semantic segmentation, we can express the output matrix using a one-hot encoding scheme: we create a channel for each class label, set cells of that channel to 1 where the pixel belongs to the corresponding class, and set the remaining cells to 0.
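A minimal numpy sketch of this encoding (the helper name `one_hot_mask` is our own):

```python
import numpy as np

def one_hot_mask(mask, n_classes):
    """Turn an H x W matrix of class IDs into an H x W x n_classes
    one-hot tensor: one channel per class, 1 where the pixel belongs
    to that class, 0 elsewhere."""
    return np.eye(n_classes, dtype=np.float32)[mask]

mask = np.array([[0, 1],
                 [2, 1]])
encoded = one_hot_mask(mask, 3)  # shape (2, 2, 3), one channel per class
```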
Implementing a Semantic Segmentation Architecture¶
Like other computer vision tasks, using a CNN for semantic segmentation would be the obvious choice. When using a CNN for semantic segmentation, the output would be an image with the same resolution as the input, unlike the fixed-length vector produced in image classification.
A naive approach to building a neural network architecture for this task is to simply stack multiple convolutional layers (with the same padding to preserve dimensions) and output a final segmentation map. This directly learns a mapping from the input image to its corresponding segmentation through successive transformation of feature mappings; however, it is quite computationally expensive to preserve full resolution across the network.
Recall that for deep convolutional networks, earlier layers tend to learn low-level concepts, while later layers develop more high-level (and specialized) feature mappings. To maintain expressiveness, we typically need to increase the number of feature maps (channels) as we go deeper into the network. This is not necessarily a problem for the task of image classification, since for that task we only care about what the image contains (and not where it is located). Thus, we can alleviate the computational burden by periodically downsampling our feature maps via pooling or strided convolutions (i.e., compressing the spatial resolution) without much concern. However, for image segmentation, we would like our model to produce a full-resolution semantic prediction. A popular approach for image segmentation models is to follow an encoder/decoder framework: we reduce the spatial resolution of the input, developing lower-resolution feature mappings that are highly efficient at discriminating between classes, and then upsample the feature representations into a full-resolution segmentation map.
Methods for upsampling:¶
There are a few different approaches we can use to increase the resolution of a feature map. While pooling operations decrease resolution by summarizing a local area with a single value (e.g., average or max pooling), unpooling operations increase resolution by distributing a single value across a higher-resolution output.
However, transposed convolutions are by far the most popular approach as they allow us to build on learned upsampling.
Whereas a typical convolution takes the dot product of the values currently in the filter view and produces a single value for the corresponding output position, a transposed convolution essentially does the opposite: it takes a single value from the low-resolution feature map, multiplies all of the weights in the filter by that value, and projects those weighted values onto the output feature map.
For filter sizes that produce an overlap in the output feature map (e.g. 3x3 filter with stride 2 - as shown in the example below), the overlapping values are simply summed. Unfortunately, this tends to produce a checkerboard artifact in the output and is undesirable, so it is best to ensure that the filter size does not produce an overlap.
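The projection-and-sum behaviour can be sketched with a naive single-channel numpy implementation (no learned weights here; an all-ones 3x3 filter with stride 2 makes the summed overlap rows and columns visible):

```python
import numpy as np

def transposed_conv2d(x, w, stride):
    """Naive transposed convolution: each input value scales the whole
    filter, and the weighted copies are projected onto the output;
    where copies overlap, the values are summed."""
    h, wd = x.shape
    k = w.shape[0]
    out = np.zeros(((h - 1) * stride + k, (wd - 1) * stride + k))
    for i in range(h):
        for j in range(wd):
            out[i * stride:i * stride + k, j * stride:j * stride + k] += x[i, j] * w
    return out

x = np.array([[1.0, 2.0],
              [3.0, 4.0]])
w = np.ones((3, 3))           # 3x3 filter with stride 2 -> overlapping row/column
y = transposed_conv2d(x, w, stride=2)   # 5x5 output; overlaps are summed
```

The middle row and column of `y` receive contributions from several input values, which is exactly the overlap pattern that produces checkerboard artifacts.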
Loss Functions for Semantic Segmentation¶
The loss function is used to guide the neural network towards optimization. Let's discuss some popular loss functions for the semantic segmentation task. Semantic segmentation models usually use a simple categorical cross-entropy loss function during training. However, if you are interested in getting more granular information out of an image, you will need to turn to slightly more advanced loss functions.
Distribution-based loss¶
Cross Entropy Loss¶
The most commonly used loss function for the image segmentation task is a pixel-wise cross-entropy loss. This loss examines each pixel individually, comparing the class predictions (pixel depth vector) with our one-hot encoded target vector.
Since the cross-entropy loss evaluates the class predictions for each pixel vector individually and then averages them across all pixels, we are essentially assigning equal weight to every pixel in the image. This can be a problem if your classes have an imbalanced representation in the image, as training may be dominated by the most prevalent class. Long et al. (FCN paper) discuss weighting this loss for each output channel in order to counteract a class imbalance present in the dataset.
Weighted cross entropy¶
It is an extension of cross-entropy (CE) that assigns a different weight to each class. In general, underrepresented classes are given higher weights.
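A minimal numpy sketch of pixel-wise cross-entropy with per-class weights (the function name and toy values are ours; uniform weights recover plain CE):

```python
import numpy as np

def weighted_pixel_crossentropy(y_true, y_pred, class_weights, eps=1e-7):
    """Pixel-wise CE averaged over all pixels. `class_weights` holds one
    weight per channel, counteracting class imbalance; all-ones weights
    give the standard (unweighted) cross-entropy."""
    y_pred = np.clip(y_pred, eps, 1.0)
    # per-pixel loss: -sum_c w_c * t_c * log(p_c)
    per_pixel = -(class_weights * y_true * np.log(y_pred)).sum(axis=-1)
    return float(per_pixel.mean())

# Toy "image" of 1x2 pixels, 2 classes, one-hot targets.
y_true = np.array([[[1.0, 0.0], [0.0, 1.0]]])
y_pred = np.array([[[0.8, 0.2], [0.4, 0.6]]])
uniform = weighted_pixel_crossentropy(y_true, y_pred, np.array([1.0, 1.0]))
```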
Focal Loss¶
It adapts CE to handle extreme foreground-background class imbalance by down-weighting the loss attributed to well-classified examples.
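A hedged numpy sketch of the idea (function name and toy pixels are ours): cross-entropy is scaled by a modulating factor so that confident, correct pixels contribute almost nothing.

```python
import numpy as np

def focal_loss(y_true, y_pred, gamma=2.0, eps=1e-7):
    """Focal loss sketch: CE scaled by (1 - p_t)^gamma, so pixels with
    p_t close to 1 (well classified) are strongly down-weighted."""
    y_pred = np.clip(y_pred, eps, 1.0)
    p_t = (y_true * y_pred).sum(axis=-1)  # predicted probability of the true class
    return float((-((1.0 - p_t) ** gamma) * np.log(p_t)).mean())

# A confidently correct pixel vs. a badly misclassified one:
easy = focal_loss(np.array([[1.0, 0.0]]), np.array([[0.9, 0.1]]))
hard = focal_loss(np.array([[1.0, 0.0]]), np.array([[0.1, 0.9]]))
```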
Region-based Loss¶
Region-based loss functions aim to minimize the mismatch or maximize the overlapping regions between the ground truth and the predicted segmentation.
Dice Loss¶
The Dice coefficient is equivalent to the F1 score, and Dice loss directly tries to optimize it. Similarly, the IoU score can be optimized directly.
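A minimal numpy sketch of the soft Dice loss (the function name is ours):

```python
import numpy as np

def dice_loss(y_true, y_pred, eps=1e-7):
    """Soft Dice loss: 1 - 2|A∩B| / (|A| + |B|). Minimizing it directly
    maximizes the Dice/F1 overlap between prediction and ground truth."""
    intersection = (y_true * y_pred).sum()
    return 1.0 - (2.0 * intersection + eps) / (y_true.sum() + y_pred.sum() + eps)
```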
IoU Loss¶
IoU Loss (also called Jaccard Loss), similar to Dice Loss, is also used to directly optimize the target metric.
Tversky loss¶
Sets different weights for false negatives (FN) and false positives (FP), unlike Dice loss, which weights FN and FP equally.
Focal Tversky loss¶
Applies the concept of focal loss to focus on difficult cases with low probabilities.
Metrics for Semantic Segmentation¶
Let's discuss the metrics that are commonly used to understand and evaluate the results of a model.
Pixel Accuracy¶
Pixel accuracy is the most basic metric that can be used to validate the results. It is the proportion of correctly classified pixels over the total number of pixels:
Accuracy = (TP+TN)/(TP+TN+FP+FN)
The main drawback of this metric is that the result may look good when one class dominates the image. Say, for example, the background class covers 90% of the input image; we can achieve 90% accuracy just by classifying every pixel as background.
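The 90%-background example can be reproduced in a few lines of numpy (the helper name and toy masks are ours):

```python
import numpy as np

def pixel_accuracy(y_true, y_pred):
    """Share of pixels whose predicted class ID matches the ground truth."""
    return float((y_true == y_pred).mean())

# 100-pixel image: 90 background pixels (class 0) and 10 object pixels (class 1).
truth = np.zeros((10, 10), dtype=int)
truth[0, :] = 1
all_background = np.zeros_like(truth)

acc = pixel_accuracy(truth, all_background)  # 0.9 despite missing the object entirely
```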
Metrics based on IoU¶
Intuitively, a successful prediction is one that maximizes the overlap between the predicted and true objects. Two related but different metrics for this goal are the Dice and Jaccard coefficients (or indices):
IoU is defined as the ratio of the intersection of the ground truth and predicted segmentation outputs over their union. For multiple classes, the IoU of each class is calculated and averaged. It is a better metric than pixel-wise accuracy: in the 2-class example above, predicting every pixel as background gives a background IoU of 90/100 and an object IoU of 0, so the mean IoU is 45%, which represents the result far better than 90% accuracy.
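A minimal numpy sketch of per-class IoU and its average (function name and toy masks are ours; classes absent from both masks are skipped):

```python
import numpy as np

def mean_iou(y_true, y_pred, n_classes):
    """Per-class IoU = |intersection| / |union|, averaged over the
    classes that occur in either mask."""
    ious = []
    for c in range(n_classes):
        t, p = (y_true == c), (y_pred == c)
        union = (t | p).sum()
        if union == 0:
            continue  # class absent everywhere: skip rather than divide by 0
        ious.append((t & p).sum() / union)
    return float(np.mean(ious))

# Same imbalanced example: 90 background pixels, 10 object pixels.
truth = np.zeros((10, 10), dtype=int)
truth[0, :] = 1
all_background = np.zeros_like(truth)
miou = mean_iou(truth, all_background, 2)  # (90/100 + 0) / 2 = 0.45
```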
In terms of the confusion matrix, the metrics can be reformulated using true/false positives/negatives: Dice = 2TP / (2TP + FP + FN) and IoU = TP / (TP + FP + FN).
Here is an illustration of the Dice and IoU metrics given two circles representing the ground truth and predicted masks for an arbitrary object class:
Frequency weighted IOU¶
This is an extension on top of the average IOU we discussed and is used to combat class imbalance. If a class dominates most of the images in a dataset, such as background, it needs to be weighted compared to other classes. So, instead of taking the average of all the results for the class, a weighted average is taken based on the frequency of the region of the class in the dataset.
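A numpy sketch of frequency-weighted IoU (function name is ours): each class's IoU is weighted by that class's pixel frequency in the ground truth instead of being averaged uniformly.

```python
import numpy as np

def frequency_weighted_iou(y_true, y_pred, n_classes):
    """Sum over classes of (class pixel frequency) * (class IoU)."""
    total = y_true.size
    fwiou = 0.0
    for c in range(n_classes):
        t, p = (y_true == c), (y_pred == c)
        union = (t | p).sum()
        if union == 0:
            continue
        fwiou += (t.sum() / total) * ((t & p).sum() / union)
    return fwiou

# Imbalanced example again: 90 background pixels, 10 object pixels.
truth = np.zeros((10, 10), dtype=int)
truth[0, :] = 1
all_background = np.zeros_like(truth)
fw = frequency_weighted_iou(truth, all_background, 2)  # 0.9 * 0.9 + 0.1 * 0 = 0.81
```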
FCN - Fully Convolutional Networks¶
The general architecture of a CNN consists of a few convolutional and pooling layers followed by a few fully connected layers at the end. The Fully Convolutional Network paper released in 2014 argues that the final fully connected layer can be thought of as doing a 1x1 convolution that covers the entire region.
Thus, the final dense layers can be replaced with convolutional layers that achieve the same result. The advantage of doing so is that the input size no longer needs to be fixed: with dense layers, the input size is restricted, so inputs of a different size must be resized; with convolutions, this restriction disappears.
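Since a 1x1 convolution applies the same weight matrix at every spatial position, the equivalence can be sketched in plain numpy (shapes and weights here are made up):

```python
import numpy as np

# A dense layer on a 1x1xC feature map is the same linear map as a
# 1x1 convolution with the same weights applied at each spatial position.
rng = np.random.default_rng(0)
W = rng.standard_normal((512, 10))            # dense weights: 512 features -> 10 classes

features = rng.standard_normal((1, 1, 512))   # "normal" input: 1x1 spatial map
dense_out = features.reshape(512) @ W         # fully connected output, shape (10,)

# The same weights, used as a 1x1 conv, slide over any spatial size:
big_features = rng.standard_normal((7, 7, 512))
conv_out = big_features @ W                   # W applied at every (y, x): a 7x7x10 map
```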
Furthermore, when a larger image is provided as input, the output is a feature map rather than a single class output. The observed behavior of this final feature map is that of a heatmap of the required class, i.e., the position of the object is highlighted, which is exactly the information our segmentation use case needs.
Since the feature map obtained at the output layer is downsampled due to the set of convolutions and pooling operations performed, we would like to upsample it using an interpolation technique. Bilinear upsampling works, but the paper proposes learned upsampling with deconvolution, which can even learn a nonlinear upsampling.
The downsampling part of the network is called the encoder, and the upsampling part is called the decoder. This is a pattern we will see in many architectures, namely downsampling with the encoder and then upsampling with the decoder. In an ideal world, we would not want to downsample using pooling and keep the same size, but this would lead to a huge amount of parameters and would be computationally infeasible.
Although the output results obtained were decent, the observed output is rough and not smooth. The reason is the loss of information in the final feature layer caused by the 32x downsampling through the pooling layers, which makes it very difficult for the network to recover detail during 32x upsampling. This architecture is called FCN-32.
To solve this problem, the paper proposed two other architectures, FCN-16 and FCN-8. In FCN-16, information from the previous pooling layer is used along with the final feature map, so the network only has to learn 16x upsampling, which is easier than in FCN-32. FCN-8 improves this further by including information from one more earlier pooling layer.
Use Case: Land Cover Class Segmentation¶
Let's start with an example using the Landcover AI dataset:
The LandCover.ai (Land Cover from Aerial Imagery) dataset is a dataset for automatic mapping of buildings, forests, water and roads from aerial imagery.
Dataset features
- land cover of Poland, Central Europe
- three spectral bands - RGB
- 33 orthophotos with 25 cm pixel resolution (~9000x9500 px)
- 8 orthophotos with 50 cm pixel resolution (~4200x4700 px)
- total area of 216.27 km2
Dataset format
- rasters are three-channel GeoTiffs with EPSG:2180 spatial reference system
- masks are single-channel GeoTiffs with EPSG:2180 spatial reference system
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
import glob
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
Mosaics are stored in Drive. The first step is to divide them into square patches of TARGET_SIZE x TARGET_SIZE pixels (1024x1024 in the code below):
IMGS_DIR = "drive/MyDrive/Datasets/Landcover_AI/images"
MASKS_DIR = "drive/MyDrive/Datasets/Landcover_AI/masks"
OUTPUT_DIR = "./output"
TARGET_SIZE = 1024
img_paths = glob.glob(os.path.join(IMGS_DIR, "*.tif"))
mask_paths = glob.glob(os.path.join(MASKS_DIR, "*.tif"))
img_paths.sort()
mask_paths.sort()
os.makedirs(OUTPUT_DIR, exist_ok=True)
for i, (img_path, mask_path) in enumerate(zip(img_paths, mask_paths)):
    img_filename = os.path.splitext(os.path.basename(img_path))[0]
    mask_filename = os.path.splitext(os.path.basename(mask_path))[0]
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path)
    mask = mask[:, :, 0]  # masks are single-channel; keep one band
    assert img_filename == mask_filename and img.shape[:2] == mask.shape[:2]
    k = 0
    for y in range(0, img.shape[0], TARGET_SIZE):
        for x in range(0, img.shape[1], TARGET_SIZE):
            img_tile = img[y:y + TARGET_SIZE, x:x + TARGET_SIZE]
            mask_tile = mask[y:y + TARGET_SIZE, x:x + TARGET_SIZE]
            # keep only full-size tiles (edge remainders are discarded)
            if img_tile.shape[0] == TARGET_SIZE and img_tile.shape[1] == TARGET_SIZE:
                out_img_path = os.path.join(OUTPUT_DIR, "{}_{}.jpg".format(img_filename, k))
                cv2.imwrite(out_img_path, img_tile)
                out_mask_path = os.path.join(OUTPUT_DIR, "{}_{}_m.png".format(mask_filename, k))
                cv2.imwrite(out_mask_path, mask_tile)
                k += 1
    print("Processed {} {}/{}".format(img_filename, i + 1, len(img_paths)))
Processed M-33-20-D-c-4-2 1/41 Processed M-33-20-D-d-3-3 2/41 Processed M-33-32-B-b-4-4 3/41 Processed M-33-48-A-c-4-4 4/41 Processed M-33-7-A-d-2-3 5/41 Processed M-33-7-A-d-3-2 6/41 Processed M-34-32-B-a-4-3 7/41 Processed M-34-32-B-b-1-3 8/41 Processed M-34-5-D-d-4-2 9/41 Processed M-34-51-C-b-2-1 10/41 Processed M-34-51-C-d-4-1 11/41 Processed M-34-55-B-b-4-1 12/41 Processed M-34-56-A-b-1-4 13/41 Processed M-34-6-A-d-2-2 14/41 Processed M-34-65-D-a-4-4 15/41 Processed M-34-65-D-c-4-2 16/41 Processed M-34-65-D-d-4-1 17/41 Processed M-34-68-B-a-1-3 18/41 Processed M-34-77-B-c-2-3 19/41 Processed N-33-104-A-c-1-1 20/41 Processed N-33-119-C-c-3-3 21/41 Processed N-33-130-A-d-3-3 22/41 Processed N-33-130-A-d-4-4 23/41 Processed N-33-139-C-d-2-2 24/41 Processed N-33-139-C-d-2-4 25/41 Processed N-33-139-D-c-1-3 26/41 Processed N-33-60-D-c-4-2 27/41 Processed N-33-60-D-d-1-2 28/41 Processed N-33-96-D-d-1-1 29/41 Processed N-34-106-A-b-3-4 30/41 Processed N-34-106-A-c-1-3 31/41 Processed N-34-140-A-b-3-2 32/41 Processed N-34-140-A-b-4-2 33/41 Processed N-34-140-A-d-3-4 34/41 Processed N-34-140-A-d-4-2 35/41 Processed N-34-61-B-a-1-1 36/41 Processed N-34-66-C-c-4-3 37/41 Processed N-34-77-A-b-1-4 38/41 Processed N-34-94-A-b-2-4 39/41 Processed N-34-97-C-b-1-2 40/41 Processed N-34-97-D-c-2-4 41/41
Patches were created and stored in a folder in Colab content. Please note that once your session ends, this data will be lost.
path = '/content/output'
Let's then import the images and their respective masks and convert them into numpy arrays:
list_img = [f for f in os.listdir(path) if f.endswith('.jpg')]
X = []
Y = []
NCLASSES = 5
for path_img in list_img:
    full_path = os.path.join(path, path_img)
    img = cv2.imread(full_path)
    img = cv2.resize(img, (256, 256))
    id_img = path_img.split('.')[0]
    mask_id = id_img + '_m.png'
    mask_path = os.path.join(path, mask_id)
    mask = cv2.imread(mask_path)
    # nearest-neighbor interpolation keeps class IDs intact when resizing masks
    mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
    mask = mask[:, :, 0]
    mask = to_categorical(mask, NCLASSES)
    X.append(img)
    Y.append(mask)
X = np.array(X)
Y = np.array(Y)
X.shape
(2513, 256, 256, 3)
So we can plot an example of an image and its mask:
plt.figure(figsize=[6,6])
plt.imshow(X[100])
plt.axis('off')
(-0.5, 255.5, 255.5, -0.5)
plt.figure(figsize=[6,6])
plt.imshow(np.argmax(Y[100], axis=2))
plt.axis('off')
(-0.5, 255.5, 255.5, -0.5)
Now it's time to split the data into training and testing, rescaling the values to a range from 0 to 1:
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.3, random_state=10)
del X, Y
del list_img
x_train = x_train/255
x_test = x_test/255
The next step is to import some keras and tensorflow functions for the FCN implementation:
from keras.models import Model
from keras.regularizers import l2
from keras.layers import *
from keras.models import *
from tensorflow.keras import backend as K
import tensorflow as tf
#from tensorflow.keras.optimizers import Adam
#from tensorflow.keras.optimizers.legacy import Adam
from keras.optimizers import Adam
from tensorflow.keras.losses import Dice
This cell defines the BilinearUpSampling2D operation that will be used in the final part of the FCN architecture:
def resize_images_bilinear(X, height_factor=1, width_factor=1, target_height=None, target_width=None, data_format='default'):
    '''Resizes the images contained in a 4D tensor of shape
    - [batch, channels, height, width] (for 'channels_first' data_format)
    - [batch, height, width, channels] (for 'channels_last' data_format)
    by a factor of (height_factor, width_factor). Both factors should be
    positive integers.
    '''
    if data_format == 'default':
        data_format = K.image_data_format()
    if data_format == 'channels_first':
        original_shape = K.int_shape(X)
        if target_height and target_width:
            new_shape = tf.constant(np.array((target_height, target_width)).astype('int32'))
        else:
            new_shape = tf.shape(X)[2:]
            new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32'))
        X = K.permute_dimensions(X, [0, 2, 3, 1])
        X = tf.compat.v1.image.resize_bilinear(X, new_shape)
        X = K.permute_dimensions(X, [0, 3, 1, 2])
        if target_height and target_width:
            X.set_shape((None, None, target_height, target_width))
        else:
            X.set_shape((None, None, original_shape[2] * height_factor, original_shape[3] * width_factor))
        return X
    elif data_format == 'channels_last':
        original_shape = K.int_shape(X)
        if target_height and target_width:
            new_shape = tf.constant(np.array((target_height, target_width)).astype('int32'))
        else:
            new_shape = tf.shape(X)[1:3]
            new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32'))
        X = tf.compat.v1.image.resize_bilinear(X, new_shape)
        if target_height and target_width:
            X.set_shape((None, target_height, target_width, None))
        else:
            X.set_shape((None, original_shape[1] * height_factor, original_shape[2] * width_factor, None))
        return X
    else:
        raise Exception('Invalid data_format: ' + data_format)
class BilinearUpSampling2D(Layer):
    def __init__(self, size=(1, 1), target_size=None, data_format='default', **kwargs):
        if data_format == 'default':
            data_format = K.image_data_format()
        self.size = tuple(size)
        if target_size is not None:
            self.target_size = tuple(target_size)
        else:
            self.target_size = None
        assert data_format in {'channels_last', 'channels_first'}, 'data_format must be channels_last or channels_first'
        self.data_format = data_format
        self.input_spec = [InputSpec(ndim=4)]
        super(BilinearUpSampling2D, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            width = int(self.size[0] * input_shape[2]) if input_shape[2] is not None else None
            height = int(self.size[1] * input_shape[3]) if input_shape[3] is not None else None
            if self.target_size is not None:
                width = self.target_size[0]
                height = self.target_size[1]
            return (input_shape[0], input_shape[1], width, height)
        elif self.data_format == 'channels_last':
            width = int(self.size[0] * input_shape[1]) if input_shape[1] is not None else None
            height = int(self.size[1] * input_shape[2]) if input_shape[2] is not None else None
            if self.target_size is not None:
                width = self.target_size[0]
                height = self.target_size[1]
            return (input_shape[0], width, height, input_shape[3])
        else:
            raise Exception('Invalid data_format: ' + self.data_format)

    def call(self, x, mask=None):
        if self.target_size is not None:
            return resize_images_bilinear(x, target_height=self.target_size[0], target_width=self.target_size[1], data_format=self.data_format)
        else:
            return resize_images_bilinear(x, height_factor=self.size[0], width_factor=self.size[1], data_format=self.data_format)

    def get_config(self):
        config = {'size': self.size, 'target_size': self.target_size}
        base_config = super(BilinearUpSampling2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
And then we can implement FCN:
weight_decay = 0.
img_input = Input(shape=x_train.shape[1:])
# Block 1
x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1', kernel_regularizer=l2(weight_decay))(img_input)
x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2', kernel_regularizer=l2(weight_decay))(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
# Block 2
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2', kernel_regularizer=l2(weight_decay))(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
# Block 3
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3', kernel_regularizer=l2(weight_decay))(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
# Block 4
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3', kernel_regularizer=l2(weight_decay))(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
# Block 5
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2', kernel_regularizer=l2(weight_decay))(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3', kernel_regularizer=l2(weight_decay))(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
# Convolutional layers transferred from fully-connected layers
x = Conv2D(4096, (7, 7), activation='relu', padding='same', name='fc1', kernel_regularizer=l2(weight_decay))(x)
x = Dropout(0.5)(x)
x = Conv2D(4096, (1, 1), activation='relu', padding='same', name='fc2', kernel_regularizer=l2(weight_decay))(x)
x = Dropout(0.5)(x)
#classifying layer
x = Conv2D(NCLASSES, (1, 1), kernel_initializer='he_normal', activation='softmax', padding='valid', strides=(1, 1), kernel_regularizer=l2(weight_decay))(x)
x = BilinearUpSampling2D(size=(32, 32))(x)
model = Model(img_input, x)
model.compile(optimizer=Adam(learning_rate = 1e-5), loss = 'categorical_crossentropy', metrics=['accuracy'])
model.summary()
Model: "functional"
Layer (type)                                   Output Shape            Param #
input_layer (InputLayer)                       (None, 224, 224, 3)     0
block1_conv1 (Conv2D)                          (None, 224, 224, 64)    1,792
block1_conv2 (Conv2D)                          (None, 224, 224, 64)    36,928
block1_pool (MaxPooling2D)                     (None, 112, 112, 64)    0
block2_conv1 (Conv2D)                          (None, 112, 112, 128)   73,856
block2_conv2 (Conv2D)                          (None, 112, 112, 128)   147,584
block2_pool (MaxPooling2D)                     (None, 56, 56, 128)     0
block3_conv1 (Conv2D)                          (None, 56, 56, 256)     295,168
block3_conv2 (Conv2D)                          (None, 56, 56, 256)     590,080
block3_conv3 (Conv2D)                          (None, 56, 56, 256)     590,080
block3_pool (MaxPooling2D)                     (None, 28, 28, 256)     0
block4_conv1 (Conv2D)                          (None, 28, 28, 512)     1,180,160
block4_conv2 (Conv2D)                          (None, 28, 28, 512)     2,359,808
block4_conv3 (Conv2D)                          (None, 28, 28, 512)     2,359,808
block4_pool (MaxPooling2D)                     (None, 14, 14, 512)     0
block5_conv1 (Conv2D)                          (None, 14, 14, 512)     2,359,808
block5_conv2 (Conv2D)                          (None, 14, 14, 512)     2,359,808
block5_conv3 (Conv2D)                          (None, 14, 14, 512)     2,359,808
block5_pool (MaxPooling2D)                     (None, 7, 7, 512)       0
fc1 (Conv2D)                                   (None, 7, 7, 4096)      102,764,544
dropout (Dropout)                              (None, 7, 7, 4096)      0
fc2 (Conv2D)                                   (None, 7, 7, 4096)      16,781,312
dropout_1 (Dropout)                            (None, 7, 7, 4096)      0
conv2d (Conv2D)                                (None, 7, 7, 5)         20,485
bilinear_up_sampling2d (BilinearUpSampling2D)  (None, 224, 224, 5)     0
Total params: 134,281,029 (512.24 MB)
Trainable params: 134,281,029 (512.24 MB)
Non-trainable params: 0 (0.00 B)
history = model.fit(x = x_train, y= y_train, batch_size=8, epochs=100, verbose=1, shuffle=True, validation_split=0.25)
Epoch 1/100
165/165 [==============================] - 41s 177ms/step - loss: 1.0905 - accuracy: 0.5832 - val_loss: 1.0422 - val_accuracy: 0.5485
Epoch 2/100
165/165 [==============================] - 27s 161ms/step - loss: 0.9695 - accuracy: 0.5864 - val_loss: 1.0068 - val_accuracy: 0.5485
Epoch 3/100
165/165 [==============================] - 27s 161ms/step - loss: 0.8548 - accuracy: 0.6902 - val_loss: 0.8560 - val_accuracy: 0.7204
...
Epoch 98/100
165/165 [==============================] - 27s 161ms/step - loss: 0.2939 - accuracy: 0.8991 - val_loss: 0.3941 - val_accuracy: 0.8670
Epoch 99/100
165/165 [==============================] - 27s 161ms/step - loss: 0.2939 - accuracy: 0.8990 - val_loss: 0.3785 - val_accuracy: 0.8704
Epoch 100/100
165/165 [==============================] - 27s 161ms/step - loss: 0.2976 - accuracy: 0.8974 - val_loss: 0.3766 - val_accuracy: 0.8721
After training, we can plot the loss and accuracy curves:
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='lower right')
plt.show()
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper right')
plt.show()
Now we can compute the pixel-wise accuracy on the test set:
del x_train, y_train
x_test = x_test/255
predict = model.predict(x_test)
24/24 [==============================] - 6s 159ms/step
# Take the most probable class per pixel; rounding the probabilities first
# is unnecessary and can zero out pixels whose maximum probability is below 0.5
pred = np.argmax(predict, axis=3)
true = np.argmax(y_test, axis=3)
accuracy = accuracy_score(true.flatten(),pred.flatten())
print(accuracy)
0.8773271069918767
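Pixel accuracy can be misleading when classes are imbalanced: a model that mostly predicts the dominant class still scores well. A common complementary metric for segmentation is per-class intersection-over-union (IoU). A minimal NumPy sketch, operating on the same integer class-ID masks as `true` and `pred` above:

```python
import numpy as np

def per_class_iou(true, pred, n_classes):
    """Intersection-over-union for each class, given integer class-ID masks."""
    ious = []
    for c in range(n_classes):
        t, p = (true == c), (pred == c)
        inter = np.logical_and(t, p).sum()
        union = np.logical_or(t, p).sum()
        ious.append(inter / union if union else float('nan'))  # nan: class absent
    return ious
```

For example, `np.nanmean(per_class_iou(true.flatten(), pred.flatten(), NCLASSES))` would give the mean IoU over the classes present in the test set.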
Finally, let's plot an example of the predicted result compared to the original mask:
i = 100
plt.figure(figsize=[20,8])
plt.subplot(131)
plt.imshow(x_test[i])
plt.title('RGB Image')
plt.axis('off')
plt.subplot(132)
plt.imshow(true[i], vmin=0,vmax=4)
plt.title('Original Mask')
plt.axis('off')
plt.subplot(133)
plt.imshow(pred[i], vmin=0,vmax=4)
plt.title('Predict Mask')
plt.axis('off')
SegNet¶
SegNet has an encoder network and a corresponding decoder network, followed by a final pixel-wise classification layer. This architecture is illustrated in the figure above. The encoder network consists of 13 convolutional layers that correspond to the first 13 convolutional layers of the VGG16 network designed for object classification. The fully connected layers of VGG16 are discarded in favor of retaining higher-resolution feature maps at the output of the deepest encoder layer. This also significantly reduces the number of parameters in the SegNet encoder network (from 134M to 14.7M) compared to other recent architectures. Each encoder layer has a corresponding decoder layer, so the decoder network also has 13 layers. The final decoder output is fed to a multi-class softmax classifier to produce class probabilities for each pixel independently.
The encoder network performs convolution with a filter bank to produce a set of feature maps. These are then batch-normalized, and an element-wise rectified linear non-linearity (ReLU), max(0, x), is applied. Next, max pooling with a 2 × 2 window and stride 2 (non-overlapping windows) is performed, and the resulting output is subsampled by a factor of 2. Max pooling is used to achieve translation invariance over small spatial shifts in the input image, and subsampling gives each pixel in the feature map a large input-image context (spatial window). Although multiple layers of max pooling and subsampling achieve more translation invariance for robust classification, there is a corresponding loss of spatial resolution in the feature maps. This increasingly lossy representation of boundary detail is not beneficial for segmentation, where boundary delineation is vital. It is therefore necessary to capture and store boundary information in the encoder feature maps before downsampling. If memory during inference were not constrained, all encoder feature maps (after downsampling) could be stored; since this is rarely the case in practice, SegNet stores this information more efficiently: only the max-pooling indices, i.e., the locations of the maximum feature value in each pooling window, are memorized for each encoder feature map. In principle, this requires only 2 bits per 2 × 2 pooling window and is therefore much cheaper to store than memorizing the feature maps in floating-point precision.
The decoder network upsamples its input feature map(s) using the memorized max-pooling indices from the corresponding encoder feature map(s). This step produces sparse feature map(s). This SegNet decoding technique is illustrated in the figure below. These sparse maps are then convolved with a trainable decoder filter bank to produce dense feature maps, and a batch normalization step is applied to each of them. Note that the decoder corresponding to the first encoder (closest to the input image) produces a multi-channel feature map, even though its encoder input has 3 channels (RGB). This is different from the other decoders in the network, which produce feature maps with the same size and number of channels as their encoder inputs. The high-dimensional feature representation at the final decoder output is fed to a trainable softmax classifier, which classifies each pixel independently. The output of the softmax classifier is an image of K channel probabilities, where K is the number of classes. The predicted segmentation corresponds to the class with the maximum probability at each pixel.
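To make the index mechanism concrete, here is a minimal NumPy sketch (not the notebook's Keras code) of 2 × 2 max pooling that records each maximum's flat location, and of the corresponding SegNet-style unpooling that scatters the pooled values back into a sparse map. It assumes a single-channel map whose sides are divisible by the pooling window size:

```python
import numpy as np

def max_pool_with_indices(x, k=2):
    """k x k max pooling that also records the flat index of each maximum,
    mimicking the information SegNet's encoder stores per pooling window."""
    H, W = x.shape  # assumes H and W are divisible by k
    out = np.zeros((H // k, W // k), dtype=x.dtype)
    idx = np.zeros((H // k, W // k), dtype=np.int64)
    for i in range(0, H, k):
        for j in range(0, W, k):
            window = x[i:i + k, j:j + k]
            r, c = np.unravel_index(np.argmax(window), window.shape)
            out[i // k, j // k] = window[r, c]
            idx[i // k, j // k] = (i + r) * W + (j + c)  # flat location in the input
    return out, idx

def unpool_with_indices(pooled, idx, shape):
    """SegNet-style unpooling: scatter each pooled value back to its
    remembered location, leaving the rest of the (sparse) map at zero."""
    up = np.zeros(np.prod(shape), dtype=pooled.dtype)
    up[idx.ravel()] = pooled.ravel()
    return up.reshape(shape)
```

The sparse map returned by `unpool_with_indices` is what the decoder then convolves with its trainable filter bank to densify.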
SegNet and FCN decoders. a, b, c, d correspond to values in a feature map. SegNet uses the max pooling indices to upsample (without learning) the feature map and convolves it with a trainable decoder filter bank. FCN upsamples by learning to deconvolve the input feature map and adds the corresponding encoder feature map to produce the decoder output. This feature map is the output of the max pooling layer (includes downsampling) in the corresponding encoder. Note that there are no trainable decoder filters in FCN.
Below is a Keras implementation of SegNet. Note that it is simplified: standard Keras has no max-unpooling layer, so the decoder uses UpSampling2D (simple upsampling) instead of the max-pooling indices described above.
droprate = 0.1
pool_size = (2,2)
inputs = Input(shape=x_train.shape[1:])
conv_1 = Conv2D(32, (3, 3), kernel_initializer='he_uniform', padding='same')(inputs)
conv_1 = BatchNormalization()(conv_1)
conv_1 = Activation("relu")(conv_1)
conv_1 = Dropout(droprate)(conv_1)
conv_2 = Conv2D(32, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_1)
conv_2 = BatchNormalization()(conv_2)
conv_2 = Activation("relu")(conv_2)
pool_1 = MaxPooling2D(pool_size=pool_size)(conv_2)
conv_3 = Conv2D(64, (3, 3), kernel_initializer='he_uniform', padding='same')(pool_1)
conv_3 = BatchNormalization()(conv_3)
conv_3 = Activation("relu")(conv_3)
conv_3 = Dropout(droprate)(conv_3)
conv_4 = Conv2D(64, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_3)
conv_4 = BatchNormalization()(conv_4)
conv_4 = Activation("relu")(conv_4)
pool_2 = MaxPooling2D(pool_size=pool_size)(conv_4)
conv_5 = Conv2D(128, (3, 3), kernel_initializer='he_uniform', padding='same')(pool_2)
conv_5 = BatchNormalization()(conv_5)
conv_5 = Activation("relu")(conv_5)
conv_5 = Dropout(droprate)(conv_5)
conv_6 = Conv2D(128, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_5)
conv_6 = BatchNormalization()(conv_6)
conv_6 = Activation("relu")(conv_6)
conv_6 = Dropout(droprate)(conv_6)
conv_7 = Conv2D(128, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_6)
conv_7 = BatchNormalization()(conv_7)
conv_7 = Activation("relu")(conv_7)
pool_3 = MaxPooling2D(pool_size=pool_size)(conv_7)
conv_8 = Conv2D(256, (3, 3), kernel_initializer='he_uniform', padding='same')(pool_3)
conv_8 = BatchNormalization()(conv_8)
conv_8 = Activation("relu")(conv_8)
conv_8 = Dropout(droprate)(conv_8)
conv_9 = Conv2D(256, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_8)
conv_9 = BatchNormalization()(conv_9)
conv_9 = Activation("relu")(conv_9)
conv_9 = Dropout(droprate)(conv_9)
conv_10 = Conv2D(256, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_9)
conv_10 = BatchNormalization()(conv_10)
conv_10 = Activation("relu")(conv_10)
pool_4 = MaxPooling2D(pool_size=pool_size)(conv_10)
conv_11 = Conv2D(512, (3, 3), kernel_initializer='he_uniform', padding='same')(pool_4)
conv_11 = BatchNormalization()(conv_11)
conv_11 = Activation("relu")(conv_11)
conv_11 = Dropout(droprate)(conv_11)
conv_12 = Conv2D(512, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_11)
conv_12 = BatchNormalization()(conv_12)
conv_12 = Activation("relu")(conv_12)
conv_12 = Dropout(droprate)(conv_12)
conv_13 = Conv2D(512, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_12)
conv_13 = BatchNormalization()(conv_13)
conv_13 = Activation("relu")(conv_13)
pool_5 = MaxPooling2D(pool_size=pool_size)(conv_13)
unpool_1 = UpSampling2D(size=pool_size)(pool_5)
conv_14 = Conv2D(512, (3, 3), kernel_initializer='he_uniform', padding='same')(unpool_1)
conv_14 = BatchNormalization()(conv_14)
conv_14 = Activation("relu")(conv_14)
conv_14 = Dropout(droprate)(conv_14)
conv_15 = Conv2D(512, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_14)
conv_15 = BatchNormalization()(conv_15)
conv_15 = Activation("relu")(conv_15)
conv_15 = Dropout(droprate)(conv_15)
conv_16 = Conv2D(256, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_15)
conv_16 = BatchNormalization()(conv_16)
conv_16 = Activation("relu")(conv_16)
unpool_2 = UpSampling2D(size=pool_size)(conv_16)
conv_17 = Conv2D(256, (3, 3), kernel_initializer='he_uniform', padding='same')(unpool_2)
conv_17 = BatchNormalization()(conv_17)
conv_17 = Activation("relu")(conv_17)
conv_17 = Dropout(droprate)(conv_17)
conv_18 = Conv2D(256, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_17)
conv_18 = BatchNormalization()(conv_18)
conv_18 = Activation("relu")(conv_18)
conv_18 = Dropout(droprate)(conv_18)
conv_19 = Conv2D(128, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_18)
conv_19 = BatchNormalization()(conv_19)
conv_19 = Activation("relu")(conv_19)
unpool_3 = UpSampling2D(size=pool_size)(conv_19)
conv_20 = Conv2D(128, (3, 3), kernel_initializer='he_uniform', padding='same')(unpool_3)
conv_20 = BatchNormalization()(conv_20)
conv_20 = Activation("relu")(conv_20)
conv_20 = Dropout(droprate)(conv_20)
conv_21 = Conv2D(128, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_20)
conv_21 = BatchNormalization()(conv_21)
conv_21 = Activation("relu")(conv_21)
conv_21 = Dropout(droprate)(conv_21)
conv_22 = Conv2D(64, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_21)
conv_22 = BatchNormalization()(conv_22)
conv_22 = Activation("relu")(conv_22)
unpool_4 = UpSampling2D(size=pool_size)(conv_22)
conv_23 = Conv2D(64, (3, 3), kernel_initializer='he_uniform', padding='same')(unpool_4)
conv_23 = BatchNormalization()(conv_23)
conv_23 = Activation("relu")(conv_23)
conv_23 = Dropout(droprate)(conv_23)
conv_24 = Conv2D(32, (3, 3), kernel_initializer='he_uniform', padding='same')(conv_23)
conv_24 = BatchNormalization()(conv_24)
conv_24 = Activation("relu")(conv_24)
unpool_5 = UpSampling2D(size=pool_size)(conv_24)
conv_25 = Conv2D(32, (3, 3), kernel_initializer='he_uniform', padding='same')(unpool_5)
conv_25 = BatchNormalization()(conv_25)
conv_25 = Activation("relu")(conv_25)
conv_26 = Conv2D(NCLASSES, (1, 1), kernel_initializer='he_uniform', padding='same')(conv_25)
conv_26 = BatchNormalization()(conv_26)
outputs = Activation("softmax")(conv_26)
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer=Adam(learning_rate=0.00001), loss=Dice, metrics=['accuracy'])
model.summary()
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 224, 224, 3)]     0
_________________________________________________________________
conv2d_26 (Conv2D)           (None, 224, 224, 32)      896
_________________________________________________________________
batch_normalization_26 (Batc (None, 224, 224, 32)      128
_________________________________________________________________
activation_26 (Activation)   (None, 224, 224, 32)      0
_________________________________________________________________
...                          (remaining encoder/decoder layers)
_________________________________________________________________
conv2d_51 (Conv2D)           (None, 224, 224, 5)       165
_________________________________________________________________
batch_normalization_51 (Batc (None, 224, 224, 5)       20
_________________________________________________________________
activation_51 (Activation)   (None, 224, 224, 5)       0
=================================================================
Total params: 15,639,193
Trainable params: 15,628,687
Non-trainable params: 10,506
_________________________________________________________________
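The `Dice` loss passed to `compile` above is not defined in this section. For reference, a NumPy sketch of one common soft Dice loss formulation for one-hot masks and per-pixel class probabilities (an assumption; the notebook's own definition may differ in detail):

```python
import numpy as np

def soft_dice_loss(y_true, y_pred, eps=1e-6):
    """1 minus the mean soft Dice coefficient over samples and classes.
    y_true: one-hot masks (N, H, W, C); y_pred: probabilities (N, H, W, C).
    Returns 0 for a perfect overlap, approaching 1 for no overlap."""
    axes = tuple(range(1, y_true.ndim - 1))  # spatial axes (H, W)
    intersection = np.sum(y_true * y_pred, axis=axes)
    denom = np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes)
    dice = (2.0 * intersection + eps) / (denom + eps)  # per sample, per class
    return 1.0 - np.mean(dice)
```

Unlike pixel-wise cross-entropy, Dice directly optimizes region overlap, which is why it is a popular choice for segmentation with imbalanced classes.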
References:
https://nanonets.com/blog/how-to-do-semantic-segmentation-using-deep-learning/
https://nanonets.com/blog/semantic-image-segmentation-2020/
https://www.jeremyjordan.me/semantic-segmentation/
https://towardsdatascience.com/review-fcn-semantic-segmentation-eb8c9b50d2d1
http://ronny.rest/tutorials/module/seg_01/segmentation_04_training/
https://github.com/ykamikawa/tf-keras-SegNet/blob/master/model.py
https://ai-pool.com/m/segnet-1555409707
https://www.kaggle.com/santhalnr/cityscapes-image-segmentation-pspnet
https://divamgupta.com/image-segmentation/2019/06/06/deep-learning-semantic-segmentation-keras.html