# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# Source: https://github.com/facebookresearch/ToMe/blob/main/tome/utils.py
# --------------------------------------------------------

import time
from typing import Tuple, Union
import torch
from tqdm import tqdm


def benchmark(
        model: torch.nn.Module,
        device: Union[torch.device, int] = 0,
        input_size: Tuple[int, int, int] = (3, 224, 224),
        batch_size: int = 64,
        runs: int = 40,
        throw_out: float = 0.25,
        use_fp16: bool = False,
        verbose: bool = False,
) -> float:
    """

    Benchmark the given model with random inputs at the given batch size.



    Args:

     - model: the module to benchmark

     - device: the device to use for benchmarking

     - input_size: the input size to pass to the model (channels, h, w)

     - batch_size: the batch size to use for evaluation

     - runs: the number of total runs to do

     - throw_out: the percentage of runs to throw out at the start of testing

     - use_fp16: whether or not to benchmark with float16 and autocast

     - verbose: whether or not to use tqdm to print progress / print throughput at end



    Returns:

     - the throughput measured in images / second

    """
    
    if not isinstance(device, torch.device):
        device = torch.device(device)
    is_cuda = device.type == "cuda"

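    # Put the model in eval mode and build one random batch on the target
    # device; the input contents don't matter for a throughput measurement.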
    model = model.eval().to(device)
    input = torch.rand(batch_size, *input_size, device=device)
    if use_fp16:
        input = input.half()

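    # Discard the first `throw_out` fraction of runs as warm-up (kernel
    # compilation, cache effects) so they don't skew the measurement.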
    warm_up = int(runs * throw_out)
    total = 0
    start = time.time()

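    # Enable autocast only when benchmarking fp16; no_grad avoids building
    # autograd state during the forward passes.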
    with torch.autocast(device.type, enabled=use_fp16):
        with torch.no_grad():
            for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
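                # Once the warm-up runs finish, reset the counter and timer.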
                if i == warm_up:
                    if is_cuda:
                        torch.cuda.synchronize()
                    total = 0
                    start = time.time()

                model(input)
                total += batch_size

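    # CUDA kernels launch asynchronously; synchronize so all pending work
    # completes before the clock is read.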
    if is_cuda:
        torch.cuda.synchronize()

    end = time.time()
    elapsed = end - start

    throughput = total / elapsed

    if verbose:
        print(f"Throughput: {throughput:.2f} im/s")

    return throughput
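

# A minimal usage sketch (not part of the original file). It assumes
# torchvision is available; the model and batch size are illustrative
# choices only.
if __name__ == "__main__":
    from torchvision.models import resnet18

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    benchmark(
        resnet18(),
        device=device,
        input_size=(3, 224, 224),
        batch_size=32,
        use_fp16=device.type == "cuda",  # fp16/autocast mainly helps on GPU
        verbose=True,  # prints throughput in images / second
    )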