Keras: model实现固定部分layer,训练部分layer操作

所属分类: 脚本专栏 / python 阅读数: 985
收藏 0 赞 0 分享

需求:Resnet50做调优训练,将最后分类数目由1000改为500。

问题:网上下载了resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5,更改了Resnet50后,由于所有层均参加训练,导致训练速度慢。实际上只需要训练最后3层,前面的层都不需要训练。

解决办法:

①将模型拆分为两个模型,一个为前面的notop部分,一个为最后三层,然后利用model的trainable属性设置只有后一个model训练,最后将两个模型合并起来。

②不用拆分,遍历模型的所有层,将前面层的trainable设置为False即可。代码如下:

for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

注意事项:

①尽量不要这样:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

因为容易出错。。。

②加载notop参数时注意by_name=True.

补充知识:Keras关于训练冻结部分层

设置冻结层有两种方式。

(不推荐)是在搭建网络时,直接将某层的trainable设置为false,例如:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

在网络搭建完成时,遍历model.layer,然后将layer.trainable设置为False:

# 冻结网络倒数的3层
for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

也可以根据layer.name来确定哪些层需要冻结,例如冻结最后一层和RNN层:

for layer in model.layers:
 layerName=str(layer.name)
 if layerName.startswith("RNN_") or layerName.startswith("Final_"):
 layer.trainable=False

可以在实例化之后将网络层的 trainable 属性设置为 True 或 False。为了使之生效,在修改 trainable 属性之后,需要在模型上调用 compile()。

这是一个例子

x = Input(shape=(32,))
layer = Dense(32)
layer.trainable = False
y = layer(x)
 
frozen_model = Model(x, y)
# 在下面的模型中,训练期间不会更新层的权重
frozen_model.compile(optimizer='rmsprop', loss='mse')
 
layer.trainable = True
trainable_model = Model(x, y)
# 使用这个模型,训练期间 `layer` 的权重将被更新
# (这也会影响上面的模型,因为它使用了同一个网络层实例)
trainable_model.compile(optimizer='rmsprop', loss='mse')
 
frozen_model.fit(data, labels) # 这不会更新 `layer` 的权重
trainable_model.fit(data, labels) # 这会更新 `layer` 的权重

在网络搭建时,可以考虑最后一个分类层命名和分类数量关联,这样当费雷数量方式变化时,model.load_weight(“weight.h5”,by_name=True)不会加载最后一层

以上这篇Keras: model实现固定部分layer,训练部分layer操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

更多精彩内容其他人还在看

python2.7无法使用pip的解决方法(安装easy_install)

下面小编就为大家分享一篇python2.7无法使用pip的解决方法(安装easy_install),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
收藏 0 赞 0 分享

Python实现的计算马氏距离算法示例

这篇文章主要介绍了Python实现的计算马氏距离算法,简单说明了马氏距离算法原理,并结合实例形式分析了Python实现与使用马氏距离算法的相关操作技巧,需要的朋友可以参考下
收藏 0 赞 0 分享

python逐行读写txt文件的实例讲解

下面小编就为大家分享一篇python逐行读写txt文件的实例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
收藏 0 赞 0 分享

python批量读取txt文件为DataFrame的方法

下面小编就为大家分享一篇python批量读取txt文件为DataFrame的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
收藏 0 赞 0 分享

Python通过调用mysql存储过程实现更新数据功能示例

这篇文章主要介绍了Python通过调用mysql存储过程实现更新数据功能,结合实例形式分析了Python调用mysql存储过程实现更新数据的具体步骤与相关操作技巧,需要的朋友可以参考下
收藏 0 赞 0 分享

Python实现的HMacMD5加密算法示例

这篇文章主要介绍了Python实现的HMacMD5加密算法,简单说明了HMAC-MD5加密算法的概念、原理并结合实例形式分析了Python实现HMAC-MD5加密算法的相关操作技巧,,末尾还附带了Java实现HMAC-MD5加密算法的示例,需要的朋友可以参考下
收藏 0 赞 0 分享

图解Python变量与赋值

Python是一门独特的语言,与C语言有很大区别,初学Python很多萌新表示对变量与赋值不理解,这里就大家介绍一下,需要的朋友可以参考下
收藏 0 赞 0 分享

Python中的并发处理之asyncio包使用的详解

本篇文章主要介绍了Python中的并发处理之asyncio包使用的详解,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
收藏 0 赞 0 分享

Python获取二维矩阵每列最大值的方法

下面小编就为大家分享一篇Python获取二维矩阵每列最大值的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
收藏 0 赞 0 分享

numpy找出array中的最大值,最小值实例

下面小编就为大家分享一篇numpy找出array中的最大值,最小值实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
收藏 0 赞 0 分享
查看更多