tensorflow 固定部分参数训练,只训练部分参数的实例

所属分类: 脚本专栏 / python 阅读数: 1231
收藏 0 赞 0 分享

在使用tensorflow来训练一个模型的时候,有时候需要依靠验证集来判断模型是否已经过拟合,是否需要停止训练。

1.首先想到的是用tf.placeholder()载入不同的数据来进行计算,比如

def inference(input_):
  """
  this is where you put your graph.
  the following is just an example.
  """
  
  conv1 = tf.layers.conv2d(input_)
 
  conv2 = tf.layers.conv2d(conv1)
 
  return conv2
 
 
input_ = tf.placeholder()
output = inference(input_)
...
calculate_loss_op = ...
train_op = ...
...
 
with tf.Session() as sess:
  sess.run([loss, train_op], feed_dict={input_: train_data})
 
  if validation == True:
    sess.run([loss], feed_dict={input_: validate_date})

这种方式很简单,也很直接了然。

2.但是,如果处理的数据量很大的时候,使用 tf.placeholder() 来载入数据会严重地拖慢训练的进度,因此,常用tfrecords文件来读取数据。

此时,很容易想到,将不同的值传入inference()函数中进行计算。

train_batch, label_batch = decode_train()
val_train_batch, val_label_batch = decode_validation()
 
 
train_result = inference(train_batch)
...
loss = ..
train_op = ...
...
 
if validation == True:
  val_result = inference(val_train_batch)
  val_loss = ..
  
 
with tf.Session() as sess:
  sess.run([loss, train_op])
 
  if validation == True:
    sess.run([val_result, val_loss])

这种方式看似能够直接调用inference()来对验证数据进行前向传播计算,但是,实则会在原图上添加上许多新的结点,这些结点的参数都是需要重新初始化的,也是就是说,验证的时候并不是使用训练的权重。

3.用一个tf.placeholder来控制是否训练、验证。

def inference(input_):
  ...
  ...
  ...
  
  return inference_result
 
 
train_batch, label_batch = decode_train()
val_batch, val_label = decode_validation()
 
is_training = tf.placeholder(tf.bool, shape=())
 
x = tf.cond(is_training, lambda: train_batch, lambda: val_batch)
y = tf.cond(is_training, lambda: train_label, lambda: val_label)
 
logits = inference(x)
loss = cal_loss(logits, y)
train_op = optimize(loss)
 
with tf.Session() as sess:
  
  loss, _ = sess.run([loss, train_op], feed_dict={is_training: True})
  
  if validation == True:
    loss = sess.run(loss, feed_dict={is_training: False})

使用这种方式就可以在一个大图里创建一个分支条件,从而通过控制placeholder来控制是否进行验证。

以上这篇tensorflow 固定部分参数训练,只训练部分参数的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

更多精彩内容其他人还在看

Python实现按学生年龄排序的实际问题详解

这篇文章主要给大家介绍了关于Python实现按学生年龄排序实际问题的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面跟着小编来一起学习学习吧。
收藏 0 赞 0 分享

Python开发的HTTP库requests详解

Requests是用Python语言编写,基于urllib,采用Apache2 Licensed开源协议的HTTP库。它比urllib更加方便,可以节约我们大量的工作,完全满足HTTP测试需求。Requests的哲学是以PEP 20 的习语为中心开发的,所以它比urllib更加P
收藏 0 赞 0 分享

Python网络爬虫与信息提取(实例讲解)

下面小编就为大家带来一篇Python网络爬虫与信息提取(实例讲解)。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
收藏 0 赞 0 分享

在python3环境下的Django中使用MySQL数据库的实例

下面小编就为大家带来一篇在python3环境下的Django中使用MySQL数据库的实例。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
收藏 0 赞 0 分享

Python 3.x读写csv文件中数字的方法示例

在我们日常开发中经常需要对csv文件进行读写,下面这篇文章主要给大家介绍了关于Python 3.x读写csv文件中数字的相关资料,文中通过示例代码介绍的非常详细,对大家具有一定的参考学习价值,需要的朋友们下面跟着小编来一起学习学习吧。
收藏 0 赞 0 分享

Python实现解析Bit Torrent种子文件内容的方法

这篇文章主要介绍了Python实现解析Bit Torrent种子文件内容的方法,结合实例形式分析了Python针对Torrent文件的读取与解析相关操作技巧与注意事项,需要的朋友可以参考下
收藏 0 赞 0 分享

Python实现文件内容批量追加的方法示例

这篇文章主要介绍了Python实现文件内容批量追加的方法,结合实例形式分析了Python文件的读写相关操作技巧,需要的朋友可以参考下
收藏 0 赞 0 分享

Python简单实现自动删除目录下空文件夹的方法

这篇文章主要介绍了Python简单实现自动删除目录下空文件夹的方法,涉及Python针对文件与目录的读取、判断、删除等相关操作技巧,需要的朋友可以参考下
收藏 0 赞 0 分享

简单学习Python多进程Multiprocessing

这篇文章主要和大家一起简单的学习Python多进程Multiprocessing ,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
收藏 0 赞 0 分享

Python导入模块时遇到的错误分析

这篇文章主要给大家详细解释了在Python处理导入模块的时候出现错误以及具体的情况分析,非常的详尽,有需要的小伙伴可以参考下
收藏 0 赞 0 分享
查看更多