| import torch |
| import cProfile |
| from pstats import Stats |
| |
| |
def profile_torch(
    func, args, row_limit=10, save_output=False, func_name=None, file_name="trace"
):
    """Use PyTorch's profiler to profile torch ops.

    Runs ``func(*args)`` once under ``torch.profiler.profile`` inside a
    ``record_function`` region. It then prints a table of the recorded ops
    sorted by total CPU time and optionally exports a Chrome trace.

    To see the graph: call with ``save_output=True`` and load the resulting
    ``{file_name}.json`` (``trace.json`` by default) in chrome://tracing.

    More details about PyTorch profiler:
    https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html

    Args:
        func: Callable to profile. It may be a plain function or a callable
            object such as an ``nn.Module`` instance.
        args: Tuple of positional arguments passed to ``func``.
        row_limit: Maximum number of rows in the printed table.
        save_output: If True, write the Chrome trace to ``{file_name}.json``.
        func_name: Label for the ``record_function`` region. It defaults to
            ``func.__name__``, or to the class name for callables that have
            no ``__name__`` (e.g. ``nn.Module`` instances).
        file_name: Output path for the trace, without the ``.json`` extension.

    Returns:
        Whatever ``func(*args)`` returns.
    """
    # Callable instances (e.g. nn.Module objects) have no ``__name__``; fall
    # back to their class name instead of raising AttributeError.
    func_name = func_name or getattr(func, "__name__", type(func).__name__)
    with torch.profiler.profile() as prof:
        with torch.profiler.record_function(func_name):
            result = func(*args)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=row_limit))
    if save_output:
        prof.export_chrome_trace(f"{file_name}.json")
    return result
| |
| |
def profile_python(func, args, row_limit=10, save_output=False, file_name="stats"):
    """Use cProfile to profile python function calls.

    Runs ``func(*args)`` once under ``cProfile``. It then prints the
    ``row_limit`` entries with the highest ``tottime`` and optionally dumps
    the raw stats.

    To see the graph, call with ``save_output=True``, then run:
        1. python -m pip install snakeviz
        2. snakeviz {file_name}.prof   (``stats.prof`` by default)

    Args:
        func: Callable to profile.
        args: Tuple of positional arguments passed to ``func``.
        row_limit: Maximum number of entries printed.
        save_output: If True, dump the raw stats to ``{file_name}.prof``.
        file_name: Output path for the stats, without the ``.prof`` extension.

    Returns:
        Whatever ``func(*args)`` returns.

    Raises:
        Any exception raised by ``func`` propagates. The profiler is still
        disabled first, so later profiling sessions are not blocked.
    """
    pr = cProfile.Profile()
    pr.enable()
    try:
        result = func(*args)
    finally:
        # Always turn the profiler off, even if func raises; otherwise it
        # keeps profiling the rest of the process.
        pr.disable()
    stats = Stats(pr)
    stats.sort_stats("tottime").print_stats(row_limit)
    if save_output:
        pr.dump_stats(f"{file_name}.prof")
    return result
| |
| |
if __name__ == "__main__":
    # Example usage: profile the same sparse x sparse matmul twice, once via
    # MPACT and once via stock PyTorch, with each of the two profilers.
    from mpact.models.kernels import MMNet
    from mpact_benchmark.utils.tensor_generator import generate_tensor
    from mpact.mpactbackend import mpact_jit

    # Build both dense inputs first (seeds 0 and 1), then convert to CSR.
    dense_inputs = [
        generate_tensor(seed=seed, shape=(32, 32), sparsity=0.8) for seed in (0, 1)
    ]
    lhs_csr, rhs_csr = [dense.to_sparse_csr() for dense in dense_inputs]

    # PyTorch profiler (torch-operator view): MPACT-compiled, then torch sparse.
    profile_torch(mpact_jit, (MMNet(), lhs_csr, rhs_csr))
    profile_torch(MMNet(), (lhs_csr, rhs_csr), func_name="sparsexsparse matmul")

    # cProfile (Python-call view): MPACT-compiled, then torch sparse.
    profile_python(mpact_jit, (MMNet(), lhs_csr, rhs_csr))
    profile_python(MMNet(), (lhs_csr, rhs_csr))