PyTorch 从内存中load 模型报错:EOFError Ran out of input

创建日期: 2023-07-03 15:19 | 作者: 风波 | 浏览次数: 17 | 分类: PyTorch

报错信息:EOFError: Ran out of input

加载模型报错的代码

def load_model():
    checkpoint_name = os.path.join(curtdir, "model.pt")
    if not os.path.isfile(checkpoint_name):
        raise Exception("file not exists: {}".format(checkpoint_name))

    # Load the checkpoint.
    bio = io.BytesIO()
    with open(checkpoint_name, 'rb') as f:
        for chunk in iter(lambda: f.read(1*1024*1024), b""):
            bio.write(chunk)
    sd = torch.load(bio)
    return sd

原因是 bio 此时的文件位置是在文件尾部,需要把文件读取头 seek 到文件开始位置。

修正后的代码:

def load_model():
    checkpoint_name = os.path.join(curtdir, "model.pt")
    if not os.path.isfile(checkpoint_name):
        raise Exception("file not exists: {}".format(checkpoint_name))

    # Load the checkpoint.
    bio = io.BytesIO()
    with open(checkpoint_name, 'rb') as f:
        for chunk in iter(lambda: f.read(1*1024*1024), b""):
            bio.write(chunk)
    bio.seek(0)
    sd = torch.load(bio)
    return sd
17 浏览
0 评论