2025-04-03 Update From: SLTechnology News&Howtos
Shulou(Shulou.com)06/02 Report--
This article explains how to save and read a TensorFlow neural network model in Python. The method introduced here is simple, fast, and practical, and interested readers are invited to follow along.
TensorFlow provides a very simple API, the tf.train.Saver class, to save and restore a neural network model.
The following code shows how to save the TensorFlow model:
import os

import tensorflow as tf

# Declare two variables
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer()  # Op that initializes all variables
saver = tf.train.Saver(write_version=tf.train.SaverDef.V1)  # Declare a tf.train.Saver for saving the model
os.makedirs("save", exist_ok=True)  # saver.save expects the target folder to exist
with tf.Session() as sess:
    sess.run(init_op)
    print("v1:", sess.run(v1))  # Print the freshly initialized values of v1 and v2
    print("v2:", sess.run(v2))
    saver_path = saver.save(sess, "save/model.ckpt")  # Save the model to save/model.ckpt
    print("Model saved in file:", saver_path)
Note: the Saver format has since changed, and the current version is V2. Passing write_version=tf.train.SaverDef.V1, as in the code above, keeps the V1 format; TensorFlow prints a warning, which can be ignored. With saver = tf.train.Saver(), the current default (V2) is used. In that case four files appear in the save folder: instead of the single model.ckpt file written by V1, the variable values go into model.ckpt.data-00000-of-00001, accompanied by an index file, model.ckpt.index. Thanks to the reader who pointed this out in the comments. TensorFlow has been open source since the end of 2015, and many functions have since been changed, updated, or deprecated, so code that worked at the time may raise errors later. Time of writing: 2017.4.30.
In this code, the saver.save function saves the TensorFlow model to the save/model.ckpt file. The path "save/model.ckpt" is relative, so the model is written to a save folder under the current working directory (normally the directory the program is run from).
TensorFlow models are saved in files with the suffix .ckpt. With the V1 format used above, three files appear in the save folder after saving, because TensorFlow stores the structure of the computation graph separately from the values of the parameters on the graph.
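With the default V2 format mentioned in the note above, the set of files differs. The sketch below lists what a default tf.train.Saver writes. It assumes TensorFlow 2.x (or a late 1.x release) where the article's TF1 API lives under tf.compat.v1, calls tf.compat.v1.disable_v2_behavior() to get back graph mode, and uses a temporary directory in place of the article's save folder.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TensorFlow 2.x (or a late 1.x release) that exposes the TF1 API
# under tf.compat.v1; disable_v2_behavior() switches back to graph mode.
tf.compat.v1.disable_v2_behavior()

# A temporary directory stands in for the article's "save" folder.
save_dir = tempfile.mkdtemp()

with tf.Graph().as_default():
    v1 = tf.compat.v1.Variable(tf.compat.v1.random_normal([1, 2]), name="v1")
    v2 = tf.compat.v1.Variable(tf.compat.v1.random_normal([2, 3]), name="v2")
    saver = tf.compat.v1.train.Saver()  # no write_version: default V2 format
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver.save(sess, os.path.join(save_dir, "model.ckpt"))

files = sorted(os.listdir(save_dir))
print(files)
```

The listing should contain checkpoint, model.ckpt.meta, model.ckpt.index and a model.ckpt.data-... shard, but no plain model.ckpt file; the path passed to saver.save acts as a prefix for these files.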
The checkpoint file holds a list of all model files in a directory and is generated and maintained automatically by the tf.train.Saver class. It records the file names of all TensorFlow models persisted by a tf.train.Saver; when a saved model file is deleted, the corresponding file name is also removed from the checkpoint file. The contents of the checkpoint file follow the CheckpointState Protocol Buffer format.
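The checkpoint file can also be read back from code. The sketch below, under the same tf.compat.v1 assumption and with a temporary directory, saves twice with global_step and max_to_keep=1, which makes the Saver delete the older checkpoint and drop it from the checkpoint file. It then reads that file with tf.compat.v1.train.get_checkpoint_state and tf.train.latest_checkpoint.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TensorFlow 2.x (or a late 1.x release) with the tf.compat.v1 API.
tf.compat.v1.disable_v2_behavior()

save_dir = tempfile.mkdtemp()
prefix = os.path.join(save_dir, "model.ckpt")

with tf.Graph().as_default():
    v1 = tf.compat.v1.Variable(tf.compat.v1.random_normal([1, 2]), name="v1")
    # max_to_keep=1: the Saver deletes older checkpoints as newer ones are written
    saver = tf.compat.v1.train.Saver(max_to_keep=1)
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        for step in (1, 2):
            saver.save(sess, prefix, global_step=step)  # writes model.ckpt-<step>.*

# get_checkpoint_state parses the "checkpoint" file (a CheckpointState proto)
state = tf.compat.v1.train.get_checkpoint_state(save_dir)
print("latest:", state.model_checkpoint_path)
print("all:", list(state.all_model_checkpoint_paths))
print("latest_checkpoint:", tf.train.latest_checkpoint(save_dir))
```

After the second save, only model.ckpt-2 should be listed, and the model.ckpt-1 files should be gone from the directory.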
The model.ckpt.meta file stores the structure of the TensorFlow computation graph, which can be understood as the network structure of the neural network.
TensorFlow uses a MetaGraph to record information about the nodes in a computation graph and the metadata needed to run them. A MetaGraph in TensorFlow is defined by the MetaGraphDef Protocol Buffer, and the contents of MetaGraphDef make up the first file written when TensorFlow persists a model. Files that store MetaGraphDef information use the .meta suffix by default, so model.ckpt.meta holds the metagraph data.
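To check that the .meta file holds structure rather than values, it can be parsed directly as a MetaGraphDef. This sketch assumes the tf.compat.v1 API and imports the generated protobuf module tensorflow.core.protobuf.meta_graph_pb2, which is an internal module path rather than part of the documented tf.* namespace.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2  # assumption: internal module path

# Assumption: TensorFlow 2.x (or a late 1.x release) with the tf.compat.v1 API.
tf.compat.v1.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

with tf.Graph().as_default():
    v1 = tf.compat.v1.Variable(tf.compat.v1.random_normal([1, 2]), name="v1")
    v2 = tf.compat.v1.Variable(tf.compat.v1.random_normal([2, 3]), name="v2")
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver.save(sess, prefix)

# The .meta file is a serialized MetaGraphDef protocol buffer
meta_graph = meta_graph_pb2.MetaGraphDef()
with open(prefix + ".meta", "rb") as f:
    meta_graph.ParseFromString(f.read())

# graph_def lists the ops (nodes) of the graph, not the variables' values
node_names = [node.name for node in meta_graph.graph_def.node]
print(node_names)
# collection_def holds the graph's named collections
print(sorted(meta_graph.collection_def))
```

The node list should include the variable ops v1 and v2 together with the initializer and save/restore ops that the Saver added.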
The model.ckpt file stores the value of every variable in the TensorFlow program. It is stored in SSTable format and can be roughly understood as a list of (key, value) pairs. The first entry of the list describes meta-information about the file, such as the list of variables stored in it. Each remaining entry holds a slice of a variable, defined by the SavedSlice Protocol Buffer, which records the variable's name, information about the current slice, and the variable's value. TensorFlow provides the tf.train.NewCheckpointReader class for inspecting the variables stored in the model.ckpt file; its usage is not covered in detail here, so refer to the TensorFlow API documentation.
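A minimal way to use the checkpoint reader, assuming the tf.compat.v1 API: tf.compat.v1.train.NewCheckpointReader opens a checkpoint by its prefix and reports each stored variable's shape and value without building a graph. The sketch uses the default V2 format, so the prefix model.ckpt refers to the .index/.data pair rather than a single file.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x (or a late 1.x release) with the tf.compat.v1 API.
tf.compat.v1.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

with tf.Graph().as_default():
    v1 = tf.compat.v1.Variable(tf.compat.v1.random_normal([1, 2]), name="v1")
    v2 = tf.compat.v1.Variable(tf.compat.v1.random_normal([2, 3]), name="v2")
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saved_v1 = sess.run(v1)
        saver.save(sess, prefix)

# Open the checkpoint by its prefix; no graph or session is needed
reader = tf.compat.v1.train.NewCheckpointReader(prefix)
shape_map = reader.get_variable_to_shape_map()  # {variable name: shape}
for name in sorted(shape_map):
    print(name, shape_map[name])
    print(reader.get_tensor(name))  # the stored value as a NumPy array
```

The keys of the shape map are the variable names given via name=..., which is also what a Saver uses to match variables when restoring.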
The following code shows how to load the TensorFlow model:
Compare whether the printed v1 and v2 values are newly randomized or the same as the values saved earlier.
import tensorflow as tf

# Declare the same variables as in the saving code
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
saver = tf.train.Saver()  # Declare a tf.train.Saver for restoring the model
with tf.Session() as sess:
    saver.restore(sess, "save/model.ckpt")  # Load the variable values persisted at the save path
    print("v1:", sess.run(v1))  # Print v1 and v2 to compare with the saved values
    print("v2:", sess.run(v2))
    print("Model Restored")
Run results:
v1: [[ 0.76705766 1.82217288]]
v2: [[-0.98012197 1.2369734 0.5797025 ]
 [ 2.50458145 0.81897354 0.07858191]]
Model Restored
The code that loads the model is basically the same as the code that saves it: it also defines all the operations on the TensorFlow computation graph and declares a tf.train.Saver. The only difference is that the loading code does not run the variable initializer; instead, the variable values are loaded from the saved model.
With that, TensorFlow has completed saving and reading a model.
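The comparison asked about above can be automated. This sketch, again assuming the tf.compat.v1 API and a temporary directory, runs the "save" and "load" programs as two separate graphs; the helper build_graph is hypothetical and simply redefines the article's two variables with the same names and shapes.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x (or a late 1.x release) with the tf.compat.v1 API.
tf.compat.v1.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")


def build_graph():
    """Hypothetical helper: define v1 and v2 exactly as in the article."""
    v1 = tf.compat.v1.Variable(tf.compat.v1.random_normal([1, 2]), name="v1")
    v2 = tf.compat.v1.Variable(tf.compat.v1.random_normal([2, 3]), name="v2")
    return v1, v2


# Program 1: initialize randomly, record the values, then save.
with tf.Graph().as_default():
    v1, v2 = build_graph()
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saved = sess.run([v1, v2])
        saver.save(sess, prefix)

# Program 2: a fresh graph; no initializer runs, the values come from the file.
with tf.Graph().as_default():
    v1, v2 = build_graph()
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
        saver.restore(sess, prefix)
        restored = sess.run([v1, v2])

same = all(np.array_equal(a, b) for a, b in zip(saved, restored))
print("restored values equal saved values:", same)
```

Because the second graph never runs the random initializer, matching values can only have come from the checkpoint.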
If you do not want to repeat the operations of the graph definition, you can also load the persisted graph directly:
import tensorflow as tf

# In the following code, all variables defined on the TensorFlow computation graph are loaded by default
# Load the persisted graph directly
saver = tf.train.import_meta_graph("save/model.ckpt.meta")
with tf.Session() as sess:
    saver.restore(sess, "save/model.ckpt")
    # Get the tensor by its name
    print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))
Running the program outputs:
[[ 0.76705766 1.82217288]]
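The claim that all graph variables come back by default can be checked as well. This sketch assumes the tf.compat.v1 API; disable_v2_behavior() also switches to TF1-style reference variables, which is what makes the tensor name "v1:0" evaluate to the variable's value as in the article's code. The loading half defines no variables in Python and rebuilds everything from the .meta file.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: tf.compat.v1 API; disable_v2_behavior() also gives TF1-style
# reference variables, so "v1:0" evaluates to the variable's value.
tf.compat.v1.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Program 1: define, initialize and save v1 and v2.
with tf.Graph().as_default():
    v1 = tf.compat.v1.Variable(tf.compat.v1.random_normal([1, 2]), name="v1")
    v2 = tf.compat.v1.Variable(tf.compat.v1.random_normal([2, 3]), name="v2")
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saved_v1 = sess.run(v1)
        saver.save(sess, prefix)

# Program 2: no variable definitions; the graph is rebuilt from the .meta file.
with tf.Graph().as_default() as graph:
    saver = tf.compat.v1.train.import_meta_graph(prefix + ".meta")
    # import_meta_graph also restores the graph's collections, so the
    # imported variables appear in the global-variables collection.
    var_names = sorted(v.name for v in tf.compat.v1.global_variables())
    with tf.compat.v1.Session() as sess:
        saver.restore(sess, prefix)
        v1_value = sess.run(graph.get_tensor_by_name("v1:0"))

print(var_names)
print(v1_value)
```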
Sometimes you need to save or load only a few variables.
For example, suppose you previously trained a five-layer neural network and now want to build a six-layer one. You can load the parameters of the five-layer network directly into the new model and retrain only the new last layer.
To save or load only some variables, pass a list of those variables when constructing the tf.train.Saver. For example, if the loading code builds the saver with saver = tf.train.Saver([v1]), only the variable v1 is loaded.
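A minimal sketch of such a partial restore, assuming the tf.compat.v1 API and a temporary directory: the second program creates v1 and v2 with zero initializers (an illustrative choice so that unrestored values are easy to spot), passes only [v1] to the Saver, and therefore has to initialize v2 itself, much like the new last layer in the six-layer example.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x (or a late 1.x release) with the tf.compat.v1 API.
tf.compat.v1.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Program 1: save both v1 and v2.
with tf.Graph().as_default():
    v1 = tf.compat.v1.Variable(tf.compat.v1.random_normal([1, 2]), name="v1")
    v2 = tf.compat.v1.Variable(tf.compat.v1.random_normal([2, 3]), name="v2")
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saved_v1 = sess.run(v1)
        saver.save(sess, prefix)

# Program 2: zero-initialized variables; only v1 is handed to the Saver.
with tf.Graph().as_default():
    v1 = tf.compat.v1.Variable(tf.zeros([1, 2]), name="v1")
    v2 = tf.compat.v1.Variable(tf.zeros([2, 3]), name="v2")
    saver = tf.compat.v1.train.Saver([v1])  # var_list: restore v1 only
    with tf.compat.v1.Session() as sess:
        # v2 is not restored, so it still needs its initializer
        sess.run(tf.compat.v1.global_variables_initializer())
        saver.restore(sess, prefix)  # overwrites v1 with the saved value
        restored_v1, fresh_v2 = sess.run([v1, v2])

print("v1 restored:", np.array_equal(restored_v1, saved_v1))
print("v2 untouched (all zeros):", not np.any(fresh_v2))
```

Running the initializer before restore matters here: restore only assigns the variables in the Saver's list, so every other variable must already be initialized.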
That covers the methods for saving and reading a TensorFlow neural network model in Python; the best way to consolidate them is to try them out yourself.
© 2024 shulou.com SLNews company. All rights reserved.