
How to Use Attention Mechanisms for Medical Image Segmentation: Explanation and PyTorch Implementation

2025-01-16 Update From: SLTechnology News&Howtos


Shulou (Shulou.com) 06/02 Report

This article explains how attention mechanisms are used for medical image segmentation and gives a PyTorch implementation. I hope you find it practical and get something useful out of it.


Author: Léo Fillioux

Compiled by: ronghuaiyang

Overview

Two recent papers that use attention mechanisms for segmentation are analyzed, and a simple PyTorch implementation is given for each.

From natural language processing to recent computer vision tasks, attention mechanisms have been one of the hottest topics in deep learning. In this article, we focus on how attention is used in recent architectures for medical image segmentation. To this end, we describe the architectures introduced in two recent papers and try to give some intuition about their methods, hoping it gives you ideas for applying attention to your own problems. We will also look at a simple PyTorch implementation.

There are two main differences between medical image segmentation and natural image segmentation.

First, most medical images look very similar, because they are acquired in standardized settings: there is little variation in orientation, position, range of pixel values, and so on. Second, there is often a large imbalance between positive pixels (or voxels) and negative ones, for example when trying to segment a tumor.

Note: the code and explanations are, of course, simplifications of the complex architectures described in the papers. The main goal is to give an intuition and a good idea of what was done, not to explain every detail.

1. Attention UNet

UNet is the go-to architecture for segmentation, and most current advances in segmentation use it as a backbone. In this paper, the authors propose a way to apply an attention mechanism to a standard UNet.

1.1. What is proposed?

The architecture uses a standard UNet as its backbone and leaves the contraction path unchanged. What changes is the expansion path; more precisely, the attention mechanism is integrated into the skip connections.

Block diagram of Attention UNet; a block of the expansion path is framed in red.

To explain how a block of the expansion path works, let's call g the input from the previous block and x the skip connection from the contraction path. The block computes:

$$\text{output} = \text{ConvBlock}\big(\text{concat}(\text{Attention}(x, g),\ \text{Upsample}(g))\big)$$

The Upsample block is very simple, and ConvBlock is just a sequence of two (convolution + batch norm + ReLU) blocks. The only part that needs explaining is the attention block.

Block diagram of the attention block. The dimensions shown assume a 3D input image.

Both x and g go through a 1x1 convolution that brings them to the same number of channels without changing their spatial size. After an upsampling operation (so that both have the same spatial size), they are summed and passed through a ReLU, another 1x1 convolution and a sigmoid. This yields an importance score between 0 and 1 for each location of the feature map, which is then multiplied by the skip input to produce the final output of the attention block.

1.2. Why is this effective?

In UNet, the contraction path can be seen as an encoder and the expansion path as a decoder. What is interesting about UNet is that the skip connections let the decoder directly use the features extracted by the encoder. Because the features of the contraction path are concatenated with those of the expansion path, the network learns to use them when "reconstructing" the mask of the image.

Applying an attention block before this concatenation lets the network put more weight on the relevant features of the skip connection. The skip connection can then focus on specific parts of the input instead of passing every feature through.

The attention distribution is multiplied by the skip-connection feature map, keeping only the important parts. This distribution is extracted from what is called the query (the input) and the value (the skip connection). The attention operation selectively picks out information contained in the value, and this selection is based on the query.

Summary: the input and the skip connection are used to decide which parts of the skip connection to focus on. This subset of the skip connection is then used, together with the input, in the standard expansion path.
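The gate described above can be traced step by step with plain tensors. This is a minimal sketch that assumes arbitrary sizes (64 skip channels, 128 gating channels, 32 intermediate channels) and omits batch normalization; it is not the paper's exact code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: x is the skip connection (value) from the contraction path,
# g is the gating input (query) from the previous, coarser block.
x = torch.randn(1, 64, 32, 32)   # (batch, C_x, H, W)
g = torch.randn(1, 128, 16, 16)  # (batch, C_g, H/2, W/2)

W_x = nn.Conv2d(64, 32, kernel_size=1)   # 1x1 conv: C_x -> intermediate channels
W_g = nn.Conv2d(128, 32, kernel_size=1)  # 1x1 conv: C_g -> intermediate channels
psi = nn.Conv2d(32, 1, kernel_size=1)    # 1x1 conv: intermediate -> one score map

x1 = W_x(x)                                                     # (1, 32, 32, 32)
g1 = F.interpolate(W_g(g), size=x1.shape[2:],
                   mode="bilinear", align_corners=False)        # (1, 32, 32, 32)
alpha = torch.sigmoid(psi(torch.relu(x1 + g1)))                 # (1, 1, 32, 32), in [0, 1]
out = alpha * x  # the same score map re-weights all 64 channels of the skip connection
```

Because `alpha` depends on both `x` and `g`, the same skip features are kept or suppressed differently depending on the query coming from the decoder.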

1.3. A brief implementation

The following code defines the attention block (simplified version) and the "up-block" used in the UNet expansion path. The "down-block" is the same as in the original UNet.

import torch
import torch.nn as nn


class ConvBatchNorm(nn.Module):
    # Conv2d + BatchNorm2d + ReLU; the 3x3 kernel with padding 1 is an assumption
    # (the text only names the three layers)
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                                   nn.BatchNorm2d(out_channels),
                                   nn.ReLU())

    def forward(self, x):
        return self.block(x)


class AttentionBlock(nn.Module):
    def __init__(self, in_channels_x, in_channels_g, int_channels):
        super().__init__()
        # 1x1 convolutions that bring x and g to the same number of channels
        self.Wx = nn.Sequential(nn.Conv2d(in_channels_x, int_channels, kernel_size=1),
                                nn.BatchNorm2d(int_channels))
        self.Wg = nn.Sequential(nn.Conv2d(in_channels_g, int_channels, kernel_size=1),
                                nn.BatchNorm2d(int_channels))
        # 1x1 convolution to a single channel, then a sigmoid: scores in [0, 1]
        self.psi = nn.Sequential(nn.Conv2d(int_channels, 1, kernel_size=1),
                                 nn.BatchNorm2d(1),
                                 nn.Sigmoid())

    def forward(self, x, g):
        # apply Wx to the skip connection
        x1 = self.Wx(x)
        # after applying Wg to the input, upsample to the size of the skip connection
        g1 = nn.functional.interpolate(self.Wg(g), x1.shape[2:], mode='bilinear', align_corners=False)
        # psi already ends with a sigmoid, so no second sigmoid is needed
        out = self.psi(torch.relu(x1 + g1))
        # weight every channel of the skip connection by the attention map
        return out * x


class AttentionUpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.attention = AttentionBlock(out_channels, in_channels, out_channels // 2)
        self.conv_bn1 = ConvBatchNorm(in_channels + out_channels, out_channels)
        self.conv_bn2 = ConvBatchNorm(out_channels, out_channels)

    def forward(self, x, x_skip):
        # note: x_skip is the skip connection and x is the input from the previous block
        # apply the attention block to the skip connection, using x as context
        x_attention = self.attention(x_skip, x)
        # upsample x (bilinear interpolation) to the same size as the attention map
        x = nn.functional.interpolate(x, x_skip.shape[2:], mode='bilinear', align_corners=False)
        # stack their channels to feed to both convolution blocks
        x = torch.cat((x_attention, x), dim=1)
        x = self.conv_bn1(x)
        return self.conv_bn2(x)

A simple implementation of the attention block and of the UNet expansion-path block that uses it.

Note: ConvBatchNorm is a sequence consisting of Conv2d, BatchNorm2d, and ReLU activation functions.

2. Multi-scale guided attention

The second architecture we discuss is more novel than the first. It does not rely on a UNet, but on feature extraction followed by guided attention blocks.

Block diagram of the proposed method

The first step is to extract features from the image. To do so, the input image is fed to a pretrained ResNet, and feature maps are extracted at four different levels. This is interesting because low-level features tend to appear early in the network and high-level features towards the end, so we get access to features at multiple scales. All the feature maps are upsampled to the size of the largest one with bilinear interpolation. This gives four feature maps of the same size, which are concatenated and fed into a convolution block. The output of this convolution block (the multi-scale feature map) is concatenated with each of the four feature maps, which gives the inputs of our attention blocks. These are a bit more complex than the previous one.
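A minimal sketch of this multi-scale fusion step follows. It assumes random tensors in place of real ResNet features, with toy channel counts, and the fusion block (one 3x3 convolution + batch norm + ReLU with 64 output channels) is also an assumption, since its exact design is not given here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Random stand-ins for four backbone stages (toy channel counts and sizes,
# not those of a real pretrained ResNet): resolution halves at each level.
feats = [
    torch.randn(1, 16, 32, 32),   # early stage: low-level, highest resolution
    torch.randn(1, 32, 16, 16),
    torch.randn(1, 64, 8, 8),
    torch.randn(1, 128, 4, 4),    # late stage: high-level, lowest resolution
]

# Upsample every map to the size of the largest one with bilinear interpolation.
size = feats[0].shape[2:]
upsampled = [F.interpolate(f, size=size, mode="bilinear", align_corners=False) for f in feats]

# Concatenate along channels and fuse with a conv block -> multi-scale map F_ms.
fuse = nn.Sequential(
    nn.Conv2d(16 + 32 + 64 + 128, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)
F_ms = fuse(torch.cat(upsampled, dim=1))  # (1, 64, 32, 32)

# Each scale-specific map concatenated with F_ms is the input of one attention block.
attention_inputs = [torch.cat((f, F_ms), dim=1) for f in upsampled]
```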

2.1. What is proposed?

The guided attention block relies on position and channel attention modules, so let's start with an overall description of those.

Block diagram of the position and channel attention modules

We will try to understand what happens in these modules, but we won't go into the details of every operation in them (these can be followed in the code section below).

The two modules are actually very similar; the only difference is whether information is extracted from the channels or from the positions. In the position module, the convolutions applied before flattening reduce the number of channels, which gives more importance to position. In the channel attention module, the original number of channels is kept during the reshape, so more weight is given to the channels.

In each module, note that the top two branches are responsible for extracting the specific attention distribution. For example, in the position attention module we get a (WH) x (WH) attention distribution, where element (i, j) indicates how much location i influences location j. In the channel module, we get a C x C attention distribution that tells us how much one channel influences another. In the third branch of each module, this distribution is multiplied by a transformation of the input, giving the channel or position attention output. As in the previous section, multiplying the attention distribution by the input extracts the relevant information from the input, given the multi-scale features as context. The outputs of the two modules are then summed element-wise, giving the final self-attention features. Now let's look at how the outputs of these two modules are used in the overall framework.

Block diagram of two refinement steps of the guided attention module

Guided attention builds a succession of refinement steps for each scale (there are four scales in the proposed architecture). The input feature map is fed to the position and channel attention modules, which output a single feature map. It is also passed through an autoencoder that reconstructs the input. In each block, the attention map is generated by multiplying these two outputs. This attention map is then multiplied by the previously generated multi-scale feature map, so the output indicates which parts of the multi-scale features to focus on for that particular scale. Concatenating the output of one block with the multi-scale feature map and using the result as the input of the next block yields a sequence of guided attention modules.

Two losses are combined to make sure the refinement steps work correctly: a standard reconstruction loss, which ensures that the autoencoder correctly reconstructs the input feature map, and a guided loss, which tries to minimize the distance between two subsequent latent representations of the input.
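One way to write these two auxiliary losses is sketched below, for one scale. It assumes mean squared error as the distance in both terms and equal weighting; `guided_attention_losses` is a hypothetical helper, and the paper's exact formulation and weights may differ. Its inputs correspond to the `F_reconstructed` and `F_latent` values returned by the guided attention module in section 2.3.

```python
import torch
import torch.nn.functional as F


def guided_attention_losses(inputs, reconstructions, latents):
    """Hypothetical auxiliary losses for one scale.

    inputs[k], reconstructions[k], latents[k] are the input feature map,
    autoencoder reconstruction and latent code of refinement step k.
    """
    # Reconstruction loss: each autoencoder should reproduce its own input.
    rec = sum(F.mse_loss(r, x) for r, x in zip(reconstructions, inputs))
    # Guided loss: consecutive latent representations should stay close.
    guided = sum(F.mse_loss(latents[k + 1], latents[k]) for k in range(len(latents) - 1))
    return rec, guided


# Usage with random stand-ins for two refinement steps.
torch.manual_seed(0)
inputs = [torch.randn(1, 8, 16, 16) for _ in range(2)]
recs = [x + 0.1 * torch.randn_like(x) for x in inputs]
lats = [torch.randn(1, 32, 12, 12) for _ in range(2)]
rec_loss, guided_loss = guided_attention_losses(inputs, recs, lats)
total_aux = rec_loss + guided_loss  # added to the segmentation loss during training
```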

After that, each attention feature map predicts a mask through a convolution block. The final prediction is obtained by averaging the four masks, which can be seen as an ensemble of the model over features at different scales.
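Averaging the per-scale predictions can be sketched as follows. The sketch assumes one 1x1 convolution head per scale and averages sigmoid probabilities; averaging logits before the sigmoid would be an equally plausible reading.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical refined attention feature maps, one per scale (64 channels each).
attention_feats = [torch.randn(1, 64, 32, 32) for _ in range(4)]
# One small prediction head per scale (a single 1x1 conv here, an assumption).
heads = nn.ModuleList([nn.Conv2d(64, 1, kernel_size=1) for _ in range(4)])

# Per-scale mask probabilities, then the ensemble average.
masks = [torch.sigmoid(head(f)) for head, f in zip(heads, attention_feats)]
final_mask = torch.stack(masks, dim=0).mean(dim=0)  # (1, 1, 32, 32)
```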

2.2. Why is this effective?

Because this architecture is much more complex than the previous one, it is harder to understand what happens inside the attention modules. Here is my understanding of the contribution of each block.

The position attention module tries to specify which locations of the scale-specific features to focus on, based on the multi-scale representation of the input image. The channel attention module does the same thing by specifying how much attention each channel deserves. The specific operations used in either module produce an attention distribution over channels or positions that indicates which ones are more important. Combining the two modules gives an attention map that scores each position-channel pair, that is, each element of the feature map.
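The two attention distributions can be compared on a toy tensor. This stripped-down sketch omits the learned 1x1 convolutions, the multiplication by the input and the output convolution of the implementation in section 2.3; it only shows the (WH) x (WH) and C x C maps and how their outputs are summed into one value per element.

```python
import torch

torch.manual_seed(0)
B, C, H, W = 2, 16, 8, 8
feat = torch.randn(B, C, H, W)
flat = feat.reshape(B, C, H * W)  # (B, C, HW)

# Position attention: softmax over a (HW x HW) similarity matrix between locations.
pos_attn = torch.softmax(flat.transpose(1, 2) @ flat, dim=-1)   # (B, HW, HW)
# Channel attention: softmax over a (C x C) similarity matrix between channels.
chan_attn = torch.softmax(flat @ flat.transpose(1, 2), dim=-1)  # (B, C, C)

# Third branch: re-weight the input with each distribution.
pos_out = (flat @ pos_attn).reshape(B, C, H, W)    # mixes information across locations
chan_out = (chan_attn @ flat).reshape(B, C, H, W)  # mixes information across channels

# Summing the two gives one value per (channel, position) element of the map.
combined = pos_out + chan_out
```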

The autoencoder ensures that the representation of the feature map does not change completely from one step to the next. Because the latent space is low-dimensional, only key information is extracted. We don't want this information to change from one refinement step to the next; we only want minor adjustments, which should not show up in the latent representation.

With a sequence of guided attention modules, the final attention map is progressively refined: the noise gradually disappears and more weight is given to the truly important regions.

Ensembling several such multi-scale networks lets the model capture both global and local features. These features are then combined into a multi-scale feature map. By applying attention specific to each scale to the multi-scale feature map, we get a better idea of which features are most valuable for the final output.

2.3. A short implementation

import torch
import torch.nn as nn


class PositionAttentionModule(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.first_branch_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.second_branch_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.third_branch_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.output_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, F):
        # shapes in the comments omit the batch dimension
        # first branch
        F1 = self.first_branch_conv(F)                  # (C/8, H, W)
        F1 = F1.reshape((F1.size(0), F1.size(1), -1))   # (C/8, H*W)
        F1 = torch.transpose(F1, -2, -1)                # (H*W, C/8)
        # second branch
        F2 = self.second_branch_conv(F)                 # (C/8, H, W)
        F2 = F2.reshape((F2.size(0), F2.size(1), -1))   # (C/8, H*W)
        F2 = nn.Softmax(dim=-1)(torch.matmul(F1, F2))   # (H*W, H*W)
        # third branch
        F3 = self.third_branch_conv(F)                  # (C, H, W)
        F3 = F3.reshape((F3.size(0), F3.size(1), -1))   # (C, H*W)
        F3 = torch.matmul(F3, F2)                       # (C, H*W)
        F3 = F3.reshape(F.shape)                        # (C, H, W)
        return self.output_conv(F3 * F)


class ChannelAttentionModule(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.output_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, F):
        # first branch
        F1 = F.reshape((F.size(0), F.size(1), -1))      # (C, H*W)
        F1 = torch.transpose(F1, -2, -1)                # (H*W, C)
        # second branch
        F2 = F.reshape((F.size(0), F.size(1), -1))      # (C, H*W)
        F2 = nn.Softmax(dim=-1)(torch.matmul(F2, F1))   # (C, C)
        # third branch
        F3 = F.reshape((F.size(0), F.size(1), -1))      # (C, H*W)
        F3 = torch.matmul(F2, F3)                       # (C, H*W)
        F3 = F3.reshape(F.shape)                        # (C, H, W)
        return self.output_conv(F3 * F)


class GuidedAttentionModule(nn.Module):
    def __init__(self, in_channels_F, in_channels_Fms):
        super().__init__()
        in_channels = in_channels_F + in_channels_Fms
        self.pam = PositionAttentionModule(in_channels)
        self.cam = ChannelAttentionModule(in_channels)
        # encoder: two unpadded 3x3 convolutions (spatial size shrinks by 4)
        self.encoder = nn.Sequential(nn.Conv2d(in_channels, 2 * in_channels, kernel_size=3),
                                     nn.BatchNorm2d(2 * in_channels),
                                     nn.Conv2d(2 * in_channels, 4 * in_channels, kernel_size=3),
                                     nn.BatchNorm2d(4 * in_channels),
                                     nn.ReLU())
        # decoder: two 3x3 transposed convolutions restore the input spatial size
        self.decoder = nn.Sequential(nn.ConvTranspose2d(4 * in_channels, 2 * in_channels, kernel_size=3),
                                     nn.BatchNorm2d(2 * in_channels),
                                     nn.ConvTranspose2d(2 * in_channels, in_channels, kernel_size=3),
                                     nn.BatchNorm2d(in_channels),
                                     nn.ReLU())
        self.attention_map_conv = nn.Sequential(nn.Conv2d(in_channels, in_channels_Fms, kernel_size=1),
                                                nn.BatchNorm2d(in_channels_Fms),
                                                nn.ReLU())

    def forward(self, F, F_ms):
        F = torch.cat((F, F_ms), dim=1)           # concatenate the extracted feature map with the multi-scale feature map
        F_pcam = self.pam(F) + self.cam(F)        # sum the outputs of the position and channel attention modules
        F_latent = self.encoder(F)                # latent-space representation, used for the guided loss
        F_reconstructed = self.decoder(F_latent)  # output of the autoencoder, used for the reconstruction loss
        F_output = self.attention_map_conv(F_reconstructed * F_pcam)  # attention map from both outputs
        F_output = F_output * F_ms                # apply it to the multi-scale feature map
        return F_output, F_reconstructed, F_latent

A short implementation of the position attention module, the channel attention module, and a guided attention module.

Attention can be seen as a mechanism that helps identify which features the network should focus on, given some context.

In UNet, given the features extracted in the expansion path, attention determines which features extracted in the contraction path deserve more focus. This makes the skip connections more meaningful: they pass on relevant information rather than every extracted feature.


© 2024 shulou.com SLNews company. All rights reserved.
