diff options
Diffstat (limited to 'contrib/llvm-project/llvm/include/llvm/Transforms/Scalar/GVNExpression.h')
-rw-r--r-- | contrib/llvm-project/llvm/include/llvm/Transforms/Scalar/GVNExpression.h | 664 |
1 files changed, 664 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/include/llvm/Transforms/Scalar/GVNExpression.h b/contrib/llvm-project/llvm/include/llvm/Transforms/Scalar/GVNExpression.h new file mode 100644 index 000000000000..2433890d0df8 --- /dev/null +++ b/contrib/llvm-project/llvm/include/llvm/Transforms/Scalar/GVNExpression.h @@ -0,0 +1,664 @@ +//===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file +/// +/// The header file for the GVN pass that contains expression handling +/// classes +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H +#define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H + +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/ArrayRecycler.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <iterator> +#include <utility> + +namespace llvm { + +class BasicBlock; +class Type; + +namespace GVNExpression { + +enum ExpressionType { + ET_Base, + ET_Constant, + ET_Variable, + ET_Dead, + ET_Unknown, + ET_BasicStart, + ET_Basic, + ET_AggregateValue, + ET_Phi, + ET_MemoryStart, + ET_Call, + ET_Load, + ET_Store, + ET_MemoryEnd, + ET_BasicEnd +}; + +class Expression { +private: + ExpressionType EType; + unsigned Opcode; + mutable hash_code HashVal = 0; + +public: + Expression(ExpressionType ET = ET_Base, unsigned O = ~2U) + : EType(ET), Opcode(O) {} + Expression(const Expression &) = delete; + Expression &operator=(const Expression &) = delete; + virtual ~Expression(); + + static unsigned getEmptyKey() { return ~0U; } + static unsigned getTombstoneKey() { return ~1U; } + + bool operator!=(const Expression &Other) const { return !(*this == Other); } + bool operator==(const Expression &Other) const { + if (getOpcode() != Other.getOpcode()) + return false; + if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey()) + return true; + // Compare the expression type for anything but load and store. + // For load and store we set the opcode to zero to make them equal. + if (getExpressionType() != ET_Load && getExpressionType() != ET_Store && + getExpressionType() != Other.getExpressionType()) + return false; + + return equals(Other); + } + + hash_code getComputedHash() const { + // It's theoretically possible for a thing to hash to zero. In that case, + // we will just compute the hash a few extra times, which is no worse that + // we did before, which was to compute it always. + if (static_cast<unsigned>(HashVal) == 0) + HashVal = getHashValue(); + return HashVal; + } + + virtual bool equals(const Expression &Other) const { return true; } + + // Return true if the two expressions are exactly the same, including the + // normally ignored fields. + virtual bool exactlyEquals(const Expression &Other) const { + return getExpressionType() == Other.getExpressionType() && equals(Other); + } + + unsigned getOpcode() const { return Opcode; } + void setOpcode(unsigned opcode) { Opcode = opcode; } + ExpressionType getExpressionType() const { return EType; } + + // We deliberately leave the expression type out of the hash value. + virtual hash_code getHashValue() const { return getOpcode(); } + + // Debugging support + virtual void printInternal(raw_ostream &OS, bool PrintEType) const { + if (PrintEType) + OS << "etype = " << getExpressionType() << ","; + OS << "opcode = " << getOpcode() << ", "; + } + + void print(raw_ostream &OS) const { + OS << "{ "; + printInternal(OS, true); + OS << "}"; + } + + LLVM_DUMP_METHOD void dump() const; +}; + +inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) { + E.print(OS); + return OS; +} + +class BasicExpression : public Expression { +private: + using RecyclerType = ArrayRecycler<Value *>; + using RecyclerCapacity = RecyclerType::Capacity; + + Value **Operands = nullptr; + unsigned MaxOperands; + unsigned NumOperands = 0; + Type *ValueType = nullptr; + +public: + BasicExpression(unsigned NumOperands) + : BasicExpression(NumOperands, ET_Basic) {} + BasicExpression(unsigned NumOperands, ExpressionType ET) + : Expression(ET), MaxOperands(NumOperands) {} + BasicExpression() = delete; + BasicExpression(const BasicExpression &) = delete; + BasicExpression &operator=(const BasicExpression &) = delete; + ~BasicExpression() override; + + static bool classof(const Expression *EB) { + ExpressionType ET = EB->getExpressionType(); + return ET > ET_BasicStart && ET < ET_BasicEnd; + } + + /// Swap two operands. Used during GVN to put commutative operands in + /// order. + void swapOperands(unsigned First, unsigned Second) { + std::swap(Operands[First], Operands[Second]); + } + + Value *getOperand(unsigned N) const { + assert(Operands && "Operands not allocated"); + assert(N < NumOperands && "Operand out of range"); + return Operands[N]; + } + + void setOperand(unsigned N, Value *V) { + assert(Operands && "Operands not allocated before setting"); + assert(N < NumOperands && "Operand out of range"); + Operands[N] = V; + } + + unsigned getNumOperands() const { return NumOperands; } + + using op_iterator = Value **; + using const_op_iterator = Value *const *; + + op_iterator op_begin() { return Operands; } + op_iterator op_end() { return Operands + NumOperands; } + const_op_iterator op_begin() const { return Operands; } + const_op_iterator op_end() const { return Operands + NumOperands; } + iterator_range<op_iterator> operands() { + return iterator_range<op_iterator>(op_begin(), op_end()); + } + iterator_range<const_op_iterator> operands() const { + return iterator_range<const_op_iterator>(op_begin(), op_end()); + } + + void op_push_back(Value *Arg) { + assert(NumOperands < MaxOperands && "Tried to add too many operands"); + assert(Operands && "Operandss not allocated before pushing"); + Operands[NumOperands++] = Arg; + } + bool op_empty() const { return getNumOperands() == 0; } + + void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) { + assert(!Operands && "Operands already allocated"); + Operands = Recycler.allocate(RecyclerCapacity::get(MaxOperands), Allocator); + } + void deallocateOperands(RecyclerType &Recycler) { + Recycler.deallocate(RecyclerCapacity::get(MaxOperands), Operands); + } + + void setType(Type *T) { ValueType = T; } + Type *getType() const { return ValueType; } + + bool equals(const Expression &Other) const override { + if (getOpcode() != Other.getOpcode()) + return false; + + const auto &OE = cast<BasicExpression>(Other); + return getType() == OE.getType() && NumOperands == OE.NumOperands && + std::equal(op_begin(), op_end(), OE.op_begin()); + } + + hash_code getHashValue() const override { + return hash_combine(this->Expression::getHashValue(), ValueType, + hash_combine_range(op_begin(), op_end())); + } + + // Debugging support + void printInternal(raw_ostream &OS, bool PrintEType) const override { + if (PrintEType) + OS << "ExpressionTypeBasic, "; + + this->Expression::printInternal(OS, false); + OS << "operands = {"; + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + OS << "[" << i << "] = "; + Operands[i]->printAsOperand(OS); + OS << " "; + } + OS << "} "; + } +}; + +class op_inserter { +private: + using Container = BasicExpression; + + Container *BE; + +public: + using iterator_category = std::output_iterator_tag; + using value_type = void; + using difference_type = void; + using pointer = void; + using reference = void; + + explicit op_inserter(BasicExpression &E) : BE(&E) {} + explicit op_inserter(BasicExpression *E) : BE(E) {} + + op_inserter &operator=(Value *val) { + BE->op_push_back(val); + return *this; + } + op_inserter &operator*() { return *this; } + op_inserter &operator++() { return *this; } + op_inserter &operator++(int) { return *this; } +}; + +class MemoryExpression : public BasicExpression { +private: + const MemoryAccess *MemoryLeader; + +public: + MemoryExpression(unsigned NumOperands, enum ExpressionType EType, + const MemoryAccess *MemoryLeader) + : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader) {} + MemoryExpression() = delete; + MemoryExpression(const MemoryExpression &) = delete; + MemoryExpression &operator=(const MemoryExpression &) = delete; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() > ET_MemoryStart && + EB->getExpressionType() < ET_MemoryEnd; + } + + hash_code getHashValue() const override { + return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader); + } + + bool equals(const Expression &Other) const override { + if (!this->BasicExpression::equals(Other)) + return false; + const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other); + + return MemoryLeader == OtherMCE.MemoryLeader; + } + + const MemoryAccess *getMemoryLeader() const { return MemoryLeader; } + void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; } +}; + +class CallExpression final : public MemoryExpression { +private: + CallInst *Call; + +public: + CallExpression(unsigned NumOperands, CallInst *C, + const MemoryAccess *MemoryLeader) + : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {} + CallExpression() = delete; + CallExpression(const CallExpression &) = delete; + CallExpression &operator=(const CallExpression &) = delete; + ~CallExpression() override; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_Call; + } + + // Debugging support + void printInternal(raw_ostream &OS, bool PrintEType) const override { + if (PrintEType) + OS << "ExpressionTypeCall, "; + this->BasicExpression::printInternal(OS, false); + OS << " represents call at "; + Call->printAsOperand(OS); + } +}; + +class LoadExpression final : public MemoryExpression { +private: + LoadInst *Load; + +public: + LoadExpression(unsigned NumOperands, LoadInst *L, + const MemoryAccess *MemoryLeader) + : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {} + + LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L, + const MemoryAccess *MemoryLeader) + : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {} + + LoadExpression() = delete; + LoadExpression(const LoadExpression &) = delete; + LoadExpression &operator=(const LoadExpression &) = delete; + ~LoadExpression() override; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_Load; + } + + LoadInst *getLoadInst() const { return Load; } + void setLoadInst(LoadInst *L) { Load = L; } + + bool equals(const Expression &Other) const override; + bool exactlyEquals(const Expression &Other) const override { + return Expression::exactlyEquals(Other) && + cast<LoadExpression>(Other).getLoadInst() == getLoadInst(); + } + + // Debugging support + void printInternal(raw_ostream &OS, bool PrintEType) const override { + if (PrintEType) + OS << "ExpressionTypeLoad, "; + this->BasicExpression::printInternal(OS, false); + OS << " represents Load at "; + Load->printAsOperand(OS); + OS << " with MemoryLeader " << *getMemoryLeader(); + } +}; + +class StoreExpression final : public MemoryExpression { +private: + StoreInst *Store; + Value *StoredValue; + +public: + StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue, + const MemoryAccess *MemoryLeader) + : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S), + StoredValue(StoredValue) {} + StoreExpression() = delete; + StoreExpression(const StoreExpression &) = delete; + StoreExpression &operator=(const StoreExpression &) = delete; + ~StoreExpression() override; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_Store; + } + + StoreInst *getStoreInst() const { return Store; } + Value *getStoredValue() const { return StoredValue; } + + bool equals(const Expression &Other) const override; + + bool exactlyEquals(const Expression &Other) const override { + return Expression::exactlyEquals(Other) && + cast<StoreExpression>(Other).getStoreInst() == getStoreInst(); + } + + // Debugging support + void printInternal(raw_ostream &OS, bool PrintEType) const override { + if (PrintEType) + OS << "ExpressionTypeStore, "; + this->BasicExpression::printInternal(OS, false); + OS << " represents Store " << *Store; + OS << " with StoredValue "; + StoredValue->printAsOperand(OS); + OS << " and MemoryLeader " << *getMemoryLeader(); + } +}; + +class AggregateValueExpression final : public BasicExpression { +private: + unsigned MaxIntOperands; + unsigned NumIntOperands = 0; + unsigned *IntOperands = nullptr; + +public: + AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands) + : BasicExpression(NumOperands, ET_AggregateValue), + MaxIntOperands(NumIntOperands) {} + AggregateValueExpression() = delete; + AggregateValueExpression(const AggregateValueExpression &) = delete; + AggregateValueExpression & + operator=(const AggregateValueExpression &) = delete; + ~AggregateValueExpression() override; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_AggregateValue; + } + + using int_arg_iterator = unsigned *; + using const_int_arg_iterator = const unsigned *; + + int_arg_iterator int_op_begin() { return IntOperands; } + int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; } + const_int_arg_iterator int_op_begin() const { return IntOperands; } + const_int_arg_iterator int_op_end() const { + return IntOperands + NumIntOperands; + } + unsigned int_op_size() const { return NumIntOperands; } + bool int_op_empty() const { return NumIntOperands == 0; } + void int_op_push_back(unsigned IntOperand) { + assert(NumIntOperands < MaxIntOperands && + "Tried to add too many int operands"); + assert(IntOperands && "Operands not allocated before pushing"); + IntOperands[NumIntOperands++] = IntOperand; + } + + virtual void allocateIntOperands(BumpPtrAllocator &Allocator) { + assert(!IntOperands && "Operands already allocated"); + IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands); + } + + bool equals(const Expression &Other) const override { + if (!this->BasicExpression::equals(Other)) + return false; + const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other); + return NumIntOperands == OE.NumIntOperands && + std::equal(int_op_begin(), int_op_end(), OE.int_op_begin()); + } + + hash_code getHashValue() const override { + return hash_combine(this->BasicExpression::getHashValue(), + hash_combine_range(int_op_begin(), int_op_end())); + } + + // Debugging support + void printInternal(raw_ostream &OS, bool PrintEType) const override { + if (PrintEType) + OS << "ExpressionTypeAggregateValue, "; + this->BasicExpression::printInternal(OS, false); + OS << ", intoperands = {"; + for (unsigned i = 0, e = int_op_size(); i != e; ++i) { + OS << "[" << i << "] = " << IntOperands[i] << " "; + } + OS << "}"; + } +}; + +class int_op_inserter { +private: + using Container = AggregateValueExpression; + + Container *AVE; + +public: + using iterator_category = std::output_iterator_tag; + using value_type = void; + using difference_type = void; + using pointer = void; + using reference = void; + + explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {} + explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {} + + int_op_inserter &operator=(unsigned int val) { + AVE->int_op_push_back(val); + return *this; + } + int_op_inserter &operator*() { return *this; } + int_op_inserter &operator++() { return *this; } + int_op_inserter &operator++(int) { return *this; } +}; + +class PHIExpression final : public BasicExpression { +private: + BasicBlock *BB; + +public: + PHIExpression(unsigned NumOperands, BasicBlock *B) + : BasicExpression(NumOperands, ET_Phi), BB(B) {} + PHIExpression() = delete; + PHIExpression(const PHIExpression &) = delete; + PHIExpression &operator=(const PHIExpression &) = delete; + ~PHIExpression() override; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_Phi; + } + + bool equals(const Expression &Other) const override { + if (!this->BasicExpression::equals(Other)) + return false; + const PHIExpression &OE = cast<PHIExpression>(Other); + return BB == OE.BB; + } + + hash_code getHashValue() const override { + return hash_combine(this->BasicExpression::getHashValue(), BB); + } + + // Debugging support + void printInternal(raw_ostream &OS, bool PrintEType) const override { + if (PrintEType) + OS << "ExpressionTypePhi, "; + this->BasicExpression::printInternal(OS, false); + OS << "bb = " << BB; + } +}; + +class DeadExpression final : public Expression { +public: + DeadExpression() : Expression(ET_Dead) {} + DeadExpression(const DeadExpression &) = delete; + DeadExpression &operator=(const DeadExpression &) = delete; + + static bool classof(const Expression *E) { + return E->getExpressionType() == ET_Dead; + } +}; + +class VariableExpression final : public Expression { +private: + Value *VariableValue; + +public: + VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {} + VariableExpression() = delete; + VariableExpression(const VariableExpression &) = delete; + VariableExpression &operator=(const VariableExpression &) = delete; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_Variable; + } + + Value *getVariableValue() const { return VariableValue; } + void setVariableValue(Value *V) { VariableValue = V; } + + bool equals(const Expression &Other) const override { + const VariableExpression &OC = cast<VariableExpression>(Other); + return VariableValue == OC.VariableValue; + } + + hash_code getHashValue() const override { + return hash_combine(this->Expression::getHashValue(), + VariableValue->getType(), VariableValue); + } + + // Debugging support + void printInternal(raw_ostream &OS, bool PrintEType) const override { + if (PrintEType) + OS << "ExpressionTypeVariable, "; + this->Expression::printInternal(OS, false); + OS << " variable = " << *VariableValue; + } +}; + +class ConstantExpression final : public Expression { +private: + Constant *ConstantValue = nullptr; + +public: + ConstantExpression() : Expression(ET_Constant) {} + ConstantExpression(Constant *constantValue) + : Expression(ET_Constant), ConstantValue(constantValue) {} + ConstantExpression(const ConstantExpression &) = delete; + ConstantExpression &operator=(const ConstantExpression &) = delete; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_Constant; + } + + Constant *getConstantValue() const { return ConstantValue; } + void setConstantValue(Constant *V) { ConstantValue = V; } + + bool equals(const Expression &Other) const override { + const ConstantExpression &OC = cast<ConstantExpression>(Other); + return ConstantValue == OC.ConstantValue; + } + + hash_code getHashValue() const override { + return hash_combine(this->Expression::getHashValue(), + ConstantValue->getType(), ConstantValue); + } + + // Debugging support + void printInternal(raw_ostream &OS, bool PrintEType) const override { + if (PrintEType) + OS << "ExpressionTypeConstant, "; + this->Expression::printInternal(OS, false); + OS << " constant = " << *ConstantValue; + } +}; + +class UnknownExpression final : public Expression { +private: + Instruction *Inst; + +public: + UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {} + UnknownExpression() = delete; + UnknownExpression(const UnknownExpression &) = delete; + UnknownExpression &operator=(const UnknownExpression &) = delete; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_Unknown; + } + + Instruction *getInstruction() const { return Inst; } + void setInstruction(Instruction *I) { Inst = I; } + + bool equals(const Expression &Other) const override { + const auto &OU = cast<UnknownExpression>(Other); + return Inst == OU.Inst; + } + + hash_code getHashValue() const override { + return hash_combine(this->Expression::getHashValue(), Inst); + } + + // Debugging support + void printInternal(raw_ostream &OS, bool PrintEType) const override { + if (PrintEType) + OS << "ExpressionTypeUnknown, "; + this->Expression::printInternal(OS, false); + OS << " inst = " << *Inst; + } +}; + +} // end namespace GVNExpression + +} // end namespace llvm + +#endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H |