pytorch如何保存和加载模型

浪矢\n / 2023-08-07 / 原文

两种方法:保存和加载参数  和  保存加载整个模型

保存和加载参数 

#保存
torch.save(model.state_dict,PATH)  #PATH推荐格式为.pt
#加载
model=TheModelClass(*args, **kwargs )
model.load_state_dict(torch.load(PATH) )

保存加载整个模型 

#保存
torch.save(model,PATH)
#加载
model = torch.load(PATH)