
How to analyze the UNet network structure and write its code in PyTorch

Shulou (Shulou.com) 06/01 report, updated 2025-03-28. From: SLTechnology News&Howtos.

This article explains the UNet network structure and walks through writing its code in PyTorch. It is quite detailed, and I hope it serves as a useful reference.

I. Preface

The code was developed on Windows with the following environment:

Development environment: Windows

Development language: Python 3.7.4

Framework version: PyTorch 2.3.0

CUDA: 10.2

cuDNN: 7.6.0

This article focuses on the UNet network structure and how to write the corresponding code.

PS: All the code in this article can be downloaded from my GitHub; follows and stars are welcome.

II. UNet network structure

In semantic segmentation, the pioneering deep-learning-based algorithm is FCN (Fully Convolutional Networks for Semantic Segmentation). UNet follows the FCN principle and improves on it so that it works well for simple segmentation tasks with small datasets.

UNet paper: "U-Net: Convolutional Networks for Biomedical Image Segmentation" (Ronneberger, Fischer, and Brox, MICCAI 2015).

To study a deep learning algorithm, start with its network structure; once that is understood, move on to how the loss is computed, how the network is trained, and so on. This article covers UNet's network structure; the other topics will be covered in later articles.

1. Principles of the network structure

UNet was first published at the MICCAI conference in 2015. In just over 4 years, the paper has been cited more than 9,700 times.

UNet has become the baseline for most medical image segmentation tasks. It has also inspired many researchers to study U-shaped network structures, and many papers proposing improvements based on UNet have been published.

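The two main features of the UNet structure are its U shape and its skip connections. The left half of the U is the encoder, which repeatedly applies convolutions and downsampling; the right half is the decoder, which upsamples step by step. Each skip connection carries an encoder feature map across to the decoder level with the same resolution, where the two feature maps are fused by stacking them along the channel dimension.

This kind of "stacking" operation is called Concat.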

For example, take a feature map of size 256×256×64, i.e., width w = 256, height h = 256, and 64 channels. Concatenating it with a 256×256×32 feature map yields a 256×256×96 feature map.
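As a minimal sketch of this channel-wise Concat in PyTorch (the batch size of 1 and the random values are placeholders chosen for illustration), torch.cat joins two tensors in the (batch, channels, height, width) layout along dim=1:

```python
import torch

# Two feature maps in (batch, channels, height, width) layout
a = torch.randn(1, 64, 256, 256)  # 256 x 256 x 64
b = torch.randn(1, 32, 256, 256)  # 256 x 256 x 32

# Concat stacks them along the channel dimension (dim=1)
fused = torch.cat([a, b], dim=1)
print(fused.shape)  # torch.Size([1, 96, 256, 256])
```

The first 64 channels of the result are exactly a, and the last 32 are exactly b; nothing is summed or mixed.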

In practice, the two feature maps being concatenated do not always have the same size; for example, a 256×256×64 feature map may need to be concatenated with a 240×240×32 one.

In that case there are two options:

The first is cropping: crop the 256×256×64 feature map to 240×240×64, for example by discarding 8 pixels from each of the top, bottom, left, and right, and then concatenate to get a 240×240×96 feature map.

The second is padding: pad the smaller 240×240×32 feature map to 256×256×32, for example by filling 8 pixels on each of the top, bottom, left, and right, and then concatenate to get a 256×256×96 feature map.
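The two options can be sketched with plain tensor operations. This is an illustrative example, assuming the (batch, channels, height, width) layout and random placeholder data; cropping is done by slicing, and zero padding uses torch.nn.functional.pad, whose list gives (left, right, top, bottom) for the last two dimensions:

```python
import torch
import torch.nn.functional as F

big = torch.randn(1, 64, 256, 256)    # 256 x 256 x 64
small = torch.randn(1, 32, 240, 240)  # 240 x 240 x 32
diff = big.shape[-1] - small.shape[-1]  # 16 pixels in each spatial dimension
half = diff // 2                         # 8 pixels per side

# Option 1: crop 8 pixels from every side of the larger map, then concatenate
cropped = big[:, :, half:half + small.shape[-2], half:half + small.shape[-1]]
out_crop = torch.cat([cropped, small], dim=1)
print(out_crop.shape)  # torch.Size([1, 96, 240, 240])

# Option 2: zero-pad the smaller map by 8 pixels on every side, then concatenate
padded = F.pad(small, [half, diff - half, half, diff - half], mode='constant', value=0)
out_pad = torch.cat([big, padded], dim=1)
print(out_pad.shape)  # torch.Size([1, 96, 256, 256])
```

Cropping throws away border information from the larger map, while padding keeps everything but adds zero-valued borders to the smaller one.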

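The original UNet paper uses the first scheme (its architecture diagram labels the skip connections "copy and crop"), but the implementation in this article uses the second: the smaller feature map is padded with zeros, i.e., ordinary constant padding.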

2. Code

Readers who are new to PyTorch may want to start with the official introductory tutorial on the PyTorch website, which covers the basic concepts and how to write PyTorch code in about an hour.

We split UNet into several modules and explain each one in turn.

DoubleConv module:

First, consider the two consecutive convolutions.

As the UNet diagram shows, every level, in both the downsampling and the upsampling path, performs two consecutive convolutions. Because this pattern repeats many times in UNet, it can be written as a separate DoubleConv module:

```python
import torch.nn as nn


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Two unpadded (padding=0) 3x3 convolutions, as in the original UNet;
        # each one shrinks the height and width by 2 pixels.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)
```

About the PyTorch code above: torch.nn.Sequential is an ordered container. Modules are added to it in the order they are passed in and are run in that same order, so the code above executes convolution -> BN -> ReLU -> convolution -> BN -> ReLU.

The in_channels and out_channels of the DoubleConv module can be set freely, so the module can be reused throughout the network.

For the first layer of the network shown in the UNet architecture figure, in_channels is set to 1 and out_channels to 64.

The input image is 572×572. A 3×3 convolution with stride 1 and padding 0 produces a 570×570 feature map, and a second such convolution produces a 568×568 feature map.
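To check these sizes, here is a minimal sketch that rebuilds the DoubleConv module from above and feeds it a random single-channel 572×572 tensor (a stand-in for a real image). Slicing the nn.Sequential with [:3] runs only the first convolution, BN, and ReLU:

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """(convolution => BN => ReLU) * 2 with unpadded 3x3 convolutions"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


block = DoubleConv(1, 64)           # first layer: 1 input channel, 64 output channels
x = torch.randn(1, 1, 572, 572)     # placeholder for one 572 x 572 grayscale image
with torch.no_grad():
    after_first = block.double_conv[:3](x)  # conv -> BN -> ReLU only
    out = block(x)                          # both convolutions

print(after_first.shape)  # torch.Size([1, 64, 570, 570])
print(out.shape)          # torch.Size([1, 64, 568, 568])
```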

Calculation formula:

$$O = \frac{H - F + 2P}{S} + 1$$

where H is the size of the input feature map, O is the size of the output feature map, F is the convolution kernel size, P is the padding, and S is the stride. When the division is not exact, the result is rounded down.
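The formula translates directly into a small helper; conv_output_size is a hypothetical name used only for this illustration. The same formula also gives the output size of the 2×2, stride-2 max pooling used in the Down module:

```python
def conv_output_size(H, F, P=0, S=1):
    """O = (H - F + 2P) / S + 1, rounded down when the division is not exact."""
    return (H - F + 2 * P) // S + 1


print(conv_output_size(572, 3))       # 570: first 3x3 conv, stride 1, padding 0
print(conv_output_size(570, 3))       # 568: second 3x3 conv
print(conv_output_size(568, 3, P=1))  # 568: padding 1 would keep the size unchanged
print(conv_output_size(568, 2, S=2))  # 284: 2x2 max pooling with stride 2
```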

Down module:

UNet downsamples four times in total. The module code is as follows:

```python
# Continues unet_parts.py: uses nn and DoubleConv defined above.
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),  # 2x2 max pooling halves the height and width
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.maxpool_conv(x)
```

The code is very simple: a max-pooling layer that downsamples, followed by a DoubleConv module.
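A quick shape check, as a sketch assuming a random 568×568×64 input (the size the first DoubleConv produces for a 572×572 image): max pooling halves the height and width to 284×284, and the two unpadded convolutions then trim 2 pixels each, giving 280×280×128. The DoubleConv and Down definitions are repeated so the snippet runs on its own:

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))

    def forward(self, x):
        return self.maxpool_conv(x)


down1 = Down(64, 128)
x = torch.randn(1, 64, 568, 568)  # placeholder for the first DoubleConv's output
with torch.no_grad():
    pooled = down1.maxpool_conv[0](x)  # max pooling only
    out = down1(x)                     # pooling + DoubleConv

print(pooled.shape)  # torch.Size([1, 64, 284, 284])
print(out.shape)     # torch.Size([1, 128, 280, 280])
```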

This completes the code for the downsampling path in the left half of UNet. Next comes the upsampling path in the right half.

Up module:

Besides the upsampling operation itself, the upsampling path also performs feature fusion.

As a result, this module is a little more complicated to implement:

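```python
# Continues unet_parts.py; these imports belong at the top of the file.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            # Upsample keeps the channel count, so a 1x1 conv halves it; after
            # concatenation with the skip map, DoubleConv then gets in_channels.
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(in_channels, in_channels // 2, kernel_size=1),
            )
        else:
            # Transposed convolution doubles H and W and halves the channels
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        # x1: decoder feature map to upsample; x2: encoder feature map from the skip connection
        x1 = self.up(x1)
        # input is NCHW; compute the height/width gaps as plain ints
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        # zero-pad x1 (left, right, top, bottom) to x2's spatial size
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)  # fuse along the channel dimension
        return self.conv(x)
```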

The code is more complex, so let's take it piece by piece. The __init__ function defines the upsampling method, and the convolution uses DoubleConv. Two upsampling methods are defined: nn.Upsample and nn.ConvTranspose2d, i.e., bilinear interpolation and deconvolution (transposed convolution).
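The two upsampling options can be compared on a placeholder tensor. This sketch assumes a 28×28×1024 input, which is what the bottom of the U produces for a 572×572 image with the unpadded convolutions used here. Bilinear interpolation keeps the channel count and has no learnable parameters, while the transposed convolution (configured as in the Up module) halves the channels and learns its weights:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1024, 28, 28)  # placeholder for the bottom of the U

# Option 1: bilinear interpolation doubles H and W; channels stay at 1024
up_bilinear = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
# Option 2: transposed convolution doubles H and W and halves the channels
up_transposed = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)

with torch.no_grad():
    by_interp = up_bilinear(x)
    by_deconv = up_transposed(x)

print(by_interp.shape)  # torch.Size([1, 1024, 56, 56])
print(by_deconv.shape)  # torch.Size([1, 512, 56, 56])
print(sum(p.numel() for p in up_bilinear.parameters()))    # 0
print(sum(p.numel() for p in up_transposed.parameters()))  # 2097664 = 1024*512*2*2 + 512
```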

Bilinear interpolation is easy to understand; here is a schematic diagram:

Readers familiar with bilinear interpolation will recognize this figure. In short: given the values at the four known points Q11, Q12, Q21, and Q22, first compute R1 from Q11 and Q21, then R2 from Q12 and Q22, and finally P from R1 and R2. That is bilinear interpolation.
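The two-step procedure can be written out directly. In this sketch, bilinear is a hypothetical helper; the four known points sit at (x1, y1), (x2, y1), (x1, y2), (x2, y2) with values q11, q21, q12, q22, and the sample values are made up for illustration:

```python
def bilinear(x, y, x1, x2, y1, y2, q11, q21, q12, q22):
    """Interpolate f(x, y) from q11=f(x1,y1), q21=f(x2,y1), q12=f(x1,y2), q22=f(x2,y2)."""
    wx1 = (x2 - x) / (x2 - x1)  # weight of the x1 column
    wx2 = (x - x1) / (x2 - x1)  # weight of the x2 column
    r1 = wx1 * q11 + wx2 * q21  # R1 = f(x, y1), from Q11 and Q21
    r2 = wx1 * q12 + wx2 * q22  # R2 = f(x, y2), from Q12 and Q22
    # Interpolate P = f(x, y) between R1 and R2 along y
    return (y2 - y) / (y2 - y1) * r1 + (y - y1) / (y2 - y1) * r2


print(bilinear(0.5, 0.5, 0, 1, 0, 1, q11=1.0, q21=2.0, q12=3.0, q22=4.0))  # 2.5
print(bilinear(0, 0, 0, 1, 0, 1, q11=1.0, q21=2.0, q12=3.0, q22=4.0))      # 1.0
```

At the center the result is the average of the four corners, and at a corner it reproduces that corner's value.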

For a feature map, upsampling inserts new pixels between the existing ones, and the value of each inserted pixel is computed from its neighboring pixels.

Deconvolution, as the name suggests, works in the opposite direction from convolution: convolution makes the feature map progressively smaller, while deconvolution makes it larger. Diagram:

In the diagram, the blue grid at the bottom is the input, the white dashed cells around it are padding (usually zeros), and the green grid on top is the output.

The diagram shows a 2×2 feature map being upsampled to a 4×4 feature map.
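A 2×2 to 4×4 transposed convolution like the one in the diagram can be reproduced with nn.ConvTranspose2d, assuming a 3×3 kernel with stride 1 and no padding (one configuration that produces this size, since (2 - 1) × 1 + 3 = 4). The second half of the sketch checks the "pad with zeros, then convolve" picture: zero-padding the input by 2 on every side and running an ordinary convolution with the spatially flipped kernel gives the same output:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# 1 input channel -> 1 output channel, 3x3 kernel, stride 1, no padding
deconv = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=1, padding=0, bias=False)
x = torch.randn(1, 1, 2, 2)  # a 2x2 feature map

with torch.no_grad():
    y = deconv(x)
    # Equivalent ordinary convolution: swap in/out channels, flip the kernel spatially,
    # and convolve over the input zero-padded by kernel_size - 1 = 2 on every side
    w = torch.flip(deconv.weight.transpose(0, 1), dims=[2, 3])
    y_ref = F.conv2d(F.pad(x, [2, 2, 2, 2]), w)

print(y.shape)                              # torch.Size([1, 1, 4, 4])
print(torch.allclose(y, y_ref, atol=1e-6))  # True
```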

In the forward function, x1 receives the feature map to be upsampled and x2 receives the feature map used for fusion. As described above, fusion first pads the smaller feature map and then concatenates the two.
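The padding step in forward can be traced with concrete sizes. This sketch assumes a 28×28×1024 decoder input and a 64×64×512 skip feature map, the sizes at up1 for a 572×572 image with the modules in this article, and uses random placeholder data:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

up = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
x1 = torch.randn(1, 1024, 28, 28)  # decoder input (placeholder for down4's output)
x2 = torch.randn(1, 512, 64, 64)   # skip feature map (placeholder for down3's output)

with torch.no_grad():
    x1 = up(x1)  # -> (1, 512, 56, 56)

# input is NCHW: size(2) is the height, size(3) the width
diffY = x2.size(2) - x1.size(2)  # 8
diffX = x2.size(3) - x1.size(3)  # 8
# Zero-pad (left, right, top, bottom); an odd gap puts the extra pixel on the right/bottom
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

x = torch.cat([x2, x1], dim=1)  # skip features first, then upsampled features
print(x.shape)  # torch.Size([1, 1024, 64, 64])
```

The concatenated 1024 channels are what DoubleConv(1024, 512) in up1 expects.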

OutConv module:

The DoubleConv, Down, and Up modules above are enough to assemble UNet's main structure. At the output, UNet must map the channels to the number of segmentation classes, as shown in the following figure:

The operation is very simple: it only transforms the number of channels. The figure above shows the two-class case (2 output channels).

Although this operation is simple and is called only once, it is wrapped in its own module to keep the code clean.
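A sketch of the channel transformation, assuming 64 input channels, 2 classes as in the figure, and a small random 8×8 feature map as a placeholder. Taking the argmax over the channel dimension to get a per-pixel class map is a common convention, shown here as an assumption rather than part of the original code:

```python
import torch
import torch.nn as nn

out_conv = nn.Conv2d(64, 2, kernel_size=1)  # 1x1 conv: 64 feature channels -> 2 classes
features = torch.randn(1, 64, 8, 8)         # placeholder decoder output

with torch.no_grad():
    logits = out_conv(features)

print(logits.shape)          # torch.Size([1, 2, 8, 8]): spatial size unchanged
pred = logits.argmax(dim=1)  # per-pixel class index
print(pred.shape)            # torch.Size([1, 8, 8])
```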

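```python
# Continues unet_parts.py: uses nn defined above.
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        # 1x1 convolution: maps in_channels to out_channels (one per class),
        # leaving the height and width unchanged
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
```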

All the modules used in UNet are now written. Put the module code above into a unet_parts.py file, then create unet_model.py and, following the UNet structure, set the input and output channels of each module and the order in which they are called:

```python
""" Full assembly of the parts to form the complete network
Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""
import torch.nn as nn
import torch.nn.functional as F

from unet_parts import *


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder (left half of the U)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        # Decoder (right half of the U)
        self.up1 = Up(1024, 512, bilinear)
        self.up2 = Up(512, 256, bilinear)
        self.up3 = Up(256, 128, bilinear)
        self.up4 = Up(128, 64, bilinear)
        # 1x1 conv to one channel per class
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # Each up step fuses with the encoder output of the same level (skip connection)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


if __name__ == '__main__':
    net = UNet(n_channels=3, n_classes=1)
    print(net)
```

Run python unet_model.py. If there are no errors, you will get the following output:

```
UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down2): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down3): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down4): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (up1): Up(
    (up): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up2): Up(
    (up): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up3): Up(
    (up): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up4): Up(
    (up): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (outc): OutConv(
    (conv): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)
```

That covers how to analyze the UNet network structure in PyTorch and how to write its code. I hope it is helpful; if you found the article useful, feel free to share it.
