
How to Train a LeNet Network Model and Use It for Prediction in Python

2025-01-18 Update From: SLTechnology News&Howtos


Shulou (Shulou.com) 06/02 report

This article shows how to train a LeNet network model in Python with PyTorch and then use it for prediction. I find it quite practical, so I am sharing it here; I hope you get something useful out of it. Let's take a look.

1. LeNet model training script

The complete training script is shown below, and each part is explained in detail afterwards.

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import DataLoader
from pytorch.lenet.model import LeNet
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

transform = transforms.Compose(
    # convert the images to tensors
    [transforms.ToTensor(),
     # normalize each of the three channels (R, G, B) with mean 0.5 and standard deviation 0.5
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# the first argument is the save path; set download=True the first time to download the dataset,
# then change it to False. train=True selects the training set, train=False the test set
train_set = torchvision.datasets.CIFAR10('./data', train=True, download=False, transform=transform)
# batch size, whether to shuffle, and num_workers (number of loader processes);
# on Windows set it to 0 or an error may be raised, on Linux it can be non-zero
train_loader = DataLoader(train_set, batch_size=36, shuffle=True, num_workers=0)

test_set = torchvision.datasets.CIFAR10('./data', train=False, download=False, transform=transform)
# batch_size=10000 loads all test images in a single batch
test_loader = DataLoader(test_set, batch_size=10000, shuffle=False, num_workers=0)
# take the image data and labels of the whole test set
test_img, test_label = next(iter(test_loader))

# the ten CIFAR10 class names
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# ---------- display images (optional) ----------
# def imshow(img, label):
#     fig = plt.figure()
#     for i in range(len(img)):
#         ax = fig.add_subplot(1, len(img), i + 1)
#         nping = img[i].numpy().transpose([1, 2, 0])
#         npimg = nping / 2 + 0.5
#         plt.imshow(npimg)
#         title = '{}'.format(classes[label[i].item()])
#         ax.set_title(title)
#         plt.axis('off')
#     plt.show()
# batch_image = test_img[:5]
# label_img = test_label[:5]
# imshow(batch_image, label_img)
# ------------------------------------------------

net = LeNet()
# nn.CrossEntropyLoss() applies log-softmax internally,
# so the last layer of the model does not need a softmax activation
loss_function = nn.CrossEntropyLoss()
# the optimizer updates all parameters of the network
optimizer = optim.Adam(net.parameters(), lr=0.001)

# train for five epochs
for epoch in range(5):
    # reset the accumulated loss
    running_loss = 0.0
    # iterate over the training set; step starts from 1
    for step, data in enumerate(train_loader, start=1):
        inputs, labels = data
        # clear the gradients, otherwise they keep accumulating across steps
        optimizer.zero_grad()
        # forward pass: feed the images into the model
        outputs = net(inputs)
        # compute the loss from the predictions and the true labels
        loss = loss_function(outputs, labels)
        # backpropagation
        loss.backward()
        # update the parameters
        optimizer.step()
        # loss is a tensor, so use item() to get a Python number
        running_loss += loss.item()
        # every 500 steps, print the log and evaluate on the test set
        if step % 500 == 0:
            # no gradient tracking is needed during evaluation
            with torch.no_grad():
                # outputs has shape (batch, 10)
                outputs = net(test_img)
                # torch.max returns (values, indices); [1] takes the index of the maximum along dim=1
                predict_y = torch.max(outputs, dim=1)[1]
                # count correct predictions and divide by the number of test samples
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
                # running_loss / 500 is the average loss per step over the last 500 steps
                print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                      (epoch + 1, step, running_loss / 500, accuracy))
                running_loss = 0.0

print('Finished training')

save_path = 'lenet.pth'
# save the model parameters as a state dict
torch.save(net.state_dict(), save_path)

(1). Download the CIFAR10 dataset

First, to train a network model we need enough images to form a dataset. Here we use the CIFAR10 dataset provided by torchvision.datasets (the PyTorch website lists the other datasets that torchvision provides). CIFAR10 contains 60,000 color images of 32x32 pixels in 10 classes, split into 50,000 training images and 10,000 test images.

train_set = torchvision.datasets.CIFAR10('./data', train=True, download=False, transform=transform)
test_set = torchvision.datasets.CIFAR10('./data', train=False, download=False, transform=transform)

This code loads CIFAR10. The first argument is the directory where the dataset is stored; train=True selects the training set and train=False the test set; transform is the preprocessing applied to every image. With download=False the files must already be in ./data, so set download=True on the first run.

(2). Image preprocessing

transform = transforms.Compose(
    # convert the images to tensors
    [transforms.ToTensor(),
     # normalize: mean 0.5 and standard deviation 0.5 for each of the three channels
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

This is simple preprocessing. transforms.ToTensor() converts each image (a PIL image with pixel values from 0 to 255) into a float tensor with values in [0, 1] and shape [channels, height, width]. transforms.Normalize() takes two tuples: the per-channel mean and the per-channel standard deviation. Each tuple has three elements because a color image has three channels (R, G, B), and every channel is transformed as (x - mean) / std, so with mean 0.5 and std 0.5 the values in [0, 1] are mapped to [-1, 1]. PyTorch expects images in channels-first order, [channels, height, width], which is exactly what ToTensor() produces. You will also often see transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]); these are the per-channel mean and standard deviation of the ImageNet dataset and are widely used for normalization.
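To make the per-channel arithmetic concrete, here is a minimal NumPy sketch of what Normalize does to a channels-first image and how to invert it. The helper names normalize and denormalize are hypothetical; torchvision performs the same (x - mean) / std computation on tensors, and this version is only for illustration.

```python
import numpy as np

def normalize(img, mean, std):
    """Per-channel (x - mean) / std for a channels-first array of shape (C, H, W)."""
    # reshape to (C, 1, 1) so each value broadcasts over its own channel
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (img - mean) / std

def denormalize(img, mean, std):
    """Inverse of normalize: x * std + mean."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return img * std + mean

# a fake 3x32x32 image with values in [0, 1), like the output of ToTensor()
rng = np.random.default_rng(0)
img = rng.random((3, 32, 32), dtype=np.float32)

normed = normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
restored = denormalize(normed, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
print(normed.min() >= -1.0, normed.max() <= 1.0)  # True True
```

With mean = std = 0.5, a pixel value of 0 becomes -1 and a value of 1 becomes 1, which is why the normalized data lies in [-1, 1].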

(3). Load the training and test sets

# load the data sets: batch size, whether to shuffle, and num_workers (the number of loader processes);
# on Windows set it to 0 or an error may be raised, on Linux it can be non-zero
train_loader = DataLoader(dataset=train_set, batch_size=36, shuffle=True, num_workers=0)
test_loader = DataLoader(dataset=test_set, batch_size=36, shuffle=False, num_workers=0)

Only four parameters are set here, but each one matters. The first, dataset, is the training or test set to load. shuffle=True shuffles the data. batch_size=36 packs 36 images into one batch, so each image tensor of shape [3, 32, 32] becomes part of a batch of shape [36, 3, 32, 32]. The network expects input of shape [N, channels, height, width], where N is the number of images in the batch; anything else raises an error. num_workers is the number of subprocesses used to load data: on Windows it is safest to set it to 0 (a non-zero value raises an error unless the script's entry point is guarded by if __name__ == '__main__':), while on Linux a non-zero value works.
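The batching described above can be sketched in NumPy as follows. This is a simplified stand-in, not DataLoader itself: make_batches is a hypothetical helper, the images are placeholder zeros, and, like DataLoader with its default drop_last=False, it keeps the final partial batch.

```python
import math
import numpy as np

def make_batches(images, labels, batch_size, shuffle=True, seed=0):
    """Hypothetical helper: yield (images, labels) batches, keeping a final partial batch."""
    indices = np.arange(len(images))
    if shuffle:
        # shuffle the sample order, as DataLoader(shuffle=True) does
        np.random.default_rng(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        idx = indices[start:start + batch_size]
        # selecting N images of shape (3, 32, 32) gives an array of shape (N, 3, 32, 32)
        yield images[idx], labels[idx]

# 100 placeholder CIFAR10-sized images (3 channels, 32x32) and their labels
images = np.zeros((100, 3, 32, 32), dtype=np.float32)
labels = np.arange(100) % 10

batches = list(make_batches(images, labels, batch_size=36))
print([b[0].shape for b in batches])  # [(36, 3, 32, 32), (36, 3, 32, 32), (28, 3, 32, 32)]

# the 50,000 CIFAR10 training images at batch_size=36
num_train_batches = math.ceil(50000 / 36)
print(num_train_batches)  # 1389
```

So one epoch over the CIFAR10 training set is 1389 steps, the last of which holds only 32 images.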

(4). Display some of the images

def imshow(img, label):
    fig = plt.figure()
    for i in range(len(img)):
        ax = fig.add_subplot(1, len(img), i + 1)
        # convert from [channels, height, width] to [height, width, channels] for matplotlib
        nping = img[i].numpy().transpose([1, 2, 0])
        # undo Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)): x * 0.5 + 0.5
        npimg = nping / 2 + 0.5
        plt.imshow(npimg)
        title = '{}'.format(classes[label[i].item()])
        ax.set_title(title)
        plt.axis('off')
    plt.show()

batch_image = test_img[:5]
label_img = test_label[:5]
imshow(batch_image, label_img)

This code displays the first five images of the test set side by side in a single figure.

Because every image in this dataset is only 32x32 pixels, some of them are hard to make out; each image's true label is shown as its title. Note: the display code may trigger the warning "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers)". It appears when the un-normalization step does not exactly invert Normalize: since Normalize computed (x - 0.5) / 0.5, the inverse is x * 0.5 + 0.5 (written as nping / 2 + 0.5), which keeps the values in [0, 1]. A variant such as nping * 2 + 0.5 pushes values outside that range, so matplotlib clips them. Converting the array with npimg.astype('uint8') only hides the warning and distorts the displayed image.
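A small NumPy check shows why only the exact inverse avoids the clipping warning. The random array below is a hypothetical stand-in for one normalized test image; no plotting is done, the sketch only inspects the value ranges that imshow would receive.

```python
import numpy as np

rng = np.random.default_rng(1)
# a stand-in for a normalized image in channels-first layout, values in [-1, 1)
normed = rng.uniform(-1.0, 1.0, size=(3, 32, 32))

# matplotlib expects (height, width, channels)
hwc = normed.transpose(1, 2, 0)

wrong = hwc * 2 + 0.5  # roughly [-1.5, 2.5]: imshow would clip these values and warn
right = hwc / 2 + 0.5  # exact inverse of (x - 0.5) / 0.5: stays within [0, 1]

print(hwc.shape)                                  # (32, 32, 3)
print(right.min() >= 0.0, right.max() <= 1.0)     # True True
```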

(5). Initialize the model

With the data prepared, we can now set up the actual training.

net = LeNet()
# define the loss function; nn.CrossEntropyLoss() applies log-softmax internally,
# so the last layer of the model does not need a softmax activation
loss_function = nn.CrossEntropyLoss()
# define the optimizer, which updates all parameters of the model
optimizer = optim.Adam(net.parameters(), lr=0.001)

First we create the LeNet network, the cross-entropy loss function, and the Adam optimizer. To verify the comment about softmax, Ctrl+click CrossEntropyLoss() in your IDE to jump to the class: its documentation states that this criterion is equivalent to applying LogSoftmax followed by NLLLoss, so it expects raw, unnormalized scores (logits). That is why the last layer of the LeNet model does not use a softmax activation.
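The equivalence can be checked with a minimal NumPy sketch of cross-entropy as log-softmax followed by negative log-likelihood. The function names and the logits are illustrative assumptions, not PyTorch's implementation; the last line shows that adding an extra softmax before the loss changes the value, which is why the model should output raw scores.

```python
import numpy as np

def log_softmax(logits):
    # subtract the row max for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def nll_loss(log_probs, targets):
    # mean over the batch of -log p(correct class)
    return -log_probs[np.arange(len(targets)), targets].mean()

def cross_entropy(logits, targets):
    # what nn.CrossEntropyLoss computes (with the default mean reduction)
    return nll_loss(log_softmax(logits), targets)

# hypothetical raw scores for 2 samples over 3 classes
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
targets = np.array([0, 2])

loss = cross_entropy(logits, targets)
# applying softmax in the model and again inside the loss gives a different (wrong) value
double = cross_entropy(np.exp(log_softmax(logits)), targets)
```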

(6). Train the model and save the model parameters

# train for five epochs
for epoch in range(5):
    # reset the accumulated loss
    running_loss = 0.0
    # iterate over the training set; step starts from 1
    for step, data in enumerate(train_loader, start=1):
        inputs, labels = data
        # clear the gradients, otherwise they keep accumulating across steps
        optimizer.zero_grad()
        # feed the images into the model to get the output
        outputs = net(inputs)
        # compute the loss from the predictions and the true labels
        loss = loss_function(outputs, labels)
        # backpropagation
        loss.backward()
        # update the parameters (weights W and biases b)
        optimizer.step()
        # accumulate the loss; loss is a tensor, so use item() to get the value
        running_loss += loss.item()
        # every 500 steps, print the log and evaluate on the test set
        if step % 500 == 0:
            # no gradient tracking is needed during evaluation
            with torch.no_grad():
                # predict all test images; the result has shape (batch, 10),
                # where dim=1 holds the scores of the ten classes
                outputs = net(test_img)
                # torch.max returns (max values, max indices); [1] takes the indices
                predict_y = torch.max(outputs, dim=1)[1]
                # count the correct predictions and divide by the number of test labels;
                # the result is a tensor, so use item() to get the value
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
                # running_loss / 500 is the average loss per step over the last 500 steps
                print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                      (epoch + 1, step, running_loss / 500, accuracy))
                running_loss = 0.0

print('Finished training')

save_path = 'lenet.pth'
# save the model parameters as a state dict
torch.save(net.state_dict(), save_path)

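The code is commented step by step and the process is not complicated: for each batch, clear the gradients, run the forward pass, compute the loss, backpropagate, and update the parameters; every 500 steps, report the average training loss and the test accuracy. Read it a few times and it will become clear. Finally, the trained parameters are saved to lenet.pth.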
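The evaluation arithmetic inside the loop can be sketched in NumPy. The outputs, labels, and per-step losses below are made-up illustrative values, not results from training; outputs.argmax(axis=1) plays the role of torch.max(outputs, dim=1)[1].

```python
import numpy as np

# hypothetical model outputs for 4 test images over 3 classes, shape (batch, num_classes)
outputs = np.array([[0.1, 2.0, 0.3],
                    [1.5, 0.2, 0.1],
                    [0.0, 0.1, 0.9],
                    [0.3, 0.8, 0.2]])
test_label = np.array([1, 0, 1, 1])

# index of the largest score in each row, like torch.max(outputs, dim=1)[1]
predict_y = outputs.argmax(axis=1)
# fraction of predictions that match the labels
accuracy = (predict_y == test_label).sum() / test_label.shape[0]

# running_loss / 500: the average of the per-step losses accumulated over 500 steps
step_losses = np.full(500, 1.2)
running_loss = step_losses.sum()
avg_loss = running_loss / 500

print(predict_y.tolist(), accuracy)  # [1, 0, 2, 1] 0.75
```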

2. Prediction script

With the model trained and its parameters saved in lenet.pth, prediction is simple: find an image on the Internet that belongs to one of the dataset's categories, load the parameter file into the model, preprocess the image in the same way as the training data, and feed it to the model. Here is the code:

import torch
import torchvision.transforms as transforms
from PIL import Image
from pytorch.lenet.model import LeNet

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

transform = transforms.Compose(
    # resize the image to the 32x32 size used in training
    [transforms.Resize([32, 32]),
     transforms.ToTensor(),
     # use the same normalization as in training
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

net = LeNet()
# load the trained parameters
net.load_state_dict(torch.load('lenet.pth'))

img_path = '../Photo/cat2.jpg'
# convert to RGB so the image has exactly three channels
img = Image.open(img_path).convert('RGB')
# preprocess the image
img = transform(img)
# add a batch dimension: (channels, height, width) -> (batch, channels, height, width),
# the input shape PyTorch models expect
img = torch.unsqueeze(img, dim=0)

with torch.no_grad():
    output = net(img)
    # dim=1 is the dimension of the 10 class scores in [batch, 10]; take the index of the maximum
    prediction1 = torch.max(output, dim=1)[1].item()
    # softmax turns the scores into a probability for each class
    probs = torch.softmax(output, dim=1)
    # index of the highest probability
    prediction2 = torch.argmax(probs, dim=1).item()

# both approaches give the same class
print(classes[prediction1])
print(classes[prediction2], probs[0, prediction2].item())
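Why do both approaches agree? Softmax is monotonic, so it never changes which score is largest. The NumPy sketch below illustrates this together with the extra batch axis that torch.unsqueeze adds. The logits are hypothetical values chosen for illustration, not output from the trained model.

```python
import numpy as np

def softmax(logits, axis=-1):
    # subtract the max for numerical stability; the result sums to 1 along axis
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# a single preprocessed image (C, H, W) gains a batch axis, like torch.unsqueeze(img, dim=0)
img = np.zeros((3, 32, 32), dtype=np.float32)
batch = np.expand_dims(img, axis=0)

# hypothetical raw scores (logits) for one image, shape (1, 10)
output = np.array([[0.2, -1.0, 0.5, 2.1, 0.0, 3.0, -0.3, 0.4, -2.0, 0.1]])
probs = softmax(output, axis=1)

prediction1 = int(output.argmax(axis=1)[0])  # argmax of the raw scores
prediction2 = int(probs.argmax(axis=1)[0])   # argmax of the probabilities

print(batch.shape, classes[prediction1], classes[prediction2])  # (1, 3, 32, 32) dog dog
```

Taking the softmax is therefore only needed when you want the probability itself, for example to report how confident the model is.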

In my test, the cat image was classified as a dog with a probability of 90.01%, which is clearly wrong. It also shows that LeNet is a very shallow network whose feature extraction is not deep enough for images like this one.

That is how to train a LeNet network model and use it for prediction in Python. Some of these techniques are likely to come up in everyday work, and I hope this article helps you learn more about them.
