blob: ef7a3bd6268a438a454fb4c5c8b9a3d534b8bbef [file] [log] [blame]
import torch
import numpy as np
from mpact.models.gcn import GraphConv
from mpact_benchmark.utils.benchmark_utils import benchmark, Backends
@benchmark(
    [
        {
            # One case per (shape, format, dtype) combination below; the name
            # encodes all three so results can be told apart.
            "name": f"{fmt}_{shape}_{dtype.__name__}",
            # Two shapes per case -- presumably one per model input
            # (e.g. adjacency and feature matrix); TODO confirm against
            # GraphConv.forward.
            "shape": shape,
            # Presumably the storage format for each corresponding input.
            "formats": fmt,
            "dtype": dtype,
            # NOTE(review): looks like the value range for generated input
            # data -- confirm in benchmark_utils.
            "drange": (1, 100),
            "sparsity": [0, 0.5, 0.9, 0.99],
            # Run every backend the benchmark utilities define.
            "backends": list(Backends),
        }
        for shape in [
            [[128, 128], [128, 128]],
            [[512, 512], [512, 512]],
            [[1024, 1024], [1024, 1024]],
        ]
        for fmt in [["dense", "csr"]]
        for dtype in [np.float32]
    ]
)
def GCN() -> "type[torch.nn.Module]":
    """Graph Convolution Network benchmark.

    Registers ``GraphConv`` with the ``@benchmark`` decorator. Three square
    problem sizes (128, 512, 1024) are crossed with the ``["dense", "csr"]``
    format pair, ``float32``, four sparsity levels and all ``Backends``.

    Returns:
        The ``GraphConv`` class itself, not an instance. The annotation is
        ``type[torch.nn.Module]`` for that reason. It is kept as a string so it
        is never evaluated at import time.
    """
    return GraphConv
if __name__ == "__main__":
    # Script entry point: invoke the @benchmark-decorated GCN. What this call
    # does (presumably runs every configured benchmark case) is defined by the
    # decorator in mpact_benchmark.utils.benchmark_utils -- confirm there.
    GCN()