[mpact][benchmarks] set up benchmark framework (#25)
diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index 2d67cdc..4256292 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt
@@ -2,4 +2,24 @@ # The MPACT Compiler Benchmarks #------------------------------------------------------------------------------- -# TODO(yinying): add all our benchmarks under benchmark/python/* +declare_mlir_python_sources(MPACTBenchmarkPythonSources) + +declare_mlir_python_sources(MPACTBenchmarkPythonSources.BenchmarkSuite + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/python" + ADD_TO_PARENT MPACTBenchmarkPythonSources + SOURCES_GLOB + benchmarks/*.py +) + +declare_mlir_python_sources(MPACTBenchmarkPythonSources.BenchmarkUtils + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/python" + ADD_TO_PARENT MPACTBenchmarkPythonSources + SOURCES_GLOB + utils/*.py +) + +add_mlir_python_modules(MPACTBenchmarkPythonPythonModules + ROOT_PREFIX "${MPACT_PYTHON_PACKAGES_DIR}/mpact/mpact_benchmark" + INSTALL_PREFIX "python_packages/mpact/mpact_benchmark" + DECLARED_SOURCES MPACTBenchmarkPythonSources +)
diff --git a/benchmark/python/benchmarks/gcn_benchmark.py b/benchmark/python/benchmarks/gcn_benchmark.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/benchmark/python/benchmarks/gcn_benchmark.py
diff --git a/benchmark/python/benchmarks/kernels_benchmark.py b/benchmark/python/benchmarks/kernels_benchmark.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/benchmark/python/benchmarks/kernels_benchmark.py
diff --git a/benchmark/python/benchmarks/lif_benchmark.py b/benchmark/python/benchmarks/lif_benchmark.py new file mode 100644 index 0000000..5c789ea --- /dev/null +++ b/benchmark/python/benchmarks/lif_benchmark.py
@@ -0,0 +1,44 @@ +import torch +import numpy as np +from mpact.models.lif import Block +from mpact_benchmark.utils.benchmark_utils import benchmark, Backends + + +@benchmark( + [ + { + "name": f"{fmt}_{shape}_{dtype.__name__}", + "shape": shape, + "formats": [fmt], + "dtype": dtype, + # Simulate batch normalization. + "drange": (-1, 1), + "sparsity": [0, 0.5, 0.9, 0.99], + # to_dense() in LIF prop hack is not supported in torch inductor. + # TODO: add torch inductor once prop hack is no longer needed. + "backends": [ + b + for b in Backends + if b.value + not in ( + Backends.TORCH_SPARSE_INDUCTOR.value, + Backends.TORCH_DENSE_INDUCTOR.value, + ) + ], + } + for shape in [ + [[64, 3, 32, 32, 1]], + [[32, 3, 64, 64, 1]], + [[16, 3, 224, 224, 1]], + ] + for fmt in ["dense"] + for dtype in [np.float64] + ] +) +def SNN() -> torch.nn.Module: + """Spiking Neural Network.""" + return Block + + +if __name__ == "__main__": + SNN()
diff --git a/benchmark/python/benchmarks/resnet_benchmark.py b/benchmark/python/benchmarks/resnet_benchmark.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/benchmark/python/benchmarks/resnet_benchmark.py
diff --git a/benchmark/python/utils/benchmark_utils.py b/benchmark/python/utils/benchmark_utils.py new file mode 100644 index 0000000..275853b --- /dev/null +++ b/benchmark/python/utils/benchmark_utils.py
@@ -0,0 +1,217 @@ +import functools +import torch +from enum import Enum +from typing import Any, Callable +from torch.utils import benchmark as torch_benchmark +from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run +from mpact_benchmark.utils.tensor_generator import generate_inputs + + +class Backends(Enum): + TORCH_SPARSE_EAGER = 1 + TORCH_DENSE_EAGER = 2 + TORCH_SPARSE_INDUCTOR = 3 + TORCH_DENSE_INDUCTOR = 4 + MPACT_SPARSE = 5 + MPACT_DENSE = 6 + + +def timer(stmt: str, description: str, setup: str = "", **kwargs: Any) -> Any: + """Timer for benchmark.""" + return torch_benchmark.Timer( + stmt=stmt, + globals=kwargs["variables"], + setup=setup, + num_threads=1, + label=kwargs["variables"]["label"], + sub_label=kwargs["variables"]["sub_label"], + description=description, + ).adaptive_autorange() + + +def get_dynamo_compile_time(sub_label: str, label: str, description: str) -> Any: + """Get compile time from dynamo and create a benchmark measurement object.""" + try: + compile_time = torch_benchmark.Measurement( + 1, + [ + float( + torch._dynamo.utils.compile_times(repr="csv")[1][0] + .split(",")[-1] + .strip() + ) + ], + torch_benchmark.TaskSpec( + sub_label, + None, + description=description, + label=label, + ), + ) + return compile_time + except ValueError: + print(f"No compilation happened for {description}: {sub_label}.") + return None + + +def run_benchmark( + sparse_inputs: tuple[torch.Tensor, ...], + dense_inputs: tuple[torch.Tensor, ...], + torch_net: torch.nn.Module, + variables: dict[str, Any], + backends: tuple[Backends, ...], + runtime_results: list[torch_benchmark.Measurement], + compile_time_results: list[torch_benchmark.Measurement], +): + """Run benchmark with specified backends.""" + output = [] + + with torch.no_grad(): + for backend in backends: + match backend: + case Backends.TORCH_SPARSE_EAGER: + output.append(torch_net(*sparse_inputs)) + runtime_results.append( + timer( + "torch_net(*sparse_inputs)", + "torch-sparse-eager", + variables=variables, + ) + ) + case Backends.TORCH_DENSE_EAGER: + output.append(torch_net(*dense_inputs)) + runtime_results.append( + timer( + "torch_net(*dense_inputs)", + "torch-dense-eager", + variables=variables, + ) + ) + case Backends.TORCH_SPARSE_INDUCTOR: + torch_inductor = torch.compile(torch_net) + torch_out = torch_inductor(*sparse_inputs) + output.append(torch_out) + compile_time = get_dynamo_compile_time( + variables["sub_label"], + variables["label"], + "torch-sparse-inductor-compile", + ) + if compile_time: + compile_time_results.append(compile_time) + runtime_results.append( + timer( + "torch_inductor(*sparse_inputs)", + "torch-sparse-inductor-runtime", + variables=dict(variables, **locals()), + ) + ) + case Backends.TORCH_DENSE_INDUCTOR: + torch_inductor = torch.compile(torch_net) + output.append(torch_inductor(*dense_inputs)) + compile_time = get_dynamo_compile_time( + variables["sub_label"], + variables["label"], + "torch-dense-inductor-compile", + ) + if compile_time: + compile_time_results.append(compile_time) + runtime_results.append( + timer( + "torch_inductor(*dense_inputs)", + "torch-dense-inductor-runtime", + variables=dict(variables, **locals()), + ) + ) + case Backends.MPACT_SPARSE: + output.append( + torch.from_numpy(mpact_jit(torch_net, *sparse_inputs)) + ) + invoker, f = mpact_jit_compile(torch_net, *sparse_inputs) + compile_time_results.append( + timer( + "mpact_jit_compile(torch_net, *sparse_inputs)", + "mpact-sparse-compile", + "from mpact.mpactbackend import mpact_jit_compile", + variables=dict(variables, **locals()), + ) + ) + runtime_results.append( + timer( + "mpact_jit_run(invoker, f, *sparse_inputs)", + "mpact-sparse-runtime", + "from mpact.mpactbackend import mpact_jit_run", + variables=dict(variables, **locals()), + ) + ) + case Backends.MPACT_DENSE: + output.append(torch.from_numpy(mpact_jit(torch_net, *dense_inputs))) + invoker, f = mpact_jit_compile(torch_net, *dense_inputs) + compile_time_results.append( + timer( + "mpact_jit_compile(torch_net, *dense_inputs)", + "mpact-dense-compile", + "from mpact.mpactbackend import mpact_jit_compile", + variables=dict(variables, **locals()), + ) + ) + runtime_results.append( + timer( + "mpact_jit_run(invoker, f, *dense_inputs)", + "mpact-dense-runtime", + "from mpact.mpactbackend import mpact_jit_run", + variables=dict(variables, **locals()), + ) + ) + case _: + print(f"{backend} is not supported yet.") + + # Sanity check. + if output: + assert all(output[0].to_dense().allclose(out.to_dense()) for out in output) + + +def benchmark(*args: Any) -> Callable: + """Wrapper for benchmark.""" + + def decorator(func): + @functools.wraps(func) + def wrapper(test_cases=args[0]): + net = func() + runtime_results = [] + compile_time_results = [] + for test_case in test_cases: + label = func.__name__ + for sparsity in test_case["sparsity"]: + sub_label = f"{test_case['name']}_{sparsity}" + dense_inputs, sparse_inputs = generate_inputs( + test_case["shape"], + sparsity, + test_case["formats"], + test_case["dtype"], + test_case["drange"], + ) + if "GCN" in label: + torch_net = net(*test_case["shape"][0]) + else: + torch_net = net() + + run_benchmark( + sparse_inputs, + dense_inputs, + torch_net, + locals(), + test_case["backends"], + runtime_results, + compile_time_results, + ) + + compare1 = torch_benchmark.Compare(runtime_results) + compare1.print() + compare2 = torch_benchmark.Compare(compile_time_results) + compare2.print() + + return func + + return wrapper + + return decorator
diff --git a/benchmark/python/utils/tensor_generator.py b/benchmark/python/utils/tensor_generator.py new file mode 100644 index 0000000..5f9d304 --- /dev/null +++ b/benchmark/python/utils/tensor_generator.py
@@ -0,0 +1,74 @@ +import torch +import math +import numpy as np +from typing import Any + + +def generate_inputs( + shapes: tuple[Any, ...], + sparsity: float, + formats: tuple[str, ...], + dtype: Any = np.float64, + drange: tuple[Any, ...] = (1, 100), +) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]: + """Generates dense and sparse tensor inputs. + + Args: + shapes: Shape for each input. + sparsity: Sparsity level for the inputs. + formats: Sparsity format for each input. + dtype: Data type of the generated inputs. Default is np.float64. + drange: Data range of the non-zero values. Default is (1, 100). + + Returns: + dense_inputs: all dense tensors. + sparse_inputs: inputs are of the specified sparsity format, such as CSR. + """ + dense_inputs = [] + sparse_inputs = [] + # Each input has a different seed. + for seed, shape in enumerate(shapes): + dense_inputs.append(generate_tensor(seed, shape, sparsity, dtype, drange)) + for idx, dense_input in enumerate(dense_inputs): + if formats[idx] == "dense": + sparse_inputs.append(dense_input) + else: + # TODO: support more sparsity formats. + sparse_inputs.append(dense_input.to_sparse_csr()) + return dense_inputs, sparse_inputs + + +def generate_tensor( + seed: int, + shape: tuple[Any, ...], + sparsity: float, + dtype: Any = np.float64, + drange: tuple[Any, ...] = (1, 100), +) -> torch.Tensor: + """Generates a tensor given sparsity level, shape and data type. + + Args: + seed: Seed value for np.random. + shape: A tuple for the shape of tensor. + sparsity: Sparsity level in the range of [0, 1]. + dtype: Data type of the generated tensor. Default is np.float64. + drange: Data range of the non-zero values. Default is (1, 100). + + Returns: + A dense torch tensor with the specified shape, sparsity level and type. + + Note: the tensor generated doesn't guarantee each batch will have the same + number of specified elements. Therefore, for batched CSR, torch.cat can be + used to concatenate generated tensors in the specified dimension. + """ + np.random.seed(seed) + size = math.prod(shape) + nse = size - int(math.ceil(sparsity * size)) + + flat_output = np.zeros(size) + indices = np.random.choice(size, nse, replace=False) + values = np.random.uniform(drange[0], drange[1], nse) + flat_output[indices] = values + + result = np.reshape(flat_output, shape).astype(dtype) + return torch.from_numpy(result)