[mpact][benchmark] add sparsity safety to tensor generator (#64)
diff --git a/benchmark/python/utils/tensor_generator.py b/benchmark/python/utils/tensor_generator.py index 5f9d304..c98a68b 100644 --- a/benchmark/python/utils/tensor_generator.py +++ b/benchmark/python/utils/tensor_generator.py
@@ -50,9 +50,9 @@ Args: seed: Seed value for np.random. shape: A tuple for the shape of tensor. - sparsity: Sparsity level in the range of [0, 1]. + sparsity: Sparsity level in the range of [0, 1], viz. 0=dense and 1=all-zeros dtype: Data type of the generated tensor. Default is np.float64. - drange: Data range of the non-zero values. Default is (1, 100). + drange: Data range of the non-zero values (inclusive). Default is (1, 100). Returns: A dense torch tensor with the specified shape, sparsity level and type. @@ -61,6 +61,9 @@ number of specified elements. Therefore, for batched CSR, torch.cat can be used to concatenate generated tensors in the specified dimension. """ + if sparsity < 0.0 or sparsity > 1.0: + raise ValueError("Invalid sparsity level: %f" % sparsity) + np.random.seed(seed) size = math.prod(shape) nse = size - int(math.ceil(sparsity * size))