//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions into AMX load/store. If the
/// bitcast cannot be combined with the load/store, we transform the bitcast
/// into an AMX load/store and a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g. "clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in spill.)
/// volatileTileData() handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// is transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}
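// For reference, the shape rules implemented by getShape() below for the dot
// product intrinsics: given
//   %d = tdpbssd.internal(i16 %m, i16 %n, i16 %k, %c, %a, %b)
// operand %c (OpNo 3, the accumulator) has shape %m x %n, operand %a (OpNo 4)
// has shape %m x %k, and operand %b (OpNo 5) has shape %k/4 x %n, since four
// bytes are packed into one i32 element.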
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      if (isa<ConstantInt>(II->getArgOperand(2)))
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
      else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a const value and it is not a function argument, we
        // create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118, and we try to get the shape for
        // %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
        //   i32> %115).
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
        //   %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
        //   %117).
        // If we create %row = udiv i16 %106, 4 before %118 (aka. II), then its
        // definition is after its user (the new tileload for %117).
        // So, the best choice is to create %row right after the definition of
        // %106.
        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
      } else {
        // When it is not a const value and it is a function argument, we
        // create Row at the entry bb.
        IRBuilder<> NewBuilder(
            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
      }
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}
namespace {
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
};

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform a bitcast into AMX load/store intrinsics plus a vector
// store/load through a stack slot.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}
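// Walk the blocks in post order and the instructions of each block in
// reverse, rewriting every bitcast that produces or consumes x86_amx.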
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate the vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr,
        //                                                  i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr,
        //                                                  i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine bitcast and store into an
        // AMX store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col,
        //                                                    %addr, %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();
  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace
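// Create a new <256 x i32> stack slot at the function entry and return it as
// an i8* for the tile load/store intrinsics. Note that each call creates a
// fresh alloca rather than reusing an existing slot.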
static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
  return I8Ptr;
}

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore = Builder.CreateIntrinsic(
      Intrinsic::x86_tilestored64_internal, None, Args);
  return TileStore;
}

static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad = Builder.CreateIntrinsic(
      Intrinsic::x86_tileloadd64_internal, None, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}
// Let all AMX tile data become volatile data, shortening the live range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) The PHI incoming values should be tilestored just after their defs.
// 3) The mem of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                  *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
//   use %t0`                                               *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before use.
// None of its users is a PHI node.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All the uses of tile data come from tileloads just in time.
// 2) All the defs of tile data are tilestored into mem immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t3, t1, t2)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
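    // (PHI incoming values are skipped here; volatileTilePHI() below stores
    // them and rewrites their non-PHI uses via updatePhiIncomings().)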
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}
} // anonymous namespace

namespace {
class X86LowerAMXCast {
  Function &Func;

public:
  X86LowerAMXCast(Function &F) : Func(F) {}
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};

static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand
    // becomes dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out
      // the operand, and if it is 'trivially' dead, delete it in a future
      // loop iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI)) {
          WorkList.insert(OpI);
        }
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///
/// A  ->  B    amxcast
/// PHI
/// B  ->  A    amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
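///
/// For illustration (hypothetical IR, with A = <256 x i32> and B = x86_amx):
///   %phi = phi x86_amx [ %t0, %bb0 ], [ %t1, %bb1 ] ; %t0/%t1 are A->B casts
///   %v = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %phi)
/// becomes
///   %phi.new = phi <256 x i32> [ %v0, %bb0 ], [ %v1, %bb1 ]
/// where %v0/%v1 are the vector operands of the original A->B casts, and all
/// uses of %v are replaced with %phi.new.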
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes; before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (Value *IncValue : OldPN->incoming_values()) {
      // TODO: currently, we ignore cases where the incoming value is a const.
      // In the future, we might support consts.
      if (isa<Constant>(IncValue))
        return false;

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        // example:
        //   bb.0:
        //     %0 = amxcast ...
        //   bb.1:
        //     %1 = amxcast ...
        //   bb.2:
        //     %goodphi = phi %0, %1
        //     %3 = amxcast %goodphi
        //   bb.3:
        //     %goodphi2 = phi %0, %goodphi
        //     %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXcast from a const.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process their users,
  // which are Stores and BitCasts. Without this processing
  // NewPHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.
  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push PHINodes into DeadInst since they are
        // operands of rootPN; DCE can safely delete rootPN's operands if
        // rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}
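// combineAMXcast() proceeds in three steps: fold tile2vec(vec2tile %v) -> %v
// (and the symmetric vec2tile(tile2vec %t) -> %t), erase casts that became
// dead, then optimize the remaining A->B->A chains that pass through PHI
// nodes and DCE the leftovers.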
bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts,
                     Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V2
        // -->
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V0
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);

  // Handle the A->B->A cast where there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // We skip the dead AMXcast.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
  // might have no uses. We do some DeadCodeElimination for them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  return Change;
}

// There might be remaining AMXcasts after combineAMXcast and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60,
    //                                           x86_amx %2)
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2,
    //                                                  i64 60)
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60,
    //                                           x86_amx %2)
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}
bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    bool C = false;
    TargetMachine *TM =
        &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXcasts after combineAMXcast and they should
    // be handled elegantly.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: it may be better to check the volatile model of AMX code, not
    // just check Attribute::OptimizeNone and CodeGenOpt::None.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the mid/back end uses O0 (e.g.
      // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
      // AMX data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}