diff options
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXIntrinsics.td')
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 42 |
1 files changed, 28 insertions, 14 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index c52195fb0449..76a4a1d4030a 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -7400,7 +7400,9 @@ class WMMA_REGINFO<WMMA_REGS r> // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1) !or(!eq(geom,"m8n8k128"), - !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63]); + !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63], + + !eq(geom, "m8n8k4") : [hasSM70, hasPTX64]); // template DAGs for instruction inputs/output. dag Outs = !dag(outs, ptx_regs, reg_names); @@ -7546,25 +7548,37 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB, let OutOperandList = FragD.Outs; let InOperandList = !con(Args, (ins MmaCode:$ptx)); string TypeList = !cond( + !eq(FragD.geom, "m8n8k4") : "." # FragD.ptx_elt_type + # ".f16.f16." + # FragC.ptx_elt_type, !eq(FragD.ptx_elt_type, "s32") : ".s32" # "." # FragA.ptx_elt_type # "." # FragB.ptx_elt_type # ".s32", 1: "." # FragD.ptx_elt_type # "." # FragC.ptx_elt_type, ); - let AsmString = "wmma.mma" - # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") - # ".sync" - # "${ptx:aligned}" - # "." # ALayout - # "." # BLayout - # "." # FragA.geom - # TypeList - # !if(Satfinite, ".satfinite", "") # "\n\t\t" - # FragD.regstring # ",\n\t\t" - # FragA.regstring # ",\n\t\t" - # FragB.regstring # ",\n\t\t" - # FragC.regstring # ";"; + let AsmString = !if(!eq(FragA.geom, "m8n8k4"), + "mma.sync.aligned.m8n8k4" + # "." # ALayout + # "." # BLayout + # TypeList # "\n\t\t" + # FragD.regstring # ",\n\t\t" + # FragA.regstring # ",\n\t\t" + # FragB.regstring # ",\n\t\t" + # FragC.regstring # ";", + "wmma.mma" + # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") + # ".sync" + # "${ptx:aligned}" + # "." # ALayout + # "." # BLayout + # "." # FragA.geom + # TypeList + # !if(Satfinite, ".satfinite", "") # "\n\t\t" + # FragD.regstring # ",\n\t\t" + # FragA.regstring # ",\n\t\t" + # FragB.regstring # ",\n\t\t" + # FragC.regstring # ";"); } defset list<WMMA_INSTR> MMAs = { |