How to classify with an RNN

2025-01-18 Update From: SLTechnology News&Howtos

Shulou (Shulou.com) 06/01 Report

This article shows how to use an RNN for classification. The editor finds it quite practical and shares it here in the hope that you will take something useful away from it.

Today we look at how an RNN can be used for a classification task.

The MNIST dataset of handwritten digits should be familiar by now: we have already used it with CNN classifiers and with classical machine learning methods. Today we use an RNN to classify the MNIST images.

To do this, we treat each 28x28-pixel image as a sequence of 28 time steps, where each time step is one row of 28 pixels. The network consists of a layer of 150 recurrent neurons, followed by a fully connected layer of 10 neurons (one per class) and finally a softmax layer.
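To make the sequence view concrete, the NumPy sketch below uses a synthetic array in place of a real MNIST image (an assumption for illustration only). MNIST images are loaded as flat vectors of 784 pixel values; reshaping one to 28x28 makes row t the input at time step t, and a batch of such vectors becomes a tensor of shape [batch, n_steps, n_inputs].

```python
import numpy as np

n_steps, n_inputs = 28, 28

# Synthetic stand-in for one flattened MNIST image (784 pixel values)
flat_image = np.arange(n_steps * n_inputs, dtype=np.float32)

# View the image as a sequence: 28 time steps, each a 28-pixel row
sequence = flat_image.reshape(n_steps, n_inputs)
print(sequence.shape)  # (28, 28)

# Time step t is simply pixel row t of the image
t = 5
assert np.array_equal(sequence[t], flat_image[t * n_inputs:(t + 1) * n_inputs])

# A batch of flat images becomes a 3-D tensor [batch, n_steps, n_inputs];
# -1 lets NumPy infer the batch dimension
batch = np.stack([flat_image, flat_image])  # shape (2, 784)
batch_seq = batch.reshape((-1, n_steps, n_inputs))
print(batch_seq.shape)  # (2, 28, 28)
```

The same reshape((-1, n_steps, n_inputs)) call appears in the data-loading and training code, where it turns each batch of flat images into sequences.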

The construction phase is straightforward and very similar to the DNN we built in previous installments, except that an unrolled RNN replaces the hidden layers. Note that the fully connected layer is connected to the RNN's states tensor, which holds only the final state of the RNN (its output at the 28th time step), and that y is the placeholder for the target classes.
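Under the hood, tf.nn.dynamic_rnn applies the same cell at every time step. The NumPy sketch below mimics this for a basic cell with a tanh activation (BasicRNNCell's default) and an all-zero initial state, using random weights with hypothetical names (W_x, W_h, W_out) instead of trained ones. TensorFlow's cell multiplies a single kernel by the concatenation of x_t and h_{t-1}; splitting that kernel into W_x and W_h, as done here, is mathematically equivalent. The sketch illustrates why the dense layer can use states directly: for this cell, the final state is the output at the last time step.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n_steps, n_inputs, n_neurons, n_outputs = 4, 28, 28, 150, 10

# Hypothetical random parameters standing in for trained weights
W_x = rng.normal(scale=0.1, size=(n_inputs, n_neurons))
W_h = rng.normal(scale=0.1, size=(n_neurons, n_neurons))
b = np.zeros(n_neurons)
W_out = rng.normal(scale=0.1, size=(n_neurons, n_outputs))
b_out = np.zeros(n_outputs)

X = rng.random((batch_size, n_steps, n_inputs))  # stand-in for a batch of images

def basic_rnn(X):
    """Unroll a basic RNN cell over time: h_t = tanh(x_t W_x + h_{t-1} W_h + b)."""
    h = np.zeros((X.shape[0], n_neurons))  # initial state is all zeros
    outputs = []
    for t in range(X.shape[1]):
        h = np.tanh(X[:, t, :] @ W_x + h @ W_h + b)
        outputs.append(h)  # for a basic cell, the output equals the new state
    return np.stack(outputs, axis=1), h

outputs, states = basic_rnn(X)
print(outputs.shape, states.shape)  # (4, 28, 150) (4, 150)

# The final state is just the output at the last (28th) time step
assert np.allclose(states, outputs[:, -1, :])

# The classifier head only sees the final state
logits = states @ W_out + b_out
print(logits.shape)  # (4, 10)
```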

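```python
import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x
from tensorflow.contrib.layers import fully_connected

n_steps = 28      # each image is fed as 28 time steps (one row per step)
n_inputs = 28     # each time step is one row of 28 pixels
n_neurons = 150   # recurrent neurons in the RNN layer
n_outputs = 10    # one output per digit class
learning_rate = 0.001

X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.int32, [None])  # target class indices

basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)

# The fully connected layer sees only the final state of the RNN;
# softmax is applied inside the loss, so no activation here
logits = fully_connected(states, n_outputs, activation_fn=None)
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=y, logits=logits)
loss = tf.reduce_mean(xentropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)

# Accuracy: fraction of examples whose top-scoring class is the target
correct = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

init = tf.global_variables_initializer()
```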
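The last few ops in the graph can be read as follows: sparse_softmax_cross_entropy_with_logits takes integer class labels and raw logits, applies softmax internally, and returns the negative log-probability of the correct class, while in_top_k(logits, y, 1) checks whether the correct class has the highest logit. The NumPy re-implementation below is for illustration only (it is not TensorFlow's code and ignores in_top_k's tie handling), applied to two made-up examples with three classes.

```python
import numpy as np

def sparse_softmax_cross_entropy(labels, logits):
    """Per-example cross-entropy between integer labels and unnormalized logits."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

def in_top_1(logits, labels):
    """True where the highest-scoring class is the target (ties not handled)."""
    return logits.argmax(axis=1) == labels

# Two made-up examples, three classes
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
labels = np.array([0, 1])

loss = sparse_softmax_cross_entropy(labels, logits).mean()
accuracy = in_top_1(logits, labels).astype(np.float32).mean()
print(loss)
print(accuracy)  # 0.5: only the first example's top class matches its label
```

The second example is confidently wrong (class 2 scores far above the true class 1), so it contributes most of the loss; this is the signal the Adam optimizer minimizes.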

Next, we load the dataset and reshape the test images into sequences:

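```python
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/data/")
# Test images come as flat 784-pixel vectors; reshape to [batch, 28 steps, 28 inputs]
X_test = mnist.test.images.reshape((-1, n_steps, n_inputs))
y_test = mnist.test.labels  # integer class labels (one_hot defaults to False)
```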

Now we train the RNN. The execution phase is very similar to the one we used for the DNN:

```python
n_epochs = 100
batch_size = 150

with tf.Session() as sess:
    init.run()
    for epoch in range(n_epochs):
        for iteration in range(mnist.train.num_examples // batch_size):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            # Reshape each flat batch into [batch, n_steps, n_inputs]
            X_batch = X_batch.reshape((-1, n_steps, n_inputs))
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        # Training accuracy is measured on the last mini-batch only
        acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
        acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
        print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)
```

The output is as follows:

0 Train accuracy: 0.713333 Test accuracy: 0.7299

1 Train accuracy: 0.766667 Test accuracy: 0.7977

...

98 Train accuracy: 0.986667 Test accuracy: 0.9777

99 Train accuracy: 0.986667 Test accuracy: 0.9809
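Two details make this log easier to read. First, the loop runs mnist.train.num_examples // batch_size iterations per epoch; assuming read_data_sets uses its default split (55,000 training images, with 5,000 of MNIST's 60,000 held out for validation), that is 366 iterations per epoch. Second, acc_train is evaluated only on the last mini-batch of each epoch, so it can only take values k/150, whereas acc_test covers all 10,000 test images. The plain-Python check below recovers the counts of correct predictions from the logged figures.

```python
# Assumption: read_data_sets() uses its default split, which keeps 55,000
# of MNIST's 60,000 training images in mnist.train (5,000 go to validation)
num_train_examples = 55_000
test_set_size = 10_000   # MNIST test images, all used for acc_test
batch_size = 150
n_epochs = 100

# Loop arithmetic from the training code
iterations_per_epoch = num_train_examples // batch_size
total_training_steps = iterations_per_epoch * n_epochs
print(iterations_per_epoch, total_training_steps)  # 366 36600

# acc_train comes from the last mini-batch of 150 images only, so it can
# only take values k / 150; recover k from the logged figures
logged_train_acc = [0.713333, 0.766667, 0.986667, 0.986667]
logged_test_acc = [0.7299, 0.7977, 0.9777, 0.9809]
train_correct = [round(acc * batch_size) for acc in logged_train_acc]
test_correct = [round(acc * test_set_size) for acc in logged_test_acc]
print(train_correct)  # [107, 115, 148, 148]
print(test_correct)   # [7299, 7977, 9777, 9809]

# Each logged training accuracy matches k / 150 to the printed precision
for acc, k in zip(logged_train_acc, train_correct):
    assert abs(acc - k / batch_size) < 1e-6
```

Because the training figure is based on just 150 images, it is much noisier than the test figure; the test accuracy is the better guide to how well the model generalizes.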

In the end, the test accuracy reaches about 98%, which is quite good. Tuning the hyperparameters, changing how the RNN weights are initialized, training longer, or adding some regularization should improve the result further.

© 2024 shulou.com SLNews company. All rights reserved.