Converting a PyTorch model to the TorchScript model.pt format for Triton

Created: 2024-08-12 17:19 | Author: 风波 | Category: PyTorch

1. Save the model

Source: https://pytorch.org/tutorials/beginner/saving_loading_models.html

import torch
torch.save(model, PATH)  # pickles the whole nn.Module `model` (not TorchScript) to the file PATH
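Step 2 below needs the model object in memory. If it was saved with torch.save as above, a minimal sketch of reloading it might look like this. TinyNet and the file name are hypothetical stand-ins for your own model and PATH. Because torch.save(model, PATH) pickles the entire module, the class definition must be importable when loading, and since PyTorch 2.6 torch.load defaults to weights_only=True, so such files need weights_only=False.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for your trained model."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyNet().eval()
PATH = "tinynet_full.pth"  # hypothetical file name

# Save the whole module (pickled), as in step 1
torch.save(model, PATH)

# Reload it. weights_only=False is required for full pickled modules on
# PyTorch 2.6+; only use it for files you trust.
reloaded = torch.load(PATH, weights_only=False)
reloaded.eval()

# Both models share the same weights, so outputs should match
x = torch.randn(3, 4)
print(torch.allclose(model(x), reloaded(x)))  # True
```

The reloaded `reloaded` object can then be passed to torch.jit.trace in place of the hub model used below.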

2. Convert to the TorchScript model.pt format

Source: https://github.com/triton-inference-server/tutorials/blob/main/Quick_Deploy/PyTorch/export.py

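import torch

# Workaround used in the Triton tutorial: skip torch.hub's forked-repo check,
# which can fail (for example with GitHub API rate-limit errors).
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True

# Load a pretrained ResNet-50, switch it to evaluation mode, and move it to the GPU
model = (
    torch.hub.load("pytorch/vision:v0.10.0", "resnet50", pretrained=True)
    .eval()
    .to("cuda")
)
# Trace with a dummy input shaped (batch, channels, height, width)
traced_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224).to("cuda"))
# Write the traced TorchScript module to model.pt
torch.jit.save(traced_model, "model.pt")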
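The script above needs network access to download ResNet-50 and a CUDA GPU. As a sketch of the same trace-and-save pattern that also runs offline and on CPU, the example below uses a hypothetical small CNN (TinyNet) and a hypothetical helper, export_torchscript, then reloads model.pt with torch.jit.load to check that the traced module reproduces the eager model's output.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical small CNN standing in for ResNet-50."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)


def export_torchscript(model, example_input, path="model.pt"):
    """Hypothetical helper: trace `model` on `example_input` and save it to `path`."""
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)
    torch.jit.save(traced, path)
    return traced


# Use the GPU when present, as the tutorial does, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyNet().to(device)
example = torch.randn(1, 3, 224, 224, device=device)
export_torchscript(model, example, "model.pt")

# Reload the TorchScript file and compare it with the eager model
loaded = torch.jit.load("model.pt", map_location=device)
with torch.no_grad():
    print(torch.allclose(model(example), loaded(example), atol=1e-5))  # True
```

Comparing outputs after torch.jit.load is a cheap sanity check before copying model.pt into a Triton model repository.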

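The essential work is done by the last two lines: torch.jit.trace runs the model once on the example input and records the operations it executes as a TorchScript module, and torch.jit.save writes that module to model.pt.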

traced_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224).to("cuda"))
torch.jit.save(traced_model, "model.pt")
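Because torch.jit.trace only records the operations executed for the example input, any data-dependent control flow in forward is frozen to whichever branch that input took. The hypothetical Gate module below illustrates this. torch.jit.script, which compiles the Python control flow instead of recording it, is shown only as an alternative for such models; it is not part of the tutorial's method.

```python
import torch
import torch.nn as nn


class Gate(nn.Module):
    """Hypothetical module whose output depends on a data-dependent branch."""

    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x * -1


gate = Gate()
pos = torch.ones(3)
neg = -torch.ones(3)

# Tracing with `pos` records only the "x * 2" branch
# (PyTorch emits a TracerWarning about converting a tensor to a Python bool).
traced = torch.jit.trace(gate, pos)
print(traced(neg))  # tensor([-2., -2., -2.]): the traced branch is baked in
print(gate(neg))    # tensor([1., 1., 1.]): the eager model takes the other branch

# Scripting compiles the if/else, so both branches are preserved
scripted = torch.jit.script(gate)
print(scripted(neg))  # tensor([1., 1., 1.])
```

For a plain feed-forward model such as ResNet-50 tracing is sufficient; TracerWarning messages during tracing are a hint that a model may need closer inspection.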