将Pytorch模型从CPU转换成GPU的实现方法

所属分类: 脚本专栏 / python 阅读数: 730
收藏 0 赞 0 分享

最近将Pytorch程序迁移到GPU上去的一些工作和思考

环境:Ubuntu 16.04.3

Python版本:3.5.2

Pytorch版本:0.4.0

0. 序言

大家知道,在深度学习中使用GPU来对模型进行训练是可以通过并行化其计算来提高运行效率,这里就不多谈了。

最近申请到了实验室的服务器来跑程序,成功将我简陋的程序改成了“高大上”GPU版本。

看到网上总体来说少了很多介绍,这里决定将我的一些思考和工作记录下来。

1. 如何进行迁移

由于我使用的是Pytorch写的模型,网上给出了一个非常简单的转换方式: 对模型和相应的数据进行.cuda()处理。通过这种方式,我们就可以将内存中的数据复制到GPU的显存中去。从而可以通过GPU来进行运算了。

网上说的非常简单,但是实际使用过程中还是遇到了一些疑惑。下面分数据和模型两方面的迁移来进行说明介绍。

1.1 判定使用GPU

下载了对应的GPU版本的Pytorch之后,要确保GPU是可以进行使用的,通过torch.cuda.is_available()的返回值来进行判断。返回True则具有能够使用的GPU。

通过torch.cuda.device_count()可以获得能够使用的GPU数量。其他就不多赘述了。

常常通过如下判定来写可以跑在GPU和CPU上的通用模型:

if torch.cuda.is_available():
  ten1 = ten1.cuda()
  MyModel = MyModel.cuda() 

2. 对应数据的迁移

数据方面常用的主要是两种 —— Tensor和Variable。实际上这两种类型是同一个东西,因为Variable实际上只是一个容器,这里先视其不同。

2.1 将Tensor迁移到显存中去

不论是什么类型的Tensor(FloatTensor或者是LongTensor等等),一律直接使用方法.cuda()即可。

例如:

ten1 = torch.FloatTensor(2)
>>>> 6.1101e+24
   4.5659e-41
   [torch.FloatTensor of size 2]

ten1_cuda = ten1.cuda()
>>>>  6.1101e+24
    4.5659e-41
    [torch.cuda.FloatTensor of size 2 (GPU 0)]

其数据类型会由torch.FloatTensor变为torch.cuda.FloatTensor (GPU 0)这样代表这个数据现在存储在

GPU 0的显存中了。

如果要将显存中的数据复制到内存中,则对cuda数据类型使用.cpu()方法即可。

2.2 将Variable迁移到显存中去

在模型中,我们最常使用的是Variable这个容器来装载使用数据。主要是由于Variable可以进行反向传播来进行自动求导。

同样地,要将Variable迁移到显存中,同样只需要使用.cuda()即可实现。

这里有一个小疑问,对Variable直接使用.cuda和对Tensor进行.cuda然后再放置到Variable中结果是否一致呢。答案是肯定的。

ten1 = torch.FloatTensor(2)
>>> 6.1101e+24
   4.5659e-41
  [torch.FloatTensor of size 2]

ten1_cuda = ten1.cuda()
>>>> 6.1101e+24
   4.5659e-41
  [torch.cuda.FloatTensor of size 2 (GPU 0)]

V1_cpu = autograd.Variable(ten1)
>>>> Variable containing:
   6.1101e+24
   4.5659e-41
  [torch.FloatTensor of size 2]

V2 = autograd.Variable(ten1_cuda)
>>>> Variable containing:
   6.1101e+24
   4.5659e-41
  [torch.cuda.FloatTensor of size 2 (GPU 0)]

V1 = V1_cpu.cuda()
>>>> Variable containing:
   6.1101e+24
   4.5659e-41
  [torch.cuda.FloatTensor of size 2 (GPU 0)]

最终我们能发现他们都能够达到相同的目的,但是他们完全一样了吗?我们使用V1 is V2发现,结果是否定的。

对于V1,我们是直接对Variable进行操作的,这样子V1的.grad_fn中会记录下创建的方式。因此这二者并不是完全相同的。

2.3 数据迁移小结

.cuda()操作默认使用GPU 0也就是第一张显卡来进行操作。当我们想要存储在其他显卡中时可以使用.cuda(<显卡号数>)来将数据存储在指定的显卡中。还有很多种方式,具体参考官方文档。

对于不同存储位置的变量,我们是不可以对他们直接进行计算的。存储在不同位置中的数据是不可以直接进行交互计算的。

换句话说也就是上面例子中的torch.FloatTensor是不可以直接与torch.cuda.FloatTensor进行基本运算的。位于不同GPU显存上的数据也是不能直接进行计算的。

对于Variable,其实就仅仅是一种能够记录操作信息并且能够自动求导的容器,实际上的关键信息并不在Variable本身,而更应该侧重于Variable中存储的data。

3. 模型迁移

模型的迁移这里指的是torch.nn下面的一些网络模型以及自己创建的模型迁移到GPU上去。

