blob: 976c10c859e497626cf38ea9fff6117e9b57f076 [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
import torch
import numpy as np
from mpact.mpactbackend import mpact_linalg
from mpact.models.kernels import MMNet
# Instantiate the MMNet kernel (lowered to linalg.matmul below) and two
# deterministic 4x4 float32 inputs: X holds 0..15 and Y holds 16..31,
# reshaped row-major via view(4, 4).
net = MMNet()
X = torch.arange(0, 16, dtype=torch.float32).view(4, 4)
Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)

# Expected Linalg IR for net(X, Y): a zero-filled 4x4 accumulator
# (tensor.empty + linalg.fill) is passed as the `outs` operand of
# linalg.matmul, and the matmul result is what @main returns.
#
# Every SSA value is captured with a FileCheck pattern variable instead of
# being spelled literally (e.g. `%2`), so the expectations do not depend on
# the printer's value numbering, and the return is tied to the matmul result.
#
# CHECK: module {
# CHECK: func.func @main(%[[A0:.*]]: tensor<4x4xf32>, %[[A1:.*]]: tensor<4x4xf32>) -> tensor<4x4xf32> {
# CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
# CHECK: %[[T0:.*]] = tensor.empty() : tensor<4x4xf32>
# CHECK: %[[T1:.*]] = linalg.fill ins(%[[C0]] : f32) outs(%[[T0]] : tensor<4x4xf32>) -> tensor<4x4xf32>
# CHECK: %[[T2:.*]] = linalg.matmul
# CHECK-SAME: ins(%[[A0]], %[[A1]] : tensor<4x4xf32>, tensor<4x4xf32>)
# CHECK-SAME: outs(%[[T1]] : tensor<4x4xf32>) -> tensor<4x4xf32>
# CHECK: return %[[T2]] : tensor<4x4xf32>
# CHECK: }
# CHECK: }
#
linalg = mpact_linalg(net, X, Y)
print(linalg)