diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 291 |
1 files changed, 226 insertions, 65 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 90314b17b5e2..8e251ca940a3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -42,6 +42,8 @@ #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/MatrixUtils.h" using namespace llvm; using namespace PatternMatch; @@ -61,6 +63,9 @@ static cl::opt<unsigned> TileSize( "fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc( "Tile size for matrix instruction fusion using square-shaped tiles.")); +static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false), + cl::Hidden, + cl::desc("Generate loop nest for tiling.")); static cl::opt<bool> ForceFusion( "force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable.")); @@ -182,10 +187,10 @@ class LowerMatrixIntrinsics { Function &Func; const DataLayout &DL; const TargetTransformInfo &TTI; - AliasAnalysis &AA; - DominatorTree &DT; - LoopInfo &LI; - OptimizationRemarkEmitter &ORE; + AliasAnalysis *AA; + DominatorTree *DT; + LoopInfo *LI; + OptimizationRemarkEmitter *ORE; /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. struct OpInfoTy { @@ -241,7 +246,7 @@ class LowerMatrixIntrinsics { void setVector(unsigned i, Value *V) { Vectors[i] = V; } - Type *getElementType() { return getVectorTy()->getElementType(); } + Type *getElementType() const { return getVectorTy()->getElementType(); } unsigned getNumVectors() const { if (isColumnMajor()) @@ -271,7 +276,7 @@ class LowerMatrixIntrinsics { return getVectorTy(); } - VectorType *getVectorTy() { + VectorType *getVectorTy() const { return cast<VectorType>(Vectors[0]->getType()); } @@ -329,9 +334,8 @@ class LowerMatrixIntrinsics { Value *extractVector(unsigned I, unsigned J, unsigned NumElts, IRBuilder<> &Builder) const { Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); - Value *Undef = UndefValue::get(Vec->getType()); return Builder.CreateShuffleVector( - Vec, Undef, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), + Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), "block"); } }; @@ -393,8 +397,8 @@ class LowerMatrixIntrinsics { public: LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, - AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI, - OptimizationRemarkEmitter &ORE) + AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, + OptimizationRemarkEmitter *ORE) : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), LI(LI), ORE(ORE) {} @@ -442,12 +446,11 @@ public: // Otherwise split MatrixVal. SmallVector<Value *, 16> SplitVecs; - Value *Undef = UndefValue::get(VType); for (unsigned MaskStart = 0; MaskStart < cast<FixedVectorType>(VType)->getNumElements(); MaskStart += SI.getStride()) { Value *V = Builder.CreateShuffleVector( - MatrixVal, Undef, createSequentialMask(MaskStart, SI.getStride(), 0), + MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), "split"); SplitVecs.push_back(V); } @@ -485,6 +488,7 @@ public: case Instruction::FAdd: case Instruction::FSub: case Instruction::FMul: // Scalar multiply. + case Instruction::FNeg: case Instruction::Add: case Instruction::Mul: case Instruction::Sub: @@ -527,8 +531,7 @@ public: // list. LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n"); while (!WorkList.empty()) { - Instruction *Inst = WorkList.back(); - WorkList.pop_back(); + Instruction *Inst = WorkList.pop_back_val(); // New entry, set the value and insert operands bool Propagate = false; @@ -598,8 +601,7 @@ public: // worklist. LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n"); while (!WorkList.empty()) { - Value *V = WorkList.back(); - WorkList.pop_back(); + Value *V = WorkList.pop_back_val(); size_t BeforeProcessingV = WorkList.size(); if (!isa<Instruction>(V)) @@ -721,14 +723,18 @@ public: Value *Op2; if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) Changed |= VisitBinaryOperator(BinOp); + if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) + Changed |= VisitUnaryOperator(UnOp); if (match(Inst, m_Load(m_Value(Op1)))) Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); } - RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func); - RemarkGen.emitRemarks(); + if (ORE) { + RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func); + RemarkGen.emitRemarks(); + } for (Instruction *Inst : reverse(ToRemove)) Inst->eraseFromParent(); @@ -934,10 +940,8 @@ public: unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements(); assert(NumElts >= BlockNumElts && "Too few elements for current block"); - Value *Undef = UndefValue::get(Block->getType()); Block = Builder.CreateShuffleVector( - Block, Undef, - createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); + Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts)); // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7, // 8, 4, 5, 6 @@ -1085,7 +1089,7 @@ public: MemoryLocation StoreLoc = MemoryLocation::get(Store); MemoryLocation LoadLoc = MemoryLocation::get(Load); - AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc); + AliasResult LdAliased = AA->alias(LoadLoc, StoreLoc); // If we can statically determine noalias we're good. if (!LdAliased) @@ -1101,14 +1105,17 @@ public: // as we adjust Check0 and Check1's branches. SmallVector<DominatorTree::UpdateType, 4> DTUpdates; for (BasicBlock *Succ : successors(Check0)) - DTUpdates.push_back({DT.Delete, Check0, Succ}); + DTUpdates.push_back({DT->Delete, Check0, Succ}); - BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, - nullptr, "alias_cont"); + BasicBlock *Check1 = + SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, + nullptr, "alias_cont"); BasicBlock *Copy = - SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy"); - BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, - nullptr, "no_alias"); + SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, + nullptr, "copy"); + BasicBlock *Fusion = + SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI, + nullptr, "no_alias"); // Check if the loaded memory location begins before the end of the store // location. If the condition holds, they might overlap, otherwise they are @@ -1152,11 +1159,11 @@ public: PHI->addIncoming(NewLd, Copy); // Adjust DT. - DTUpdates.push_back({DT.Insert, Check0, Check1}); - DTUpdates.push_back({DT.Insert, Check0, Fusion}); - DTUpdates.push_back({DT.Insert, Check1, Copy}); - DTUpdates.push_back({DT.Insert, Check1, Fusion}); - DT.applyUpdates(DTUpdates); + DTUpdates.push_back({DT->Insert, Check0, Check1}); + DTUpdates.push_back({DT->Insert, Check0, Fusion}); + DTUpdates.push_back({DT->Insert, Check1, Copy}); + DTUpdates.push_back({DT->Insert, Check1, Fusion}); + DT->applyUpdates(DTUpdates); return PHI; } @@ -1202,6 +1209,63 @@ public: return Res; } + void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape, + Value *RPtr, ShapeInfo RShape, StoreInst *Store, + bool AllowContract) { + auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); + + // Create the main tiling loop nest. + TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); + Instruction *InsertI = cast<Instruction>(MatMul); + BasicBlock *Start = InsertI->getParent(); + BasicBlock *End = + SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue"); + IRBuilder<> Builder(MatMul); + BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI); + + Type *TileVecTy = + FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize); + MatrixTy TileResult; + // Insert in the inner loop header. + Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator()); + // Create PHI nodes for the result columns to accumulate across iterations. + SmallVector<PHINode *, 4> ColumnPhis; + for (unsigned I = 0; I < TileSize; I++) { + auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I)); + Phi->addIncoming(ConstantAggregateZero::get(TileVecTy), + TI.RowLoopHeader->getSingleSuccessor()); + TileResult.addVector(Phi); + ColumnPhis.push_back(Phi); + } + + // Insert in the inner loop body, which computes + // Res += Load(CurrentRow, K) * Load(K, CurrentColumn) + Builder.SetInsertPoint(InnerBody->getTerminator()); + // Load tiles of the operands. + MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK, + {TileSize, TileSize}, EltType, Builder); + MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol, + {TileSize, TileSize}, EltType, Builder); + emitMatrixMultiply(TileResult, A, B, AllowContract, Builder, true); + // Store result after the inner loop is done. + Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator()); + storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(), + Store->isVolatile(), {LShape.NumRows, RShape.NumColumns}, + TI.CurrentRow, TI.CurrentCol, EltType, Builder); + + for (unsigned I = 0; I < TileResult.getNumVectors(); I++) + ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch); + + // Force unrolling of a few iterations of the inner loop, to make sure there + // is enough work per iteration. + // FIXME: The unroller should make this decision directly instead, but + // currently the cost-model is not up to the task. + unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize); + addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader), + "llvm.loop.unroll.count", InnerLoopUnrollCount); + } + void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, StoreInst *Store, SmallPtrSetImpl<Instruction *> &FusedInsts) { @@ -1224,28 +1288,34 @@ public: bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && MatMul->hasAllowContract()); - IRBuilder<> Builder(Store); - for (unsigned J = 0; J < C; J += TileSize) - for (unsigned I = 0; I < R; I += TileSize) { - const unsigned TileR = std::min(R - I, unsigned(TileSize)); - const unsigned TileC = std::min(C - J, unsigned(TileSize)); - MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); - - for (unsigned K = 0; K < M; K += TileSize) { - const unsigned TileM = std::min(M - K, unsigned(TileSize)); - MatrixTy A = - loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), - LShape, Builder.getInt64(I), Builder.getInt64(K), - {TileR, TileM}, EltType, Builder); - MatrixTy B = - loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), - RShape, Builder.getInt64(K), Builder.getInt64(J), - {TileM, TileC}, EltType, Builder); - emitMatrixMultiply(Res, A, B, AllowContract, Builder, true); + if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0)) + createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store, + AllowContract); + else { + IRBuilder<> Builder(Store); + for (unsigned J = 0; J < C; J += TileSize) + for (unsigned I = 0; I < R; I += TileSize) { + const unsigned TileR = std::min(R - I, unsigned(TileSize)); + const unsigned TileC = std::min(C - J, unsigned(TileSize)); + MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); + + for (unsigned K = 0; K < M; K += TileSize) { + const unsigned TileM = std::min(M - K, unsigned(TileSize)); + MatrixTy A = + loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), + LShape, Builder.getInt64(I), Builder.getInt64(K), + {TileR, TileM}, EltType, Builder); + MatrixTy B = + loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), + RShape, Builder.getInt64(K), Builder.getInt64(J), + {TileM, TileC}, EltType, Builder); + emitMatrixMultiply(Res, A, B, AllowContract, Builder, true); + } + storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, + Builder.getInt64(I), Builder.getInt64(J), EltType, + Builder); } - storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, - Builder.getInt64(I), Builder.getInt64(J), EltType, Builder); - } + } // Mark eliminated instructions as fused and remove them. FusedInsts.insert(Store); @@ -1272,9 +1342,11 @@ public: void LowerMatrixMultiplyFused(CallInst *MatMul, SmallPtrSetImpl<Instruction *> &FusedInsts) { if (!FuseMatrix || !MatMul->hasOneUse() || - MatrixLayout != MatrixLayoutTy::ColumnMajor) + MatrixLayout != MatrixLayoutTy::ColumnMajor || !DT) return; + assert(AA && LI && "Analyses should be available"); + auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0)); auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1)); auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); @@ -1283,7 +1355,7 @@ public: // we create invalid IR. // FIXME: See if we can hoist the store address computation. auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1)); - if (AddrI && (!DT.dominates(AddrI, MatMul))) + if (AddrI && (!DT->dominates(AddrI, MatMul))) return; emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); @@ -1300,6 +1372,8 @@ public: const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); + assert(Lhs.getElementType() == Rhs.getElementType() && + "Matrix multiply argument element types do not match."); const unsigned R = LShape.NumRows; const unsigned C = RShape.NumColumns; @@ -1307,6 +1381,8 @@ public: // Initialize the output MatrixTy Result(R, C, EltType); + assert(Lhs.getElementType() == Result.getElementType() && + "Matrix multiply result element type does not match arguments."); bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && MatMul->hasAllowContract()); @@ -1424,6 +1500,40 @@ public: return true; } + /// Lower unary operators, if shape information is available. + bool VisitUnaryOperator(UnaryOperator *Inst) { + auto I = ShapeMap.find(Inst); + if (I == ShapeMap.end()) + return false; + + Value *Op = Inst->getOperand(0); + + IRBuilder<> Builder(Inst); + ShapeInfo &Shape = I->second; + + MatrixTy Result; + MatrixTy M = getMatrix(Op, Shape, Builder); + + // Helper to perform unary op on vectors. + auto BuildVectorOp = [&Builder, Inst](Value *Op) { + switch (Inst->getOpcode()) { + case Instruction::FNeg: + return Builder.CreateFNeg(Op); + default: + llvm_unreachable("Unsupported unary operator for matrix"); + } + }; + + for (unsigned I = 0; I < Shape.getNumVectors(); ++I) + Result.addVector(BuildVectorOp(M.getVector(I))); + + finalizeLowering(Inst, + Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * + Result.getNumVectors()), + Builder); + return true; + } + /// Helper to linearize a matrix expression tree into a string. Currently /// matrix expressions are linarized by starting at an expression leaf and /// linearizing bottom up. @@ -1488,7 +1598,7 @@ public: if (Value *Ptr = getPointerOperand(V)) return getUnderlyingObjectThroughLoads(Ptr); else if (V->getType()->isPointerTy()) - return GetUnderlyingObject(V, DL); + return getUnderlyingObject(V); return V; } @@ -1524,7 +1634,7 @@ public: write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {})) .drop_front(StringRef("llvm.matrix.").size())); write("."); - std::string Tmp = ""; + std::string Tmp; raw_string_ostream SS(Tmp); switch (II->getIntrinsicID()) { @@ -1737,7 +1847,6 @@ public: for (Value *Op : cast<Instruction>(V)->operand_values()) collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared); - return; } /// Calculate the number of exclusive and shared op counts for expression @@ -1863,15 +1972,25 @@ public: PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, FunctionAnalysisManager &AM) { auto &TTI = AM.getResult<TargetIRAnalysis>(F); - auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); - auto &AA = AM.getResult<AAManager>(F); - auto &DT = AM.getResult<DominatorTreeAnalysis>(F); - auto &LI = AM.getResult<LoopAnalysis>(F); + OptimizationRemarkEmitter *ORE = nullptr; + AAResults *AA = nullptr; + DominatorTree *DT = nullptr; + LoopInfo *LI = nullptr; + + if (!Minimal) { + ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + AA = &AM.getResult<AAManager>(F); + DT = &AM.getResult<DominatorTreeAnalysis>(F); + LI = &AM.getResult<LoopAnalysis>(F); + } LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE); if (LMT.Visit()) { PreservedAnalyses PA; - PA.preserveSet<CFGAnalyses>(); + if (!Minimal) { + PA.preserve<LoopAnalysis>(); + PA.preserve<DominatorTreeAnalysis>(); + } return PA; } return PreservedAnalyses::all(); @@ -1894,7 +2013,7 @@ public: auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE); + LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE); bool C = LMT.Visit(); return C; } @@ -1925,3 +2044,45 @@ INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, Pass *llvm::createLowerMatrixIntrinsicsPass() { return new LowerMatrixIntrinsicsLegacyPass(); } + +namespace { + +/// A lightweight version of the matrix lowering pass that only requires TTI. +/// Advanced features that require DT, AA or ORE like tiling are disabled. This +/// is used to lower matrix intrinsics if the main lowering pass is not run, for +/// example with -O0. +class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass { +public: + static char ID; + + LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) { + initializeLowerMatrixIntrinsicsMinimalLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr); + bool C = LMT.Visit(); + return C; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.setPreservesCFG(); + } +}; +} // namespace + +static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)"; +char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass, + "lower-matrix-intrinsics-minimal", pass_name_minimal, + false, false) +INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass, + "lower-matrix-intrinsics-minimal", pass_name_minimal, false, + false) + +Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() { + return new LowerMatrixIntrinsicsMinimalLegacyPass(); +} |