Shulou (Shulou.com), SLTechnology News&Howtos, 06/01 report; updated 2025-01-19.
This article takes an in-depth look at the UNet model in PyTorch. The content is detailed, and interested readers can use it as a reference. I hope it helps.
I. Project background
A deep learning algorithm is simply a tool for solving a problem. Which network to train, how to preprocess the data, and which loss function and optimization method to use all depend on the specific task.
So let's take a look at today's task first.
Yes, it is the classic task from the UNet paper: medical image segmentation.
I chose it because it is simple and easy to get started with.
In brief: as shown in the animation, given an image of cell structures, we want to separate the individual cells from one another.
There are only 30 training images, each with a resolution of 512x512. They are electron microscopy images of fruit flies.
With the task introduced, let's get ready to train the model.
II. UNet training
Training a deep learning model can be roughly divided into three steps:
Data loading: how to load the data, how to define the labels, and which data augmentation methods to use are all decided in this step.
Model selection: the model is already prepared; it is the UNet network covered in the previous article in this series.
Algorithm selection: choosing which loss function and which optimization algorithm to use.
These steps are generic, so let's go through each one in the context of today's medical image segmentation task.
1. Data loading
There is a lot you can do in this step, but it boils down to how to load the images and how to define the labels. To make the algorithm more robust or to enlarge the dataset, you can also apply some data augmentation.
Since we are dealing with data, let's first look at what it looks like before deciding how to process it.
The data has been prepared and is hosted on GitHub.
If downloading from GitHub is slow, you can use the Baidu link at the end of the article to download the dataset instead.
The data is split into a training set and a test set of 30 images each; the training images have labels and the test images do not.
How data is loaded depends on the task and the dataset. Our segmentation task needs little processing, but because the dataset is tiny (only 30 images), we can use data augmentation to expand it.
PyTorch provides a framework for loading data, and we can use it to load ours. Here is the pseudocode:
```python
import torch

# ================================================================== #
#                 Input pipeline for custom dataset                  #
# ================================================================== #

# You should build your custom dataset as below.
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        # TODO
        # 1. Initialize file paths or a list of file names.
        pass

    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.transforms).
        # 3. Return a data pair (e.g. image and label).
        pass

    def __len__(self):
        # You should change 0 to the total size of your dataset.
        # (With shuffle=True, DataLoader raises an error while the length is 0.)
        return 0

# You can then use the prebuilt data loader.
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
                                           batch_size=64,
                                           shuffle=True)
```
This is the standard template we use to load the data, define the labels, and apply data augmentation.
Create a dataset.py file and write the following code:
```python
import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
import random


class ISBI_Loader(Dataset):
    def __init__(self, data_path):
        # Initialization: collect the paths of all images under data_path
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, 'image/*.png'))

    def augment(self, image, flipCode):
        # Data augmentation with cv2.flip: flipCode 1 = horizontal flip,
        # 0 = vertical flip, -1 = horizontal + vertical flip
        flip = cv2.flip(image, flipCode)
        return flip

    def __getitem__(self, index):
        # Pick the image path by index
        image_path = self.imgs_path[index]
        # Derive label_path from image_path
        label_path = image_path.replace('image', 'label')
        # Read the training image and its label image
        image = cv2.imread(image_path)
        label = cv2.imread(label_path)
        # Convert both to single-channel (grayscale) images
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        # Random data augmentation (flipCode 2 means no flip). Flip while the
        # arrays are still 2-D (H, W) so cv2.flip acts on the image axes.
        flipCode = random.choice([-1, 0, 1, 2])
        if flipCode != 2:
            image = self.augment(image, flipCode)
            label = self.augment(label, flipCode)
        # Add the channel dimension: (H, W) -> (1, H, W)
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        # Process the label: map pixel value 255 to 1
        if label.max() > 1:
            label = label / 255
        return image, label

    def __len__(self):
        # Return the size of the training set
        return len(self.imgs_path)


if __name__ == "__main__":
    isbi_dataset = ISBI_Loader("data/train/")
    print("number of data:", len(isbi_dataset))
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=2,
                                               shuffle=True)
    for image, label in train_loader:
        print(image.shape)
```
Run the code, and you can see the following results:
Explanation of the code:
The __init__ function is the class's initializer. It collects the paths of all images under the given data path and stores them in the self.imgs_path list.
The __len__ function returns the number of samples; it is what len() calls on an instance of the class.
The __getitem__ function returns one sample, so this is where you write how to read and process the data; preprocessing and data augmentation also go here. My processing is very simple: read the image and convert it to a single-channel image. Because the label pixels are 0 and 255, the label is divided by 255 so that its values become 0 and 1. Data augmentation is also applied at random.
The augment function is the data augmentation function, and you can implement whatever augmentation you like. Here I just do a simple flip.
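The flipCode values and the order of operations can be sketched with NumPy alone. This is an illustrative stand-in for cv2.flip, not the article's code, and flip_2d and augment_and_preprocess are hypothetical names. One point worth knowing: OpenCV reads a 3-D NumPy array as rows × columns × channels, so a (1, H, W) array looks like a single row of H columns with W channels. Flipping is therefore safest on the 2-D (H, W) arrays, before the channel axis is added, and the same flip must be applied to the image and its label.

```python
import random
import numpy as np

# flipCode convention of cv2.flip:
#  0 -> flip around the x-axis (reverse the rows, i.e. upside down)
#  1 -> flip around the y-axis (reverse the columns, i.e. mirror left/right)
# -1 -> flip around both axes
# 2 means "no flip", as in ISBI_Loader.
_FLIP_AXES = {0: (0,), 1: (1,), -1: (0, 1)}


def flip_2d(array, flip_code):
    """Hypothetical NumPy stand-in for cv2.flip on a 2-D (H, W) array."""
    if flip_code == 2:
        return array
    # np.flip returns a view with negative strides, which torch.from_numpy
    # rejects, so return a contiguous copy.
    return np.flip(array, axis=_FLIP_AXES[flip_code]).copy()


def augment_and_preprocess(image, label, flip_code=None):
    """Flip image and label identically, then add the channel axis and scale the label."""
    if flip_code is None:
        flip_code = random.choice([-1, 0, 1, 2])
    # 1. Augment while the arrays are still 2-D (H, W)
    image = flip_2d(image, flip_code)
    label = flip_2d(label, flip_code)
    # 2. (H, W) -> (1, H, W): PyTorch expects channels first
    image = image.reshape(1, image.shape[0], image.shape[1])
    label = label.reshape(1, label.shape[0], label.shape[1])
    # 3. Map label pixel values 0/255 to 0/1
    if label.max() > 1:
        label = label / 255
    return image, label


image = np.arange(6, dtype=np.uint8).reshape(2, 3)
label = np.array([[0, 255, 255], [0, 0, 255]], dtype=np.uint8)
img_out, lab_out = augment_and_preprocess(image, label, flip_code=1)
print(img_out.shape, lab_out.shape)  # (1, 2, 3) (1, 2, 3)
```

With four choices (three flips plus "no flip"), each of the 30 training images can appear in four variants.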
Inside this class you don't need to shuffle the dataset or group the samples by batch size: after instantiating the class, we pass it to torch.utils.data.DataLoader, which takes the batch size and decides whether to shuffle the data.
The DataLoader that PyTorch provides is quite powerful: we can even specify how many worker processes load the data (num_workers) and whether batches are placed in pinned (page-locked) host memory for faster copies to the GPU (pin_memory).
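A minimal sketch of those DataLoader options, assuming random stand-in tensors in place of ISBI_Loader (make_loader is a hypothetical helper, and the tensors use the same (1, H, W) per-sample layout as ISBI_Loader but a smaller size):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size=2, num_workers=0):
    """Wrap a dataset in a DataLoader (hypothetical helper)."""
    return DataLoader(
        dataset,
        batch_size=batch_size,      # samples per batch
        shuffle=True,               # reshuffle the samples every epoch
        num_workers=num_workers,    # 0 = load in the main process; >0 = that many worker processes
        pin_memory=torch.cuda.is_available(),  # page-locked memory speeds up host-to-GPU copies
    )


# Random stand-in data: 6 samples, each image and label shaped (1, 64, 64)
images = torch.rand(6, 1, 64, 64)
labels = (torch.rand(6, 1, 64, 64) > 0.5).float()
loader = make_loader(TensorDataset(images, labels), batch_size=2)
for image, label in loader:
    print(image.shape)  # torch.Size([2, 1, 64, 64])
```

If you raise num_workers above 0 on Windows or macOS (where worker processes are spawned rather than forked), keep the loading code under an `if __name__ == "__main__":` guard.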
2. Model selection
For the model, we use the UNet network structure explained in the previous article, "Pytorch Deep Learning practical course (2): UNet semantic Segmentation Network".
However, the network needs a small adjustment. If we follow the paper's structure exactly, the output is slightly smaller than the input image (the paper uses unpadded 3x3 convolutions), so we would have to resize the output afterwards. To avoid that step, we modify the network so that its output size exactly equals the input size.
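The size arithmetic behind this adjustment can be sketched in plain Python. The assumptions here are ours, for illustration: two 3x3 convolutions with stride 1 per level, 2x2 max pooling, 2x upsampling, four down/up levels, and hypothetical helper names. With padding=0 (as in the paper) a 572x572 input shrinks to 388x388; with padding=1 (as in the DoubleConv module of unet_parts.py) a 512x512 input stays 512x512.

```python
def conv_out(size, kernel=3, padding=0, stride=1):
    # Standard output-size formula for a convolution
    return (size + 2 * padding - kernel) // stride + 1


def double_conv(size, padding):
    # Two 3x3 convolutions in a row, as in each UNet level
    return conv_out(conv_out(size, padding=padding), padding=padding)


def unet_output_size(size, padding, depth=4):
    # Encoder: double conv, then 2x2 max pooling (floor division)
    for _ in range(depth):
        size = double_conv(size, padding) // 2
    # Bottleneck
    size = double_conv(size, padding)
    # Decoder: 2x upsampling, then double conv
    # (skip connections are cropped or padded to match, so they don't change this size)
    for _ in range(depth):
        size = double_conv(size * 2, padding)
    return size


print(unet_output_size(572, padding=0))  # 388: the paper's unpadded network
print(unet_output_size(512, padding=1))  # 512: padded convolutions keep the size
print(unet_output_size(512, padding=0))  # 324: unpadded network on our 512x512 images
```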
Create the unet_parts.py file and write the following code:
```python
"""Parts of the U-Net model.
Based on https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # padding=1 keeps the spatial size unchanged after each 3x3 convolution
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # If bilinear, use plain upsampling and let the convolutions reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Input is CHW; pad x1 so its height and width match the skip connection x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # Concatenate along the channel dimension
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
```
Create the unet_model.py file and write the following code:
```python
"""Full assembly of the parts to form the complete network.
Refer to https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""
import torch.nn as nn

from .unet_parts import DoubleConv, Down, Up, OutConv


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


if __name__ == '__main__':
    # Because of the relative import above, run this as a module from the
    # project root: python -m model.unet_model
    net = UNet(n_channels=3, n_classes=1)
    print(net)
```
After this adjustment, the network's output has the same size as the input image.
3. Algorithm selection
The choice of loss function is very important: it affects how well the algorithm fits the data.
The loss is also chosen according to the task. Our task today is to segment cell boundaries, which is a simple binary classification problem for each pixel, so we can use BCEWithLogitsLoss.
What is BCEWithLogitsLoss? It is a PyTorch loss that computes the binary cross-entropy for two-class problems. It takes the raw network outputs (logits), applies a sigmoid internally, and computes the cross-entropy in one numerically stable step.
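Its formula (with the default mean reduction over all N elements) is:

$$\ell(x, y) = -\frac{1}{N}\sum_{n=1}^{N}\Big[\,y_n \log \sigma(x_n) + (1 - y_n)\log\big(1 - \sigma(x_n)\big)\Big], \qquad \sigma(x) = \frac{1}{1 + e^{-x}}$$

where $x_n$ is the network's output (logit) for element $n$ and $y_n \in \{0, 1\}$ is its label.
Readers of my machine learning tutorial series will recognize this formula: it is the loss function of logistic regression. It relies on the fact that the sigmoid function maps any input into (0, 1), so its output can be read as a probability and thresholded for classification.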
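To check the formula against PyTorch, this small sketch computes the loss by hand and compares it with nn.BCEWithLogitsLoss on random data. It uses torch.nn.functional.logsigmoid together with the identity 1 − σ(x) = σ(−x); manual_bce_with_logits is our own hypothetical helper, not a PyTorch API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def manual_bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed directly from logits (hypothetical helper)."""
    # log(sigmoid(x)) and log(1 - sigmoid(x)) = log(sigmoid(-x)), both computed stably
    log_p = F.logsigmoid(logits)
    log_one_minus_p = F.logsigmoid(-logits)
    per_element = -(targets * log_p + (1 - targets) * log_one_minus_p)
    return per_element.mean()  # matches the default reduction='mean'


torch.manual_seed(0)
logits = torch.randn(2, 1, 8, 8)                   # fake network outputs
targets = (torch.rand(2, 1, 8, 8) > 0.5).float()   # fake 0/1 masks

ours = manual_bce_with_logits(logits, targets)
builtin = nn.BCEWithLogitsLoss()(logits, targets)
print(ours.item(), builtin.item())  # the two values agree to float precision
```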
For the derivation of the formula, see my tutorial "Machine Learning practical tutorials (6): gradient ascent algorithm for the basic part of Logistic regression"; I won't repeat it here.
With the objective function (the loss) defined, how do we optimize it?
The simplest approach is the familiar gradient descent algorithm, which gradually approaches a local minimum.
However, plain gradient descent is slow: it takes many iterations to reach a good solution.
The optimization algorithms discussed here are all, at heart, variants of gradient descent. For example, the most common one, SGD (stochastic gradient descent), estimates the gradient from a randomly sampled mini-batch instead of the whole dataset. SGD with momentum additionally accumulates past gradients with exponential decay.
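The effect of momentum shows up even on a toy problem. This NumPy sketch (function names are hypothetical) minimizes f(w) = w²/2, whose gradient is w. The momentum update v ← μ·v + g, w ← w − lr·v is the same form that PyTorch's SGD uses with momentum and no dampening.

```python
import numpy as np


def sgd_step(param, grad, lr=0.01):
    # Plain gradient descent: step against the gradient
    param -= lr * grad
    return param


def momentum_step(param, grad, state, lr=0.01, momentum=0.9):
    # Velocity = exponentially decaying sum of past gradients
    v = state.setdefault("velocity", np.zeros_like(param))
    v *= momentum
    v += grad
    param -= lr * v
    return param


# Minimize f(w) = w**2 / 2 (gradient = w), starting from w = 1
w_sgd = np.array([1.0])
w_mom = np.array([1.0])
state = {}
for _ in range(50):
    sgd_step(w_sgd, w_sgd.copy())
    momentum_step(w_mom, w_mom.copy(), state)
print(w_sgd[0], w_mom[0])  # momentum ends much closer to the minimum at 0
```

After 50 steps plain gradient descent has only shrunk w to 0.99^50 ≈ 0.6, while the momentum version, although it oscillates a little, is already below 0.1 in magnitude.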
Besides these basic optimization algorithms, there are adaptive learning-rate algorithms. Their key feature is that each parameter gets its own learning rate, which is adapted automatically during training to achieve better convergence.
This article uses an adaptive optimization algorithm, RMSProp.
Space is limited, so I won't go into detail here; this optimization algorithm alone could fill an article. To understand RMSProp you first need to know AdaGrad, because RMSProp is an improvement on AdaGrad.
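The core difference between AdaGrad and RMSProp fits in a few lines of NumPy. This is a sketch, not PyTorch's code: adagrad_step and rmsprop_step are hypothetical helpers, and rmsprop_step follows the non-centered update rule described in the torch.optim.RMSprop documentation (its alpha, eps, momentum and weight_decay arguments mirror that optimizer's parameters).

```python
import numpy as np


def adagrad_step(param, grad, state, lr=0.01, eps=1e-10):
    # AdaGrad: divide by the root of the SUM of all past squared gradients
    s = state.setdefault("sum", np.zeros_like(param))
    s += grad * grad
    param -= lr * grad / (np.sqrt(s) + eps)
    return param


def rmsprop_step(param, grad, state, lr=0.01, alpha=0.99, eps=1e-8,
                 momentum=0.0, weight_decay=0.0):
    # RMSProp: replace AdaGrad's sum with an exponential moving average
    if weight_decay:
        grad = grad + weight_decay * param   # L2 penalty folded into the gradient
    sq = state.setdefault("square_avg", np.zeros_like(param))
    sq *= alpha
    sq += (1 - alpha) * grad * grad
    avg = np.sqrt(sq) + eps                  # per-parameter scaling factor
    if momentum > 0:
        buf = state.setdefault("momentum_buffer", np.zeros_like(param))
        buf *= momentum
        buf += grad / avg
        param -= lr * buf
    else:
        param -= lr * grad / avg
    return param


# Same constant gradient for 1000 steps: compare the size of the last step
g = np.ones(1)
p_ada, s_ada = np.zeros(1), {}
p_rms, s_rms = np.zeros(1), {}
for _ in range(1000):
    prev_ada, prev_rms = p_ada.copy(), p_rms.copy()
    adagrad_step(p_ada, g, s_ada)
    rmsprop_step(p_rms, g, s_rms)
ada_last = abs(p_ada - prev_ada)[0]
rms_last = abs(p_rms - prev_rms)[0]
print(ada_last, rms_last)
```

With a constant gradient, AdaGrad's running sum keeps growing, so its step shrinks roughly like lr / sqrt(t) (about 0.0003 here); RMSProp's moving average levels off, so its step stays close to lr. That ever-decaying learning rate is the AdaGrad weakness RMSProp was designed to fix.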
There are also optimization algorithms more advanced than RMSProp, such as the well-known Adam, which can be thought of as a modified combination of Momentum and RMSProp.
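To see why Adam is described as Momentum plus RMSProp, here is a NumPy sketch of the standard Adam update (adam_step is a hypothetical helper; the defaults lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8 are the commonly used values). m is a momentum-style average of gradients and v is an RMSProp-style average of squared gradients.

```python
import numpy as np


def adam_step(param, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    state["t"] = state.get("t", 0) + 1
    t = state["t"]
    m = state.setdefault("m", np.zeros_like(param))
    v = state.setdefault("v", np.zeros_like(param))
    # Momentum part: moving average of gradients
    m *= beta1
    m += (1 - beta1) * grad
    # RMSProp part: moving average of squared gradients
    v *= beta2
    v += (1 - beta2) * grad * grad
    # Bias correction: m and v start at zero, so early averages are too small
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    param -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return param


for g in (1.0, 100.0):
    p = np.array([1.0])
    adam_step(p, np.array([g]), {})
    print(g, p[0])  # about 0.999 in both cases
```

Thanks to the bias correction and the per-parameter scaling, the very first step has size close to lr whatever the gradient's scale: a gradient of 1 and a gradient of 100 both move the parameter by about 1e-3.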
In short, as a beginner you only need to know that RMSProp is an adaptive optimization algorithm and a more advanced one than plain gradient descent.
Now we can write the code to train UNet. Create train.py and write the following code:
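```python
from model.unet_model import UNet
from utils.dataset import ISBI_Loader
from torch import optim
import torch.nn as nn
import torch


def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
    # Load the training set
    isbi_dataset = ISBI_Loader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    # Define the RMSprop optimizer
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # Define the loss function
    criterion = nn.BCEWithLogitsLoss()
    # Track the best loss, initialized to positive infinity
    best_loss = float('inf')
    # Train for `epochs` epochs
    for epoch in range(epochs):
        # Training mode
        net.train()
        # Iterate over the data in batches of batch_size
        for image, label in train_loader:
            optimizer.zero_grad()
            # Copy the data to the device
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            # Forward pass with the current network parameters
            pred = net(image)
            # Compute the loss
            loss = criterion(pred, label)
            print('Loss/train', loss.item())
            # Save the network parameters with the lowest loss so far
            # (compare plain floats so no autograd graph is kept alive)
            if loss.item() < best_loss:
                best_loss = loss.item()
                torch.save(net.state_dict(), 'best_model.pth')
            # Update the parameters
            loss.backward()
            optimizer.step()


if __name__ == "__main__":
    # Choose the device: CUDA if available, otherwise the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the network: single-channel input, 1 output class
    net = UNet(n_channels=1, n_classes=1)
    # Copy the network to the device
    net.to(device=device)
    # Specify the training-set path and start training
    data_path = "data/train/"
    train_net(net, device, data_path)
```

To keep the project clean and simple, we create a model folder for the model-related code, i.e. our network structure code, unet_parts.py and unet_model.py. We also create a utils folder for utility code, such as the data loading tool dataset.py. This modular organization greatly improves the maintainability of the code.
train.py can simply go in the project root. A brief explanation of the code: since there are only 30 images, we don't split them into training and validation sets; instead, we save the network parameters with the lowest training loss as the best model parameters.
If everything works, you will see the loss gradually converge.
III. Prediction
With the model trained, we can use it to see how it performs on the test set.
Create a predict.py file in the project root and write the following code:

```python
import glob
import os
import numpy as np
import torch
import cv2
from model.unet_model import UNet

if __name__ == "__main__":
    # Choose the device: CUDA if available, otherwise the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the network: single-channel input, 1 output class
    net = UNet(n_channels=1, n_classes=1)
    # Copy the network to the device
    net.to(device=device)
    # Load the trained parameters
    net.load_state_dict(torch.load('best_model.pth', map_location=device))
    # Evaluation mode
    net.eval()
    # Collect all test image paths (skip earlier results so reruns don't re-predict them)
    tests_path = [p for p in glob.glob('data/test/*.png') if not p.endswith('_res.png')]
    # Iterate over all images
    for test_path in tests_path:
        # Path for saving the result
        save_res_path = os.path.splitext(test_path)[0] + '_res.png'
        # Read the image
        img = cv2.imread(test_path)
        # Convert to grayscale (cv2.imread returns BGR, as in dataset.py)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # Reshape to batch size 1, 1 channel, 512x512
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        # Convert to a tensor
        img_tensor = torch.from_numpy(img)
        # Copy the tensor to the device (CPU or CUDA, whichever is in use)
        img_tensor = img_tensor.to(device=device, dtype=torch.float32)
        # Predict; no gradients are needed at inference time
        with torch.no_grad():
            pred = net(img_tensor)
        # The network outputs logits (it was trained with BCEWithLogitsLoss),
        # so apply a sigmoid to get per-pixel probabilities
        pred = torch.sigmoid(pred)
        # Extract the single 512x512 result map
        pred = np.array(pred.cpu()[0])[0]
        # Threshold the probabilities into a 0/255 mask
        pred[pred >= 0.5] = 255
        pred[pred < 0.5] = 0
        # Save the image as 8-bit
        cv2.imwrite(save_res_path, pred.astype(np.uint8))
```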
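Because the network is trained with BCEWithLogitsLoss, its raw outputs are logits, not probabilities: a logit of 0.5 already corresponds to a probability of about 0.62. A small sketch (logits_to_mask is a hypothetical helper) that applies a sigmoid before thresholding at 0.5, which is equivalent to thresholding the logits at 0:

```python
import torch


def logits_to_mask(logits, threshold=0.5):
    """Turn raw UNet outputs (logits) into a 0/255 uint8 mask (hypothetical helper)."""
    probs = torch.sigmoid(logits)                       # per-pixel foreground probability
    return (probs >= threshold).to(torch.uint8) * 255   # 1 -> 255, 0 -> 0


logits = torch.tensor([[-2.0, 0.0, 0.3, 3.0]])
print(logits_to_mask(logits))  # tensor([[  0, 255, 255, 255]], dtype=torch.uint8)
```

Note the logit 0.3: it is below 0.5, yet its probability (about 0.57) is above 0.5, so it belongs to the foreground.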
After running predict.py, you can see the prediction results in the data/test directory:
And with that, the task is complete!
That's all for this in-depth look at the UNet model in PyTorch. I hope it helps you learn something new. If you found the article useful, feel free to share it.