2025-04-09 Update From: SLTechnology News&Howtos
Shulou(Shulou.com)06/01 Report--
Many newcomers are unsure how to train on the MNIST dataset with transfer learning and resnet18 in PyTorch. This article summarizes the changes that are needed and how to make them, and should help you get it working.
Prerequisites
Building your own CNN model to train on MNIST (without transfer learning)
https://blog.csdn.net/qq_42951560/article/details/109565625
The official PyTorch transfer learning tutorial (ant and bee classification)
https://blog.csdn.net/qq_42951560/article/details/109950786
Learning goal
Today we train on the MNIST dataset using transfer learning in PyTorch.
Choosing a pretrained model
Transfer learning starts from a pretrained model. Our task is small, so resnet18 is enough.
Data preprocessing
The input to the pretrained resnet18 has CHW of (3, 224, 224).
A single image in the MNIST dataset has CHW of (1, 28, 28).
So we need to preprocess the MNIST dataset:
from torchvision import datasets, transforms

# Preprocessing: upscale to 224x224 and replicate the gray channel three times
my_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(3),
    transforms.ToTensor(),
    transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081)),
])
# Training set
train_file = datasets.MNIST(root='./dataset/', train=True, transform=my_transform)
# Test set
test_file = datasets.MNIST(root='./dataset/', train=False, transform=my_transform)
For a tutorial on data augmentation and image processing in PyTorch (torchvision.transforms), see my other article.
Replacing the fully connected layer
resnet18 was trained on ImageNet, so its final layer outputs 1000 features. MNIST has 10 classes, so the output size of the fully connected layer must be changed to 10.
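import torch.nn as nn
from torchvision import models

# Load resnet18 with ImageNet weights
# (newer torchvision versions use the weights= argument instead of pretrained=)
model = models.resnet18(pretrained=True)
# Replace the 1000-way output layer with a 10-way one
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, 10)
Adjusting the learning rate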
The learning rate for Adam was previously set to 1e-3. Now that we are using transfer learning, it is reduced to 1e-4.
Training results
resnet18 is much deeper than an ordinary one- or two-layer convolutional network, and the MNIST dataset is fairly large, with 70,000 images in total. To save time, we trained on 7 GeForce GTX 1080 Ti GPUs:
Data parallel (DataParallel)
EPOCH: 01/10 STEP: 67/67 LOSS: 0.0266 ACC: 0.9940 VAL-LOSS: 0.0246 VAL-ACC: 0.9938 TOTAL-TIME: 102EPOCH: 02/10 STEP: 67/67 LOSS: 0.0141 ACC: 0.9973 VAL-LOSS: 0.0177 VAL-ACC: 0.9948 TOTAL-TIME: 80EPOCH: 03/10 STEP: 67/67 LOSS: 0.0067 ACC: 0.9990 VAL-LOSS: 0.0147 VAL-ACC: 0.9958 TOTAL-TIME: 80EPOCH: 04/10 STEP : 67/67 LOSS: 0.0042 ACC: 0.9995 VAL-LOSS: 0.0151 VAL-ACC: 0.9948 TOTAL-TIME: 80EPOCH: 05/10 STEP: 67/67 LOSS: 0.0029 ACC: 0.9997 VAL-LOSS: 0.0143 VAL-ACC: 0.9955 TOTAL-TIME: 80EPOCH: 06/10 STEP: 67/67 LOSS: 0.0019 ACC: 0.9999 VAL-LOSS: 0.0133 VAL-ACC: 0.9962 TOTAL-TIME: 80EPOCH: 07/10 STEP: 67/67 LOSS: 0.0013 ACC: 1.0000 VAL-LOSS: 0.0132 VAL-ACC: 0.9963 TOTAL-TIME: 80EPOCH: 08/10 STEP: 67/67 LOSS: 0.0008 ACC: 1.0000 VAL-LOSS: 0.0132 VAL-ACC: 0.9963 TOTAL-TIME: 79EPOCH: 09/10 STEP: 67/67 LOSS: 0.0006 ACC: 1.0000 VAL-LOSS: 0.0122 VAL-ACC: 0.9962 TOTAL-TIME: 79EPOCH: 10/10 STEP: 67/67 LOSS: 0.0005 ACC: 1 .0000 VAL-LOSS: 0.0131 VAL-ACC: 0.9959 TOTAL-TIME: 79 | BEST-MODEL | EPOCH: 07 VAL-ACC 10 STEP: 67 ACC 67 LOSS: 0.0013 ACC: 1.0000 VAL-LOSS: 0.0132 VAL-ACC: 0.9963
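The 67 steps per epoch in the log are consistent with a global batch size of 896, for example 128 images on each of the 7 GPUs. The batch size is not stated in the article, so the per-GPU figure below is an assumption used only to illustrate the arithmetic.

```python
import math

train_images = 60_000   # MNIST training split
num_gpus = 7            # from the article
per_gpu_batch = 128     # assumption: not stated in the article

global_batch = per_gpu_batch * num_gpus
steps_per_epoch = math.ceil(train_images / global_batch)

print(global_batch, steps_per_epoch)  # 896 67
```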
After 10 epochs of training, the best model appears at epoch 7, with a best validation accuracy of 0.9963. In the earlier article, we built a network with two convolutional layers and trained it for 10 epochs, reaching a best accuracy of 0.9923. The accuracy has improved by 0.0040. With 10,000 images in the test set, that means 40 more images are predicted correctly, which is a substantial improvement. Of course, because the network is deeper, training also takes longer.
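The figure of 40 images follows directly from the two accuracies and the size of the test set, as this short calculation shows.

```python
test_images = 10_000
acc_cnn = 0.9923        # two-layer convolutional network
acc_resnet18 = 0.9963   # resnet18 with transfer learning

extra_correct = round((acc_resnet18 - acc_cnn) * test_images)
print(extra_correct)  # 40
```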