| # TODOs: |
| # |
| # focusing on: |
| # - implement the "new" syntax for the target description |
| # - allow nested ops on the left side: lavm.sub(tgt.neg(%a), %b) : lavm.add(%a, tgt.neg(%b)); |
| # - support operations with same name but different number of arguments |
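#   (e.g. a two-argument tgt.dma(%to, %from), as used in the commented-out matmul sketch below, alongside
#   the three-argument tgt.dma(%to, %from, %size) defined in the target section)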
# - error on a wrong number of arguments in the map and target sections
| # - catch duplicate argument names in the target section |
# - catch mismatches between the argument count and the type signature in the target section
| # |
| # functionality: |
| # -- define memory description section |
| # -- allow [parallel] loops / ops with regions? |
| # -- generate machine description C++ class |
| # -- generate patterns for structured op tiling |
| # - generate td file with the target dialect? |
# - allow sequences on the rhs of the map section: (op0(args...), op1(args...), ...) so that we can chain buffer-level ops
| # - have type containers? (as in vreg, sreg, etc.) |
| # - allow immediates |
# - support naming: lavm.add(%a, %b) : tgt.add(tgt.neg(%a):$n, tgt.neg($n));
| # - allow instantiations for specific types only: lavm.add(%a : f32, %b : f32) -> f32 : zzz; |
| # - how to express instructions operating on memory directly? |
| # - allow arbitrary variable names and various parameter counts(?) |
| # - support default values for arguments: lavm.add(%a, %b = 42)... |
| # - detect infinite cycles in expansions |
| # - how to represent denorm/NaN/rounding/saturating behavior? |
# - add support for AnyType, which is helpful for defining MEM <-> AnyType loads/stores and op(%a) : %a empty expansions
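#   (e.g. tgt.load(%address) : (memref<HBM>) -> AnyType; together with the identity expansion lavm.none(%a) : %a;)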
| # - ??? change syntax to (lavm.matmul %dst %a %b) : (lavm.store (tgt.matmul (lavm.load %a) (lavm.load %b)) %dst); |
| # - !!! change syntax to be target op-centric instead of lavm-op centric (define target ops rather than mappings |
| # from lavm ops) |
# - allow lavm ops to drop the "lavm." prefix?
| # - add pre-checkin tests |
| # - re-implement the target description parser |
| # |
| # error checking: |
# - emit a parse error *and stop* when ';' is missing at the end of a line
| # (this is a frequent cause of the late assert 'GetTargetExpansions(): lavm_op != nullptr') |
# - target section: remove variable names from the lhs, as in tgt.nuts(%, %), or get rid of the lhs args entirely?
# - catch duplicate expansions, or anything else duplicated on the lhs or rhs; issue a warning and ignore them
| # - filter out duplicates that differ only in variable names |
| # |
| # other: |
# - rename the 'map' section to 'expansions'?
| |
| ### Operations and types supported by the target. ### |
| target { |
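# tgt.nuts: parser exercises for nested function-type signatures; aaa, bbb, etc. are placeholder
# types, and some forms (e.g. (a, b, c) -> (x, y, z)) do not match the two-argument lhs.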
| tgt.nuts(%x, %y) : (aaa, bbb) -> ccc, (a, b, c) -> (x, y, z), (f32, vector<128xf32>) -> (); |
| tgt.nuts(%x, %y) : ((m) -> p) -> q; |
| tgt.nuts(%x, %y) : ((m) -> ((p) -> bar)) -> q; |
| tgt.nuts(%x, %y) : ((m) -> ((p) -> (bar) -> foo, (p) -> ((bar) -> foo), ((p) -> (bar)) -> foo)) -> ((q) -> r); |
| tgt.nuts(%x, %y) : (m, ((n) -> ()) -> p) -> ((q) -> r); |
| |
| tgt.i2f(%x) : (i32) -> f32, (vector<8x128xi32>) -> vector<8x128xf32>; |
| tgt.f2i(%x) : (f32) -> i32, (vector<8x128xf32>) -> vector<8x128xi32>; |
| tgt.neg(%x) : vector<8x128xf32>, f32; # vector<128xf32> |
| |
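# the "@ N" suffix attaches a numeric annotation to a form, presumably a cost or latency hint
# for choosing among alternative expansions.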
| tgt.add(%x, %y) : vector<128xf32> @ 2, vector<8x128xf32> @ 3, f32 @ 1; |
| tgt.sub(%x, %y) : vector<128xf32>, vector<8x128xf32>, f32; |
| tgt.mul(%x, %y) : f32; |
| |
| tgt.dot(%x, %y) : (vector<128xf32>, vector<128xf32>) -> vector<128xf32> @ 42; |
| tgt.matmul(%x, %y) : (vector<128x128xf32>, vector<128x128xf32>) -> vector<128x128xf32> @ 99; |
| tgt.wait() : () -> (); |
| |
| tgt.dma(%to, %from, %size) : (memref<HBM>, memref<VMEM>, i32) -> (); |
| tgt.dma(%to, %from, %size) : (memref<VMEM>, memref<HBM>, i32) -> (); |
| tgt.dma(%to, %from, %size) : (memref<HBM>, memref<SMEM>, i32) -> (); |
| tgt.dma(%to, %from, %size) : (memref<SMEM>, memref<HBM>, i32) -> (); |
| tgt.load(%address) : (memref<SMEM>) -> f32, (memref<SMEM>) -> i32; |
| tgt.load(%address) : (memref<VMEM>) -> vector<128xf32>, (memref<VMEM>) -> vector<128xi32>; |
| tgt.load(%address) : (memref<VMEM>) -> vector<8x128xf32>, (memref<VMEM>) -> vector<8x128xi32>; |
| tgt.store(%value, %address) : (f32, memref<SMEM>) -> (), (i32, memref<SMEM>) -> (); |
| tgt.store(%value, %address) : (vector<128xf32>, memref<VMEM>) -> (), (vector<128xi32>, memref<VMEM>) -> (); |
| tgt.store(%value, %address) : (vector<8x128xf32>, memref<VMEM>) -> (), (vector<8x128xi32>, memref<VMEM>) -> (); |
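# note the overloading pattern above: scalar loads/stores are tied to SMEM, vector loads/stores
# to VMEM, while HBM is reachable only through tgt.dma.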
| |
| # pseudo.combine(%x) : f32 -> vector<128xf32>; |
| # pseudo.extract(%x) : vector<128xf32> -> f32; |
| # pseudo.combine(%x) : vector<128xf32> -> vector<8x128xf32>; |
| # pseudo.extract(%x) : vector<8x128xf32> -> vector<128xf32>; |
| # pseudo.combine(%x) : vector<8x128xf32> -> vector<128x128xf32>; |
| # pseudo.extract(%x) : vector<128x128xf32> -> vector<8x128xf32>; |
| |
| # pseudo.combine_f32_to_128xf32(%x) : f32 -> vector<128xf32>; |
| # pseudo.combine_f32_to_8x128xf32(%x) : f32 -> vector<8x128xf32>; |
| # pseudo.combine_f32_to_128x128xf32(%x) : f32 -> vector<128x128xf32>; |
| # pseudo.combine_128xf32_to_8x128xf32(%x) : vector<128xf32> -> vector<8x128xf32>; |
| # pseudo.combine_128xf32_to_128x128xf32(%x) : vector<128xf32> -> vector<128x128xf32>; |
| # pseudo.combine_8x128xf32_to_128x128xf32(%x) : vector<8x128xf32> -> vector<128x128xf32>; |
| # pseudo.extract_f32_from_128xf32(%x) : vector<128xf32> -> f32; |
| # pseudo.extract_f32_from_8x128xf32(%x) : vector<8x128xf32> -> f32; |
| # pseudo.extract_f32_from_128x128xf32(%x) : vector<128x128xf32> -> f32; |
| # pseudo.extract_128xf32_from_8x128xf32(%x) : vector<8x128xf32> -> vector<128xf32>; |
| # pseudo.extract_128xf32_from_128x128xf32(%x) : vector<128x128xf32> -> vector<128xf32>; |
| # pseudo.extract_8x128xf32_from_128x128xf32(%x) : vector<128x128xf32> -> vector<8x128xf32>; |
| } |
| |
| ### Map LAVM ops to previously defined target ops or their combinations. ### |
# the current restriction is that a given lavm op must appear on the lhs with the same argument names and count in every rule
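# (e.g. both lavm.add rules below write the lhs as lavm.add(%a, %b); writing one of them as
# lavm.add(%x, %y) would violate this restriction)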
| map { |
| lavm.neg(%a) : tgt.neg(%a); |
| lavm.neg(%a) : tgt.neg(tgt.i2f(%a)); |
| |
| # not yet supported |
| # lavm.none(%a) : %a; |
| # lavm.zadd(%a, %b) : tgt.add(lavm.none(%a), %b); |
| |
| lavm.add(%a, %b) : tgt.add(%a, %b), |
| tgt.sub(%a, tgt.neg(%b)), |
| tgt.add(tgt.i2f(%a), tgt.i2f(%b)), |
| tgt.add(tgt.i2f(%b), %a), tgt.add(%b, tgt.i2f(%a)), |
| tgt.add(%b, tgt.add(tgt.add(%a, %b), tgt.neg(%b))); |
| lavm.add(%a, %b) : lavm.neg(tgt.add(lavm.neg(%a), lavm.neg(%b))), |
| lavm.neg(tgt.add(lavm.neg(tgt.f2i(%b)), tgt.neg(%a))); |
| |
| lavm.sub(%a, %b) : lavm.add(%a, tgt.neg(%b)); |
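# rhs ops may themselves be lavm ops (lavm.add above expands through lavm.neg, and lavm.sub through
# lavm.add), so expansions apply recursively; hence the "detect infinite cycles in expansions" TODO.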
| |
| # lavm.add_x(%dst, %a, %b) : # tgt.dma( |
| # tgt.store(tgt.add(tgt.load(tgt.dma(%a)), |
| # tgt.load(tgt.dma(%b))), |
| # %dst) |
| # # ) |
| # ; |
| # lavm.matmul(%dst : memref<HBM>, %a : memref<HBM>, %b : memref<HBM>) : |
| # %aa = alloc() : memref<VMEM>, |
| # %bb = alloc() : memref<VMEM>, |
| # %cc = alloc() : memref<VMEM>, |
| # tgt.dma(%a, %aa), |
| # tgt.dma(%b, %bb), |
| # lavm.store(lavm.matmul(lavm.load(%aa), lavm.load(%bb)), %cc) |
| # tgt.dma(%cc, %dst); # dealloc??? |
| |
| lavm.dma(%to, %from, %size) : tgt.dma(%to, %from, %size); |
| |
| lavm.load(%address) : tgt.load(%address); |
| # lavm.load(%address) : lavm.combine(tgt.load(%address)); |
| # lavm.load(%address) : lavm.extract(tgt.load(%address)); |
| lavm.store(%value, %address) : tgt.store(%value, %address); |
| |
| # want: |
| lavm.matmul(%dst, %a, %b) : lavm.store(tgt.matmul(lavm.load(%a), |
| lavm.load(%b)), |
| %dst); |
| |
| # ...no: get rid of these |
| # lavm.combine(%a) : pseudo.combine(%a); |
| # lavm.extract(%a) : pseudo.extract(%a); |
| # lavm.combine(%a) : pseudo.combine(pseudo.combine(%a)); |
| # lavm.extract(%a) : pseudo.extract(pseudo.extract(%a)); |
| # lavm.combine(%a) : pseudo.combine(pseudo.combine(pseudo.combine(%a))); |
| # lavm.extract(%a) : pseudo.extract(pseudo.extract(pseudo.extract(%a))); |
| # |
| # lavm.matmulY(%dst, %a, %b) : lavm.store(lavm.extract( |
| # tgt.matmul(lavm.combine(lavm.load(%a)), |
| # lavm.combine(lavm.load(%b)))), |
| # %dst); |
| |
| # ... also no |
| # lavm.combineX(%a) : pseudo.combine_f32_to_128xf32(%a); |
| # lavm.combineX(%a) : pseudo.combine_f32_to_8x128xf32(%a); |
| # lavm.combineX(%a) : pseudo.combine_f32_to_128x128xf32(%a); |
| # lavm.combineX(%a) : pseudo.combine_128xf32_to_8x128xf32(%a); |
| # lavm.combineX(%a) : pseudo.combine_128xf32_to_128x128xf32(%a); |
| # lavm.combineX(%a) : pseudo.combine_8x128xf32_to_128x128xf32(%a); |
| # lavm.extractX(%a) : pseudo.extract_f32_from_128xf32(%a); |
| # lavm.extractX(%a) : pseudo.extract_f32_from_8x128xf32(%a); |
| # lavm.extractX(%a) : pseudo.extract_f32_from_128x128xf32(%a); |
| # lavm.extractX(%a) : pseudo.extract_128xf32_from_8x128xf32(%a); |
| # lavm.extractX(%a) : pseudo.extract_128xf32_from_128x128xf32(%a); |
| # lavm.extractX(%a) : pseudo.extract_8x128xf32_from_128x128xf32(%a); |
| # |
| # lavm.matmulX(%dst, %a, %b) : lavm.store(lavm.extractX( |
| # tgt.matmul(lavm.combineX(lavm.load(%a)), |
| # lavm.combineX(lavm.load(%b)))), |
| # %dst); |
| |
# lavm.add(%a, %b) : tgt.add(%a, %a); # should cause a warning that %b is unused (implemented in the wrong place, too late)
| # lavm.add(%a, %b) : tgt.add(lavm.neg(%a), lavm.neg(%a)); # should cause an earlier warning |
| # lavm.add(%a, %b) : tgt.boo(%a, %b); # should be an error because tgt.boo is undefined |
# lavm.add(%a, %b) : tgt.add(tgt.boo(%a), %b); # should be an error: undefined tgt.boo (implemented in the wrong place, too late)
| # lavm.add(%a, %b) : tgt.add(%a, %c); # this is an error! should be caught (NYI) |
| } |
| |
| ### Target memory description. ### |
| memory { |
| |
| # TPU-like memory: |
| HBM : size = 16G; |
| HBM : garbage = ignored; |
| VMEM : size = 16M; |
| SMEM : size = 16K; |
| CMEM : size = 16M; |
| HBM -> VMEM; |
| VMEM -> HBM; |
| HBM -> SMEM; |
| SMEM -> HBM; |
| HBM -> CMEM -> VMEM; |
| VMEM -> CMEM -> HBM; |
| # GPU-like memory: |
| GLOBAL : size = 8G; |
| SHARED : size = 16M; |
| LOCAL : size = 1M; |
| # CPU-like memory: |
MEMORY : size = 64G;
| $L1 : size = 512K; |
| $L2 : size = 4M; |
| MEMORY -> $L2 -> $L1; |
| $L1 -> $L2 -> MEMORY; |
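# "A -> B" declares a legal transfer path from A to B; chains such as MEMORY -> $L2 -> $L1 route
# through intermediate levels (cf. the "name -> name : dma instructions" sketch below).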
| |
# things to specify:
# - on the memory side: what the address is, what the value is, stride, banks
# - on the register(?) side: size, stride, zero/sign extension, mask, whether a target reg can be partially filled
# - load/store latency, size/granularity
# - dma latency, size/granularity
# - cache latency, size/granularity, policy, associativity, and on which path it sits
# some of these are properties of the memory instructions themselves; attributes can be used to represent them
| # |
| # Interface: cache levels/sizes/latencies/lane sizes/policies |
| # memory spaces, dmas between memory spaces, sizes, etc. |
| # |
# how to specify direct access to memory from other instructions? using implicit addresses? chained instructions?
# for example, a chain of instructions that process in-memory data, where each one modifies the implicit
# pointer used by the next instruction.
| # |
| # how to represent memory local to PEs? |
| # how to represent in-memory calculations? |
| |
| # name [-> $name -> ...] -> 'reg : load instructions; # do not use register? |
| # 'reg [-> $name -> ...] -> name : store instructions; |
| # name -> name : dma instructions; |
| |
| # TPU-like memory: |
| # SMEM -> 'reg : sload(%address); |
| # 'reg -> SMEM : sstore(%value, %address); |
| # VMEM -> 'reg : vload(%address); # stride, mask |
| # 'reg -> VMEM : vstore(%value, %address); # stride, mask |
| # HBM -> VMEM : dma(%from, %to, %size) size0^latency0, size1^latency1; # syntax? |
| # VMEM -> HBM : dma(); |
| # HBM -> SMEM : dma(); |
| # SMEM -> HBM : dma(); |
| |
| # GPU-like memory: |
| # GLOBAL -> 'reg : ld(%address); |
| # SHARED -> 'reg : lds(%address); |
| # LOCAL -> 'reg : ldl(%address); |
| # 'reg -> GLOBAL : st(%value, %address); |
| # 'reg -> SHARED : sts(%value, %address); |
| # 'reg -> LOCAL : stl(%value, %address); |
| |
| # CPU-like memory: |
| # MEMORY -> $L2 -> $L1 -> 'reg : load(%address); |
| # 'reg -> $L1 -> $L2 -> MEMORY : store(%value, %address); |
| } |