TODOs:

focusing on:

- implement the “new” syntax for the target description

- allow nested ops on the left side: lavm.sub(tgt.neg(%a), %b) : lavm.add(%a, tgt.neg(%b));

- support operations with the same name but different numbers of arguments (see the sketch after this list)

- error on a wrong number of arguments in the map and target sections

- catch duplicate argument names in the target section

- catch mismatches between the argument count and the argument types in the target section
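
For reference, a hypothetical fragment illustrating the four items above; the ops, arities, and types here are made up for the example and are not taken from the descriptions further below:

tgt.load(%address) : (memref) -> f32; # one-argument form
tgt.load(%address, %offset) : (memref, i32) -> f32; # same name, different argument count
tgt.nuts(%x, %x) : (f32, f32) -> f32; # should be caught: duplicate argument name
tgt.add(%x, %y) : (f32) -> f32; # should be caught: two arguments but only one argument type
lavm.load(%a, %b, %c) : tgt.load(%a, %b, %c); # map section: should be an error, there is no three-argument tgt.load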

functionality:

-- define memory description section

-- allow [parallel] loops / ops with regions?

-- generate machine description C++ class

-- generate patterns for structured op tiling

- generate td file with the target dialect?

- allow sequences on the rhs of the map section: (op0(args...), op1(args), ...) so that buffer-level ops can be chained

- have type containers? (as in vreg, sreg, etc.)

- allow immediates

- support naming: lavm.add(%a, %b) : tgt.add(tgt.neg(%a):$n, tgt.neg($n));

- allow instantiations for specific types only: lavm.add(%a : f32, %b : f32) -> f32 : zzz;

- how to express instructions operating on memory directly?

- allow arbitrary variable names and various parameter counts(?)

- support default values for arguments: lavm.add(%a, %b = 42)...

- detect infinite cycles in expansions (see the sketch at the end of this list)

- how to represent denorm/NaN/rounding/saturating behavior?

- add support for AnyType, which would help define MEM <-> AnyType loads/stores and empty expansions such as op(%a) : %a

- ??? change syntax to (lavm.matmul %dst %a %b) : (lavm.store (tgt.matmul (lavm.load %a) (lavm.load %b)) %dst);

- !!! change syntax to be target-op-centric instead of lavm-op-centric (define target ops rather than mappings from lavm ops)

- allow lavm ops to drop the “lavm.” prefix?

- add pre-checkin tests

- re-implement the target description parser
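
As an illustration of the expansion-cycle problem, a made-up pair of map rules (not among the expansions below) whose expansion never terminates, since each rule re-introduces the op the other one rewrites:

lavm.add(%a, %b) : lavm.sub(%a, tgt.neg(%b));
lavm.sub(%a, %b) : lavm.add(%a, tgt.neg(%b));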

error checking:

- emit a parse error and stop when ‘;’ is missing at the end of a line

(this is a frequent cause of the late assert ‘GetTargetExpansions(): lavm_op != nullptr’)

- target section: remove variable names on the lhs, as in tgt.nuts(%, %), or drop the lhs arguments entirely?

- catch duplicate expansions, or anything else duplicated on the lhs or rhs; issue a warning and ignore the duplicate

- filter out duplicates that differ only in variable names (see the sketch below)
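
For example (a hypothetical fragment, only meant to show what the duplicate checks would catch):

lavm.add(%a, %b) : tgt.add(%a, %b), tgt.add(%a, %b); # exact duplicate on the rhs: warn and keep one
lavm.add(%a, %b) : tgt.add(%a, %b);
lavm.add(%x, %y) : tgt.add(%x, %y); # same expansion as the previous line up to variable renaming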

other:

- rename ‘map’ section into ‘expansions’?

Operations and types supported by the target.

target {

tgt.nuts(%x, %y) : (aaa, bbb) -> ccc, (a, b, c) -> (x, y, z), (f32, vector<128xf32>) -> ();

tgt.nuts(%x, %y) : ((m) -> p) -> q;

tgt.nuts(%x, %y) : ((m) -> ((p) -> bar)) -> q;

tgt.nuts(%x, %y) : ((m) -> ((p) -> (bar) -> foo, (p) -> ((bar) -> foo), ((p) -> (bar)) -> foo)) -> ((q) -> r);

tgt.nuts(%x, %y) : (m, ((n) -> ()) -> p) -> ((q) -> r);

tgt.i2f(%x) : (i32) -> f32, (vector<8x128xi32>) -> vector<8x128xf32>;

tgt.f2i(%x) : (f32) -> i32, (vector<8x128xf32>) -> vector<8x128xi32>;

tgt.neg(%x) : vector<8x128xf32>, f32; # vector<128xf32>

tgt.add(%x, %y) : vector<128xf32> @ 2, vector<8x128xf32> @ 3, f32 @ 1;

tgt.sub(%x, %y) : vector<128xf32>, vector<8x128xf32>, f32;

tgt.mul(%x, %y) : f32;

tgt.dot(%x, %y) : (vector<128xf32>, vector<128xf32>) -> vector<128xf32> @ 42;

tgt.matmul(%x, %y) : (vector<128x128xf32>, vector<128x128xf32>) -> vector<128x128xf32> @ 99;

tgt.wait() : () -> ();

tgt.dma(%to, %from, %size) : (memref, memref, i32) -> ();

tgt.dma(%to, %from, %size) : (memref, memref, i32) -> ();

tgt.dma(%to, %from, %size) : (memref, memref, i32) -> ();

tgt.dma(%to, %from, %size) : (memref, memref, i32) -> ();

tgt.load(%address) : (memref) -> f32, (memref) -> i32;

tgt.load(%address) : (memref) -> vector<128xf32>, (memref) -> vector<128xi32>;

tgt.load(%address) : (memref) -> vector<8x128xf32>, (memref) -> vector<8x128xi32>;

tgt.store(%value, %address) : (f32, memref) -> (), (i32, memref) -> ();

tgt.store(%value, %address) : (vector<128xf32>, memref) -> (), (vector<128xi32>, memref) -> ();

tgt.store(%value, %address) : (vector<8x128xf32>, memref) -> (), (vector<8x128xi32>, memref) -> ();

pseudo.combine(%x) : f32 -> vector<128xf32>;

pseudo.extract(%x) : vector<128xf32> -> f32;

pseudo.combine(%x) : vector<128xf32> -> vector<8x128xf32>;

pseudo.extract(%x) : vector<8x128xf32> -> vector<128xf32>;

pseudo.combine(%x) : vector<8x128xf32> -> vector<128x128xf32>;

pseudo.extract(%x) : vector<128x128xf32> -> vector<8x128xf32>;

pseudo.combine_f32_to_128xf32(%x) : f32 -> vector<128xf32>;

pseudo.combine_f32_to_8x128xf32(%x) : f32 -> vector<8x128xf32>;

pseudo.combine_f32_to_128x128xf32(%x) : f32 -> vector<128x128xf32>;

pseudo.combine_128xf32_to_8x128xf32(%x) : vector<128xf32> -> vector<8x128xf32>;

pseudo.combine_128xf32_to_128x128xf32(%x) : vector<128xf32> -> vector<128x128xf32>;

pseudo.combine_8x128xf32_to_128x128xf32(%x) : vector<8x128xf32> -> vector<128x128xf32>;

pseudo.extract_f32_from_128xf32(%x) : vector<128xf32> -> f32;

pseudo.extract_f32_from_8x128xf32(%x) : vector<8x128xf32> -> f32;

pseudo.extract_f32_from_128x128xf32(%x) : vector<128x128xf32> -> f32;

pseudo.extract_128xf32_from_8x128xf32(%x) : vector<8x128xf32> -> vector<128xf32>;

pseudo.extract_128xf32_from_128x128xf32(%x) : vector<128x128xf32> -> vector<128xf32>;

pseudo.extract_8x128xf32_from_128x128xf32(%x) : vector<128x128xf32> -> vector<8x128xf32>;

}

Map LAVM ops to previously defined target ops or their combinations.

Current restriction: lavm ops on the lhs must appear with the same argument names and argument count.

map {

lavm.neg(%a) : tgt.neg(%a);

lavm.neg(%a) : tgt.neg(tgt.i2f(%a));

not yet supported

lavm.none(%a) : %a;

lavm.zadd(%a, %b) : tgt.add(lavm.none(%a), %b);

lavm.add(%a, %b) : tgt.add(%a, %b), tgt.sub(%a, tgt.neg(%b)), tgt.add(tgt.i2f(%a), tgt.i2f(%b)), tgt.add(tgt.i2f(%b), %a), tgt.add(%b, tgt.i2f(%a)), tgt.add(%b, tgt.add(tgt.add(%a, %b), tgt.neg(%b)));

lavm.add(%a, %b) : lavm.neg(tgt.add(lavm.neg(%a), lavm.neg(%b))), lavm.neg(tgt.add(lavm.neg(tgt.f2i(%b)), tgt.neg(%a)));

lavm.sub(%a, %b) : lavm.add(%a, tgt.neg(%b));

lavm.add_x(%dst, %a, %b) : # tgt.dma(
                           tgt.store(tgt.add(tgt.load(tgt.dma(%a)),
                                             tgt.load(tgt.dma(%b))),
                                     %dst)
                           # )
                           ;

lavm.matmul(%dst : memref, %a : memref, %b : memref) :
    %aa = alloc() : memref,
    %bb = alloc() : memref,
    %cc = alloc() : memref,
    tgt.dma(%a, %aa),
    tgt.dma(%b, %bb),
    lavm.store(lavm.matmul(lavm.load(%aa), lavm.load(%bb)), %cc),
    tgt.dma(%cc, %dst); # dealloc???

lavm.dma(%to, %from, %size) : tgt.dma(%to, %from, %size);

lavm.load(%address) : tgt.load(%address);

lavm.load(%address) : lavm.combine(tgt.load(%address));

lavm.load(%address) : lavm.extract(tgt.load(%address));

lavm.store(%value, %address) : tgt.store(%value, %address);

want:

lavm.matmul(%dst, %a, %b) : lavm.store(tgt.matmul(lavm.load(%a), lavm.load(%b)), %dst);

...no: get rid of these

lavm.combine(%a) : pseudo.combine(%a);

lavm.extract(%a) : pseudo.extract(%a);

lavm.combine(%a) : pseudo.combine(pseudo.combine(%a));

lavm.extract(%a) : pseudo.extract(pseudo.extract(%a));

lavm.combine(%a) : pseudo.combine(pseudo.combine(pseudo.combine(%a)));

lavm.extract(%a) : pseudo.extract(pseudo.extract(pseudo.extract(%a)));

lavm.matmulY(%dst, %a, %b) : lavm.store(lavm.extract(
                                            tgt.matmul(lavm.combine(lavm.load(%a)),
                                                       lavm.combine(lavm.load(%b)))),
                                        %dst);

... also no

lavm.combineX(%a) : pseudo.combine_f32_to_128xf32(%a);

lavm.combineX(%a) : pseudo.combine_f32_to_8x128xf32(%a);

lavm.combineX(%a) : pseudo.combine_f32_to_128x128xf32(%a);

lavm.combineX(%a) : pseudo.combine_128xf32_to_8x128xf32(%a);

lavm.combineX(%a) : pseudo.combine_128xf32_to_128x128xf32(%a);

lavm.combineX(%a) : pseudo.combine_8x128xf32_to_128x128xf32(%a);

lavm.extractX(%a) : pseudo.extract_f32_from_128xf32(%a);

lavm.extractX(%a) : pseudo.extract_f32_from_8x128xf32(%a);

lavm.extractX(%a) : pseudo.extract_f32_from_128x128xf32(%a);

lavm.extractX(%a) : pseudo.extract_128xf32_from_8x128xf32(%a);

lavm.extractX(%a) : pseudo.extract_128xf32_from_128x128xf32(%a);

lavm.extractX(%a) : pseudo.extract_8x128xf32_from_128x128xf32(%a);

lavm.matmulX(%dst, %a, %b) : lavm.store(lavm.extractX(
                                            tgt.matmul(lavm.combineX(lavm.load(%a)),
                                                       lavm.combineX(lavm.load(%b)))),
                                        %dst);

lavm.add(%a, %b) : tgt.add(%a, %a); # should cause a warning that %b is unused (implemented in the wrong place, too late)

lavm.add(%a, %b) : tgt.add(lavm.neg(%a), lavm.neg(%a)); # should cause an earlier warning

lavm.add(%a, %b) : tgt.boo(%a, %b); # should be an error because tgt.boo is undefined

lavm.add(%a, %b) : tgt.add(tgt.boo(%a), %b); # should be an error: undefined tgt.boo (implemented in the wrong place, too late)

lavm.add(%a, %b) : tgt.add(%a, %c); # this is an error! should be caught (NYI)

}

Target memory description.

memory {

TPU-like memory:

HBM : size = 16G;

HBM : garbage = ignored;

VMEM : size = 16M;

SMEM : size = 16K;

CMEM : size = 16M;

HBM -> VMEM;

VMEM -> HBM;

HBM -> SMEM;

SMEM -> HBM;

HBM -> CMEM -> VMEM;

VMEM -> CMEM -> HBM;

GPU-like memory:

GLOBAL : size = 8G;

SHARED : size = 16M;

LOCAL : size = 1M;

CPU-like memory:

MEMORY : size = 64G;

$L1 : size = 512K;

$L2 : size = 4M;

MEMORY -> $L2 -> $L1;

$L1 -> $L2 -> MEMORY;

things to specify:

- on the memory side: what the address is, what the value is, stride, banks

- on the register(?) side: size, stride, zero/sign extension, mask, whether a target register can be partially filled

- load/store latency, size/granularity

- dma latency, size/granularity

- cache latency, size/granularity, policy, associativity, on which path

some properties belong to the memory instructions themselves; attributes can be used to represent them
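
One possible shape for this, purely a guess at syntax that does not exist yet (reusing the ‘@’ annotation from the target section), with made-up attribute names:

VMEM : size = 16M, banks = 8; # memory-side properties
VMEM -> 'reg : vload(%address) @ latency = 5, stride, mask; # instruction-side properties as attributes
HBM -> VMEM : dma(%to, %from, %size) @ latency = 300, granularity = 512;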

Interface: cache levels/sizes/latencies/lane sizes/policies

memory spaces, dmas between memory spaces, sizes, etc.

how to specify direct access to memory from other instructions? using implicit addresses? chained instructions?

for example, a chain of instructions that process data in memory, where each instruction modifies the implicit pointer used by the next one.

how to represent memory local to PEs?

how to represent in-memory calculations?

name [-> $name -> ...] -> 'reg : load instructions; # do not use register?

'reg [-> $name -> ...] -> name : store instructions;

name -> name : dma instructions;

TPU-like memory:

SMEM -> 'reg : sload(%address);

'reg -> SMEM : sstore(%value, %address);

VMEM -> 'reg : vload(%address); # stride, mask

'reg -> VMEM : vstore(%value, %address); # stride, mask

HBM -> VMEM : dma(%from, %to, %size) size0^latency0, size1^latency1; # syntax?

VMEM -> HBM : dma();

HBM -> SMEM : dma();

SMEM -> HBM : dma();

GPU-like memory:

GLOBAL -> 'reg : ld(%address);

SHARED -> 'reg : lds(%address);

LOCAL -> 'reg : ldl(%address);

'reg -> GLOBAL : st(%value, %address);

'reg -> SHARED : sts(%value, %address);

'reg -> LOCAL : stl(%value, %address);

CPU-like memory:

MEMORY -> $L2 -> $L1 -> 'reg : load(%address);

'reg -> $L1 -> $L2 -> MEMORY : store(%value, %address);

}