aboutsummaryrefslogtreecommitdiff
path: root/include/llvm/Analysis/DivergenceAnalysis.h
blob: 3cfb9d13df94d094760ee816d81a5229057ced09 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
//===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// The divergence analysis determines which instructions and branches are
// divergent given a set of divergent source instructions.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
#define LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H

#include "llvm/ADT/DenseSet.h"
#include "llvm/Analysis/SyncDependenceAnalysis.h"
#include "llvm/IR/Function.h"
#include "llvm/Pass.h"
#include <vector>

namespace llvm {
class Module;
class Value;
class Instruction;
class Loop;
class raw_ostream;
class TargetTransformInfo;

/// \brief Generic divergence analysis for reducible CFGs.
///
/// This analysis propagates divergence in a data-parallel context from sources
/// of divergence to all users. It requires reducible CFGs. All assignments
/// should be in SSA form.
///
/// The analysis stores references to the Function, DominatorTree, LoopInfo and
/// SyncDependenceAnalysis passed to its constructor; those objects must outlive
/// the DivergenceAnalysis instance.
class DivergenceAnalysis {
public:
  /// \brief This instance will analyze the whole function \p F or the loop \p
  /// RegionLoop.
  ///
  /// \param RegionLoop if non-null the analysis is restricted to \p RegionLoop.
  /// Otherwise the whole function is analyzed.
  /// \param IsLCSSAForm whether the analysis may assume that the IR in the
  /// region is in LCSSA form.
  DivergenceAnalysis(const Function &F, const Loop *RegionLoop,
                     const DominatorTree &DT, const LoopInfo &LI,
                     SyncDependenceAnalysis &SDA, bool IsLCSSAForm);

  /// \brief The loop that defines the analyzed region (if any).
  const Loop *getRegionLoop() const { return RegionLoop; }
  /// \brief The function this analysis was constructed for.
  const Function &getFunction() const { return F; }

  /// \brief Whether \p BB is part of the region.
  bool inRegion(const BasicBlock &BB) const;
  /// \brief Whether \p I is part of the region.
  bool inRegion(const Instruction &I) const;

  /// \brief Mark \p UniVal as a value that is always uniform.
  void addUniformOverride(const Value &UniVal);

  /// \brief Mark \p DivVal as a value that is always divergent.
  void markDivergent(const Value &DivVal);

  /// \brief Propagate divergence to all instructions in the region.
  /// Divergence is seeded by calls to markDivergent().
  void compute();

  /// \brief Whether any value was marked or analyzed to be divergent.
  bool hasDetectedDivergence() const { return !DivergentValues.empty(); }

  /// \brief Whether \p Val will always return a uniform value regardless of its
  /// operands.
  bool isAlwaysUniform(const Value &Val) const;

  /// \brief Whether \p Val is a divergent value.
  bool isDivergent(const Value &Val) const;

  /// \brief Print the analysis result to \p OS.
  ///
  /// NOTE(review): the output format is defined by the out-of-line
  /// implementation; the unnamed Module argument is presumably unused.
  void print(raw_ostream &OS, const Module *) const;

private:
  /// \brief Terminator counterpart of updateNormalInstruction.
  ///
  /// NOTE(review): presumably returns whether \p Term is divergent; confirm
  /// against the implementation.
  bool updateTerminator(const Instruction &Term) const;

  /// \brief Phi-node counterpart of updateNormalInstruction.
  ///
  /// NOTE(review): presumably returns whether \p Phi is divergent; confirm
  /// against the implementation.
  bool updatePHINode(const PHINode &Phi) const;

  /// \brief Computes whether \p Inst is divergent based on the
  /// divergence of its operands.
  ///
  /// \returns Whether \p Inst is divergent.
  ///
  /// This should only be called for non-phi, non-terminator instructions.
  bool updateNormalInstruction(const Instruction &Inst) const;

  /// \brief Mark users of live-out values as divergent.
  ///
  /// \param LoopHeader the header of the divergent loop.
  ///
  /// Marks all users of live-out values of the loop headed by \p LoopHeader
  /// as divergent and puts them on the worklist.
  void taintLoopLiveOuts(const BasicBlock &LoopHeader);

