您现在的位置是:主页 > news > 网站制作服务公司/免费网上销售平台

网站制作服务公司/免费网上销售平台

admin2025/6/17 7:46:14news

简介网站制作服务公司,免费网上销售平台,宁波网站建站的公司,广告设计专业考研保存模型主要分为两类:保存整个模型和只保存模型参数 1.保存加载整个模型(不推荐): 保存整个网络模型,网络结构权重参数 torch.save(model,net.pth) 加载整个网络模型(可能比较耗时) model…

网站制作服务公司,免费网上销售平台,宁波网站建站的公司,广告设计专业考研保存模型主要分为两类:保存整个模型和只保存模型参数 1.保存加载整个模型(不推荐): 保存整个网络模型,网络结构权重参数 torch.save(model,net.pth) 加载整个网络模型(可能比较耗时) model…

保存模型主要分为两类:保存整个模型和只保存模型参数

1.保存加载整个模型(不推荐):

   保存整个网络模型,网络结构+权重参数 

torch.save(model,'net.pth')

加载整个网络模型(可能比较耗时)

model=torch.load('net.pth')

2.只保存加载模型参数(推荐)

  保存模型的权重参数(速度快,占内存少)

torch.save(model.state_dict(),'net_params.pth')

load 模型参数

因为我们只保存了 模型的参数,所以需要先定义一个网络对象,然后再加载模型参数。

model=ClassNet()

#将模型参数加载到新模型中,torch.load返回的是一个OrderedDict,说明.state_dict()只是把所有模型的参数都已OrderedDict的形式存下来。

state_dict=torch.load('net_params.pth')
model.load_state_dict(state_dict)

Note:保存模型进行推理测试时,只需保存训练好的模型的权重参数,即推荐第二种方法。

load_state_dict的参数strict=False

new_model.load_state_dict(state_dict,strict=False)

如果哪一天我们需要重新写这个网络的,比如使用new_model,如果直接load会出现unexpected key. 但是加上strict=False可以很容易地加载预训练的参数(注意检查key是否匹配),直接忽略不匹配的key,对于匹配的key则进行正常的赋值。 

3.保存加载自定义模型

上面“保存加载整个模型”加载的net.pt其实是一个字典,通常包含以下内容:

 网络结构:输入尺寸,输出尺寸以及隐藏层的信息,以便能够在加载时重建模型。

模型的权重参数,包含各网络层训练后的可学习参数,可以在模型实例上调用state_dict()方法来获取,比如只保存模型权重参数时用到的model.state_dict().

优化器参数:有时保存模型的参数需要稍后接着训练,那么就必须保存优化器的状态和其所使用的超参数,也是在优化器实例上调用state_dict()方法来获取这些参数。

其他信息:有时我们需要保存一些其他的信息,比如epoch, batch_size等超参数

我们可以自定义需要save的内容

#saving a checkpoint assuming the network class named ClassNet
checkpoint={'modle':ClassNet(),'model_state_dict':model.state_dict(),'optimize_state_dict':optimizer.state_dict(),'epoch':epoch}
torch.save(checkpoint,'checkpoint.pkl')

上面的checkpoint是个字典,里面有4各键值对,分别表示网络模型的不同信息。

然后我们要loda上面保存的自定义模型

def load_checkpoint(filepath):checkpoint=torch.load(filepath)def load_checkpoint(filepath):checkpoint = torch.load(filepath)model=checkpoint['model']#提前网络结构model.load_state_dict(checkpoint['model_state_dict'])#加载网络权重参数optimizer=TheOptimizerClass()optimizer.load_state_dict(checkpoint['optimizer_state_dict'])#加载优化器参数for parameter in model.parameters():parameter.requires_grad=Falsemodel.eval()return model
modle=load_checkpoint('checkpoint.pkl')

后续使用

   如果加载模型只是为了进行推理测试,则将每一层的requires_grad置为False,即固定这些权重参数,还需要调用model.eval()将模型置为测试模式,主要是将dropout和batch normalization层进行固定,否则模型的预测结果每次都会不同。

  如果需要继续训练,则调用model.train(),以确保网络模型处于训练模式。

跨设备保存加载模型

在CPU上加载在GPU上训练并保存的模型(save on GPU, load on CPU

device=torch.device('cpu')
model=TheModelClass()
#load all tensors onto the CPU device
model.load_state_dict(torch.load('net_params.pkl',map_location=device))

令torch.load()函数的map_location参数等于torch.device('cpu')即可,这里令map_location参数等于‘cpu'也同样可以。