TODOs:
focusing on:
- implement the “new” syntax for the target description
- allow nested ops on the left side: lavm.sub(tgt.neg(%a), %b) : lavm.add(%a, tgt.neg(%b));
- support operations with same name but different number of arguments
- error on wrong number of arguments in map section and in target sections
- catch duplicate argument names in the target section
- catch mismatch in the argument count and the type in the target section
functionality:
-- define memory description section
-- allow [parallel] loops / ops with regions?
-- generate machine description C++ class
-- generate patterns for structured op tiling
- generate td file with the target dialect?
- allow sequences in rhs of map section: (op0(args...), op1(args), ...) so that we can chain buffer-level ops
- have type containers? (as in vreg, sreg, etc.)
- allow immediates
- support naming: lavm.add(%a, %b) : tgt.add(tgt.neg(%a):$n, tgt.neg($n));
- allow instantiations for specific types only: lavm.add(%a : f32, %b : f32) -> f32 : zzz;
- how to express instructions operating on memory directly?
- allow arbitrary variable names and various parameter counts(?)
- support default values for arguments: lavm.add(%a, %b = 42)...
- detect infinite cycles in expansions
- how to represent denorm/NaN/rounding/saturating behavior?
- add support for AnyType, which is helpful to define MEM <-> AnyType loads/stores and op(%a) : %a empty expansions
- ??? change syntax to (lavm.matmul %dst %a %b) : (lavm.store (tgt.matmul (lavm.load %a) (lavm.load %b)) %dst);
- !!! change syntax to be target-op-centric instead of lavm-op-centric (define target ops rather than mappings
from lavm ops)
- allow lavm ops to drop “lavm.” prefix?
- add pre-checkin tests
- re-implement the target description parser
error checking:
- emit error at parsing and stop when ‘;’ is missing at the end of the line
(this is a frequent cause of the late assert ‘GetTargetExpansions(): lavm_op != nullptr’)
- target section: remove variable names in lhs? as in tgt.nuts(%, %) or just get rid of the args on the lhs?
- catch duplicate expansions or anything else duplicated on lhs or rhs, issue warning and ignore
- filter out duplicates that differ only in variable names
other:
- rename ‘map’ section into ‘expansions’?
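One item above asks for detecting infinite cycles in expansions. The following is a minimal sketch of such a check, written in Python rather than in the tool's own implementation language. It assumes a hypothetical table that maps each op name to the op names appearing on the rhs of its expansions; the function name find_expansion_cycle and the table layout are illustrative and not part of LAVM. The check works on op names only, so it is conservative: an op that expands into itself on a different type (as the lavm.matmul sketch in the map section does) is reported as a cycle.

```python
def find_expansion_cycle(expansions):
    """Return a list of op names forming a cycle, or None if there is none.

    `expansions` maps an op name to the set of op names used on the rhs of
    its expansions (hypothetical layout, assumed for this sketch).
    """
    WHITE, GRAY, BLACK = 0, 1, 2  # unvisited, on the current DFS path, done
    color = {}
    path = []

    def visit(op):
        color[op] = GRAY
        path.append(op)
        for dep in sorted(expansions.get(op, ())):
            state = color.get(dep, WHITE)
            if state == GRAY:
                # dep is already on the current path: found a cycle.
                return path[path.index(dep):] + [dep]
            if state == WHITE:
                cycle = visit(dep)
                if cycle:
                    return cycle
        path.pop()
        color[op] = BLACK
        return None

    for op in sorted(expansions):
        if color.get(op, WHITE) == WHITE:
            cycle = visit(op)
            if cycle:
                return cycle
    return None
```

A depth-first search with three colors is used here because it reports the offending chain of ops, which is more useful in an error message than a plain yes/no answer.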
Operations and types supported by the target.
target {
tgt.nuts(%x, %y) : (aaa, bbb) -> ccc, (a, b, c) -> (x, y, z), (f32, vector<128xf32>) -> ();
tgt.nuts(%x, %y) : ((m) -> p) -> q;
tgt.nuts(%x, %y) : ((m) -> ((p) -> bar)) -> q;
tgt.nuts(%x, %y) : ((m) -> ((p) -> (bar) -> foo, (p) -> ((bar) -> foo), ((p) -> (bar)) -> foo)) -> ((q) -> r);
tgt.nuts(%x, %y) : (m, ((n) -> ()) -> p) -> ((q) -> r);
tgt.i2f(%x) : (i32) -> f32, (vector<8x128xi32>) -> vector<8x128xf32>;
tgt.f2i(%x) : (f32) -> i32, (vector<8x128xf32>) -> vector<8x128xi32>;
tgt.neg(%x) : vector<8x128xf32>, f32; # vector<128xf32>
tgt.add(%x, %y) : vector<128xf32> @ 2, vector<8x128xf32> @ 3, f32 @ 1;
tgt.sub(%x, %y) : vector<128xf32>, vector<8x128xf32>, f32;
tgt.mul(%x, %y) : f32;
tgt.dot(%x, %y) : (vector<128xf32>, vector<128xf32>) -> vector<128xf32> @ 42;
tgt.matmul(%x, %y) : (vector<128x128xf32>, vector<128x128xf32>) -> vector<128x128xf32> @ 99;
tgt.wait() : () -> ();
tgt.dma(%to, %from, %size) : (memref, memref, i32) -> ();
tgt.dma(%to, %from, %size) : (memref, memref, i32) -> ();
tgt.dma(%to, %from, %size) : (memref, memref, i32) -> ();
tgt.dma(%to, %from, %size) : (memref, memref, i32) -> ();
tgt.load(%address) : (memref) -> f32, (memref) -> i32;
tgt.load(%address) : (memref) -> vector<128xf32>, (memref) -> vector<128xi32>;
tgt.load(%address) : (memref) -> vector<8x128xf32>, (memref) -> vector<8x128xi32>;
tgt.store(%value, %address) : (f32, memref) -> (), (i32, memref) -> ();
tgt.store(%value, %address) : (vector<128xf32>, memref) -> (), (vector<128xi32>, memref) -> ();
tgt.store(%value, %address) : (vector<8x128xf32>, memref) -> (), (vector<8x128xi32>, memref) -> ();
pseudo.combine(%x) : f32 -> vector<128xf32>;
pseudo.extract(%x) : vector<128xf32> -> f32;
pseudo.combine(%x) : vector<128xf32> -> vector<8x128xf32>;
pseudo.extract(%x) : vector<8x128xf32> -> vector<128xf32>;
pseudo.combine(%x) : vector<8x128xf32> -> vector<128x128xf32>;
pseudo.extract(%x) : vector<128x128xf32> -> vector<8x128xf32>;
pseudo.combine_f32_to_128xf32(%x) : f32 -> vector<128xf32>;
pseudo.combine_f32_to_8x128xf32(%x) : f32 -> vector<8x128xf32>;
pseudo.combine_f32_to_128x128xf32(%x) : f32 -> vector<128x128xf32>;
pseudo.combine_128xf32_to_8x128xf32(%x) : vector<128xf32> -> vector<8x128xf32>;
pseudo.combine_128xf32_to_128x128xf32(%x) : vector<128xf32> -> vector<128x128xf32>;
pseudo.combine_8x128xf32_to_128x128xf32(%x) : vector<8x128xf32> -> vector<128x128xf32>;
pseudo.extract_f32_from_128xf32(%x) : vector<128xf32> -> f32;
pseudo.extract_f32_from_8x128xf32(%x) : vector<8x128xf32> -> f32;
pseudo.extract_f32_from_128x128xf32(%x) : vector<128x128xf32> -> f32;
pseudo.extract_128xf32_from_8x128xf32(%x) : vector<8x128xf32> -> vector<128xf32>;
pseudo.extract_128xf32_from_128x128xf32(%x) : vector<128x128xf32> -> vector<128xf32>;
pseudo.extract_8x128xf32_from_128x128xf32(%x) : vector<128x128xf32> -> vector<8x128xf32>;
}
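The TODO list asks for catching duplicate argument names in the target section and for supporting operations with the same name but a different number of arguments. A rough Python sketch of both checks follows. It looks only at the head of each declaration (the op name and its argument list before the first ':') and ignores the type lists and the '@ N' annotations. The names parse_head and check_heads are hypothetical helpers for this sketch, not part of the actual parser.

```python
import re

# Matches the head of a declaration such as 'tgt.dma(%to, %from, %size) :'.
HEAD_RE = re.compile(r"^\s*([A-Za-z_][\w.]*)\s*\(([^)]*)\)\s*:")


def parse_head(statement):
    """Split 'tgt.dma(%to, %from, %size) : ...' into (name, [arg, ...])."""
    match = HEAD_RE.match(statement)
    if match is None:
        raise ValueError(f"cannot parse declaration head: {statement!r}")
    name, arg_text = match.groups()
    args = [a.strip() for a in arg_text.split(",") if a.strip()]
    return name, args


def check_heads(statements):
    """Return (overloads, errors).

    `overloads` maps (name, arity) to the argument names of the first
    declaration seen, so ops with the same name but different argument
    counts are kept apart. `errors` lists duplicate-argument diagnostics.
    """
    overloads = {}
    errors = []
    for statement in statements:
        name, args = parse_head(statement)
        dupes = sorted({a for a in args if args.count(a) > 1})
        if dupes:
            errors.append(f"{name}: duplicate argument name(s) {', '.join(dupes)}")
            continue
        overloads.setdefault((name, len(args)), args)
    return overloads, errors
```

Keying the table on (name, arity) is one possible way to let tgt.dma with two arguments coexist with tgt.dma with three; whether the real tool should resolve overloads this way is an open design question.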
Map LAVM ops to previously defined target ops or their combinations.
Current restriction: every lhs occurrence of a lavm op must use the same argument names and the same argument count.
map {
lavm.neg(%a) : tgt.neg(%a);
lavm.neg(%a) : tgt.neg(tgt.i2f(%a));
not yet supported
lavm.none(%a) : %a;
lavm.zadd(%a, %b) : tgt.add(lavm.none(%a), %b);
lavm.add(%a, %b) : tgt.add(%a, %b), tgt.sub(%a, tgt.neg(%b)), tgt.add(tgt.i2f(%a), tgt.i2f(%b)), tgt.add(tgt.i2f(%b), %a), tgt.add(%b, tgt.i2f(%a)), tgt.add(%b, tgt.add(tgt.add(%a, %b), tgt.neg(%b)));
lavm.add(%a, %b) : lavm.neg(tgt.add(lavm.neg(%a), lavm.neg(%b))), lavm.neg(tgt.add(lavm.neg(tgt.f2i(%b)), tgt.neg(%a)));
lavm.sub(%a, %b) : lavm.add(%a, tgt.neg(%b));
lavm.add_x(%dst, %a, %b) : # tgt.dma(
tgt.store(tgt.add(tgt.load(tgt.dma(%a)),
tgt.load(tgt.dma(%b))),
%dst)
# )
;
lavm.matmul(%dst : memref, %a : memref, %b : memref) :
%aa = alloc() : memref,
%bb = alloc() : memref,
%cc = alloc() : memref,
tgt.dma(%a, %aa),
tgt.dma(%b, %bb),
lavm.store(lavm.matmul(lavm.load(%aa), lavm.load(%bb)), %cc),
tgt.dma(%cc, %dst); # dealloc???
lavm.dma(%to, %from, %size) : tgt.dma(%to, %from, %size);
lavm.load(%address) : tgt.load(%address);
lavm.load(%address) : lavm.combine(tgt.load(%address));
lavm.load(%address) : lavm.extract(tgt.load(%address));
lavm.store(%value, %address) : tgt.store(%value, %address);
want:
lavm.matmul(%dst, %a, %b) : lavm.store(tgt.matmul(lavm.load(%a), lavm.load(%b)), %dst);
...no: get rid of these
lavm.combine(%a) : pseudo.combine(%a);
lavm.extract(%a) : pseudo.extract(%a);
lavm.combine(%a) : pseudo.combine(pseudo.combine(%a));
lavm.extract(%a) : pseudo.extract(pseudo.extract(%a));
lavm.combine(%a) : pseudo.combine(pseudo.combine(pseudo.combine(%a)));
lavm.extract(%a) : pseudo.extract(pseudo.extract(pseudo.extract(%a)));
lavm.matmulY(%dst, %a, %b) : lavm.store(lavm.extract(
tgt.matmul(lavm.combine(lavm.load(%a)),
lavm.combine(lavm.load(%b)))),
%dst);
... also no
lavm.combineX(%a) : pseudo.combine_f32_to_128xf32(%a);
lavm.combineX(%a) : pseudo.combine_f32_to_8x128xf32(%a);
lavm.combineX(%a) : pseudo.combine_f32_to_128x128xf32(%a);
lavm.combineX(%a) : pseudo.combine_128xf32_to_8x128xf32(%a);
lavm.combineX(%a) : pseudo.combine_128xf32_to_128x128xf32(%a);
lavm.combineX(%a) : pseudo.combine_8x128xf32_to_128x128xf32(%a);
lavm.extractX(%a) : pseudo.extract_f32_from_128xf32(%a);
lavm.extractX(%a) : pseudo.extract_f32_from_8x128xf32(%a);
lavm.extractX(%a) : pseudo.extract_f32_from_128x128xf32(%a);
lavm.extractX(%a) : pseudo.extract_128xf32_from_8x128xf32(%a);
lavm.extractX(%a) : pseudo.extract_128xf32_from_128x128xf32(%a);
lavm.extractX(%a) : pseudo.extract_8x128xf32_from_128x128xf32(%a);
lavm.matmulX(%dst, %a, %b) : lavm.store(lavm.extractX(
tgt.matmul(lavm.combineX(lavm.load(%a)),
lavm.combineX(lavm.load(%b)))),
%dst);
lavm.add(%a, %b) : tgt.add(%a, %a); # should cause a warning that %b is unused (implemented in wrong place, too late)
lavm.add(%a, %b) : tgt.add(lavm.neg(%a), lavm.neg(%a)); # should cause an earlier warning
lavm.add(%a, %b) : tgt.boo(%a, %b); # should be an error because tgt.boo is undefined
lavm.add(%a, %b) : tgt.add(tgt.boo(%a), %b); # should be an error: undefined tgt.boo (implemented in wrong place, too late)
lavm.add(%a, %b) : tgt.add(%a, %c); # should be an error: %c is not an lhs argument (NYI)
}
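The last few expansions in the map section list diagnostics that should be produced: an unused lhs argument, an rhs variable that is not an lhs argument, and an undefined target op. A minimal Python sketch of these three checks is given below, assuming the rhs is available as a string with comments already removed and the set of known target ops has been collected from the target section. The function name check_expansion and the message wording are illustrative only.

```python
import re

VAR_RE = re.compile(r"%[A-Za-z_]\w*")       # variables such as %a, %dst
OP_RE = re.compile(r"\b(tgt\.\w+)\s*\(")    # target op uses such as tgt.add(


def check_expansion(lhs_args, rhs, known_target_ops):
    """Return a list of diagnostics for one 'lhs : rhs' expansion.

    lhs_args: argument names on the lhs, e.g. ['%a', '%b'].
    rhs: text of the expansion with comments stripped (assumed).
    known_target_ops: names declared in the target section.
    """
    used = set(VAR_RE.findall(rhs))
    declared = set(lhs_args)
    diagnostics = []
    for var in sorted(used - declared):
        diagnostics.append(f"error: {var} is not an argument of the lhs")
    for var in sorted(declared - used):
        diagnostics.append(f"warning: {var} is unused")
    for op in sorted(set(OP_RE.findall(rhs)) - set(known_target_ops)):
        diagnostics.append(f"error: {op} is undefined")
    return diagnostics
```

Because this works directly on the text of a single expansion, it could run at parse time, which is earlier than the points where the notes above say the current checks fire.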
Target memory description.
memory {
TPU-like memory:
HBM : size = 16G;
HBM : garbage = ignored;
VMEM : size = 16M;
SMEM : size = 16K;
CMEM : size = 16M;
HBM -> VMEM;
VMEM -> HBM;
HBM -> SMEM;
SMEM -> HBM;
HBM -> CMEM -> VMEM;
VMEM -> CMEM -> HBM;
GPU-like memory:
GLOBAL : size = 8G;
SHARED : size = 16M;
LOCAL : size = 1M;
CPU-like memory:
MEMORY : size = 64GB;
$L1 : size = 512K;
$L2 : size = 4M;
MEMORY -> $L2 -> $L1;
$L1 -> $L2 -> MEMORY;
things to specify:
- on the memory side: what the address is, what the value is, stride, banks
- on the register(?) side: size, stride, zero/sign extension, mask, whether the target reg can be partially filled
- load/store latency, size/granularity
- dma latency, size/granularity
- cache latency, size/granularity, policy, associativity, on which path
some properties belong to the memory instructions rather than to the memory itself; attributes can represent them
Interface: cache levels/sizes/latencies/lane sizes/policies
memory spaces, dmas between memory spaces, sizes, etc.
how to specify direct access to memory from other instructions? using implicit addresses? chained instructions?
e.g. a chain of instructions that process in-memory data, each modifying the implicit pointer
used by the next instruction.
how to represent memory local to PEs?
how to represent in-memory calculations?
name [-> $name -> ...] -> 'reg : load instructions; # do not use register?
'reg [-> $name -> ...] -> name : store instructions;
name -> name : dma instructions;
TPU-like memory:
SMEM -> 'reg : sload(%address);
'reg -> SMEM : sstore(%value, %address);
VMEM -> 'reg : vload(%address); # stride, mask
'reg -> VMEM : vstore(%value, %address); # stride, mask
HBM -> VMEM : dma(%from, %to, %size) size0^latency0, size1^latency1; # syntax?
VMEM -> HBM : dma();
HBM -> SMEM : dma();
SMEM -> HBM : dma();
GPU-like memory:
GLOBAL -> 'reg : ld(%address);
SHARED -> 'reg : lds(%address);
LOCAL -> 'reg : ldl(%address);
'reg -> GLOBAL : st(%value, %address);
'reg -> SHARED : sts(%value, %address);
'reg -> LOCAL : stl(%value, %address);
CPU-like memory:
MEMORY -> $L2 -> $L1 -> 'reg : load(%address);
'reg -> $L1 -> $L2 -> MEMORY : store(%value, %address);
}
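The transfer statements in the memory section (for example 'HBM -> CMEM -> VMEM;') describe a directed graph over memory spaces. The Python sketch below shows one way such statements might be turned into edges and queried for reachability. It handles only the plain 'A -> B -> C;' form and skips property statements such as 'HBM : size = 16G;'; the instruction-annotated form ("SMEM -> 'reg : sload(%address);") is not covered. The helper names parse_paths and reachable are hypothetical.

```python
def parse_paths(text):
    """Collect directed edges from statements like 'HBM -> CMEM -> VMEM;'."""
    edges = set()
    for statement in text.split(";"):
        if "->" not in statement:
            continue  # property statements such as 'HBM : size = 16G'
        spaces = [s.strip() for s in statement.split("->")]
        # A chain A -> B -> C contributes the edges (A, B) and (B, C).
        edges.update(zip(spaces, spaces[1:]))
    return edges


def reachable(edges, src, dst):
    """Return True if data can move from src to dst along declared transfers."""
    frontier, seen = [src], {src}
    while frontier:
        node = frontier.pop()
        if node == dst:
            return True
        for a, b in edges:
            if a == node and b not in seen:
                seen.add(b)
                frontier.append(b)
    return False
```

With the TPU-like description above, a query from SMEM to VMEM would succeed by going through HBM, even though no direct SMEM -> VMEM transfer is declared.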