pytorch中获取模型input/output shape实例

所属分类: 脚本专栏 / python 阅读数: 1500
收藏 0 赞 0 分享

Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见

https://github.com/pytorch/pytorch/pull/3043

以下代码算一种workaround。由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码。

例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了。

该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape

#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json
 
 
def get_output_size(summary_dict, output):
 if isinstance(output, tuple):
 for i in xrange(len(output)):
  summary_dict[i] = OrderedDict()
  summary_dict[i] = get_output_size(summary_dict[i],output[i])
 else:
 summary_dict['output_shape'] = list(output.size())
 return summary_dict
 
def summary(input_size, model):
 def register_hook(module):
 def hook(module, input, output):
  class_name = str(module.__class__).split('.')[-1].split("'")[0]
  module_idx = len(summary)
 
  m_key = '%s-%i' % (class_name, module_idx+1)
  summary[m_key] = OrderedDict()
  summary[m_key]['input_shape'] = list(input[0].size())
  summary[m_key] = get_output_size(summary[m_key], output)
 
  params = 0
  if hasattr(module, 'weight'):
  params += torch.prod(torch.LongTensor(list(module.weight.size())))
  if module.weight.requires_grad:
   summary[m_key]['trainable'] = True
  else:
   summary[m_key]['trainable'] = False
  #if hasattr(module, 'bias'):
  # params += torch.prod(torch.LongTensor(list(module.bias.size())))
 
  summary[m_key]['nb_params'] = params
  
 if not isinstance(module, nn.Sequential) and \
  not isinstance(module, nn.ModuleList) and \
  not (module == model):
  hooks.append(module.register_forward_hook(hook))
 
 # check if there are multiple inputs to the network
 if isinstance(input_size[0], (list, tuple)):
 x = [Variable(torch.rand(1,*in_size)) for in_size in input_size]
 else:
 x = Variable(torch.rand(1,*input_size))
 
 # create properties
 summary = OrderedDict()
 hooks = []
 # register hook
 model.apply(register_hook)
 # make a forward pass
 model(x)
 # remove these hooks
 for h in hooks:
 h.remove()
 
 return summary
 
crnn = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1,32,128],crnn)
print json.dumps(x)

以pytorch版CRNN为例,输出shape如下

{
"Conv2d-1": {
"input_shape": [1, 1, 32, 128],
"output_shape": [1, 64, 32, 128],
"trainable": true,
"nb_params": 576
},
"ReLU-2": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 32, 128],
"nb_params": 0
},
"MaxPool2d-3": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 16, 64],
"nb_params": 0
},
"Conv2d-4": {
"input_shape": [1, 64, 16, 64],
"output_shape": [1, 128, 16, 64],
"trainable": true,
"nb_params": 73728
},
"ReLU-5": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 16, 64],
"nb_params": 0
},
"MaxPool2d-6": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 8, 32],
"nb_params": 0
},
"Conv2d-7": {
"input_shape": [1, 128, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 294912
},
"BatchNorm2d-8": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 256
},
"ReLU-9": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"Conv2d-10": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 589824
},
"ReLU-11": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"MaxPool2d-12": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 4, 33],
"nb_params": 0
},
"Conv2d-13": {
"input_shape": [1, 256, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 1179648
},
"BatchNorm2d-14": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-15": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"Conv2d-16": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 2359296
},
"ReLU-17": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"MaxPool2d-18": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 2, 34],
"nb_params": 0
},
"Conv2d-19": {
"input_shape": [1, 512, 2, 34],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 1048576
},
"BatchNorm2d-20": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-21": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"nb_params": 0
},
"LSTM-22": {
"input_shape": [33, 1, 512],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-23": {
"input_shape": [33, 512],
"output_shape": [33, 256],
"trainable": true,
"nb_params": 131072
},
"BidirectionalLSTM-24": {
"input_shape": [33, 1, 512],
"output_shape": [33, 1, 256],
"nb_params": 0
},
"LSTM-25": {
"input_shape": [33, 1, 256],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-26": {
"input_shape": [33, 512],
"output_shape": [33, 3755],
"trainable": true,
"nb_params": 1922560
},
"BidirectionalLSTM-27": {
"input_shape": [33, 1, 256],
"output_shape": [33, 1, 3755],
"nb_params": 0
}
}

以上这篇pytorch中获取模型input/output shape实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

更多精彩内容其他人还在看

QML用PathView实现轮播图

这篇文章主要为大家详细介绍了QML用PathView实现轮播图,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
收藏 0 赞 0 分享

Opencv图像处理:如何判断图片里某个颜色值占的比例

这篇文章主要介绍了Opencv图像处理:如何判断图片里某个颜色值占的比例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
收藏 0 赞 0 分享

python golang中grpc 使用示例代码详解

这篇文章主要介绍了python golang中grpc 使用,本文通过示例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
收藏 0 赞 0 分享

浅谈python opencv对图像颜色通道进行加减操作溢出

这篇文章主要介绍了浅谈python opencv对图像颜色通道进行加减操作溢出,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
收藏 0 赞 0 分享

解决python运行启动报错问题

这篇文章主要介绍了解决python运行启动报错问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
收藏 0 赞 0 分享

Python常见反爬虫机制解决方案

这篇文章主要介绍了Python常见反爬虫机制解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
收藏 0 赞 0 分享

解决pycharm导入本地py文件时,模块下方出现红色波浪线的问题

这篇文章主要介绍了解决pycharm导入本地py文件时,模块下方出现红色波浪线的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
收藏 0 赞 0 分享

pycharm设置默认的UTF-8编码模式的方法详解

这篇文章主要介绍了pycharm设置默认的UTF-8编码模式,本文通过图文并茂的形式给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
收藏 0 赞 0 分享

浅谈Pycharm的项目文件名是红色的原因及解决方式

这篇文章主要介绍了浅谈Pycharm的项目文件名是红色的原因及解决方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
收藏 0 赞 0 分享

Python网络爬虫四大选择器用法原理总结

这篇文章主要介绍了Python网络爬虫四大选择器用法原理总结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
收藏 0 赞 0 分享
查看更多