Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/X86/X86LowerAMXType.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Target/X86/X86LowerAMXType.cpp | 351
1 file changed, 351 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/X86/X86LowerAMXType.cpp b/contrib/llvm-project/llvm/lib/Target/X86/X86LowerAMXType.cpp
new file mode 100644
index 000000000000..85166decd8cd
--- /dev/null
+++ b/contrib/llvm-project/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -0,0 +1,351 @@
+//===- X86LowerAMXType.cpp - Lower x86_amx type -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file Pass to transform <256 x i32> load/store
+/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
+/// provides simple operations on x86_amx. Basic elementwise operations
+/// are not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
+/// and only AMX intrinsics can operate on the type, we need to transform
+/// load/store <256 x i32> instructions to AMX load/store. If the bitcast
+/// cannot be combined with the load/store, we transform the bitcast to an amx
+/// load/store and a <256 x i32> store/load.
+//
+//===----------------------------------------------------------------------===//
+//
+#include "X86.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsX86.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "lower-amx-type"
+
+static AllocaInst *CreateAllocaInst(IRBuilder<> &Builder, BasicBlock *BB) {
+  Function &F = *BB->getParent();
+  Module *M = BB->getModule();
+  const DataLayout &DL = M->getDataLayout();
+
+  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
+  LLVMContext &Ctx = Builder.getContext();
+  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
+  unsigned AllocaAS = DL.getAllocaAddrSpace();
+  AllocaInst *AllocaRes =
+      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
+  AllocaRes->setAlignment(AllocaAlignment);
+  return AllocaRes;
+}
+
+static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
+  Value *Row = nullptr, *Col = nullptr;
+  switch (II->getIntrinsicID()) {
+  default:
+    llvm_unreachable("Expect amx intrinsics");
+  case Intrinsic::x86_tileloadd64_internal:
+  case Intrinsic::x86_tilestored64_internal: {
+    Row = II->getArgOperand(0);
+    Col = II->getArgOperand(1);
+    break;
+  }
+  // a * b + c
+  // The shape depends on which operand the tile is.
+  case Intrinsic::x86_tdpbssd_internal: {
+    switch (OpNo) {
+    case 3:
+      Row = II->getArgOperand(0);
+      Col = II->getArgOperand(1);
+      break;
+    case 4:
+      Row = II->getArgOperand(0);
+      Col = II->getArgOperand(2);
+      break;
+    case 5:
+      Row = II->getArgOperand(2);
+      Col = II->getArgOperand(1);
+      break;
+    }
+    break;
+  }
+  }
+
+  return std::make_pair(Row, Col);
+}
+
+// %src = load <256 x i32>, <256 x i32>* %addr, align 64
+// %2 = bitcast <256 x i32> %src to x86_amx
+// -->
+// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
+//                                                  i8* %addr, i64 %stride64)
+static void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
+  Value *Row = nullptr, *Col = nullptr;
+  Use &U = *(Bitcast->use_begin());
+  unsigned OpNo = U.getOperandNo();
+  auto *II = cast<IntrinsicInst>(U.getUser());
+  std::tie(Row, Col) = getShape(II, OpNo);
+  IRBuilder<> Builder(Bitcast);
+  // Use the maximum column as stride.
+  Value *Stride = Builder.getInt64(64);
+  Value *I8Ptr =
+      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
+  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
+
+  Value *NewInst =
+      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
+  Bitcast->replaceAllUsesWith(NewInst);
+}
+
+// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
+//                                                    %stride);
+// %13 = bitcast x86_amx %src to <256 x i32>
+// store <256 x i32> %13, <256 x i32>* %addr, align 64
+// -->
+// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
+//                                           %stride64, %13)
+static void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
+
+  Value *Tile = Bitcast->getOperand(0);
+  auto *II = cast<IntrinsicInst>(Tile);
+  // Tile is the output of an AMX intrinsic. The first operand of the
+  // intrinsic is the row, the second operand is the column.
+  Value *Row = II->getOperand(0);
+  Value *Col = II->getOperand(1);
+  IRBuilder<> Builder(ST);
+  // Use the maximum column as stride. It must be the same as the load
+  // stride.
+  Value *Stride = Builder.getInt64(64);
+  Value *I8Ptr =
+      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
+  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
+  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+  if (Bitcast->hasOneUse())
+    return;
+  // %13 = bitcast x86_amx %src to <256 x i32>
+  // store <256 x i32> %13, <256 x i32>* %addr, align 64
+  // %add = <256 x i32> %13, <256 x i32> %src2
+  // -->
+  // %13 = bitcast x86_amx %src to <256 x i32>
+  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
+  //                                           %stride64, %13)
+  // %14 = load <256 x i32>, %addr
+  // %add = <256 x i32> %14, <256 x i32> %src2
+  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
+  Bitcast->replaceAllUsesWith(Vec);
+}
+
+// Transform a bitcast into a <store, load> pair of instructions.
+static bool transformBitcast(BitCastInst *Bitcast) {
+  IRBuilder<> Builder(Bitcast);
+  AllocaInst *AllocaAddr;
+  Value *I8Ptr, *Stride;
+  auto *Src = Bitcast->getOperand(0);
+
+  auto Prepare = [&]() {
+    AllocaAddr = CreateAllocaInst(Builder, Bitcast->getParent());
+    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
+    Stride = Builder.getInt64(64);
+  };
+
+  if (Bitcast->getType()->isX86_AMXTy()) {
+    // %2 = bitcast <256 x i32> %src to x86_amx
+    // -->
+    // %addr = alloca <256 x i32>, align 64
+    // store <256 x i32> %src, <256 x i32>* %addr, align 64
+    // %addr2 = bitcast <256 x i32>* to i8*
+    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
+    //                                                  i8* %addr2,
+    //                                                  i64 64)
+    Use &U = *(Bitcast->use_begin());
+    unsigned OpNo = U.getOperandNo();
+    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
+    if (!II)
+      return false; // May be a bitcast from x86_amx to <256 x i32>.
+    Prepare();
+    Builder.CreateStore(Src, AllocaAddr);
+    // TODO: we can pick a constant operand for the shape.
+    Value *Row = nullptr, *Col = nullptr;
+    std::tie(Row, Col) = getShape(II, OpNo);
+    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
+    Value *NewInst = Builder.CreateIntrinsic(
+        Intrinsic::x86_tileloadd64_internal, None, Args);
+    Bitcast->replaceAllUsesWith(NewInst);
+  } else {
+    // %2 = bitcast x86_amx %src to <256 x i32>
+    // -->
+    // %addr = alloca <256 x i32>, align 64
+    // %addr2 = bitcast <256 x i32>* to i8*
+    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
+    //                                           i8* %addr2, i64 %stride)
+    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
+    auto *II = dyn_cast<IntrinsicInst>(Src);
+    if (!II)
+      return false; // May be a bitcast from <256 x i32> to x86_amx.
+    Prepare();
+    Value *Row = II->getOperand(0);
+    Value *Col = II->getOperand(1);
+    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
+    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
+    Bitcast->replaceAllUsesWith(NewInst);
+  }
+
+  return true;
+}
+
+namespace {
+class X86LowerAMXType {
+  Function &Func;
+
+public:
+  X86LowerAMXType(Function &F) : Func(F) {}
+  bool visit();
+};
+
+bool X86LowerAMXType::visit() {
+  SmallVector<Instruction *, 8> DeadInsts;
+
+  for (BasicBlock *BB : post_order(&Func)) {
+    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
+         II != IE;) {
+      Instruction &Inst = *II++;
+      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
+      if (!Bitcast)
+        continue;
+
+      Value *Src = Bitcast->getOperand(0);
+      if (Bitcast->getType()->isX86_AMXTy()) {
+        if (Bitcast->user_empty()) {
+          DeadInsts.push_back(Bitcast);
+          continue;
+        }
+        LoadInst *LD = dyn_cast<LoadInst>(Src);
+        if (!LD) {
+          if (transformBitcast(Bitcast))
+            DeadInsts.push_back(Bitcast);
+          continue;
+        }
+        // If the load has multiple users, keep the vector load for them and
+        // add an AMX tile load for the bitcast.
+        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
+        // %2 = bitcast <256 x i32> %src to x86_amx
+        // %add = add <256 x i32> %src, <256 x i32> %src2
+        // -->
+        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
+        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
+        //                                            i8* %addr, i64 %stride64)
+        // %add = add <256 x i32> %src, <256 x i32> %src2
+
+        // If the load has one user, the load will be eliminated in DAG ISel.
+        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
+        // %2 = bitcast <256 x i32> %src to x86_amx
+        // -->
+        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
+        //                                            i8* %addr, i64 %stride64)
+        combineLoadBitcast(LD, Bitcast);
+        DeadInsts.push_back(Bitcast);
+        if (LD->hasOneUse())
+          DeadInsts.push_back(LD);
+      } else if (Src->getType()->isX86_AMXTy()) {
+        if (Bitcast->user_empty()) {
+          DeadInsts.push_back(Bitcast);
+          continue;
+        }
+        StoreInst *ST = nullptr;
+        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
+             UI != UE;) {
+          Value *I = (UI++)->getUser();
+          ST = dyn_cast<StoreInst>(I);
+          if (ST)
+            break;
+        }
+        if (!ST) {
+          if (transformBitcast(Bitcast))
+            DeadInsts.push_back(Bitcast);
+          continue;
+        }
+        // If the bitcast (%13) has one use, fold it and the store into an amx store.
+        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
+        //                                                    %stride);
+        // %13 = bitcast x86_amx %src to <256 x i32>
+        // store <256 x i32> %13, <256 x i32>* %addr, align 64
+        // -->
+        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
+        //                                           %stride64, %13)
+        //
+        // If the bitcast (%13) has multiple uses, transform as below.
+        // %13 = bitcast x86_amx %src to <256 x i32>
+        // store <256 x i32> %13, <256 x i32>* %addr, align 64
+        // %add = <256 x i32> %13, <256 x i32> %src2
+        // -->
+        // %13 = bitcast x86_amx %src to <256 x i32>
+        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
+        //                                           %stride64, %13)
+        // %14 = load <256 x i32>, %addr
+        // %add = <256 x i32> %14, <256 x i32> %src2
+        //
+        combineBitcastStore(Bitcast, ST);
+        // Delete the user first.
+        DeadInsts.push_back(ST);
+        DeadInsts.push_back(Bitcast);
+      }
+    }
+  }
+
+  bool C = !DeadInsts.empty();
+
+  for (auto *Inst : DeadInsts)
+    Inst->eraseFromParent();
+
+  return C;
+}
+} // anonymous namespace
+
+namespace {
+
+class X86LowerAMXTypeLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
+    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    X86LowerAMXType LAT(F);
+    bool C = LAT.visit();
+    return C;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+  }
+};
+
+} // anonymous namespace
+
+static const char PassName[] = "Lower AMX type for load/store";
+char X86LowerAMXTypeLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
+                      false)
+INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
+                    false)
+
+FunctionPass *llvm::createX86LowerAMXTypePass() {
+  return new X86LowerAMXTypeLegacyPass();
+}
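
For readers skimming the diff, here is a small hand-written illustration of the main rewrite the pass performs. It is not part of the file above; the value names, %row/%col operands, and pointer types are hypothetical and simply follow the patterns shown in the comments in the source. When the bitcast to x86_amx sits next to a plain vector load, combineLoadBitcast folds the pair into an AMX tile load that feeds the AMX intrinsic directly:

    ; before the pass
    %src  = load <256 x i32>, <256 x i32>* %addr, align 64
    %tile = bitcast <256 x i32> %src to x86_amx
    call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %dst, i64 64, x86_amx %tile)

    ; after combineLoadBitcast (the dead load/bitcast are then erased)
    %p    = bitcast <256 x i32>* %addr to i8*
    %tile = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %p, i64 64)
    call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %dst, i64 64, x86_amx %tile)

When no adjacent vector load or store exists, transformBitcast instead routes the value through a 64-byte-aligned <256 x i32> alloca, storing on one side and reloading with the matching AMX intrinsic on the other, as sketched in the comments inside that function.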