使用 libtorch (C++) 从内存加载模型参数（加密模型解密后即可用此方法直接从内存加载，无需写入磁盘）

创建日期: 2023-06-16 15:53 | 作者: 风波 | 浏览次数: 12 | 分类: PyTorch

来源:https://discuss.pytorch.org/t/can-libtorch-load-model-from-memory/55226

struct MLP : public torch::nn::Module
{
    torch::nn::Linear fc1{ nullptr }, fc2{ nullptr }, fc3{ nullptr };

    MLP()
    {
        fc1 = register_module("fc1", torch::nn::Linear(2, 10));
        fc2 = register_module("fc2", torch::nn::Linear(10, 10));
        fc3 = register_module("fc3", torch::nn::Linear(10, 1));
    }

    torch::Tensor forward(torch::Tensor x)
    {
        x = torch::leaky_relu(fc1->forward(x));
        x = torch::leaky_relu(fc2->forward(x));
        return torch::sigmoid(fc3->forward(x));
    }

    std::vector<char> SaveMemory()
    {
        std::ostringstream oss;
        torch::serialize::OutputArchive archive;

        this->save(archive);
        archive.save_to(oss);

        std::string s = oss.str();

        const char* ptr = s.c_str();
        size_t length = s.size();
        std::vector<char> retval(ptr, ptr + length);

        return retval;
    }

    void LoadMemory(std::vector<char>& data, c10::Device device)
    {
        std::istringstream iss(std::string(data.begin(), data.end()));

        torch::serialize::InputArchive archive;
        archive.load_from(iss, device);
        this->load(archive);
    }
};

void main()
{
    try
    {
        auto testInput = torch::randn({ 1, 2 });

        std::vector<char> savedState;

        {
            auto net = std::make_shared<MLP>();
            auto output1 = net->forward(testInput);
            std::cout << "Output before save: " << output1 << std::endl;
            savedState = net->SaveMemory();
        }

        {
            auto net = std::make_shared<MLP>();
            auto output1 = net->forward(testInput);
            std::cout << "Output before load: " << output1 << std::endl;

            net->LoadMemory(savedState, c10::Device(c10::DeviceType::CPU));

            auto output2 = net->forward(testInput);
            std::cout << "Output after load: " << output2 << std::endl;
        }
    }
    catch (std::runtime_error& e)
    {
        std::cout << e.what() << std::endl;
    }
    catch (const c10::Error& e)
    {
        std::cout << e.msg() << std::endl;
    }

    system("PAUSE");
}
12 浏览
10 爬虫
0 评论