diff options
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXIntrinsics.td')
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 447 |
1 files changed, 380 insertions, 67 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 8ccd47c0fcfd..de4bf2ef3055 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -274,6 +274,22 @@ defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<Int32Regs, "b32", int_nvvm_match_all_s defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<Int64Regs, "b64", int_nvvm_match_all_sync_i64p, i64imm>; +multiclass REDUX_SYNC<string BinOp, string PTXType, Intrinsic Intrin> { + def : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$mask), + "redux.sync." # BinOp # "." # PTXType # " $dst, $src, $mask;", + [(set Int32Regs:$dst, (Intrin Int32Regs:$src, Int32Regs:$mask))]>, + Requires<[hasPTX70, hasSM80]>; +} + +defm REDUX_SYNC_UMIN : REDUX_SYNC<"min", "u32", int_nvvm_redux_sync_umin>; +defm REDUX_SYNC_UMAX : REDUX_SYNC<"max", "u32", int_nvvm_redux_sync_umax>; +defm REDUX_SYNC_ADD : REDUX_SYNC<"add", "s32", int_nvvm_redux_sync_add>; +defm REDUX_SYNC_MIN : REDUX_SYNC<"min", "s32", int_nvvm_redux_sync_min>; +defm REDUX_SYNC_MAX : REDUX_SYNC<"max", "s32", int_nvvm_redux_sync_max>; +defm REDUX_SYNC_AND : REDUX_SYNC<"and", "b32", int_nvvm_redux_sync_and>; +defm REDUX_SYNC_XOR : REDUX_SYNC<"xor", "b32", int_nvvm_redux_sync_xor>; +defm REDUX_SYNC_OR : REDUX_SYNC<"or", "b32", int_nvvm_redux_sync_or>; + } // isConvergent = true //----------------------------------- @@ -289,6 +305,211 @@ def INT_MEMBAR_SYS : MEMBAR<"membar.sys;", int_nvvm_membar_sys>; //----------------------------------- +// Async Copy Functions +//----------------------------------- + +multiclass CP_ASYNC_MBARRIER_ARRIVE<string NoInc, string AddrSpace, Intrinsic Intrin> { + def _32 : NVPTXInst<(outs), (ins Int32Regs:$addr), + !strconcat("cp.async.mbarrier.arrive", NoInc, AddrSpace, ".b64 [$addr];"), + [(Intrin Int32Regs:$addr)]>, + Requires<[hasPTX70, hasSM80]>; + def _64 : NVPTXInst<(outs), (ins Int64Regs:$addr), + !strconcat("cp.async.mbarrier.arrive", NoInc, AddrSpace, ".b64 [$addr];"), + [(Intrin Int64Regs:$addr)]>, + Requires<[hasPTX70, hasSM80]>; +} + +defm CP_ASYNC_MBARRIER_ARRIVE : + CP_ASYNC_MBARRIER_ARRIVE<"", "", int_nvvm_cp_async_mbarrier_arrive>; +defm CP_ASYNC_MBARRIER_ARRIVE_SHARED : + CP_ASYNC_MBARRIER_ARRIVE<"", ".shared", int_nvvm_cp_async_mbarrier_arrive_shared>; +defm CP_ASYNC_MBARRIER_ARRIVE_NOINC : + CP_ASYNC_MBARRIER_ARRIVE<".noinc", "", int_nvvm_cp_async_mbarrier_arrive_noinc>; +defm CP_ASYNC_MBARRIER_ARRIVE_NOINC_SHARED : + CP_ASYNC_MBARRIER_ARRIVE<".noinc", ".shared", int_nvvm_cp_async_mbarrier_arrive_noinc_shared>; + +multiclass CP_ASYNC_CA_SHARED_GLOBAL_I<string cpsize, Intrinsic Intrin> { + def _32 : NVPTXInst<(outs), (ins Int32Regs:$dst, Int32Regs:$src), + !strconcat("cp.async.ca.shared.global [$dst], [$src], ", cpsize, ";"), + [(Intrin Int32Regs:$dst, Int32Regs:$src)]>, + Requires<[hasPTX70, hasSM80]>; + def _64 : NVPTXInst<(outs), (ins Int64Regs:$dst, Int64Regs:$src), + !strconcat("cp.async.ca.shared.global [$dst], [$src], ", cpsize, ";"), + [(Intrin Int64Regs:$dst, Int64Regs:$src)]>, + Requires<[hasPTX70, hasSM80]>; +} + +defm CP_ASYNC_CA_SHARED_GLOBAL_4 : + CP_ASYNC_CA_SHARED_GLOBAL_I<"4", int_nvvm_cp_async_ca_shared_global_4>; + +defm CP_ASYNC_CA_SHARED_GLOBAL_8 : + CP_ASYNC_CA_SHARED_GLOBAL_I<"8", int_nvvm_cp_async_ca_shared_global_8>; + +defm CP_ASYNC_CA_SHARED_GLOBAL_16 : + CP_ASYNC_CA_SHARED_GLOBAL_I<"16", int_nvvm_cp_async_ca_shared_global_16>; + +multiclass CP_ASYNC_CG_SHARED_GLOBAL<string cpsize, Intrinsic Intrin> { + def _32 : NVPTXInst<(outs), (ins Int32Regs:$dst, Int32Regs:$src), + !strconcat("cp.async.cg.shared.global [$dst], [$src], ", cpsize, ";"), + [(Intrin Int32Regs:$dst, Int32Regs:$src)]>, + Requires<[hasPTX70, hasSM80]>; + def _64 : NVPTXInst<(outs), (ins Int64Regs:$dst, Int64Regs:$src), + !strconcat("cp.async.cg.shared.global [$dst], [$src], ", cpsize, ";"), + [(Intrin Int64Regs:$dst, Int64Regs:$src)]>, + Requires<[hasPTX70, hasSM80]>; +} + +defm CP_ASYNC_CG_SHARED_GLOBAL_16 : + CP_ASYNC_CG_SHARED_GLOBAL<"16", int_nvvm_cp_async_cg_shared_global_16>; + +def CP_ASYNC_COMMIT_GROUP : + NVPTXInst<(outs), (ins), "cp.async.commit_group;", [(int_nvvm_cp_async_commit_group)]>, + Requires<[hasPTX70, hasSM80]>; + +def CP_ASYNC_WAIT_GROUP : + NVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group $n;", + [(int_nvvm_cp_async_wait_group (i32 timm:$n))]>, + Requires<[hasPTX70, hasSM80]>; + +def CP_ASYNC_WAIT_ALL : + NVPTXInst<(outs), (ins), "cp.async.wait_all;", + [(int_nvvm_cp_async_wait_all)]>, + Requires<[hasPTX70, hasSM80]>; + +//----------------------------------- +// MBarrier Functions +//----------------------------------- + +multiclass MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> { + def _32 : NVPTXInst<(outs), (ins Int32Regs:$addr, Int32Regs:$count), + !strconcat("mbarrier.init", AddrSpace, ".b64 [$addr], $count;"), + [(Intrin Int32Regs:$addr, Int32Regs:$count)]>, + Requires<[hasPTX70, hasSM80]>; + def _64 : NVPTXInst<(outs), (ins Int64Regs:$addr, Int32Regs:$count), + !strconcat("mbarrier.init", AddrSpace, ".b64 [$addr], $count;"), + [(Intrin Int64Regs:$addr, Int32Regs:$count)]>, + Requires<[hasPTX70, hasSM80]>; +} + +defm MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>; +defm MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared", + int_nvvm_mbarrier_init_shared>; + +multiclass MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> { + def _32 : NVPTXInst<(outs), (ins Int32Regs:$addr), + !strconcat("mbarrier.inval", AddrSpace, ".b64 [$addr];"), + [(Intrin Int32Regs:$addr)]>, + Requires<[hasPTX70, hasSM80]>; + def _64 : NVPTXInst<(outs), (ins Int64Regs:$addr), + !strconcat("mbarrier.inval", AddrSpace, ".b64 [$addr];"), + [(Intrin Int64Regs:$addr)]>, + Requires<[hasPTX70, hasSM80]>; +} + +defm MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>; +defm MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared", + int_nvvm_mbarrier_inval_shared>; + +multiclass MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> { + def _32 : NVPTXInst<(outs Int64Regs:$state), (ins Int32Regs:$addr), + !strconcat("mbarrier.arrive", AddrSpace, ".b64 $state, [$addr];"), + [(set Int64Regs:$state, (Intrin Int32Regs:$addr))]>, + Requires<[hasPTX70, hasSM80]>; + def _64 : NVPTXInst<(outs Int64Regs:$state), (ins Int64Regs:$addr), + !strconcat("mbarrier.arrive", AddrSpace, ".b64 $state, [$addr];"), + [(set Int64Regs:$state, (Intrin Int64Regs:$addr))]>, + Requires<[hasPTX70, hasSM80]>; +} + +defm MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>; +defm MBARRIER_ARRIVE_SHARED : + MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>; + +multiclass MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> { + def _32 : NVPTXInst<(outs Int64Regs:$state), + (ins Int32Regs:$addr, Int32Regs:$count), + !strconcat("mbarrier.arrive.noComplete", AddrSpace, + ".b64 $state, [$addr], $count;"), + [(set Int64Regs:$state, (Intrin Int32Regs:$addr, Int32Regs:$count))]>, + Requires<[hasPTX70, hasSM80]>; + def _64 : NVPTXInst<(outs Int64Regs:$state), + (ins Int64Regs:$addr, Int32Regs:$count), + !strconcat("mbarrier.arrive.noComplete", AddrSpace, + ".b64 $state, [$addr], $count;"), + [(set Int64Regs:$state, (Intrin Int64Regs:$addr, Int32Regs:$count))]>, + Requires<[hasPTX70, hasSM80]>; +} + +defm MBARRIER_ARRIVE_NOCOMPLETE : + MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>; +defm MBARRIER_ARRIVE_NOCOMPLETE_SHARED : + MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>; + +multiclass MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> { + def _32 : NVPTXInst<(outs Int64Regs:$state), (ins Int32Regs:$addr), + !strconcat("mbarrier.arrive_drop", AddrSpace, + ".b64 $state, [$addr];"), + [(set Int64Regs:$state, (Intrin Int32Regs:$addr))]>, + Requires<[hasPTX70, hasSM80]>; + def _64 : NVPTXInst<(outs Int64Regs:$state), (ins Int64Regs:$addr), + !strconcat("mbarrier.arrive_drop", AddrSpace, + ".b64 $state, [$addr];"), + [(set Int64Regs:$state, (Intrin Int64Regs:$addr))]>, + Requires<[hasPTX70, hasSM80]>; +} + +defm MBARRIER_ARRIVE_DROP : + MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>; +defm MBARRIER_ARRIVE_DROP_SHARED : + MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>; + +multiclass MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> { + def _32 : NVPTXInst<(outs Int64Regs:$state), + (ins Int32Regs:$addr, Int32Regs:$count), + !strconcat("mbarrier.arrive_drop.noComplete", AddrSpace, + ".b64 $state, [$addr], $count;"), + [(set Int64Regs:$state, (Intrin Int32Regs:$addr, Int32Regs:$count))]>, + Requires<[hasPTX70, hasSM80]>; + def _64 : NVPTXInst<(outs Int64Regs:$state), + (ins Int64Regs:$addr, Int32Regs:$count), + !strconcat("mbarrier.arrive_drop.noComplete", AddrSpace, + ".b64 $state, [$addr], $count;"), + [(set Int64Regs:$state, (Intrin Int64Regs:$addr, Int32Regs:$count))]>, + Requires<[hasPTX70, hasSM80]>; +} + +defm MBARRIER_ARRIVE_DROP_NOCOMPLETE : + MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>; +defm MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED : + MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared", + int_nvvm_mbarrier_arrive_drop_noComplete_shared>; + +multiclass MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> { + def _32 : NVPTXInst<(outs Int1Regs:$res), (ins Int32Regs:$addr, Int64Regs:$state), + !strconcat("mbarrier.test_wait", AddrSpace, ".b64 $res, [$addr], $state;"), + [(set Int1Regs:$res, (Intrin Int32Regs:$addr, Int64Regs:$state))]>, + Requires<[hasPTX70, hasSM80]>; + def _64 : NVPTXInst<(outs Int1Regs:$res), (ins Int64Regs:$addr, Int64Regs:$state), + !strconcat("mbarrier.test_wait", AddrSpace, ".b64 $res, [$addr], $state;"), + [(set Int1Regs:$res, (Intrin Int64Regs:$addr, Int64Regs:$state))]>, + Requires<[hasPTX70, hasSM80]>; +} + +defm MBARRIER_TEST_WAIT : + MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>; +defm MBARRIER_TEST_WAIT_SHARED : + MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>; + +class MBARRIER_PENDING_COUNT<Intrinsic Intrin> : + NVPTXInst<(outs Int32Regs:$res), (ins Int64Regs:$state), + "mbarrier.pending_count.b64 $res, $state;", + [(set Int32Regs:$res, (Intrin Int64Regs:$state))]>, + Requires<[hasPTX70, hasSM80]>; + +def MBARRIER_PENDING_COUNT : + MBARRIER_PENDING_COUNT<int_nvvm_mbarrier_pending_count>; + +//----------------------------------- // Math Functions //----------------------------------- @@ -1722,21 +1943,21 @@ multiclass VLDU_G_ELE_V2<string TyStr, NVPTXRegClass regclass> { !strconcat("ldu.global.", TyStr), []>; } -multiclass VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> { +multiclass VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> { def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins Int32Regs:$src), + regclass:$dst4), (ins Int32Regs:$src), !strconcat("ldu.global.", TyStr), []>; def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins Int64Regs:$src), + regclass:$dst4), (ins Int64Regs:$src), !strconcat("ldu.global.", TyStr), []>; def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins MEMri:$src), + regclass:$dst4), (ins MEMri:$src), !strconcat("ldu.global.", TyStr), []>; def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins MEMri64:$src), + regclass:$dst4), (ins MEMri64:$src), !strconcat("ldu.global.", TyStr), []>; def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins imemAny:$src), + regclass:$dst4), (ins imemAny:$src), !strconcat("ldu.global.", TyStr), []>; } @@ -1776,7 +1997,7 @@ defm INT_PTX_LDU_G_v4f32_ELE //----------------------------------- -// Support for ldg on sm_35 or later +// Support for ldg on sm_35 or later //----------------------------------- // Don't annotate ld.global.nc as mayLoad, because these loads go through the @@ -1824,7 +2045,7 @@ defm INT_PTX_LDG_GLOBAL_p64 // vector -// Elementized vector ldg +// Elementized vector ldg multiclass VLDG_G_ELE_V2<string TyStr, NVPTXRegClass regclass> { def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), (ins Int32Regs:$src), @@ -1843,21 +2064,21 @@ multiclass VLDG_G_ELE_V2<string TyStr, NVPTXRegClass regclass> { !strconcat("ld.global.nc.", TyStr), []>; } -multiclass VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> { +multiclass VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> { def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins Int32Regs:$src), + regclass:$dst4), (ins Int32Regs:$src), !strconcat("ld.global.nc.", TyStr), []>; def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins Int64Regs:$src), + regclass:$dst4), (ins Int64Regs:$src), !strconcat("ld.global.nc.", TyStr), []>; def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins MEMri:$src), + regclass:$dst4), (ins MEMri:$src), !strconcat("ld.global.nc.", TyStr), []>; def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins MEMri64:$src), + regclass:$dst4), (ins MEMri64:$src), !strconcat("ld.global.nc.", TyStr), []>; def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins imemAny:$src), + regclass:$dst4), (ins imemAny:$src), !strconcat("ld.global.nc.", TyStr), []>; } @@ -7347,12 +7568,15 @@ def INT_PTX_SREG_WARPSIZE : // In addition to target-independent fields provided by WMMA_REGS, it adds // the fields commonly used to implement specific PTX instruction -- register // types and names, constraints, parts of assembly, etc. -class WMMA_REGINFO<WMMA_REGS r> +class WMMA_REGINFO<WMMA_REGS r, string op> : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type> { // NVPTX register types used to carry fragment data. NVPTXRegClass regclass = !cond( !eq(ptx_elt_type, "f16") : Float16x2Regs, !eq(ptx_elt_type, "f32") : Float32Regs, + !eq(ptx_elt_type, "f64") : Float64Regs, + !eq(ptx_elt_type, "bf16") : Int32Regs, + !eq(ptx_elt_type, "tf32") : Int32Regs, !eq(ptx_elt_type, "s32") : Int32Regs, !eq(ptx_elt_type, "s8") : Int32Regs, !eq(ptx_elt_type, "u8") : Int32Regs, @@ -7381,6 +7605,9 @@ class WMMA_REGINFO<WMMA_REGS r> !or(!eq(ptx_elt_type, "f16"), !eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX60], + !and(!eq(geom,"m8n8k4"), + !eq(ptx_elt_type, "f64")) : [hasSM80, hasPTX70], + // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16 !and(!or(!eq(geom, "m8n32k16"), !eq(geom, "m32n8k16")), @@ -7395,11 +7622,46 @@ class WMMA_REGINFO<WMMA_REGS r> !eq(ptx_elt_type, "s8"), !eq(ptx_elt_type, "s32"))) : [hasSM72, hasPTX63], - // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1) - !or(!eq(geom,"m8n8k128"), - !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63], + !and(!or(!eq(geom,"m16n16k16"), + !eq(geom,"m8n32k16"), + !eq(geom,"m32n8k16")), + !eq(ptx_elt_type, "bf16")) : [hasSM80, hasPTX70], + + !and(!eq(geom,"m16n16k8"), + !eq(ptx_elt_type, "tf32")) : [hasSM80, hasPTX70], + + !and(!eq(geom,"m16n16k8"), + !eq(ptx_elt_type, "f32")) : [hasSM80, hasPTX70], - !eq(geom, "m8n8k4") : [hasSM70, hasPTX64]); + // b1 -> s32 @ m8n8k128(b1) + !and(!ne(op,"mma"), + !eq(geom,"m8n8k128")) : [hasSM75, hasPTX63], + + // u4/s4 -> s32 @ m8n8k32 (u4/s4) + !and(!ne(op,"mma"), + !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63], + + !or(!eq(geom,"m16n8k8"), + !eq(geom,"m8n8k16")) : [hasSM75, hasPTX65], + + !and(!ne(ptx_elt_type,"f64"), + !eq(geom, "m8n8k4")) : [hasSM70, hasPTX64], + + // mma m8n8k32 requires higher PTX version + !and(!eq(op,"mma"), + !eq(geom,"m8n8k32")) : [hasSM75, hasPTX65], + + !and(!eq(ptx_elt_type,"f64"), + !eq(geom, "m8n8k4")) : [hasSM80, hasPTX70], + + !and(!eq(op,"mma"), + !or(!eq(geom, "m16n8k16"), + !eq(geom, "m16n8k4"), + !eq(geom, "m16n8k32"), + !eq(geom, "m16n8k64"), + !eq(geom, "m8n8k128"), + !eq(geom, "m16n8k128"), + !eq(geom, "m16n8k256"))) : [hasSM80, hasPTX70]); // template DAGs for instruction inputs/output. dag Outs = !dag(outs, ptx_regs, reg_names); @@ -7523,60 +7785,109 @@ defset list<WMMA_INSTR> MMA_LDSTs = { foreach space = [".global", ".shared", ""] in { foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in { foreach frag = NVVM_MMA_OPS.all_ld_ops in - if NVVM_MMA_SUPPORTED<[frag], layout>.ret then - def : WMMA_LOAD<WMMA_REGINFO<frag>, layout, space, stride, addr>; + if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then + def : WMMA_LOAD<WMMA_REGINFO<frag, "load">, layout, space, stride, addr>; foreach frag = NVVM_MMA_OPS.all_st_ops in - if NVVM_MMA_SUPPORTED<[frag], layout>.ret then - def : WMMA_STORE_D<WMMA_REGINFO<frag>, layout, space, stride, addr>; + if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then + def : WMMA_STORE_D<WMMA_REGINFO<frag, "store">, layout, space, stride, addr>; } // addr } // space } // stride } // layout } // defset +// B1 instruction variants need extra constraints. +class MMA_OP_PREDICATES<WMMA_REGINFO FragA, string b1op> { + string Op = b1op; + WMMA_REGINFO Frag = FragA; + list<Predicate> ret = !listconcat( + FragA.Predicates, + !if(!eq(b1op, ".and.popc"), [hasSM80,hasPTX71],[]) + ); +} // WMMA.MMA class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB, WMMA_REGINFO FragC, WMMA_REGINFO FragD, - string ALayout, string BLayout, int Satfinite> - : WMMA_INSTR<WMMA_NAME_MMA<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record, - [FragA.Ins, FragB.Ins, FragC.Ins]>, + string ALayout, string BLayout, int Satfinite, string rnd, string b1op> + : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, FragA, FragB, FragC, FragD>.record, + [FragA.Ins, FragB.Ins, FragC.Ins]>, // Requires does not seem to have effect on Instruction w/o Patterns. // We set it here anyways and propagate to the Pat<> we construct below. - Requires<FragA.Predicates> { + Requires<MMA_OP_PREDICATES<FragA, b1op>.ret> { let OutOperandList = FragD.Outs; let InOperandList = !con(Args, (ins MmaCode:$ptx)); string TypeList = !cond( - !eq(FragD.geom, "m8n8k4") : "." # FragD.ptx_elt_type - # ".f16.f16." - # FragC.ptx_elt_type, - !eq(FragD.ptx_elt_type, "s32") : ".s32" - # "." # FragA.ptx_elt_type - # "." # FragB.ptx_elt_type - # ".s32", - 1: "." # FragD.ptx_elt_type # "." # FragC.ptx_elt_type, + !eq(FragA.ptx_elt_type, "f16") : "." # FragD.ptx_elt_type + # "." # FragC.ptx_elt_type, + 1: "." # FragD.ptx_elt_type + # "." # FragA.ptx_elt_type + # "." # FragB.ptx_elt_type + # "." # FragC.ptx_elt_type, ); - let AsmString = !if(!eq(FragA.geom, "m8n8k4"), - "mma.sync.aligned.m8n8k4" - # "." # ALayout - # "." # BLayout - # TypeList # "\n\t\t" - # FragD.regstring # ",\n\t\t" - # FragA.regstring # ",\n\t\t" - # FragB.regstring # ",\n\t\t" - # FragC.regstring # ";", - "wmma.mma" - # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") - # ".sync" - # "${ptx:aligned}" - # "." # ALayout - # "." # BLayout - # "." # FragA.geom - # TypeList - # !if(Satfinite, ".satfinite", "") # "\n\t\t" - # FragD.regstring # ",\n\t\t" - # FragA.regstring # ",\n\t\t" - # FragB.regstring # ",\n\t\t" - # FragC.regstring # ";"); + let AsmString = "wmma.mma" + # b1op + # ".sync" + # "${ptx:aligned}" + # "." # ALayout + # "." # BLayout + # "." # FragA.geom + # !if(!ne(rnd, ""), !strconcat(".", rnd), "") + # TypeList + # !if(Satfinite, ".satfinite", "") # "\n\t\t" + # FragD.regstring # ",\n\t\t" + # FragA.regstring # ",\n\t\t" + # FragB.regstring # ",\n\t\t" + # FragC.regstring # ";"; +} + +defset list<WMMA_INSTR> WMMAs = { + foreach layout_a = ["row", "col"] in { + foreach layout_b = ["row", "col"] in { + foreach satf = [0, 1] in { + foreach rnd = ["", "rn", "rz", "rm", "rp"] in { + foreach op = NVVM_MMA_OPS.all_wmma_ops in { + foreach b1op = NVVM_MMA_B1OPS<op>.ret in { + if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then { + def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">, + WMMA_REGINFO<op[1], "wmma.mma">, + WMMA_REGINFO<op[2], "wmma.mma">, + WMMA_REGINFO<op[3], "wmma.mma">, + layout_a, layout_b, satf, rnd, b1op>; + } + } // b1op + } // op + } // rnd + } // satf + } // layout_b + } // layout_a +} // defset + +// MMA +class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB, + WMMA_REGINFO FragC, WMMA_REGINFO FragD, + string ALayout, string BLayout, int Satfinite, string b1op> + : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record, + [FragA.Ins, FragB.Ins, FragC.Ins]>, + // Requires does not seem to have effect on Instruction w/o Patterns. + // We set it here anyways and propagate to the Pat<> we construct below. + Requires<MMA_OP_PREDICATES<FragA, b1op>.ret> { + let OutOperandList = FragD.Outs; + let InOperandList = !con(Args, (ins MmaCode:$ptx)); + string TypeList = "." # FragD.ptx_elt_type + # "." # FragA.ptx_elt_type + # "." # FragB.ptx_elt_type + # "." # FragC.ptx_elt_type; + let AsmString = "mma.sync.aligned." + # FragA.geom + # "." # ALayout + # "." # BLayout + # !if(Satfinite, ".satfinite", "") + # TypeList + # b1op # "\n\t\t" + # FragD.regstring # ",\n\t\t" + # FragA.regstring # ",\n\t\t" + # FragB.regstring # ",\n\t\t" + # FragC.regstring # ";"; } defset list<WMMA_INSTR> MMAs = { @@ -7584,13 +7895,15 @@ defset list<WMMA_INSTR> MMAs = { foreach layout_b = ["row", "col"] in { foreach satf = [0, 1] in { foreach op = NVVM_MMA_OPS.all_mma_ops in { - if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then { - def : WMMA_MMA<WMMA_REGINFO<op[0]>, - WMMA_REGINFO<op[1]>, - WMMA_REGINFO<op[2]>, - WMMA_REGINFO<op[3]>, - layout_a, layout_b, satf>; - } + foreach b1op = NVVM_MMA_B1OPS<op>.ret in { + if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then { + def : MMA<WMMA_REGINFO<op[0], "mma">, + WMMA_REGINFO<op[1], "mma">, + WMMA_REGINFO<op[2], "mma">, + WMMA_REGINFO<op[3], "mma">, + layout_a, layout_b, satf, b1op>; + } + } // b1op } // op } // satf } // layout_b @@ -7601,12 +7914,12 @@ defset list<WMMA_INSTR> MMAs = { // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with // the instruction record. -class WMMA_PAT<WMMA_INSTR wi> +class MMA_PAT<WMMA_INSTR wi> : Pat<wi.IntrinsicPattern, !con(!foreach(tmp, wi.Args, !subst(ins, wi, tmp)), (wi ptx.version))>, Requires<wi.Predicates>; // Build intrinsic->instruction patterns for all MMA instructions. -foreach mma = !listconcat(MMAs, MMA_LDSTs) in - def : WMMA_PAT<mma>; +foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs) in + def : MMA_PAT<mma>; |