aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXIntrinsics.td')
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXIntrinsics.td447
1 files changed, 380 insertions, 67 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 8ccd47c0fcfd..de4bf2ef3055 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -274,6 +274,22 @@ defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<Int32Regs, "b32", int_nvvm_match_all_s
defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<Int64Regs, "b64", int_nvvm_match_all_sync_i64p,
i64imm>;
+multiclass REDUX_SYNC<string BinOp, string PTXType, Intrinsic Intrin> {
+ def : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$mask),
+ "redux.sync." # BinOp # "." # PTXType # " $dst, $src, $mask;",
+ [(set Int32Regs:$dst, (Intrin Int32Regs:$src, Int32Regs:$mask))]>,
+ Requires<[hasPTX70, hasSM80]>;
+}
+
+defm REDUX_SYNC_UMIN : REDUX_SYNC<"min", "u32", int_nvvm_redux_sync_umin>;
+defm REDUX_SYNC_UMAX : REDUX_SYNC<"max", "u32", int_nvvm_redux_sync_umax>;
+defm REDUX_SYNC_ADD : REDUX_SYNC<"add", "s32", int_nvvm_redux_sync_add>;
+defm REDUX_SYNC_MIN : REDUX_SYNC<"min", "s32", int_nvvm_redux_sync_min>;
+defm REDUX_SYNC_MAX : REDUX_SYNC<"max", "s32", int_nvvm_redux_sync_max>;
+defm REDUX_SYNC_AND : REDUX_SYNC<"and", "b32", int_nvvm_redux_sync_and>;
+defm REDUX_SYNC_XOR : REDUX_SYNC<"xor", "b32", int_nvvm_redux_sync_xor>;
+defm REDUX_SYNC_OR : REDUX_SYNC<"or", "b32", int_nvvm_redux_sync_or>;
+
} // isConvergent = true
//-----------------------------------
@@ -289,6 +305,211 @@ def INT_MEMBAR_SYS : MEMBAR<"membar.sys;", int_nvvm_membar_sys>;
//-----------------------------------
+// Async Copy Functions
+//-----------------------------------
+
+multiclass CP_ASYNC_MBARRIER_ARRIVE<string NoInc, string AddrSpace, Intrinsic Intrin> {
+ def _32 : NVPTXInst<(outs), (ins Int32Regs:$addr),
+ !strconcat("cp.async.mbarrier.arrive", NoInc, AddrSpace, ".b64 [$addr];"),
+ [(Intrin Int32Regs:$addr)]>,
+ Requires<[hasPTX70, hasSM80]>;
+ def _64 : NVPTXInst<(outs), (ins Int64Regs:$addr),
+ !strconcat("cp.async.mbarrier.arrive", NoInc, AddrSpace, ".b64 [$addr];"),
+ [(Intrin Int64Regs:$addr)]>,
+ Requires<[hasPTX70, hasSM80]>;
+}
+
+defm CP_ASYNC_MBARRIER_ARRIVE :
+ CP_ASYNC_MBARRIER_ARRIVE<"", "", int_nvvm_cp_async_mbarrier_arrive>;
+defm CP_ASYNC_MBARRIER_ARRIVE_SHARED :
+ CP_ASYNC_MBARRIER_ARRIVE<"", ".shared", int_nvvm_cp_async_mbarrier_arrive_shared>;
+defm CP_ASYNC_MBARRIER_ARRIVE_NOINC :
+ CP_ASYNC_MBARRIER_ARRIVE<".noinc", "", int_nvvm_cp_async_mbarrier_arrive_noinc>;
+defm CP_ASYNC_MBARRIER_ARRIVE_NOINC_SHARED :
+ CP_ASYNC_MBARRIER_ARRIVE<".noinc", ".shared", int_nvvm_cp_async_mbarrier_arrive_noinc_shared>;
+
+multiclass CP_ASYNC_CA_SHARED_GLOBAL_I<string cpsize, Intrinsic Intrin> {
+ def _32 : NVPTXInst<(outs), (ins Int32Regs:$dst, Int32Regs:$src),
+ !strconcat("cp.async.ca.shared.global [$dst], [$src], ", cpsize, ";"),
+ [(Intrin Int32Regs:$dst, Int32Regs:$src)]>,
+ Requires<[hasPTX70, hasSM80]>;
+ def _64 : NVPTXInst<(outs), (ins Int64Regs:$dst, Int64Regs:$src),
+ !strconcat("cp.async.ca.shared.global [$dst], [$src], ", cpsize, ";"),
+ [(Intrin Int64Regs:$dst, Int64Regs:$src)]>,
+ Requires<[hasPTX70, hasSM80]>;
+}
+
+defm CP_ASYNC_CA_SHARED_GLOBAL_4 :
+ CP_ASYNC_CA_SHARED_GLOBAL_I<"4", int_nvvm_cp_async_ca_shared_global_4>;
+
+defm CP_ASYNC_CA_SHARED_GLOBAL_8 :
+ CP_ASYNC_CA_SHARED_GLOBAL_I<"8", int_nvvm_cp_async_ca_shared_global_8>;
+
+defm CP_ASYNC_CA_SHARED_GLOBAL_16 :
+ CP_ASYNC_CA_SHARED_GLOBAL_I<"16", int_nvvm_cp_async_ca_shared_global_16>;
+
+multiclass CP_ASYNC_CG_SHARED_GLOBAL<string cpsize, Intrinsic Intrin> {
+ def _32 : NVPTXInst<(outs), (ins Int32Regs:$dst, Int32Regs:$src),
+ !strconcat("cp.async.cg.shared.global [$dst], [$src], ", cpsize, ";"),
+ [(Intrin Int32Regs:$dst, Int32Regs:$src)]>,
+ Requires<[hasPTX70, hasSM80]>;
+ def _64 : NVPTXInst<(outs), (ins Int64Regs:$dst, Int64Regs:$src),
+ !strconcat("cp.async.cg.shared.global [$dst], [$src], ", cpsize, ";"),
+ [(Intrin Int64Regs:$dst, Int64Regs:$src)]>,
+ Requires<[hasPTX70, hasSM80]>;
+}
+
+defm CP_ASYNC_CG_SHARED_GLOBAL_16 :
+ CP_ASYNC_CG_SHARED_GLOBAL<"16", int_nvvm_cp_async_cg_shared_global_16>;
+
+def CP_ASYNC_COMMIT_GROUP :
+ NVPTXInst<(outs), (ins), "cp.async.commit_group;", [(int_nvvm_cp_async_commit_group)]>,
+ Requires<[hasPTX70, hasSM80]>;
+
+def CP_ASYNC_WAIT_GROUP :
+ NVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group $n;",
+ [(int_nvvm_cp_async_wait_group (i32 timm:$n))]>,
+ Requires<[hasPTX70, hasSM80]>;
+
+def CP_ASYNC_WAIT_ALL :
+ NVPTXInst<(outs), (ins), "cp.async.wait_all;",
+ [(int_nvvm_cp_async_wait_all)]>,
+ Requires<[hasPTX70, hasSM80]>;
+
+//-----------------------------------
+// MBarrier Functions
+//-----------------------------------
+
+multiclass MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> {
+ def _32 : NVPTXInst<(outs), (ins Int32Regs:$addr, Int32Regs:$count),
+ !strconcat("mbarrier.init", AddrSpace, ".b64 [$addr], $count;"),
+ [(Intrin Int32Regs:$addr, Int32Regs:$count)]>,
+ Requires<[hasPTX70, hasSM80]>;
+ def _64 : NVPTXInst<(outs), (ins Int64Regs:$addr, Int32Regs:$count),
+ !strconcat("mbarrier.init", AddrSpace, ".b64 [$addr], $count;"),
+ [(Intrin Int64Regs:$addr, Int32Regs:$count)]>,
+ Requires<[hasPTX70, hasSM80]>;
+}
+
+defm MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>;
+defm MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared",
+ int_nvvm_mbarrier_init_shared>;
+
+multiclass MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> {
+ def _32 : NVPTXInst<(outs), (ins Int32Regs:$addr),
+ !strconcat("mbarrier.inval", AddrSpace, ".b64 [$addr];"),
+ [(Intrin Int32Regs:$addr)]>,
+ Requires<[hasPTX70, hasSM80]>;
+ def _64 : NVPTXInst<(outs), (ins Int64Regs:$addr),
+ !strconcat("mbarrier.inval", AddrSpace, ".b64 [$addr];"),
+ [(Intrin Int64Regs:$addr)]>,
+ Requires<[hasPTX70, hasSM80]>;
+}
+
+defm MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>;
+defm MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared",
+ int_nvvm_mbarrier_inval_shared>;
+
+multiclass MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> {
+ def _32 : NVPTXInst<(outs Int64Regs:$state), (ins Int32Regs:$addr),
+ !strconcat("mbarrier.arrive", AddrSpace, ".b64 $state, [$addr];"),
+ [(set Int64Regs:$state, (Intrin Int32Regs:$addr))]>,
+ Requires<[hasPTX70, hasSM80]>;
+ def _64 : NVPTXInst<(outs Int64Regs:$state), (ins Int64Regs:$addr),
+ !strconcat("mbarrier.arrive", AddrSpace, ".b64 $state, [$addr];"),
+ [(set Int64Regs:$state, (Intrin Int64Regs:$addr))]>,
+ Requires<[hasPTX70, hasSM80]>;
+}
+
+defm MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>;
+defm MBARRIER_ARRIVE_SHARED :
+ MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>;
+
+multiclass MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> {
+ def _32 : NVPTXInst<(outs Int64Regs:$state),
+ (ins Int32Regs:$addr, Int32Regs:$count),
+ !strconcat("mbarrier.arrive.noComplete", AddrSpace,
+ ".b64 $state, [$addr], $count;"),
+ [(set Int64Regs:$state, (Intrin Int32Regs:$addr, Int32Regs:$count))]>,
+ Requires<[hasPTX70, hasSM80]>;
+ def _64 : NVPTXInst<(outs Int64Regs:$state),
+ (ins Int64Regs:$addr, Int32Regs:$count),
+ !strconcat("mbarrier.arrive.noComplete", AddrSpace,
+ ".b64 $state, [$addr], $count;"),
+ [(set Int64Regs:$state, (Intrin Int64Regs:$addr, Int32Regs:$count))]>,
+ Requires<[hasPTX70, hasSM80]>;
+}
+
+defm MBARRIER_ARRIVE_NOCOMPLETE :
+ MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>;
+defm MBARRIER_ARRIVE_NOCOMPLETE_SHARED :
+ MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>;
+
+multiclass MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> {
+ def _32 : NVPTXInst<(outs Int64Regs:$state), (ins Int32Regs:$addr),
+ !strconcat("mbarrier.arrive_drop", AddrSpace,
+ ".b64 $state, [$addr];"),
+ [(set Int64Regs:$state, (Intrin Int32Regs:$addr))]>,
+ Requires<[hasPTX70, hasSM80]>;
+ def _64 : NVPTXInst<(outs Int64Regs:$state), (ins Int64Regs:$addr),
+ !strconcat("mbarrier.arrive_drop", AddrSpace,
+ ".b64 $state, [$addr];"),
+ [(set Int64Regs:$state, (Intrin Int64Regs:$addr))]>,
+ Requires<[hasPTX70, hasSM80]>;
+}
+
+defm MBARRIER_ARRIVE_DROP :
+ MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>;
+defm MBARRIER_ARRIVE_DROP_SHARED :
+ MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>;
+
+multiclass MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> {
+ def _32 : NVPTXInst<(outs Int64Regs:$state),
+ (ins Int32Regs:$addr, Int32Regs:$count),
+ !strconcat("mbarrier.arrive_drop.noComplete", AddrSpace,
+ ".b64 $state, [$addr], $count;"),
+ [(set Int64Regs:$state, (Intrin Int32Regs:$addr, Int32Regs:$count))]>,
+ Requires<[hasPTX70, hasSM80]>;
+ def _64 : NVPTXInst<(outs Int64Regs:$state),
+ (ins Int64Regs:$addr, Int32Regs:$count),
+ !strconcat("mbarrier.arrive_drop.noComplete", AddrSpace,
+ ".b64 $state, [$addr], $count;"),
+ [(set Int64Regs:$state, (Intrin Int64Regs:$addr, Int32Regs:$count))]>,
+ Requires<[hasPTX70, hasSM80]>;
+}
+
+defm MBARRIER_ARRIVE_DROP_NOCOMPLETE :
+ MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>;
+defm MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED :
+ MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared",
+ int_nvvm_mbarrier_arrive_drop_noComplete_shared>;
+
+multiclass MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> {
+ def _32 : NVPTXInst<(outs Int1Regs:$res), (ins Int32Regs:$addr, Int64Regs:$state),
+ !strconcat("mbarrier.test_wait", AddrSpace, ".b64 $res, [$addr], $state;"),
+ [(set Int1Regs:$res, (Intrin Int32Regs:$addr, Int64Regs:$state))]>,
+ Requires<[hasPTX70, hasSM80]>;
+ def _64 : NVPTXInst<(outs Int1Regs:$res), (ins Int64Regs:$addr, Int64Regs:$state),
+ !strconcat("mbarrier.test_wait", AddrSpace, ".b64 $res, [$addr], $state;"),
+ [(set Int1Regs:$res, (Intrin Int64Regs:$addr, Int64Regs:$state))]>,
+ Requires<[hasPTX70, hasSM80]>;
+}
+
+defm MBARRIER_TEST_WAIT :
+ MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>;
+defm MBARRIER_TEST_WAIT_SHARED :
+ MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>;
+
+class MBARRIER_PENDING_COUNT<Intrinsic Intrin> :
+ NVPTXInst<(outs Int32Regs:$res), (ins Int64Regs:$state),
+ "mbarrier.pending_count.b64 $res, $state;",
+ [(set Int32Regs:$res, (Intrin Int64Regs:$state))]>,
+ Requires<[hasPTX70, hasSM80]>;
+
+def MBARRIER_PENDING_COUNT :
+ MBARRIER_PENDING_COUNT<int_nvvm_mbarrier_pending_count>;
+
+//-----------------------------------
// Math Functions
//-----------------------------------
@@ -1722,21 +1943,21 @@ multiclass VLDU_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
!strconcat("ldu.global.", TyStr), []>;
}
-multiclass VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
+multiclass VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins Int32Regs:$src),
+ regclass:$dst4), (ins Int32Regs:$src),
!strconcat("ldu.global.", TyStr), []>;
def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins Int64Regs:$src),
+ regclass:$dst4), (ins Int64Regs:$src),
!strconcat("ldu.global.", TyStr), []>;
def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins MEMri:$src),
+ regclass:$dst4), (ins MEMri:$src),
!strconcat("ldu.global.", TyStr), []>;
def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins MEMri64:$src),
+ regclass:$dst4), (ins MEMri64:$src),
!strconcat("ldu.global.", TyStr), []>;
def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins imemAny:$src),
+ regclass:$dst4), (ins imemAny:$src),
!strconcat("ldu.global.", TyStr), []>;
}
@@ -1776,7 +1997,7 @@ defm INT_PTX_LDU_G_v4f32_ELE
//-----------------------------------
-// Support for ldg on sm_35 or later
+// Support for ldg on sm_35 or later
//-----------------------------------
// Don't annotate ld.global.nc as mayLoad, because these loads go through the
@@ -1824,7 +2045,7 @@ defm INT_PTX_LDG_GLOBAL_p64
// vector
-// Elementized vector ldg
+// Elementized vector ldg
multiclass VLDG_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
(ins Int32Regs:$src),
@@ -1843,21 +2064,21 @@ multiclass VLDG_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
!strconcat("ld.global.nc.", TyStr), []>;
}
-multiclass VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
+multiclass VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins Int32Regs:$src),
+ regclass:$dst4), (ins Int32Regs:$src),
!strconcat("ld.global.nc.", TyStr), []>;
def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins Int64Regs:$src),
+ regclass:$dst4), (ins Int64Regs:$src),
!strconcat("ld.global.nc.", TyStr), []>;
def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins MEMri:$src),
+ regclass:$dst4), (ins MEMri:$src),
!strconcat("ld.global.nc.", TyStr), []>;
def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins MEMri64:$src),
+ regclass:$dst4), (ins MEMri64:$src),
!strconcat("ld.global.nc.", TyStr), []>;
def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins imemAny:$src),
+ regclass:$dst4), (ins imemAny:$src),
!strconcat("ld.global.nc.", TyStr), []>;
}
@@ -7347,12 +7568,15 @@ def INT_PTX_SREG_WARPSIZE :
// In addition to target-independent fields provided by WMMA_REGS, it adds
// the fields commonly used to implement specific PTX instruction -- register
// types and names, constraints, parts of assembly, etc.
-class WMMA_REGINFO<WMMA_REGS r>
+class WMMA_REGINFO<WMMA_REGS r, string op>
: WMMA_REGS<r.geom, r.frag, r.ptx_elt_type> {
// NVPTX register types used to carry fragment data.
NVPTXRegClass regclass = !cond(
!eq(ptx_elt_type, "f16") : Float16x2Regs,
!eq(ptx_elt_type, "f32") : Float32Regs,
+ !eq(ptx_elt_type, "f64") : Float64Regs,
+ !eq(ptx_elt_type, "bf16") : Int32Regs,
+ !eq(ptx_elt_type, "tf32") : Int32Regs,
!eq(ptx_elt_type, "s32") : Int32Regs,
!eq(ptx_elt_type, "s8") : Int32Regs,
!eq(ptx_elt_type, "u8") : Int32Regs,
@@ -7381,6 +7605,9 @@ class WMMA_REGINFO<WMMA_REGS r>
!or(!eq(ptx_elt_type, "f16"),
!eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX60],
+ !and(!eq(geom,"m8n8k4"),
+ !eq(ptx_elt_type, "f64")) : [hasSM80, hasPTX70],
+
// fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
!and(!or(!eq(geom, "m8n32k16"),
!eq(geom, "m32n8k16")),
@@ -7395,11 +7622,46 @@ class WMMA_REGINFO<WMMA_REGS r>
!eq(ptx_elt_type, "s8"),
!eq(ptx_elt_type, "s32"))) : [hasSM72, hasPTX63],
- // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
- !or(!eq(geom,"m8n8k128"),
- !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63],
+ !and(!or(!eq(geom,"m16n16k16"),
+ !eq(geom,"m8n32k16"),
+ !eq(geom,"m32n8k16")),
+ !eq(ptx_elt_type, "bf16")) : [hasSM80, hasPTX70],
+
+ !and(!eq(geom,"m16n16k8"),
+ !eq(ptx_elt_type, "tf32")) : [hasSM80, hasPTX70],
+
+ !and(!eq(geom,"m16n16k8"),
+ !eq(ptx_elt_type, "f32")) : [hasSM80, hasPTX70],
- !eq(geom, "m8n8k4") : [hasSM70, hasPTX64]);
+ // b1 -> s32 @ m8n8k128(b1)
+ !and(!ne(op,"mma"),
+ !eq(geom,"m8n8k128")) : [hasSM75, hasPTX63],
+
+ // u4/s4 -> s32 @ m8n8k32 (u4/s4)
+ !and(!ne(op,"mma"),
+ !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63],
+
+ !or(!eq(geom,"m16n8k8"),
+ !eq(geom,"m8n8k16")) : [hasSM75, hasPTX65],
+
+ !and(!ne(ptx_elt_type,"f64"),
+ !eq(geom, "m8n8k4")) : [hasSM70, hasPTX64],
+
+ // mma m8n8k32 requires higher PTX version
+ !and(!eq(op,"mma"),
+ !eq(geom,"m8n8k32")) : [hasSM75, hasPTX65],
+
+ !and(!eq(ptx_elt_type,"f64"),
+ !eq(geom, "m8n8k4")) : [hasSM80, hasPTX70],
+
+ !and(!eq(op,"mma"),
+ !or(!eq(geom, "m16n8k16"),
+ !eq(geom, "m16n8k4"),
+ !eq(geom, "m16n8k32"),
+ !eq(geom, "m16n8k64"),
+ !eq(geom, "m8n8k128"),
+ !eq(geom, "m16n8k128"),
+ !eq(geom, "m16n8k256"))) : [hasSM80, hasPTX70]);
// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7523,60 +7785,109 @@ defset list<WMMA_INSTR> MMA_LDSTs = {
foreach space = [".global", ".shared", ""] in {
foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
foreach frag = NVVM_MMA_OPS.all_ld_ops in
- if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
- def : WMMA_LOAD<WMMA_REGINFO<frag>, layout, space, stride, addr>;
+ if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
+ def : WMMA_LOAD<WMMA_REGINFO<frag, "load">, layout, space, stride, addr>;
foreach frag = NVVM_MMA_OPS.all_st_ops in
- if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
- def : WMMA_STORE_D<WMMA_REGINFO<frag>, layout, space, stride, addr>;
+ if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
+ def : WMMA_STORE_D<WMMA_REGINFO<frag, "store">, layout, space, stride, addr>;
} // addr
} // space
} // stride
} // layout
} // defset
+// B1 instruction variants need extra constraints.
+class MMA_OP_PREDICATES<WMMA_REGINFO FragA, string b1op> {
+ string Op = b1op;
+ WMMA_REGINFO Frag = FragA;
+ list<Predicate> ret = !listconcat(
+ FragA.Predicates,
+ !if(!eq(b1op, ".and.popc"), [hasSM80,hasPTX71],[])
+ );
+}
// WMMA.MMA
class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
- string ALayout, string BLayout, int Satfinite>
- : WMMA_INSTR<WMMA_NAME_MMA<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
- [FragA.Ins, FragB.Ins, FragC.Ins]>,
+ string ALayout, string BLayout, int Satfinite, string rnd, string b1op>
+ : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, FragA, FragB, FragC, FragD>.record,
+ [FragA.Ins, FragB.Ins, FragC.Ins]>,
// Requires does not seem to have effect on Instruction w/o Patterns.
// We set it here anyways and propagate to the Pat<> we construct below.
- Requires<FragA.Predicates> {
+ Requires<MMA_OP_PREDICATES<FragA, b1op>.ret> {
let OutOperandList = FragD.Outs;
let InOperandList = !con(Args, (ins MmaCode:$ptx));
string TypeList = !cond(
- !eq(FragD.geom, "m8n8k4") : "." # FragD.ptx_elt_type
- # ".f16.f16."
- # FragC.ptx_elt_type,
- !eq(FragD.ptx_elt_type, "s32") : ".s32"
- # "." # FragA.ptx_elt_type
- # "." # FragB.ptx_elt_type
- # ".s32",
- 1: "." # FragD.ptx_elt_type # "." # FragC.ptx_elt_type,
+ !eq(FragA.ptx_elt_type, "f16") : "." # FragD.ptx_elt_type
+ # "." # FragC.ptx_elt_type,
+ 1: "." # FragD.ptx_elt_type
+ # "." # FragA.ptx_elt_type
+ # "." # FragB.ptx_elt_type
+ # "." # FragC.ptx_elt_type,
);
- let AsmString = !if(!eq(FragA.geom, "m8n8k4"),
- "mma.sync.aligned.m8n8k4"
- # "." # ALayout
- # "." # BLayout
- # TypeList # "\n\t\t"
- # FragD.regstring # ",\n\t\t"
- # FragA.regstring # ",\n\t\t"
- # FragB.regstring # ",\n\t\t"
- # FragC.regstring # ";",
- "wmma.mma"
- # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
- # ".sync"
- # "${ptx:aligned}"
- # "." # ALayout
- # "." # BLayout
- # "." # FragA.geom
- # TypeList
- # !if(Satfinite, ".satfinite", "") # "\n\t\t"
- # FragD.regstring # ",\n\t\t"
- # FragA.regstring # ",\n\t\t"
- # FragB.regstring # ",\n\t\t"
- # FragC.regstring # ";");
+ let AsmString = "wmma.mma"
+ # b1op
+ # ".sync"
+ # "${ptx:aligned}"
+ # "." # ALayout
+ # "." # BLayout
+ # "." # FragA.geom
+ # !if(!ne(rnd, ""), !strconcat(".", rnd), "")
+ # TypeList
+ # !if(Satfinite, ".satfinite", "") # "\n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ";";
+}
+
+defset list<WMMA_INSTR> WMMAs = {
+ foreach layout_a = ["row", "col"] in {
+ foreach layout_b = ["row", "col"] in {
+ foreach satf = [0, 1] in {
+ foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
+ foreach op = NVVM_MMA_OPS.all_wmma_ops in {
+ foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+ if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
+ def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">,
+ WMMA_REGINFO<op[1], "wmma.mma">,
+ WMMA_REGINFO<op[2], "wmma.mma">,
+ WMMA_REGINFO<op[3], "wmma.mma">,
+ layout_a, layout_b, satf, rnd, b1op>;
+ }
+ } // b1op
+ } // op
+ } // rnd
+ } // satf
+ } // layout_b
+ } // layout_a
+} // defset
+
+// MMA
+class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string ALayout, string BLayout, int Satfinite, string b1op>
+ : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record,
+ [FragA.Ins, FragB.Ins, FragC.Ins]>,
+ // Requires does not seem to have effect on Instruction w/o Patterns.
+ // We set it here anyways and propagate to the Pat<> we construct below.
+ Requires<MMA_OP_PREDICATES<FragA, b1op>.ret> {
+ let OutOperandList = FragD.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ string TypeList = "." # FragD.ptx_elt_type
+ # "." # FragA.ptx_elt_type
+ # "." # FragB.ptx_elt_type
+ # "." # FragC.ptx_elt_type;
+ let AsmString = "mma.sync.aligned."
+ # FragA.geom
+ # "." # ALayout
+ # "." # BLayout
+ # !if(Satfinite, ".satfinite", "")
+ # TypeList
+ # b1op # "\n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ";";
}
defset list<WMMA_INSTR> MMAs = {
@@ -7584,13 +7895,15 @@ defset list<WMMA_INSTR> MMAs = {
foreach layout_b = ["row", "col"] in {
foreach satf = [0, 1] in {
foreach op = NVVM_MMA_OPS.all_mma_ops in {
- if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
- def : WMMA_MMA<WMMA_REGINFO<op[0]>,
- WMMA_REGINFO<op[1]>,
- WMMA_REGINFO<op[2]>,
- WMMA_REGINFO<op[3]>,
- layout_a, layout_b, satf>;
- }
+ foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+ if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
+ def : MMA<WMMA_REGINFO<op[0], "mma">,
+ WMMA_REGINFO<op[1], "mma">,
+ WMMA_REGINFO<op[2], "mma">,
+ WMMA_REGINFO<op[3], "mma">,
+ layout_a, layout_b, satf, b1op>;
+ }
+ } // b1op
} // op
} // satf
} // layout_b
@@ -7601,12 +7914,12 @@ defset list<WMMA_INSTR> MMAs = {
// Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
// dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
// the instruction record.
-class WMMA_PAT<WMMA_INSTR wi>
+class MMA_PAT<WMMA_INSTR wi>
: Pat<wi.IntrinsicPattern,
!con(!foreach(tmp, wi.Args, !subst(ins, wi, tmp)),
(wi ptx.version))>,
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, MMA_LDSTs) in
- def : WMMA_PAT<mma>;
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs) in
+ def : MMA_PAT<mma>;