tensorflow2.0保存和恢复模型3种方法

所属分类: 脚本专栏 / python 阅读数: 1220
收藏 0 赞 0 分享

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

更多精彩内容其他人还在看

Python实现图像几何变换

这篇文章主要介绍了Python实现图像几何变换的方法,实例分析了Python基于Image模块实现图像翻转、旋转、改变大小等操作的相关技巧,非常简单实用,需要的朋友可以参考下
收藏 0 赞 0 分享

Python中的urllib模块使用详解

这篇文章主要介绍了Python中的urllib模块使用详解,是Python入门学习中的基础知识,需要的朋友可以参考下
收藏 0 赞 0 分享

Python的多态性实例分析

这篇文章主要介绍了Python的多态性,以实例形式深入浅出的分析了Python在面向对象编程中多态性的原理与实现方法,需要的朋友可以参考下
收藏 0 赞 0 分享

python生成IP段的方法

这篇文章主要介绍了python生成IP段的方法,涉及Python文件读写及随机数操作的相关技巧,具有一定参考借鉴价值,需要的朋友可以参考下
收藏 0 赞 0 分享

python操作redis的方法

这篇文章主要介绍了python操作redis的方法,包括Python针对redis的连接、设置、获取、删除等常用技巧,具有一定参考借鉴价值,需要的朋友可以参考下
收藏 0 赞 0 分享

python妹子图简单爬虫实例

这篇文章主要介绍了python妹子图简单爬虫,实例分析了Python爬虫程序所涉及的页面源码获取、进度显示、正则匹配等技巧,需要的朋友可以参考下
收藏 0 赞 0 分享

分析用Python脚本关闭文件操作的机制

这篇文章主要介绍了分析用Python脚本关闭文件操作的机制,作者分Python2.x版本和3.x版本两种情况进行了阐述,需要的朋友可以参考下
收藏 0 赞 0 分享

python实现搜索指定目录下文件及文件内搜索指定关键词的方法

这篇文章主要介绍了python实现搜索指定目录下文件及文件内搜索指定关键词的方法,可实现针对文件夹及文件内关键词的搜索功能,需要的朋友可以参考下
收藏 0 赞 0 分享

python中getaddrinfo()基本用法实例分析

这篇文章主要介绍了python中getaddrinfo()基本用法,实例分析了Python中使用getaddrinfo方法进行IP地址解析的基本技巧,需要的朋友可以参考下
收藏 0 赞 0 分享

python查找指定具有相同内容文件的方法

这篇文章主要介绍了python查找指定具有相同内容文件的方法,涉及Python针对文件操作的相关技巧,需要的朋友可以参考下
收藏 0 赞 0 分享
查看更多