01月04, 2020

PyTorch模型的保存、加载和微调

与TensorFlow高度抽象化的API设计不同，PyTorch模型的保存和加载直接调用torch.save/torch.load函数即可。

举例

下面是一个简单的例子：保存Net_old的参数（state_dict），再用torch.load读回，并打印每个参数的名称和形状。

import torch
import torch.nn as nn

class Net_old(nn.Module):
    """Three 3x3 convolutions separated by in-place ReLUs, held in one ``nn.Sequential``.

    Because every layer lives inside ``self.nets``, the parameter keys in the
    state_dict are positional: ``nets.0.*``, ``nets.2.*`` and ``nets.4.*``
    (the ReLUs at indices 1 and 3 carry no parameters).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before, so parameter
        # initialisation consumes the RNG identically.
        layers = [
            nn.Conv2d(1, 2, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(2, 1, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(1, 1, 3),
        ]
        self.nets = nn.Sequential(*layers)

    def forward(self, x):
        """Feed ``x`` through the convolution stack and return the result."""
        out = self.nets(x)
        return out

class Net_new(nn.Module):
    """The same conv/ReLU stack as ``Net_old``, but with each layer as a named attribute.

    Because the layers are attributes rather than ``nn.Sequential`` entries,
    the state_dict keys are ``conv1.*``, ``conv2.*`` and ``conv3.*`` instead
    of ``nets.0.*``, ``nets.2.*`` and ``nets.4.*``. A ``Net_old`` state_dict
    therefore needs its keys remapped before a strict ``load_state_dict``
    into this class will succeed.
    """

    def __init__(self):
        # Bug fix: the original called ``super(Net_old, self).__init__()``,
        # which raises TypeError because a Net_new instance is not a Net_old.
        # Zero-argument super() avoids this copy-paste hazard entirely.
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU -> conv3 to ``x`` and return the result."""
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x

network = Net_old()
# Save only the parameters (the state_dict), after moving the module to the CPU.
torch.save(network.cpu().state_dict(), 't.pth')

# torch.load returns the saved mapping of parameter name -> tensor
# (printed below as an OrderedDict).
pretrained_net = torch.load('t.pth')
print(pretrained_net)

# Walk (name, tensor) pairs directly instead of re-indexing the dict by key.
# The original bound the index to `key` and the name to `v`, which was misleading.
for idx, (name, tensor) in enumerate(pretrained_net.items()):
    print(idx, name, tensor.shape)

输出:

OrderedDict([('nets.0.weight', tensor([[[[ 0.3163, -0.0568,  0.0468],
          [ 0.2568,  0.2733,  0.2204],
          [ 0.2367, -0.2275, -0.0172]]],


        [[[ 0.1718,  0.0443,  0.0478],
          [ 0.2824,  0.0429,  0.2064],
          [-0.0866, -0.1060,  0.1226]]]])), ('nets.0.bias', tensor([-0.0727,  0.1102])), ('nets.2.weight', tensor([[[[-0.0639, -0.2025,  0.0956],
          [-0.0614,  0.2054, -0.0764],
          [ 0.1647,  0.1547, -0.0313]],

         [[-0.2207,  0.0536,  0.2036],
          [ 0.1053, -0.1585, -0.1922],
          [-0.1237, -0.0214, -0.1206]]]])), ('nets.2.bias', tensor([0.1438])), ('nets.4.weight', tensor([[[[-0.2930, -0.3210, -0.0247],
          [-0.1606, -0.2331, -0.3283],
          [ 0.3287,  0.1118, -0.2206]]]])), ('nets.4.bias', tensor([-0.0335]))])
0 nets.0.weight torch.Size([2, 1, 3, 3])
1 nets.0.bias torch.Size([2])
2 nets.2.weight torch.Size([1, 2, 3, 3])
3 nets.2.bias torch.Size([1])
4 nets.4.weight torch.Size([1, 1, 3, 3])
5 nets.4.bias torch.Size([1])

在不同设备之间加载模型（torch.load的map_location参数）

  1. cpu -> cpu 或 gpu -> gpu（保存和加载在同一设备上，无需map_location）:
# Same-device load (CPU -> CPU or GPU -> GPU): no map_location is required.
# NOTE(review): `model` is not defined in this snippet; it is assumed to be
# an instance of the same architecture that produced the checkpoint.
checkpoint = torch.load(f='modelparameters.pth')
model.load_state_dict(state_dict=checkpoint)
  2. cpu -> gpu 1

# Move every deserialized storage onto GPU 1, whatever device it was saved on.
# Keep the result (the original discarded it) so it can be passed to
# model.load_state_dict, as in the same-device example.
checkpoint = torch.load('modelparameters.pth', map_location=lambda storage, loc: storage.cuda(1))

  3. gpu 1 -> gpu 0

# The dict remaps location tags found in the file: storages saved on 'cuda:1'
# are placed on 'cuda:0'. Keep the result (the original discarded it) so it
# can be passed to model.load_state_dict.
checkpoint = torch.load('modelparameters.pth', map_location={'cuda:1':'cuda:0'})

  4. gpu -> cpu

# The identity function returns each storage unchanged. torch.load first
# deserializes storages on the CPU, so GPU-saved tensors stay on the CPU.
# Keep the result (the original discarded it) so it can be passed to
# model.load_state_dict.
checkpoint = torch.load('modelparameters.pth', map_location=lambda storage, loc: storage)

本文链接:http://57km.cc/post/save and load pytorch.html

-- EOF --

Comments