上面讲了使用.cuda()即可将数据从内存中移植到显存中去。

对于模型来说,也是同样的方式,我们使用.cuda来将网络放到显存上去。

3.1 torch.nn下的基本模型迁移

这里使用基本的单层感知机来进行举例(线性模型)。

data1 = torch.FloatTensor(2)
data2 = data1.cuda

# 创建一个输入维度为2,输出维度为2的单层神经网络
linear = torch.nn.Linear(2, 2)
>>>> Linear(in_features=2, out_features=2)

linear_cuda = linear.cuda()
>>>> Linear(in_features=2, out_features=2)

我们很惊奇地发现对于模型来说,不像数据那样使用了.cuda()之后会改变其的数据类型。模型看起来没有任何的变化。

但是他真的没有改变吗。

我们将data1投入linear_cuda中去可以发现,系统会报错,而将.cuda之后的data2投入linear_cuda才能正常工作。并且输出的也是具有cuda的数据类型。

那是怎么一回事呢?

这是因为这些所谓的模型,其实也就是对输入参数做了一些基本的矩阵运算。所以我们对模型.cuda()实际上也相当于将模型使用到的参数存储到了显存上去。

对于上面的例子,我们可以通过观察参数来发现区别所在。

linear.weight
>>>> Parameter containing:
  -0.6847 0.2149
  -0.5473 0.6863
  [torch.FloatTensor of size 2x2]

linear_cuda.weight
>>>> Parameter containing:
  -0.6847 0.2149
  -0.5473 0.6863
  [torch.cuda.FloatTensor of size 2x2 (GPU 0)]

3.2 自己模型的迁移

对于自己创建的模型类,由于继承了torch.nn.Module,则可同样使用.cuda()来将模型中用到的所有参数都存储到显存中去。

这里笔者曾经有一个疑问:当我们对模型存储到显存中去之后,那么这个模型中的方法后面所创建出来的Tensor是不是都会默认变成cuda的数据类型。答案是否定的。具体操作留给读者自己去实现。

3.3 模型小结

对于模型而言,我们可以将其看做是一种类似于Variable的容器。我们对它进行.cuda()处理,是将其中的参数放到显存上去(因为实际使用的时候也是通过这些参数做运算)。

4. 总结

Pytorch使用起来直接简单,GPU的使用也是简单明了。然而对于多GPU和CPU的协同使用则还是有待提高。

以上这篇将Pytorch模型从CPU转换成GPU的实现方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

更多精彩内容其他人还在看

Python实现按学生年龄排序的实际问题详解

这篇文章主要给大家介绍了关于Python实现按学生年龄排序实际问题的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面跟着小编来一起学习学习吧。
收藏 0 赞 0 分享

Python开发的HTTP库requests详解

Requests是用Python语言编写,基于urllib,采用Apache2 Licensed开源协议的HTTP库。它比urllib更加方便,可以节约我们大量的工作,完全满足HTTP测试需求。Requests的哲学是以PEP 20 的习语为中心开发的,所以它比urllib更加P
收藏 0 赞 0 分享

Python网络爬虫与信息提取(实例讲解)

下面小编就为大家带来一篇Python网络爬虫与信息提取(实例讲解)。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
收藏 0 赞 0 分享

在python3环境下的Django中使用MySQL数据库的实例

下面小编就为大家带来一篇在python3环境下的Django中使用MySQL数据库的实例。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
收藏 0 赞 0 分享

Python 3.x读写csv文件中数字的方法示例

在我们日常开发中经常需要对csv文件进行读写,下面这篇文章主要给大家介绍了关于Python 3.x读写csv文件中数字的相关资料,文中通过示例代码介绍的非常详细,对大家具有一定的参考学习价值,需要的朋友们下面跟着小编来一起学习学习吧。
收藏 0 赞 0 分享

Python实现解析Bit Torrent种子文件内容的方法

这篇文章主要介绍了Python实现解析Bit Torrent种子文件内容的方法,结合实例形式分析了Python针对Torrent文件的读取与解析相关操作技巧与注意事项,需要的朋友可以参考下
收藏 0 赞 0 分享

Python实现文件内容批量追加的方法示例

这篇文章主要介绍了Python实现文件内容批量追加的方法,结合实例形式分析了Python文件的读写相关操作技巧,需要的朋友可以参考下
收藏 0 赞 0 分享

Python简单实现自动删除目录下空文件夹的方法

这篇文章主要介绍了Python简单实现自动删除目录下空文件夹的方法,涉及Python针对文件与目录的读取、判断、删除等相关操作技巧,需要的朋友可以参考下
收藏 0 赞 0 分享

简单学习Python多进程Multiprocessing

这篇文章主要和大家一起简单的学习Python多进程Multiprocessing ,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
收藏 0 赞 0 分享

Python导入模块时遇到的错误分析

这篇文章主要给大家详细解释了在Python处理导入模块的时候出现错误以及具体的情况分析,非常的详尽,有需要的小伙伴可以参考下
收藏 0 赞 0 分享
查看更多