与 TensorFlow 高度抽象化的 API 设计不同，PyTorch 模型的保存和加载只需直接调用 torch.save / torch.load 函数即可。
举例
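下面是一个简单的例子。Net_old 用 nn.Sequential 组织各层，Net_new 则把同样的层保存为独立命名的属性，因此二者 state_dict 中的键名不同（例如 nets.0.weight 与 conv1.weight）：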
import torch
import torch.nn as nn
class Net_old(nn.Module):
    """Three stacked 3x3 convolutions (1 -> 2 -> 1 -> 1 channels) with ReLUs
    in between, all held in a single ``nn.Sequential`` attribute ``nets``.

    Because the layers live inside a Sequential, their parameters show up in
    the state_dict under positional keys: ``nets.0.*``, ``nets.2.*`` and
    ``nets.4.*`` (the ReLUs at positions 1 and 3 have no parameters).
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3),
        ]
        self.nets = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the convolution stack and return the result."""
        out = self.nets(x)
        return out
class Net_new(nn.Module):
    """The same conv/ReLU stack as ``Net_old`` (1 -> 2 -> 1 -> 1 channels,
    3x3 kernels), but with every layer stored as its own named attribute
    instead of inside an ``nn.Sequential``.

    As a result the state_dict keys are ``conv1.*``, ``conv2.*`` and
    ``conv3.*`` rather than the positional ``nets.0.*`` / ``nets.2.*`` /
    ``nets.4.*`` keys produced by ``Net_old``; a ``Net_old`` checkpoint has to
    have its keys renamed before a strict ``load_state_dict`` will accept it.
    """

    def __init__(self):
        # Zero-argument super() binds to Net_new. Naming another class here
        # (e.g. super(Net_old, self)) raises TypeError, because Net_new does
        # not inherit from it.
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU -> conv3 to ``x``."""
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x
# Build the Sequential-based network and save only its parameters (the
# state_dict). `.cpu()` makes sure the saved tensors are CPU tensors; the model
# is already on the CPU in this example, so here it changes nothing.
network = Net_old()
torch.save(network.cpu().state_dict(), 't.pth')

# torch.load returns the saved object itself -- here the OrderedDict mapping
# parameter names to tensors -- not a model instance.
pretrained_net = torch.load('t.pth')
print(pretrained_net)

# Print position, parameter name and tensor shape for every entry.
for idx, (key, tensor) in enumerate(pretrained_net.items()):
    print(idx, key, tensor.shape)
输出:
OrderedDict([('nets.0.weight', tensor([[[[ 0.3163, -0.0568, 0.0468],
[ 0.2568, 0.2733, 0.2204],
[ 0.2367, -0.2275, -0.0172]]],
[[[ 0.1718, 0.0443, 0.0478],
[ 0.2824, 0.0429, 0.2064],
[-0.0866, -0.1060, 0.1226]]]])), ('nets.0.bias', tensor([-0.0727, 0.1102])), ('nets.2.weight', tensor([[[[-0.0639, -0.2025, 0.0956],
[-0.0614, 0.2054, -0.0764],
[ 0.1647, 0.1547, -0.0313]],
[[-0.2207, 0.0536, 0.2036],
[ 0.1053, -0.1585, -0.1922],
[-0.1237, -0.0214, -0.1206]]]])), ('nets.2.bias', tensor([0.1438])), ('nets.4.weight', tensor([[[[-0.2930, -0.3210, -0.0247],
[-0.1606, -0.2331, -0.3283],
[ 0.3287, 0.1118, -0.2206]]]])), ('nets.4.bias', tensor([-0.0335]))])
0 nets.0.weight torch.Size([2, 1, 3, 3])
1 nets.0.bias torch.Size([2])
2 nets.2.weight torch.Size([1, 2, 3, 3])
3 nets.2.bias torch.Size([1])
4 nets.4.weight torch.Size([1, 1, 3, 3])
5 nets.4.bias torch.Size([1])
加载时改变模型参数所在的设备（torch.load 的 map_location 参数）
- CPU -> CPU 或 GPU -> GPU（保存和加载使用同一设备，无需指定 map_location）：
# Same device on save and load: no map_location is given, so torch.load puts
# each tensor back on the device it was saved from.
# NOTE(review): `model` is not defined in this snippet -- assumes an already
# constructed nn.Module whose state_dict keys match the checkpoint.
checkpoint = torch.load('modelparameters.pth')
model.load_state_dict(checkpoint)
- CPU -> GPU 1（把所有张量加载到 1 号 GPU 上）：
# Load every tensor onto GPU 1: a callable map_location is invoked as
# fn(storage, location) for each serialized storage, and the storage it
# returns is used instead.
# NOTE: the loaded object is discarded here; assign it in real code.
torch.load('modelparameters.pth', map_location=lambda storage, loc: storage.cuda(1))
- GPU 1 -> GPU 0（把保存时位于 cuda:1 的张量映射到 cuda:0）：
# A dict map_location remaps the saved location tags (keys) to target devices
# (values): tensors saved on cuda:1 are restored onto cuda:0.
torch.load('modelparameters.pth', map_location={'cuda:1':'cuda:0'})
- GPU -> CPU（把所有张量加载到 CPU 上）：
# Load every tensor onto the CPU: storages are deserialized on the CPU first,
# and returning them unchanged keeps them there instead of moving them back to
# their saved device. map_location=torch.device('cpu') is the shorter
# documented spelling of the same thing.
torch.load('modelparameters.pth', map_location=lambda storage, loc: storage)
Comments