| import functools |
| import torch |
| from enum import Enum |
| from typing import Any, Callable |
| from torch.utils import benchmark as torch_benchmark |
| from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run |
| from mpact_benchmark.utils.tensor_generator import generate_inputs |
| |
| |
class Backends(Enum):
    """Execution backends a benchmark can be run on.

    Members are numbered 1-6 in declaration order; ``run_benchmark``
    dispatches on them with a ``match`` statement.
    """

    # Plain PyTorch eager execution.
    TORCH_SPARSE_EAGER = 1
    TORCH_DENSE_EAGER = 2
    # PyTorch execution through torch.compile (inductor).
    TORCH_SPARSE_INDUCTOR = 3
    TORCH_DENSE_INDUCTOR = 4
    # MPACT JIT compilation and execution.
    MPACT_SPARSE = 5
    MPACT_DENSE = 6
| |
| |
def timer(stmt: str, description: str, setup: str = "", **kwargs: Any) -> Any:
    """Time ``stmt`` with ``torch.utils.benchmark`` and return the measurement.

    The keyword argument ``variables`` is required. It serves as the global
    namespace for ``stmt`` and ``setup``, and must contain ``"label"`` and
    ``"sub_label"`` entries, which label the resulting measurement.
    Timing is single-threaded and uses ``Timer.adaptive_autorange``.
    """
    namespace = kwargs["variables"]
    bench = torch_benchmark.Timer(
        stmt=stmt,
        setup=setup,
        globals=namespace,
        label=namespace["label"],
        sub_label=namespace["sub_label"],
        description=description,
        num_threads=1,
    )
    return bench.adaptive_autorange()
| |
| |
def get_dynamo_compile_time(sub_label: str, label: str, description: str) -> Any:
    """Wrap the most recent TorchDynamo compile time in a benchmark Measurement.

    Reads ``torch._dynamo.utils.compile_times(repr="csv")``, takes the first
    recorded metric's comma-separated timings, and uses the last entry as a
    single-sample measurement (``number_per_run=1``).

    Args:
        sub_label: Identifies the test case. It is placed in the TaskSpec's
            first positional (``stmt``) slot.
        label: Label of the resulting measurement.
        description: Description (column name) of the resulting measurement.

    Returns:
        A ``torch_benchmark.Measurement`` holding the compile time
        (presumably in seconds, as reported by dynamo). Returns ``None``,
        after printing a notice, when no compile time has been recorded.
    """
    # Make sure the dynamo utilities are loaded even if torch.compile has not
    # been called yet in this process.
    import torch._dynamo.utils

    try:
        # repr="csv" returns a (headers, values) pair, where each value is a
        # string of comma-separated timings. If nothing has been compiled yet,
        # ``values`` is empty and indexing raises IndexError. An empty timing
        # string makes float() raise ValueError. Both mean "no compile time".
        timings = torch._dynamo.utils.compile_times(repr="csv")[1][0]
        seconds = float(timings.split(",")[-1].strip())
    except (IndexError, ValueError):
        print(f"No compilation happened for {description}: {sub_label}.")
        return None

    return torch_benchmark.Measurement(
        1,
        [seconds],
        torch_benchmark.TaskSpec(
            sub_label,
            None,
            description=description,
            label=label,
        ),
    )
