How to Implement Handwritten Digit Recognition in PyTorch

2025-04-08 Update. From: SLTechnology News&Howtos

Shulou(Shulou.com)06/01 Report--

This article shows how to implement handwritten digit image recognition in PyTorch. I hope you find it useful; let's work through it together.

The details are as follows

Dataset: MNIST. The code downloads it automatically, so there is no need to fetch it by hand. The dataset is small enough that no GPU is required, which makes it a good way to get a feel for what PyTorch can do.
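The two constants passed to `Normalize` below, 0.1307 and 0.3081, are the mean and standard deviation of the MNIST training pixels after `ToTensor` has scaled them to [0, 1]. Each pixel p becomes (p - 0.1307) / 0.3081, and displaying a normalized image requires the inverse, z * 0.3081 + 0.1307. The plain-Python sketch below shows both directions for single pixel values; `normalize` and `denormalize` are hypothetical helper names, not torchvision functions.

```python
# Constants used with torchvision.transforms.Normalize in this article:
# mean and standard deviation of MNIST training pixels in [0, 1].
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(pixel: float) -> float:
    """Per-pixel equivalent of Normalize((0.1307,), (0.3081,))."""
    return (pixel - MNIST_MEAN) / MNIST_STD


def denormalize(value: float) -> float:
    """Inverse transform, needed before showing a normalized image."""
    return value * MNIST_STD + MNIST_MEAN


# ToTensor maps 8-bit pixels 0..255 to 0.0..1.0 before normalization.
black = normalize(0 / 255)    # background pixel
white = normalize(255 / 255)  # full-intensity stroke pixel
print(round(black, 4), round(white, 4))  # -0.4242 2.8215
```

Since MNIST images typically contain both pure-black and pure-white pixels, the `x.min()` and `x.max()` values printed for the first training batch should come out close to these two numbers.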

Model, training, and prediction program:

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision
from matplotlib import pyplot as plt

from utils import plot_image, plot_curve, one_hot

# step 1: load the dataset (downloaded automatically into ./mnist_data)
batch_size = 512

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=False)

x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # three fully connected layers: 784 -> 256 -> 64 -> 10
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # x: [b, 28*28] (flattened images)
        # h2 = relu(x w1 + b1)
        x = F.relu(self.fc1(x))
        # h3 = relu(h2 w2 + b2)
        x = F.relu(self.fc2(x))
        # h4 = h3 w3 + b3
        x = self.fc3(x)
        return x


net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []

for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # each batch of images is a 4-D tensor x: [b, 1, 28, 28], y: [b]
        # the network expects 2-D input [b, 784], so flatten each image
        x = x.view(x.size(0), 28 * 28)
        # out: [b, 10]
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr * grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
# after training we have learned [w1, b1, w2, b2, w3, b3]

# evaluate on the test set (no gradients needed)
total_correct = 0
with torch.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # out: [b, 10]
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print("acc:", acc)

# show predictions for the first few test images
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")
```
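The comment `w' = w-lr*grad` in the training loop describes plain gradient descent, but the optimizer is created with `momentum=0.9`. As documented for `torch.optim.SGD` (with the default `dampening=0` and `nesterov=False`), PyTorch keeps a velocity buffer per parameter: on the first step the buffer equals the gradient, afterwards `buf = momentum * buf + grad`, and the update is `w = w - lr * buf`. The plain-Python sketch below applies that rule to a single scalar weight on the toy loss (w - 3)^2; the toy loss, the starting point, and the name `sgd_momentum` are illustrative assumptions, not part of the original program.

```python
def sgd_momentum(w: float, grad_fn, lr: float = 0.01, momentum: float = 0.9,
                 steps: int = 500) -> float:
    """Minimize a 1-D function with SGD + momentum (PyTorch-style update)."""
    buf = None  # velocity buffer, created on the first step as in torch.optim.SGD
    for _ in range(steps):
        g = grad_fn(w)
        # first step: buf = g; later steps: buf = momentum * buf + g
        buf = g if buf is None else momentum * buf + g
        w = w - lr * buf  # parameter update uses the buffer, not the raw gradient
    return w


# Toy loss L(w) = (w - 3)^2 has gradient 2 * (w - 3) and its minimum at w = 3.
w_final = sgd_momentum(w=0.0, grad_fn=lambda w: 2 * (w - 3))
print(w_final)
```

With `momentum=0.0` the same loop reduces exactly to the `w' = w - lr*grad` rule from the comment; with momentum, past gradients keep contributing, which usually speeds up progress along consistent descent directions.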

Helper functions called by the main program (saved as utils.py, since the main program imports them with `from utils import ...`):

```python
import torch
from matplotlib import pyplot as plt


def plot_curve(data):
    # plot the training loss recorded at each step
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):
    # show the first 6 images of a batch in a 2 x 3 grid
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        # undo Normalize((0.1307,), (0.3081,)) for display: x * std + mean
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    # convert integer labels [b] into one-hot vectors [b, depth]
    out = torch.zeros(label.size(0), depth)
    idx = label.long().view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out
```
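The `one_hot` helper turns each integer label into a length-10 vector with a single 1, and `F.mse_loss` (with its default `reduction='mean'`) then averages the squared differences over every element of the `[b, 10]` output. Prediction goes the other way: `argmax` picks the index of the largest score. The plain-Python sketch below reproduces these steps on a tiny batch so the arithmetic is visible; `one_hot_list`, `mse_mean`, and `argmax` are hypothetical names, and nested lists stand in for tensors.

```python
def one_hot_list(labels, depth=10):
    """Same idea as utils.one_hot: row i has a 1.0 at column labels[i]."""
    out = [[0.0] * depth for _ in labels]
    for row, label in zip(out, labels):
        row[label] = 1.0
    return out


def mse_mean(pred, target):
    """Mean squared error over all elements, like F.mse_loss(reduction='mean')."""
    diffs = [(p - t) ** 2 for pr, tr in zip(pred, target) for p, t in zip(pr, tr)]
    return sum(diffs) / len(diffs)


def argmax(row):
    """Index of the largest value, like out.argmax(dim=1) for one row."""
    return max(range(len(row)), key=row.__getitem__)


labels = [3, 0]
targets = one_hot_list(labels)
outputs = [[0.0] * 10 for _ in labels]  # a network that outputs all zeros

print(targets[0])                       # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(mse_mean(outputs, targets))       # 2 wrong entries out of 20 -> 0.1
print([argmax(r) for r in targets])     # [3, 0]
```

Because the mean is taken over all b * 10 elements, the loss values printed during training are much smaller than a per-sample squared error would be.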
