PyTorch 加速数据读取, 提高 GPU 利用率 - prefetch_generator

创建日期: 2023-02-01 17:35 | 作者: 风波 | 浏览次数: 17 | 分类: PyTorch

来源: - https://blog.csdn.net/shanglianlm/article/details/113354959 - https://github.com/shanglianlm0525/CvPytorch

1. 方案一:使用 prefetch_generator

安装

pip install prefetch_generator

PyTorch 默认的 DataLoader 会创建一些 worker 线程来预读取新的数据,但是除非这些线程的数据全部都被清空,这些线程才会读下一批数据。 使用 prefetch_generator,我们可以保证线程不会等待,每个线程都总有至少一个数据在加载。

使用

import torch
from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator

class PrefetchDataLoader(DataLoader):
    '''
        replace DataLoader with PrefetchDataLoader
    '''
    def __iter__(self):
        return BackgroundGenerator(super().__iter__())

2. 方案二:使用 data_prefetcher

使用 data_prefetcher 新开 cuda stream 来拷贝 tensor 到 gpu。 默认情况下,PyTorch 将所有涉及到 GPU 的操作(比如内核操作,cpu->gpu,gpu->cpu)都排入同一个 stream(default stream)中,并对同一个流的操作序列化,它们永远不会并行。要想并行,两个操作必须位于不同的 stream 中。

而前向传播位于 default stream 中,因此,要想将下一个 batch 数据的预读取(涉及 cpu->gpu)与当前 batch 的前向传播并行处理,就必须:

  1. cpu 上的数据 batch 必须 pinned;
  2. 预读取操作必须在另一个 stream 上进行
class DataPrefetcher(object):
    '''
        prefetcher = DataPrefetcher(train_loader, device=self.device)
        batch = prefetcher.next()
        iter_id = 0
        while batch is not None:
            iter_id += 1
            if iter_id >= num_iters:
                break
            run_step()
            batch = prefetcher.next()
    '''
    def __init__(self, loader, device):
        self.loader = loader
        self.dataset = loader.dataset
        self.stream = torch.cuda.Stream()
        self.next_input = None
        self.next_target = None
        self.device = device

    def __len__(self):
        return len(self.loader)

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loaditer)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(device=self.device, non_blocking=True)
            self.next_target = self.next_target.cuda(device=self.device, non_blocking=True)

    def __iter__(self):
        count = 0
        self.loaditer = iter(self.loader)
        self.preload()
        while self.next_input is not None:
            torch.cuda.current_stream().wait_stream(self.stream)
            input = self.next_input
            target = self.next_target
            self.preload()
            count += 1
            yield input, target
17 浏览
10 爬虫
0 评论