import functools
from enum import Enum
from typing import Any, Callable

import torch
import torch._dynamo.utils  # Explicit import so compile_times() is available.
from torch.utils import benchmark as torch_benchmark

from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
from mpact_benchmark.utils.tensor_generator import generate_inputs


class Backends(Enum):
    """Backends available to the benchmark driver."""

    TORCH_SPARSE_EAGER = 1
    TORCH_DENSE_EAGER = 2
    TORCH_SPARSE_INDUCTOR = 3
    TORCH_DENSE_INDUCTOR = 4
    MPACT_SPARSE = 5
    MPACT_DENSE = 6


def timer(
    stmt: str, description: str, setup: str = "", *, variables: dict[str, Any]
) -> torch_benchmark.Measurement:
    """Time `stmt` with torch's benchmark Timer and return a Measurement.

    `variables` is used as the statement's globals and must contain "label"
    and "sub_label" entries, plus any names that `stmt` references.
    """
    return torch_benchmark.Timer(
        stmt=stmt,
        globals=variables,
        setup=setup,
        num_threads=1,
        label=variables["label"],
        sub_label=variables["sub_label"],
        description=description,
    ).adaptive_autorange()
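# A minimal sketch of calling timer() directly (hypothetical tensors; any name
# that `stmt` references must be supplied through `variables`):
#
#   timer(
#       "torch.mm(a, b)",
#       "dense-mm",
#       variables={"torch": torch, "a": torch.eye(8), "b": torch.eye(8),
#                  "label": "matmul", "sub_label": "eye8"},
#   )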


def get_dynamo_compile_time(sub_label: str, label: str, description: str) -> Any:
"""Get compile time from dynamo and create a benchmark measurement object."""
try:
compile_time = torch_benchmark.Measurement(
1,
[
float(
torch._dynamo.utils.compile_times(repr="csv")[1][0]
.split(",")[-1]
.strip()
)
],
torch_benchmark.TaskSpec(
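                # `sub_label` is passed as TaskSpec's `stmt`, which Compare
                # falls back to for the row name when sub_label is unset.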
sub_label,
None,
description=description,
label=label,
),
)
return compile_time
except ValueError:
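        # Parsing fails (ValueError) when dynamo recorded no new compile time.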
print(f"No compilation happened for {description}: {sub_label}.")
return None


def run_benchmark(
sparse_inputs: tuple[torch.Tensor, ...],
dense_inputs: tuple[torch.Tensor, ...],
torch_net: torch.nn.Module,
variables: dict[str, Any],
backends: tuple[Backends, ...],
runtime_results: list[torch_benchmark.Measurement],
compile_time_results: list[torch_benchmark.Measurement],
) -> None:
"""Run benchmark with specified backends."""
output = []
output_type = None
with torch.no_grad():
for backend in backends:
match backend:
case Backends.TORCH_SPARSE_EAGER:
sparse_out = torch_net(*sparse_inputs)
output_type = sparse_out.layout
output.append(sparse_out)
runtime_results.append(
timer(
"torch_net(*sparse_inputs)",
"torch-sparse-eager",
variables=variables,
)
)
case Backends.TORCH_DENSE_EAGER:
output.append(torch_net(*dense_inputs))
runtime_results.append(
timer(
"torch_net(*dense_inputs)",
"torch-dense-eager",
variables=variables,
)
)
case Backends.TORCH_SPARSE_INDUCTOR:
torch_inductor = torch.compile(torch_net)
torch_out = torch_inductor(*sparse_inputs)
output.append(torch_out)
compile_time = get_dynamo_compile_time(
variables["sub_label"],
variables["label"],
"torch-sparse-inductor-compile",
)
if compile_time:
compile_time_results.append(compile_time)
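                    # locals() is merged in so the timed statement can see
                    # `torch_inductor`, which is not in the caller's `variables`.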
runtime_results.append(
timer(
"torch_inductor(*sparse_inputs)",
"torch-sparse-inductor-runtime",
variables=dict(variables, **locals()),
)
)
case Backends.TORCH_DENSE_INDUCTOR:
torch_inductor = torch.compile(torch_net)
output.append(torch_inductor(*dense_inputs))
compile_time = get_dynamo_compile_time(
variables["sub_label"],
variables["label"],
"torch-dense-inductor-compile",
)
if compile_time:
compile_time_results.append(compile_time)
runtime_results.append(
timer(
"torch_inductor(*dense_inputs)",
"torch-dense-inductor-runtime",
variables=dict(variables, **locals()),
)
)
case Backends.MPACT_SPARSE:
sp_out = mpact_jit(torch_net, *sparse_inputs)
# Construct sparse csr tensor if the output type is csr.
# TODO: return sparse tensor directly instead of a tuple of arrays.
                    if isinstance(sp_out, tuple):
# torch.sparse_csr_tensor could deduce the size incorrectly,
# so pass the dense_out's shape explicitly.
dense_out = mpact_jit(torch_net, *dense_inputs)
output.append(
torch.sparse_csr_tensor(*sp_out, size=dense_out.shape)
)
# Check MPACT and torch eager both return sparse csr output
# only when torch sparse eager has been run.
if output_type:
assert output_type == torch.sparse_csr
else:
output.append(torch.from_numpy(sp_out))
# Check MPACT and torch eager both return dense output
# only when torch sparse eager has been run.
if output_type:
assert output_type == torch.strided
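                    # Compile once so the runtime timer measures execution only;
                    # the compile-time timer below re-runs mpact_jit_compile.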
invoker, f = mpact_jit_compile(torch_net, *sparse_inputs)
compile_time_results.append(
timer(
"mpact_jit_compile(torch_net, *sparse_inputs)",
"mpact-sparse-compile",
"from mpact.mpactbackend import mpact_jit_compile",
variables=dict(variables, **locals()),
)
)
runtime_results.append(
timer(
"mpact_jit_run(invoker, f, *sparse_inputs)",
"mpact-sparse-runtime",
"from mpact.mpactbackend import mpact_jit_run",
variables=dict(variables, **locals()),
)
)
case Backends.MPACT_DENSE:
output.append(torch.from_numpy(mpact_jit(torch_net, *dense_inputs)))
invoker, f = mpact_jit_compile(torch_net, *dense_inputs)
compile_time_results.append(
timer(
"mpact_jit_compile(torch_net, *dense_inputs)",
"mpact-dense-compile",
"from mpact.mpactbackend import mpact_jit_compile",
variables=dict(variables, **locals()),
)
)
runtime_results.append(
timer(
"mpact_jit_run(invoker, f, *dense_inputs)",
"mpact-dense-runtime",
"from mpact.mpactbackend import mpact_jit_run",
variables=dict(variables, **locals()),
)
)
case _:
print(f"{backend} is not supported yet.")
        # Sanity check: every backend's output must match the first one's
        # (after densifying), within the test case's precision if provided.
        if output:
            rtol = variables.get("precision", 1e-5)
            assert all(
                torch.allclose(output[0].to_dense(), out.to_dense(), rtol=rtol)
                for out in output
            )


def benchmark(*args: Any) -> Callable:
    """Decorator factory: runs the wrapped network builder over all test cases."""
def decorator(func):
@functools.wraps(func)
def wrapper(test_cases=args[0]):
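            # Binding args[0] (the test-case list given to @benchmark(...)) as
            # a default argument lets the decorated function run with no args.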
runtime_results = []
compile_time_results = []
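            # func() returns either a ready network instance or, for GCN-style
            # benchmarks, the network class (instantiated per shape below).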
torch_net = net = func()
for test_case in test_cases:
label = func.__name__
for sparsity in test_case["sparsity"]:
sub_label = f"{test_case['name']}_{sparsity}"
dense_inputs, sparse_inputs = generate_inputs(
test_case["shape"],
sparsity,
test_case["formats"],
test_case["dtype"],
test_case["drange"],
)
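                    # GCN-style builders return the class; construct a fresh
                    # net sized to this test case's first input shape.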
if "GCN" in label:
torch_net = net(*test_case["shape"][0])
                    # Reset every iteration so a "precision" from an earlier
                    # test case cannot leak into the locals() passed below.
                    precision = test_case.get("precision", 1e-5)
run_benchmark(
sparse_inputs,
dense_inputs,
torch_net,
locals(),
test_case["backends"],
runtime_results,
compile_time_results,
)
            # Print runtime and compile-time comparison tables.
            torch_benchmark.Compare(runtime_results).print()
            torch_benchmark.Compare(compile_time_results).print()
return func
return wrapper
return decorator
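

# A minimal usage sketch (hypothetical test case; only the dictionary keys are
# what wrapper() actually reads -- the values shown here are assumptions):
#
#   @benchmark(
#       [
#           {
#               "name": "add",
#               "shape": [(1024, 1024), (1024, 1024)],
#               "sparsity": [0.5, 0.9],
#               "formats": ["csr", "csr"],
#               "dtype": torch.float64,
#               "drange": (1, 100),
#               "backends": (Backends.TORCH_SPARSE_EAGER, Backends.MPACT_SPARSE),
#           }
#       ]
#   )
#   def add_net():
#       class AddNet(torch.nn.Module):
#           def forward(self, x, y):
#               return x + y
#
#       return AddNet()
#
#   add_net()  # Runs the benchmark and prints the comparison tables.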