aboutsummaryrefslogtreecommitdiff
path: root/include/llvm/Analysis/ScalarEvolutionExpressions.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/llvm/Analysis/ScalarEvolutionExpressions.h')
-rw-r--r--include/llvm/Analysis/ScalarEvolutionExpressions.h69
1 files changed, 69 insertions, 0 deletions
diff --git a/include/llvm/Analysis/ScalarEvolutionExpressions.h b/include/llvm/Analysis/ScalarEvolutionExpressions.h
index 47b371029186..ded12974face 100644
--- a/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ b/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -15,6 +15,7 @@
#define LLVM_ANALYSIS_SCALAREVOLUTION_EXPRESSIONS_H
#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/ErrorHandling.h"
namespace llvm {
@@ -493,6 +494,74 @@ namespace llvm {
llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
}
};
+
+ /// Visit all nodes in the expression tree using worklist traversal.
+ ///
+ /// Visitor implements:
+ /// // return true to follow this node.
+ /// bool follow(const SCEV *S);
+ /// // return true to terminate the search.
+ /// bool isDone();
+ template<typename SV>
+ class SCEVTraversal {
+ SV &Visitor;
+ SmallVector<const SCEV *, 8> Worklist;
+ SmallPtrSet<const SCEV *, 8> Visited;
+
+ void push(const SCEV *S) {
+ if (Visited.insert(S) && Visitor.follow(S))
+ Worklist.push_back(S);
+ }
+ public:
+ SCEVTraversal(SV& V): Visitor(V) {}
+
+ void visitAll(const SCEV *Root) {
+ push(Root);
+ while (!Worklist.empty() && !Visitor.isDone()) {
+ const SCEV *S = Worklist.pop_back_val();
+
+ switch (S->getSCEVType()) {
+ case scConstant:
+ case scUnknown:
+ break;
+ case scTruncate:
+ case scZeroExtend:
+ case scSignExtend:
+ push(cast<SCEVCastExpr>(S)->getOperand());
+ break;
+ case scAddExpr:
+ case scMulExpr:
+ case scSMaxExpr:
+ case scUMaxExpr:
+ case scAddRecExpr: {
+ const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
+ for (SCEVNAryExpr::op_iterator I = NAry->op_begin(),
+ E = NAry->op_end(); I != E; ++I) {
+ push(*I);
+ }
+ break;
+ }
+ case scUDivExpr: {
+ const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
+ push(UDiv->getLHS());
+ push(UDiv->getRHS());
+ break;
+ }
+ case scCouldNotCompute:
+ llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
+ default:
+ llvm_unreachable("Unknown SCEV kind!");
+ }
+ }
+ }
+ };
+
+ /// Use SCEVTraversal to visit all nodes in the givien expression tree.
+ template<typename SV>
+ void visitAll(const SCEV *Root, SV& Visitor) {
+ SCEVTraversal<SV> T(Visitor);
+ T.visitAll(Root);
+ }
}
#endif