全国服务热线:18888889999
在线报名
欧陆注册CURRICULUM
欧陆资讯 NEWS CENTER
联系我们 CONTACT US
手机:
18888889999
电话:
0898-66889888
邮箱:
admin@youweb.com
地址:
海南省海口市玉沙路58号
欧陆资讯
你的位置: 首页 > 欧陆资讯
一文梳理pytorch保存和重载模型参数攻略
2024-07-08 21:34:26 点击量:

训练过程中保存模型参数,就不怕断电了——沃资基·索德


在训练完成之前,我们需要每隔一段时间保存模型当前参数值,一方面可以防止断电重跑,另一方面可以观察不同迭代次数模型的表现;在训练完成以后,我们需要保存模型参数值用于后续的测试过程。所以,保存的对象包含网络参数值、优化器参数值、epoch值等等。


一、定义一个容易识别的网络

在正式介绍模型的保存和加载之前,我们首先定义一个基本的网络Net,它只包含一个全连接层:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer = nn.Linear(1, 1)
        self.layer.weight = nn.Parameter(torch.FloatTensor([[10]]))
        self.layer.bias = nn.Parameter(torch.FloatTensor([1]))

    def forward(self, x):
         y = self.layer(x)
        return y

将全连接的权重w和偏差b分别设置为10和1,全连接的计算方式如下:

假设输入x=1,可以知道y值为11:

测试一下输出是不是11,代码如下:

x = torch.FloatTensor([[1]])
net = Net()
out = net(x)
print(out)

输出:tensor([[11.]], grad_fn=<AddmmBackward>),说明上述计算是正确的。不采用参数随机初始化,而是用特殊的数值初始化,是因为我们希望重载模型的时候,能够从特殊数值一眼判断出保存和重载过程是否正确,也可以把权重设置为一张图片数值,然后判断加载的参数值能不能恢复原图。


二、保存Net的参数值

保存模型参数之前,需要知道Net的参数值存储在其state_dict(状态字典)属性中,我们查看一下net的state_dict包含哪些参数:

print(net.state_dict())

我们将会得到net包含的所有参数名称与参数值

包含一个weight和一个bias,对应的值分别是10和1,和我们之前定义的全连接层一致。我们需要保存的就是这个state_dict,保存的函数为“torch.save()”,参数是我们需要保存的dict和存储路径

torch.save(obj=net.state_dict(), f="models/net.pth")

现在,同级目录models下将会出现net.pth文件,pth文件中的内容就是net的参数名称和值对应的state_dict,如下:


三、加载Net参数值并用于新的模型

最后一个步骤就是从pth文件中重新获取Net参数值,并把参数值装载到新定义的Model对象中。这里我们重新定义一个结构和Net类相同的类Model,区别仅仅是Model参数初始值和Net不同,代码如下:

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer = nn.Linear(1, 1)
        self.layer.weight = nn.Parameter(torch.FloatTensor([[0]]))
        self.layer.bias = nn.Parameter(torch.FloatTensor([0]))

    def forward(self, x):
        out = self.layer(x)
        return out

这里将Model的初始值权重w和偏差都设置为0,查看其state_dict:

model = Model()
print(model.state_dict())

得到的w和b值与预期相同,均为0,如下:

现在,我们将model对象的参数值设置为net.pth中的值,需要使用“model.load_state_dict()”函数重置model的参数值为"torch.load(models/ net.pth)"中的参数值,如下:

model.load_state_dict(torch.load("models/net.pth"))
print(model.state_dict())

至此,model的w和b值就不再是0了,而是net中w和b对应的10和1,如下:

其中参数值重载的核心函数为“model.load_state_dict()”,每个继承自nn.Module的网络都能通过这个函数设定参数值。


四、优化器与epoch的保存

保存优化器参数值和epoch值的主要目的是用于继续训练,保存的流程依旧是先“torch.save()”再“torch.load_state_dict()”,我们首先定义一个Adam优化器、一个任意的epoch值与net如下:

net=Net()
Adam=optim.Adam(params=net.parameters(), lr=0.001, betas=(0.5, 0.999))
epoch=96

现在,创建一个字典来保存所有的对象,并用save函数保存这个字典

all_states = {"net": net.state_dict(), "Adam": Adam.state_dict(), "epoch": epoch}
torch.save(obj=all_states, f="models/all_states.pth")

所有的对象都被保存到models文件夹下了:

可以使用load()函数把所有的对象再次提取出来:

reload_states = torch.load("models/all_states.pth")
print(reload_states)

得到的所有参数如下:


五、总结

pytorch中state_dict()和load_state_dict()函数配合使用可以实现状态的获取与重载,load()和save()函数配合使用可以实现参数的存储与读取。其中最重要的部分是“字典”的概念,因为参数对象的存储是需要“名称”——“值”对应(即键值对),读取时也是通过键值对读取的。


参考:

pytorchtutorial.com/pyt

blog.csdn.net/Code_Mart