1. Save the model
Source: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save(model, PATH)
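A minimal save/load round-trip sketch (my own example, not from the note; the model name and file paths are placeholders). The linked tutorial also recommends saving just the state_dict rather than pickling the whole model:

import torch
import torchvision.models as models

# A hypothetical example model; any nn.Module works the same way
model = models.resnet50()
model.eval()

# Option 1: save the whole model object (this is what torch.save(model, PATH) does)
torch.save(model, "model_full.pth")
loaded = torch.load("model_full.pth")
loaded.eval()

# Option 2 (what the linked tutorial recommends): save only the weights
torch.save(model.state_dict(), "model_weights.pth")
model2 = models.resnet50()
model2.load_state_dict(torch.load("model_weights.pth"))
model2.eval()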
2. Convert to model.pt (TorchScript format)
Source: https://github.com/triton-inference-server/tutorials/blob/main/Quick_Deploy/PyTorch/export.py
import torch

# Workaround: skip torch.hub's forked-repo validation (avoids GitHub API rate-limit errors)
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True

# Load a pretrained ResNet-50 from torchvision via torch.hub, set eval mode, move to GPU
model = (
    torch.hub.load("pytorch/vision:v0.10.0", "resnet50", pretrained=True)
    .eval()
    .to("cuda")
)

# Trace the model with a dummy input to get a TorchScript module, then save it as model.pt
traced_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224).to("cuda"))
torch.jit.save(traced_model, "model.pt")
The key part is the last two lines:
traced_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224).to("cuda"))
torch.jit.save(traced_model, "model.pt")
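To check that the exported model.pt loads and runs on its own, a quick verification sketch (my own addition, not from the source; assumes a CUDA device is available):

import torch

# Load the TorchScript module saved above; no Python model class definition is needed
ts_model = torch.jit.load("model.pt").to("cuda").eval()

# Run a dummy 224x224 RGB batch through it and check the output shape
dummy = torch.randn(1, 3, 224, 224, device="cuda")
with torch.no_grad():
    out = ts_model(dummy)
print(out.shape)  # expected: torch.Size([1, 1000]) for ResNet-50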