U-Net Paper Implementation using PyTorch

Main Diagram

Network Architecture

The network architecture is illustrated in Figure 1. It consists of a contracting path (left side) and an expansive path (right side). The contracting path follows the typical architecture of a convolutional network. It consists of the repeated application of two 3x3 convolutions (unpadded convolutions), each followed by a rectified linear unit (ReLU) and a 2x2 max pooling operation with stride 2 for downsampling. At each downsampling step, we double the number of feature channels. Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution (“up-convolution”) that halves the number of feature channels, a concatenation with the correspondingly cropped feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU. The cropping is necessary due to the loss of border pixels in every convolution. At the final layer, a 1x1 convolution is used to map each 64-component feature vector to the desired number of classes. In total the network has 23 convolutional layers. To allow a seamless tiling of the output segmentation map (see Figure 2), it is important to select the input tile size such that all 2x2 max-pooling operations are applied to a layer with an even x- and y-size.
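As a quick sanity check on that tile-size constraint, the short sketch below (illustrative only; the helper name trace_unet_sizes is not part of the implementation that follows) traces the spatial size through the network: each pair of unpadded 3x3 convolutions removes 4 pixels, each 2x2 max pool halves the size, and each up-convolution doubles it. For a 572x572 input tile, every pooled layer has an even size and the final output map is 388x388.

def trace_unet_sizes(size=572, depth=4):
    # contracting path: two unpadded 3x3 convs (-4), then 2x2 max pool (halve)
    for level in range(depth):
        size -= 4
        print(f"level {level}: {size} x {size} (even, so 2x2 pooling is seamless)")
        size //= 2
    size -= 4  # bottleneck double convolution
    # expansive path: each 2x2 up-convolution doubles the size, then the double conv removes 4
    for _ in range(depth):
        size = size * 2 - 4
    print(f"output segmentation map: {size} x {size}")  # 388 x 388 for a 572 x 572 input

trace_unet_sizes()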

"""
Created on Mon Jun 20 16:08:19 2021

@author: mr-siddy
"""

import torch
import torch.nn as nn


def double_conv(in_c, out_c):
    """Two unpadded 3x3 convolutions, each followed by a ReLU."""
    conv = nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True)
    )
    return conv
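# Illustrative (not called by the model): because both 3x3 convolutions are unpadded,
# each double_conv block shrinks the spatial dimensions by 4, e.g.
#   double_conv(1, 64)(torch.rand(1, 1, 572, 572)).shape -> torch.Size([1, 64, 568, 568])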

def crop_img(tensor, target_tensor):
    """Center-crop `tensor` to the spatial size of `target_tensor` (assumes square feature maps)."""
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = (tensor_size - target_size) // 2
    return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]
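# Illustrative check (not called by the model): with a 572x572 input, the x1 skip
# connection (the 64-channel, 568x568 map from the first encoder block) gets
# center-cropped to match the decoder's 392x392 map before concatenation, e.g.
#   crop_img(torch.rand(1, 64, 568, 568), torch.rand(1, 64, 392, 392)).shape
#   -> torch.Size([1, 64, 392, 392])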


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # contracting path: the number of feature channels doubles at each step
        self.down_conv_1 = double_conv(1, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        # expansive path: each 2x2 up-convolution halves the channels, then a
        # double_conv consumes the concatenated skip connection
        self.up_trans_1 = nn.ConvTranspose2d(
            in_channels=1024,
            out_channels=512,
            kernel_size=2,
            stride=2)

        self.up_conv_1 = double_conv(1024, 512)

        self.up_trans_2 = nn.ConvTranspose2d(
            in_channels=512,
            out_channels=256,
            kernel_size=2,
            stride=2)

        self.up_conv_2 = double_conv(512, 256)

        self.up_trans_3 = nn.ConvTranspose2d(
            in_channels=256,
            out_channels=128,
            kernel_size=2,
            stride=2)

        self.up_conv_3 = double_conv(256, 128)

        self.up_trans_4 = nn.ConvTranspose2d(
            in_channels=128,
            out_channels=64,
            kernel_size=2,
            stride=2)

        self.up_conv_4 = double_conv(128, 64)

        # final 1x1 convolution maps the 64 feature channels to the 2 output classes
        self.out = nn.Conv2d(
            in_channels=64,
            out_channels=2,
            kernel_size=1)


    def forward(self, image):
        # image: (batch_size, channels, height, width), e.g. (1, 1, 572, 572)

        # encoder (contracting path)
        x1 = self.down_conv_1(image)   # saved for the skip connection (copy and crop)
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2)      # saved for the skip connection
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4)      # saved for the skip connection
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6)      # saved for the skip connection
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)      # bottleneck

        # decoder (expansive path): upsample, crop-and-concatenate, double conv
        x = self.up_trans_1(x9)
        y = crop_img(x7, x)
        x = self.up_conv_1(torch.cat([x, y], 1))

        x = self.up_trans_2(x)
        y = crop_img(x5, x)
        x = self.up_conv_2(torch.cat([x, y], 1))

        x = self.up_trans_3(x)
        y = crop_img(x3, x)
        x = self.up_conv_3(torch.cat([x, y], 1))

        x = self.up_trans_4(x)
        y = crop_img(x1, x)
        x = self.up_conv_4(torch.cat([x, y], 1))

        # 1x1 convolution to class scores: (1, 2, 388, 388) for a 572x572 input
        x = self.out(x)
        print(x.size())

        return x




if __name__ == '__main__':
    image = torch.rand((1, 1, 572, 572))   # single-channel 572x572 input tile
    model = UNet()
    print(model(image))
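For completeness, below is a minimal training-step sketch showing how the 2-channel, 388x388 output could be fed to a pixel-wise cross-entropy loss. The random image/mask pair, the SGD settings, and the loss choice are illustrative placeholders, not part of the original script; the snippet assumes the UNet class defined above.

import torch
import torch.nn as nn
import torch.optim as optim

model = UNet()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.99)
criterion = nn.CrossEntropyLoss()

# hypothetical batch: one 572x572 grayscale tile with a 388x388 label map (2 classes)
image = torch.rand(1, 1, 572, 572)
target = torch.randint(0, 2, (1, 388, 388))

optimizer.zero_grad()
logits = model(image)             # (1, 2, 388, 388)
loss = criterion(logits, target)  # pixel-wise cross-entropy
loss.backward()
optimizer.step()
print(loss.item())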