| |
| |
| def run_benchmark( |
| sparse_inputs: tuple[torch.Tensor, ...], |
| dense_inputs: tuple[torch.Tensor, ...], |
| torch_net: torch.nn.Module, |
| variables: dict[str, Any], |
| backends: tuple[Backends, ...], |
| runtime_results: list[torch_benchmark.Measurement], |
| compile_time_results: list[torch_benchmark.Measurement], |
| ): |
| """Run benchmark with specified backends.""" |
| output = [] |
| output_type = None |
| |
| with torch.no_grad(): |
| for backend in backends: |
| match backend: |
| case Backends.TORCH_SPARSE_EAGER: |
| sparse_out = torch_net(*sparse_inputs) |
| output_type = sparse_out.layout |
| output.append(sparse_out) |
| runtime_results.append( |
| timer( |
| "torch_net(*sparse_inputs)", |
| "torch-sparse-eager", |
| variables=variables, |
| ) |
| ) |
| case Backends.TORCH_DENSE_EAGER: |
| output.append(torch_net(*dense_inputs)) |
| runtime_results.append( |
| timer( |
| "torch_net(*dense_inputs)", |
| "torch-dense-eager", |
| variables=variables, |
| ) |
| ) |
| case Backends.TORCH_SPARSE_INDUCTOR: |
| torch_inductor = torch.compile(torch_net) |
| torch_out = torch_inductor(*sparse_inputs) |
| output.append(torch_out) |
| compile_time = get_dynamo_compile_time( |
| variables["sub_label"], |
| variables["label"], |
| "torch-sparse-inductor-compile", |
| ) |
| if compile_time: |
| compile_time_results.append(compile_time) |
| runtime_results.append( |
| timer( |
| "torch_inductor(*sparse_inputs)", |
| "torch-sparse-inductor-runtime", |
| variables=dict(variables, **locals()), |
| ) |
| ) |
| case Backends.TORCH_DENSE_INDUCTOR: |
| torch_inductor = torch.compile(torch_net) |
| output.append(torch_inductor(*dense_inputs)) |
| compile_time = get_dynamo_compile_time( |
| variables["sub_label"], |
| variables["label"], |
| "torch-dense-inductor-compile", |
| ) |
| if compile_time: |
| compile_time_results.append(compile_time) |
| runtime_results.append( |
| timer( |
| "torch_inductor(*dense_inputs)", |
| "torch-dense-inductor-runtime", |
| variables=dict(variables, **locals()), |
| ) |
| ) |
| case Backends.MPACT_SPARSE: |
| sp_out = mpact_jit(torch_net, *sparse_inputs) |
| # Construct sparse csr tensor if the output type is csr. |
| # TODO: return sparse tensor directly instead of a tuple of arrays. |
| if type(sp_out) is tuple: |
| # torch.sparse_csr_tensor could deduce the size incorrectly, |
| # so pass the dense_out's shape explicitly. |
| dense_out = mpact_jit(torch_net, *dense_inputs) |
| output.append( |
| torch.sparse_csr_tensor(*sp_out, size=dense_out.shape) |
| ) |
| # Check MPACT and torch eager both return sparse csr output |
| # only when torch sparse eager has been run. |
| if output_type: |
| assert output_type == torch.sparse_csr |
| else: |
| output.append(torch.from_numpy(sp_out)) |
| # Check MPACT and torch eager both return dense output |
| # only when torch sparse eager has been run. |
| if output_type: |
| assert output_type == torch.strided |
| invoker, f = mpact_jit_compile(torch_net, *sparse_inputs) |
| compile_time_results.append( |
| timer( |
| "mpact_jit_compile(torch_net, *sparse_inputs)", |
| "mpact-sparse-compile", |
| "from mpact.mpactbackend import mpact_jit_compile", |
| variables=dict(variables, **locals()), |
| ) |
| ) |
| runtime_results.append( |
| timer( |
| "mpact_jit_run(invoker, f, *sparse_inputs)", |
| "mpact-sparse-runtime", |
| "from mpact.mpactbackend import mpact_jit_run", |
| variables=dict(variables, **locals()), |
| ) |
| ) |
| case Backends.MPACT_DENSE: |
| output.append(torch.from_numpy(mpact_jit(torch_net, *dense_inputs))) |
| invoker, f = mpact_jit_compile(torch_net, *dense_inputs) |
| compile_time_results.append( |
| timer( |
| "mpact_jit_compile(torch_net, *dense_inputs)", |
| "mpact-dense-compile", |
| "from mpact.mpactbackend import mpact_jit_compile", |
| variables=dict(variables, **locals()), |
| ) |
| ) |
| runtime_results.append( |
| timer( |
| "mpact_jit_run(invoker, f, *dense_inputs)", |
| "mpact-dense-runtime", |
| "from mpact.mpactbackend import mpact_jit_run", |
| variables=dict(variables, **locals()), |
| ) |
| ) |
| case _: |
| print(f"{backend} is not supported yet.") |
| |
| # Sanity check. |
| if output: |
| rtol = variables["precision"] if "precision" in variables else 1e-5 |
| assert all( |
| torch.allclose(output[0].to_dense(), out.to_dense(), rtol=rtol) |
| for out in output |
| ) |
| |
| |
def benchmark(*args: Any) -> Callable:
    """Decorator factory that benchmarks the network built by the decorated function.

    Usage::

        @benchmark(test_cases)
        def my_net():
            return SomeModule()

    ``args[0]`` is the default list of test cases, used when the decorated
    function is called without arguments. Each test case is a dict with the
    keys ``"name"``, ``"shape"``, ``"sparsity"`` (an iterable of sparsity
    values), ``"formats"``, ``"dtype"``, ``"drange"`` and ``"backends"``, plus
    an optional ``"precision"`` (rtol for the output cross-check).

    Calling the decorated function generates inputs for every (test case,
    sparsity) pair, runs ``run_benchmark`` on them, prints runtime and
    compile-time comparison tables, and returns the original undecorated
    function.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(test_cases=args[0]):
            runtime_results = []
            compile_time_results = []
            label = func.__name__
            net = func()
            torch_net = net
            for test_case in test_cases:
                for sparsity in test_case["sparsity"]:
                    sub_label = f"{test_case['name']}_{sparsity}"
                    dense_inputs, sparse_inputs = generate_inputs(
                        test_case["shape"],
                        sparsity,
                        test_case["formats"],
                        test_case["dtype"],
                        test_case["drange"],
                    )

                    # For GCN benchmarks the decorated function's return value
                    # is called with the first input's shape to build the net.
                    if "GCN" in label:
                        torch_net = net(*test_case["shape"][0])

                    # Built fresh for every run so that settings such as
                    # "precision" never leak from one test case into the next.
                    variables = {
                        "label": label,
                        "sub_label": sub_label,
                        "torch_net": torch_net,
                        "sparse_inputs": sparse_inputs,
                        "dense_inputs": dense_inputs,
                    }
                    if "precision" in test_case:
                        variables["precision"] = test_case["precision"]

                    run_benchmark(
                        sparse_inputs,
                        dense_inputs,
                        torch_net,
                        variables,
                        test_case["backends"],
                        runtime_results,
                        compile_time_results,
                    )

            torch_benchmark.Compare(runtime_results).print()
            # Eager-only runs record no compile times, so skip printing an
            # empty table.
            if compile_time_results:
                torch_benchmark.Compare(compile_time_results).print()

            return func

        return wrapper

    return decorator