  /// \brief Push all users of \p Val (in the region) to the worklist.
  void pushUsers(const Value &Val);

  /// \brief Push all phi nodes in \p Block to the worklist.
  void pushPHINodes(const BasicBlock &Block);

  /// \brief Mark \p Block as join divergent.
  ///
  /// A block is join divergent if two threads may reach it from different
  /// incoming blocks at the same time.
  void markBlockJoinDivergent(const BasicBlock &Block) {
    DivergentJoinBlocks.insert(&Block);
  }

  /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
  bool isTemporalDivergent(const BasicBlock &ObservingBlock,
                           const Value &Val) const;

  /// \brief Whether \p Block is join divergent.
  ///
  /// (see markBlockJoinDivergent).
  bool isJoinDivergent(const BasicBlock &Block) const {
    // A plain membership test; count() is 0 or 1 for a set.
    return DivergentJoinBlocks.count(&Block) != 0;
  }

  /// \brief Propagate control-induced divergence to users (phi nodes and
  /// instructions).
  ///
  /// \param JoinBlock is a divergent loop exit or join point of two disjoint
  /// paths.
  /// \returns Whether \p JoinBlock is a divergent loop exit of \p TermLoop.
  bool propagateJoinDivergence(const BasicBlock &JoinBlock,
                               const Loop *TermLoop);

  /// \brief Propagate induced value divergence due to control divergence in \p
  /// Term.
  void propagateBranchDivergence(const Instruction &Term);

  /// \brief Propagate divergence caused by a divergent loop exit.
  ///
  /// \param ExitingLoop is a divergent loop.
  void propagateLoopDivergence(const Loop &ExitingLoop);

private:
  // NOTE: the declaration order below is the member initialization order;
  // keep it in sync with the constructor's initializer list.
  const Function &F;
  // If RegionLoop != nullptr, the analysis is only performed within
  // RegionLoop. Otherwise, the whole function is analyzed.
  const Loop *RegionLoop;

  const DominatorTree &DT;
  const LoopInfo &LI;

  // Recognized divergent loops.
  DenseSet<const Loop *> DivergentLoops;

  // The SDA links divergent branches to divergent control-flow joins.
  SyncDependenceAnalysis &SDA;

  // Use simplified code path for LCSSA form.
  bool IsLCSSAForm;

  // Set of known-uniform values.
  DenseSet<const Value *> UniformOverrides;

  // Blocks with joining divergent control from different predecessors.
  DenseSet<const BasicBlock *> DivergentJoinBlocks;

  // Detected/marked divergent values.
  DenseSet<const Value *> DivergentValues;

  // Internal worklist for divergence propagation.
  std::vector<const Instruction *> Worklist;
};

/// \brief Divergence analysis frontend for GPU kernels.
class GPUDivergenceAnalysis {
  // Keep SDA declared before DA: members are initialized in declaration
  // order, and DivergenceAnalysis stores the SyncDependenceAnalysis reference
  // it is constructed with (presumably this SDA), so SDA must exist first.
  SyncDependenceAnalysis SDA;
  DivergenceAnalysis DA;

public:
  /// \brief Runs the divergence analysis on \p F, a GPU kernel.
  GPUDivergenceAnalysis(Function &F, const DominatorTree &DT,
                        const PostDominatorTree &PDT, const LoopInfo &LI,
                        const TargetTransformInfo &TTI);

  /// \brief Whether any divergence was detected.
  bool hasDivergence() const { return DA.hasDetectedDivergence(); }

  /// \brief The GPU kernel this analysis result is for.
  const Function &getFunction() const { return DA.getFunction(); }

  /// \brief Whether \p V is divergent.
  bool isDivergent(const Value &V) const;

  /// \brief Whether \p V is uniform (i.e. not divergent).
  bool isUniform(const Value &V) const { return !isDivergent(V); }

  /// \brief Print all divergent values in the kernel.
  void print(raw_ostream &OS, const Module *) const;
};

} // namespace llvm

#endif // LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H