持久化代码实现
通过tf.train.Saver类来保存和还原一个神经网络,模型文件目录下会出现三个文件。这是因为Tensorflow会将计算图的结构和图上参数取值分开保存。
- model.ckpt.meta,保存了计算图的结构。
- model.ckpt,保存程序中每一个变量的取值。
- checkpoint,保存了一个目录下所有的模型文件列表。
加载已经保存的Tensorflow模型方法。1.使用和保存模型代码中一样的方法来声明变量。2.加载已经保存的模型。sever.restore(sess,”.ckpt”)
加载模型的程序也是定义了Tensorflow计算图上的所有运算,并声明了一个tf.train.Saver类。区别在于加载模型的代码中没有运行变量的初始化过程而是将变量的值通过已经保存的模型加载了进来。也可以直接加载已经持久化的图
saver=tf.train.import_meta_graph(…..meta)
函数tf.get_default_graph().get_tensor_by_name(“add:0”)可以通过张量的名称获取张量。
也可以声明tf.train.Saver类时提供一个列表指定需要保存或者加载的变量。同样可以在保存和加载时使用字典给变量重命名。
使用Saver会保存运行程序所需的全部信息,然而有时不需要某些信息。在测试或者离线预测时,不需要某些辅助节点的信息。且多个文件存储时也并不方便。convert_variables_to_constants将计算图中的变量及其取值通过常亮保存。
导出当前计算图的GraphDef部分只需要这一部分就可以完成从输入层到输出层的计算过程。
graph_def=tf.get_default_graph().as_graph_def()
将图中的变量及其取值转化为常量,同时将图中不必要的节点去掉。一些如变量初始化操作的系统运算也会被转化为计算图的节点。可以通过【】指定需要保存的操作。
output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,[‘add’])//add为节点名
将导出模型存入文件:
with tf.gfile.GFile(“…pb”,”wb”) as f:
f.write(output_graph_def.SerialzeToString())
加载模型:
with gflie.FastGFile(model_filename//.pb,’rb’) as f:
graph_def=tf.GraphDef()
graph_def.ParseFromString(f.read())
result=tf.import_graph_def(graph_def,return_elements=[“add:0”])//add:0为一个张量
sess.run(result)
Saver持久化原理及数据格式
Tensorflow通过元图(MetaGraph)来记录计算图中节点信息以及运行计算图中节点所需要的元数据。
由Protocol Buffer定义,记录了五类信息:
- meta_info_def属性,记录计算图中的元数据以及所有使用到运算方法的信息。
- graph_def属性,记录计算图的节点信息。
- saver_def属性,记录了持久化模型时需要用到的一些参数。
- collection_def属性维护集合的底层实现是通过collection_def这个属性。
5 signature_def属性。
model.ckpt保存所有变量的取值,通过SSTable格式存储,大致为一个(key,value)列表。
checkpoint是Saver类自动生成自动维护的。当某个保存的TensorFlow模型文件被删除时,这个模型对应的文件名也会从checkpoint文件中删除。