# blob: 1f1a6801547d0ebc39b4ca1c46189f506eb82730 [file] [log] [blame] [view]
# TODOs:
#
# focusing on:
# - implement the "new" syntax for the target description
# - allow nested ops on the left side: lavm.sub(tgt.neg(%a), %b) : lavm.add(%a, tgt.neg(%b));
# - support operations with same name but different number of arguments
# - error on wrong number of arguments in map section and in target sections
# - catch duplicate argument names in the target section
# - catch mismatch in the argument count and the type in the target section
#
# functionality:
# -- define memory description section
# -- allow [parallel] loops / ops with regions?
# -- generate machine description C++ class
# -- generate patterns for structured op tiling
# - generate td file with the target dialect?
# - allow sequences in rhs of map section: (op0(args...), op1(args), ...) so that we can chain buffer-level ops
# - have type containers? (as in vreg, sreg, etc.)
# - allow immediates
# - support naming: lavm.add(%a, %b) : tgt.add(tgt.neg(%a):$n, tgt.neg($n));
# - allow instantiations for specific types only: lavm.add(%a : f32, %b : f32) -> f32 : zzz;
# - how to express instructions operating on memory directly?
# - allow arbitrary variable names and various parameter counts(?)
# - support default values for arguments: lavm.add(%a, %b = 42)...
# - detect infinite cycles in expansions
# - how to represent denorm/NaN/rounding/saturating behavior?
# - add support for AnyType, which would help define MEM <-> AnyType loads/stores and empty expansions such as op(%a) : %a
# - ??? change syntax to (lavm.matmul %dst %a %b) : (lavm.store (tgt.matmul (lavm.load %a) (lavm.load %b)) %dst);
# - !!! change syntax to be target-op-centric instead of lavm-op-centric (define target ops rather than mappings
#   from lavm ops)
# - allow lavm ops to drop "lavm." prefix?
# - add pre-checkin tests
# - re-implement the target description parser
#
# error checking:
# - emit error at parsing *and stop* when ';' is missing at the end of the line
# (this is a frequent cause of the late assert 'GetTargetExpansions(): lavm_op != nullptr')
# - target section: remove variable names on the lhs, as in tgt.nuts(%, %)? Or drop the args on the lhs entirely?
# - catch duplicate expansions or anything else duplicated on lhs or rhs, issue warning and ignore
# - filter out duplicates that differ only in variable names
#
# other:
# - rename the 'map' section to 'expansions'?
### Operations and types supported by the target. ###
target {
# Each entry declares a target op and the type signatures it supports:
#   tgt.OP(%arg, ...) : SIG, SIG, ...;
# Each SIG is either a function type "(operand types) -> result type(s)" or a
# single bare type (see tgt.neg / tgt.add / tgt.sub / tgt.mul below).
# NOTE(review): a bare type presumably means "all operands and the result have
# this type" -- confirm against the parser.
# Some signatures carry an "@ N" suffix (tgt.add, tgt.dot, tgt.matmul);
# presumably a per-signature cost/latency -- TODO confirm.
# The same op may be declared on several lines (tgt.dma, tgt.load, tgt.store),
# each line listing further signatures for that op.
#
# The tgt.nuts entries are not referenced by the map section; they appear to
# exercise the parser on nested/higher-order function types.
tgt.nuts(%x, %y) : (aaa, bbb) -> ccc, (a, b, c) -> (x, y, z), (f32, vector<128xf32>) -> ();
tgt.nuts(%x, %y) : ((m) -> p) -> q;
tgt.nuts(%x, %y) : ((m) -> ((p) -> bar)) -> q;
tgt.nuts(%x, %y) : ((m) -> ((p) -> (bar) -> foo, (p) -> ((bar) -> foo), ((p) -> (bar)) -> foo)) -> ((q) -> r);
tgt.nuts(%x, %y) : (m, ((n) -> ()) -> p) -> ((q) -> r);
# Conversions between i32 and f32, for scalars and vector<8x128x...>.
tgt.i2f(%x) : (i32) -> f32, (vector<8x128xi32>) -> vector<8x128xf32>;
tgt.f2i(%x) : (f32) -> i32, (vector<8x128xf32>) -> vector<8x128xi32>;
# Arithmetic ops declared with bare types. Only f32-element types are listed,
# so integer operands must go through tgt.i2f first (see the map section).
tgt.neg(%x) : vector<8x128xf32>, f32; # vector<128xf32>
tgt.add(%x, %y) : vector<128xf32> @ 2, vector<8x128xf32> @ 3, f32 @ 1;
tgt.sub(%x, %y) : vector<128xf32>, vector<8x128xf32>, f32;
tgt.mul(%x, %y) : f32;
# Vector dot product and 128x128 matrix multiply.
tgt.dot(%x, %y) : (vector<128xf32>, vector<128xf32>) -> vector<128xf32> @ 42;
tgt.matmul(%x, %y) : (vector<128x128xf32>, vector<128x128xf32>) -> vector<128x128xf32> @ 99;
# tgt.wait takes no operands and produces no results.
tgt.wait() : () -> ();
# DMA transfers. Operand order is (%to, %from, %size), so the first line moves
# data from VMEM to HBM. The four lines cover HBM <-> VMEM and HBM <-> SMEM,
# the same pairs listed as direct paths in the memory section.
tgt.dma(%to, %from, %size) : (memref<HBM>, memref<VMEM>, i32) -> ();
tgt.dma(%to, %from, %size) : (memref<VMEM>, memref<HBM>, i32) -> ();
tgt.dma(%to, %from, %size) : (memref<HBM>, memref<SMEM>, i32) -> ();
tgt.dma(%to, %from, %size) : (memref<SMEM>, memref<HBM>, i32) -> ();
# Loads/stores: scalars (f32/i32) go through SMEM, vectors through VMEM.
tgt.load(%address) : (memref<SMEM>) -> f32, (memref<SMEM>) -> i32;
tgt.load(%address) : (memref<VMEM>) -> vector<128xf32>, (memref<VMEM>) -> vector<128xi32>;
tgt.load(%address) : (memref<VMEM>) -> vector<8x128xf32>, (memref<VMEM>) -> vector<8x128xi32>;
tgt.store(%value, %address) : (f32, memref<SMEM>) -> (), (i32, memref<SMEM>) -> ();
tgt.store(%value, %address) : (vector<128xf32>, memref<VMEM>) -> (), (vector<128xi32>, memref<VMEM>) -> ();
tgt.store(%value, %address) : (vector<8x128xf32>, memref<VMEM>) -> (), (vector<8x128xi32>, memref<VMEM>) -> ();
# Disabled pseudo ops for combining/extracting between scalar and vector
# shapes; the map section marks the expansions that used them as abandoned.
# pseudo.combine(%x) : f32 -> vector<128xf32>;
# pseudo.extract(%x) : vector<128xf32> -> f32;
# pseudo.combine(%x) : vector<128xf32> -> vector<8x128xf32>;
# pseudo.extract(%x) : vector<8x128xf32> -> vector<128xf32>;
# pseudo.combine(%x) : vector<8x128xf32> -> vector<128x128xf32>;
# pseudo.extract(%x) : vector<128x128xf32> -> vector<8x128xf32>;
# pseudo.combine_f32_to_128xf32(%x) : f32 -> vector<128xf32>;
# pseudo.combine_f32_to_8x128xf32(%x) : f32 -> vector<8x128xf32>;
# pseudo.combine_f32_to_128x128xf32(%x) : f32 -> vector<128x128xf32>;
# pseudo.combine_128xf32_to_8x128xf32(%x) : vector<128xf32> -> vector<8x128xf32>;
# pseudo.combine_128xf32_to_128x128xf32(%x) : vector<128xf32> -> vector<128x128xf32>;
# pseudo.combine_8x128xf32_to_128x128xf32(%x) : vector<8x128xf32> -> vector<128x128xf32>;
# pseudo.extract_f32_from_128xf32(%x) : vector<128xf32> -> f32;
# pseudo.extract_f32_from_8x128xf32(%x) : vector<8x128xf32> -> f32;
# pseudo.extract_f32_from_128x128xf32(%x) : vector<128x128xf32> -> f32;
# pseudo.extract_128xf32_from_8x128xf32(%x) : vector<8x128xf32> -> vector<128xf32>;
# pseudo.extract_128xf32_from_128x128xf32(%x) : vector<128x128xf32> -> vector<128xf32>;
# pseudo.extract_8x128xf32_from_128x128xf32(%x) : vector<128x128xf32> -> vector<8x128xf32>;
}
### Map LAVM ops to previously defined target ops or their combinations. ###
# Current restriction: every lhs occurrence of a given lavm op must use the same argument names and argument count.
map {
# Each entry maps a LAVM op to one or more candidate expansions:
#   lavm.OP(%arg, ...) : EXPANSION, EXPANSION, ...;
# An EXPANSION is an expression over the lhs arguments built from tgt ops
# and/or other lavm ops. Alternatives can be given comma-separated on one
# entry or as repeated entries for the same lavm op.
# NOTE(review): presumably the alternative whose types match the target
# signatures is selected -- confirm against the expansion code.
#
# tgt.neg is declared only for f32-element types; the second form first
# converts the operand with tgt.i2f.
lavm.neg(%a) : tgt.neg(%a);
lavm.neg(%a) : tgt.neg(tgt.i2f(%a));
# Not yet supported: an empty expansion that forwards its argument unchanged
# (see the AnyType item in the TODO list at the top of the file).
# lavm.none(%a) : %a;
# lavm.zadd(%a, %b) : tgt.add(lavm.none(%a), %b);
# lavm.add alternatives: direct tgt.add, subtraction of the negation, and
# variants that convert one or both operands with tgt.i2f.
lavm.add(%a, %b) : tgt.add(%a, %b),
tgt.sub(%a, tgt.neg(%b)),
tgt.add(tgt.i2f(%a), tgt.i2f(%b)),
tgt.add(tgt.i2f(%b), %a), tgt.add(%b, tgt.i2f(%a)),
tgt.add(%b, tgt.add(tgt.add(%a, %b), tgt.neg(%b)));
# Expansions that go through other lavm ops (lavm.neg), which presumably are
# expanded in turn using the lavm.neg entries above.
lavm.add(%a, %b) : lavm.neg(tgt.add(lavm.neg(%a), lavm.neg(%b))),
lavm.neg(tgt.add(lavm.neg(tgt.f2i(%b)), tgt.neg(%a)));
# lavm.sub is expressed through lavm.add rather than directly through tgt.sub.
lavm.sub(%a, %b) : lavm.add(%a, tgt.neg(%b));
# Disabled sketches of memory-level expansions (DMA + load/compute/store, and
# an alloc-based matmul). The rhs here is a sequence, which the TODO list at
# the top of the file lists as not yet allowed.
# lavm.add_x(%dst, %a, %b) : # tgt.dma(
# tgt.store(tgt.add(tgt.load(tgt.dma(%a)),
# tgt.load(tgt.dma(%b))),
# %dst)
# # )
# ;
# lavm.matmul(%dst : memref<HBM>, %a : memref<HBM>, %b : memref<HBM>) :
# %aa = alloc() : memref<VMEM>,
# %bb = alloc() : memref<VMEM>,
# %cc = alloc() : memref<VMEM>,
# tgt.dma(%a, %aa),
# tgt.dma(%b, %bb),
# lavm.store(lavm.matmul(lavm.load(%aa), lavm.load(%bb)), %cc)
# tgt.dma(%cc, %dst); # dealloc???
# Memory ops map one-to-one onto the target ops of the same name.
lavm.dma(%to, %from, %size) : tgt.dma(%to, %from, %size);
lavm.load(%address) : tgt.load(%address);
# lavm.load(%address) : lavm.combine(tgt.load(%address));
# lavm.load(%address) : lavm.extract(tgt.load(%address));
lavm.store(%value, %address) : tgt.store(%value, %address);
# want: matmul on memory operands, expressed as loads of %a and %b, a
# tgt.matmul, and a store of the result to %dst.
lavm.matmul(%dst, %a, %b) : lavm.store(tgt.matmul(lavm.load(%a),
lavm.load(%b)),
%dst);
# ...no: get rid of these (abandoned approach: reshape operands with
# lavm.combine/lavm.extract around tgt.matmul).
# lavm.combine(%a) : pseudo.combine(%a);
# lavm.extract(%a) : pseudo.extract(%a);
# lavm.combine(%a) : pseudo.combine(pseudo.combine(%a));
# lavm.extract(%a) : pseudo.extract(pseudo.extract(%a));
# lavm.combine(%a) : pseudo.combine(pseudo.combine(pseudo.combine(%a)));
# lavm.extract(%a) : pseudo.extract(pseudo.extract(pseudo.extract(%a)));
#
# lavm.matmulY(%dst, %a, %b) : lavm.store(lavm.extract(
# tgt.matmul(lavm.combine(lavm.load(%a)),
# lavm.combine(lavm.load(%b)))),
# %dst);
# ... also no (same idea with per-shape pseudo ops)
# lavm.combineX(%a) : pseudo.combine_f32_to_128xf32(%a);
# lavm.combineX(%a) : pseudo.combine_f32_to_8x128xf32(%a);
# lavm.combineX(%a) : pseudo.combine_f32_to_128x128xf32(%a);
# lavm.combineX(%a) : pseudo.combine_128xf32_to_8x128xf32(%a);
# lavm.combineX(%a) : pseudo.combine_128xf32_to_128x128xf32(%a);
# lavm.combineX(%a) : pseudo.combine_8x128xf32_to_128x128xf32(%a);
# lavm.extractX(%a) : pseudo.extract_f32_from_128xf32(%a);
# lavm.extractX(%a) : pseudo.extract_f32_from_8x128xf32(%a);
# lavm.extractX(%a) : pseudo.extract_f32_from_128x128xf32(%a);
# lavm.extractX(%a) : pseudo.extract_128xf32_from_8x128xf32(%a);
# lavm.extractX(%a) : pseudo.extract_128xf32_from_128x128xf32(%a);
# lavm.extractX(%a) : pseudo.extract_8x128xf32_from_128x128xf32(%a);
#
# lavm.matmulX(%dst, %a, %b) : lavm.store(lavm.extractX(
# tgt.matmul(lavm.combineX(lavm.load(%a)),
# lavm.combineX(lavm.load(%b)))),
# %dst);
# Disabled error-handling cases; each trailing comment states the expected
# diagnostic and its current implementation status.
# lavm.add(%a, %b) : tgt.add(%a, %a); # should cause a warning that %b is unused (implemented in wrong place, too late)
# lavm.add(%a, %b) : tgt.add(lavm.neg(%a), lavm.neg(%a)); # should cause an earlier warning
# lavm.add(%a, %b) : tgt.boo(%a, %b); # should be an error because tgt.boo is undefined
# lavm.add(%a, %b) : tgt.add(tgt.boo(%a), %b); # should be an error: undefined tgt.boo (implemented in wrong place, too late)
# lavm.add(%a, %b) : tgt.add(%a, %c); # this is an error! should be caught (NYI)
}
### Target memory description. ###
memory {
# Entries take two forms:
#   NAME : property = value;   -- a property of a memory space (e.g. size)
#   A -> B [-> C ...];         -- a path between memory spaces, presumably the
#                                 supported data-transfer routes (confirm)
# TPU-like memory:
HBM : size = 16G;
# NOTE(review): 'garbage' is not a property used anywhere else in this file;
# presumably present to check that unknown properties are ignored -- confirm.
HBM : garbage = ignored;
VMEM : size = 16M;
SMEM : size = 16K;
CMEM : size = 16M;
# Direct HBM <-> VMEM and HBM <-> SMEM paths; these are the pairs covered by
# tgt.dma in the target section.
HBM -> VMEM;
VMEM -> HBM;
HBM -> SMEM;
SMEM -> HBM;
# Multi-hop paths between HBM and VMEM that pass through CMEM.
HBM -> CMEM -> VMEM;
VMEM -> CMEM -> HBM;
# GPU-like memory:
# NOTE(review): no paths are declared for the GPU-like spaces.
GLOBAL : size = 8G;
SHARED : size = 16M;
LOCAL : size = 1M;
# CPU-like memory:
# NOTE(review): '64GB' uses a different unit suffix than every other size here
# ('16G', '16M', '16K', ...); confirm the parser accepts 'GB', or use '64G'.
MEMORY : size = 64GB;
# '$'-prefixed names appear to denote cache levels (compare the
# 'MEMORY -> $L2 -> $L1 -> 'reg' sketch in the notes below).
$L1 : size = 512K;
$L2 : size = 4M;
MEMORY -> $L2 -> $L1;
$L1 -> $L2 -> MEMORY;
# What to specify:
# - on the mem side: specify what is address, what is the value, stride, banks
# - on the register(?) side: size, stride, zero/sign-extended, mask, if can partially fill target reg
# - load/store latency, size/granularity
# - dma latency, size/granularity
# - cache latency, size/granularity, policy, associativity, on which path
# some properties are memory instructions' properties, can use attributes to represent them
#
# Interface: cache levels/sizes/latencies/lane sizes/policies
# memory spaces, dmas between memory spaces, sizes, etc.
#
# how to specify direct access to memory from other instructions? using implicit addresses? chained instructions?
# for ex. a chain of instructions, which process in memory data, and each modifies the implicit pointer
# used in the next instruction.
#
# how to represent memory local to PEs?
# how to represent in-memory calculations?
# name [-> $name -> ...] -> 'reg : load instructions; # do not use register?
# 'reg [-> $name -> ...] -> name : store instructions;
# name -> name : dma instructions;
# TPU-like memory:
# SMEM -> 'reg : sload(%address);
# 'reg -> SMEM : sstore(%value, %address);
# VMEM -> 'reg : vload(%address); # stride, mask
# 'reg -> VMEM : vstore(%value, %address); # stride, mask
# NOTE(review): the dma sketch below orders operands as (%from, %to, %size),
# whereas tgt.dma in the target section uses (%to, %from, %size).
# HBM -> VMEM : dma(%from, %to, %size) size0^latency0, size1^latency1; # syntax?
# VMEM -> HBM : dma();
# HBM -> SMEM : dma();
# SMEM -> HBM : dma();
# GPU-like memory:
# GLOBAL -> 'reg : ld(%address);
# SHARED -> 'reg : lds(%address);
# LOCAL -> 'reg : ldl(%address);
# 'reg -> GLOBAL : st(%value, %address);
# 'reg -> SHARED : sts(%value, %address);
# 'reg -> LOCAL : stl(%value, %address);
# CPU-like memory:
# MEMORY -> $L2 -> $L1 -> 'reg : load(%address);
# 'reg -> $L1 -> $L2 -> MEMORY : store(%value, %address);
}