[mpact][benchmark] add sparsity safety to tensor generator (#64)
diff --git a/benchmark/python/utils/tensor_generator.py b/benchmark/python/utils/tensor_generator.py
index 5f9d304..c98a68b 100644
--- a/benchmark/python/utils/tensor_generator.py
+++ b/benchmark/python/utils/tensor_generator.py
@@ -50,9 +50,9 @@
Args:
seed: Seed value for np.random.
shape: A tuple for the shape of tensor.
- sparsity: Sparsity level in the range of [0, 1].
+ sparsity: Sparsity level in the range of [0, 1], viz. 0=dense and 1=all-zeros
dtype: Data type of the generated tensor. Default is np.float64.
- drange: Data range of the non-zero values. Default is (1, 100).
+ drange: Data range of the non-zero values (inclusive). Default is (1, 100).
Returns:
A dense torch tensor with the specified shape, sparsity level and type.
@@ -61,6 +61,9 @@
number of specified elements. Therefore, for batched CSR, torch.cat can be
used to concatenate generated tensors in the specified dimension.
"""
+ if sparsity < 0.0 or sparsity > 1.0:
+ raise ValueError("Invalid sparsity level: %f" % sparsity)
+
np.random.seed(seed)
size = math.prod(shape)
nse = size - int(math.ceil(sparsity * size))