LLVM 17.0.0git
SyncDependenceAnalysis.cpp
Go to the documentation of this file.
1//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements an algorithm that returns for a divergent branch
10// the set of basic blocks whose phi nodes become divergent due to divergent
11// control. These are the blocks that are reachable by two disjoint paths from
12// the branch or loop exits that have a reaching path that is disjoint from a
13// path to the loop latch.
14//
15// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
16// control-induced divergence in phi nodes.
17//
18//
19// -- Reference --
20// The algorithm is presented in Section 5 of
21//
22// An abstract interpretation for SPMD divergence
23// on reducible control flow graphs.
24// Julian Rosemann, Simon Moll and Sebastian Hack
25// POPL '21
26//
27//
28// -- Sync dependence --
29// Sync dependence characterizes the control flow aspect of the
30// propagation of branch divergence. For example,
31//
32// %cond = icmp slt i32 %tid, 10
33// br i1 %cond, label %then, label %else
34// then:
35// br label %merge
36// else:
37// br label %merge
38// merge:
39// %a = phi i32 [ 0, %then ], [ 1, %else ]
40//
41// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
42// because %tid is not on its use-def chains, %a is sync dependent on %tid
43// because the branch "br i1 %cond" depends on %tid and affects which value %a
44// is assigned to.
45//
46//
47// -- Reduction to SSA construction --
48// There are two disjoint paths from A to X, if a certain variant of SSA
49// construction places a phi node in X under the following set-up scheme.
50//
51// This variant of SSA construction ignores incoming undef values.
52// That is paths from the entry without a definition do not result in
53// phi nodes.
54//
55// entry
56// / \
57// A \
58// / \ Y
59// B C /
60// \ / \ /
61// D E
62// \ /
63// F
64//
65// Assume that A contains a divergent branch. We are interested
66// in the set of all blocks where each block is reachable from A
67// via two disjoint paths. This would be the set {D, F} in this
68// case.
69// To generally reduce this query to SSA construction we introduce
70// a virtual variable x and assign to x different values in each
71// successor block of A.
72//
73// entry
74// / \
75// A \
76// / \ Y
77// x = 0 x = 1 /
78// \ / \ /
79// D E
80// \ /
81// F
82//
83// Our flavor of SSA construction for x will construct the following
84//
85// entry
86// / \
87// A \
88// / \ Y
89// x0 = 0 x1 = 1 /
90// \ / \ /
91// x2 = phi E
92// \ /
93// x3 = phi
94//
95// The blocks D and F contain phi nodes and are thus each reachable
96// by two disjoint paths from A.
97//
98// -- Remarks --
99// * In case of loop exits we need to check the disjoint path criterion for loops.
100// To this end, we check whether the definition of x differs between the
101// loop exit and the loop header (_after_ SSA construction).
102//
103// -- Known Limitations & Future Work --
104// * The algorithm requires reducible loops because the implementation
105// implicitly performs a single iteration of the underlying data flow analysis.
106// This was done for pragmatism, simplicity and speed.
107//
108// Relevant related work for extending the algorithm to irreducible control:
109// A simple algorithm for global data flow analysis problems.
110// Matthew S. Hecht and Jeffrey D. Ullman.
111// SIAM Journal on Computing, 4(4):519–532, December 1975.
112//
113// * Another reason for requiring reducible loops is that points of
114// synchronization in irreducible loops aren't 'obvious' - there is no unique
115// header where threads 'should' synchronize when entering or coming back
116// around from the latch.
117//
118//===----------------------------------------------------------------------===//
119
121#include "llvm/ADT/SmallPtrSet.h"
123#include "llvm/IR/BasicBlock.h"
124#include "llvm/IR/CFG.h"
125#include "llvm/IR/Dominators.h"
126#include "llvm/IR/Function.h"
127
128#include <functional>
129
130#define DEBUG_TYPE "sync-dependence"
131
132// The SDA algorithm operates on a modified CFG - we modify the edges leaving
133// loop headers as follows:
134//
135// * We remove all edges leaving all loop headers.
136// * We add additional edges from the loop headers to their exit blocks.
137//
138// The modification is virtual, that is whenever we visit a loop header we
139// pretend it had different successors.
140namespace {
141using namespace llvm;
142
143// Custom Post-Order Traversal
144//
145// We cannot use the vanilla (R)PO computation of LLVM because:
146// * We (virtually) modify the CFG.
147// * We want a loop-compact block enumeration, that is the numbers assigned to
148// blocks of a loop form an interval
149//
150using POCB = std::function<void(const BasicBlock &)>;
151using VisitedSet = std::set<const BasicBlock *>;
152using BlockStack = std::vector<const BasicBlock *>;
153
154// forward
155static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
156 VisitedSet &Finalized);
157
158// for a nested region (top-level loop or nested loop)
159static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop,
160 POCB CallBack, VisitedSet &Finalized) {
161 const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr;
162 while (!Stack.empty()) {
163 const auto *NextBB = Stack.back();
164
165 auto *NestedLoop = LI.getLoopFor(NextBB);
166 bool IsNestedLoop = NestedLoop != Loop;
167
168 // Treat the loop as a node
169 if (IsNestedLoop) {
171 NestedLoop->getUniqueExitBlocks(NestedExits);
172 bool PushedNodes = false;
173 for (const auto *NestedExitBB : NestedExits) {
174 if (NestedExitBB == LoopHeader)
175 continue;
176 if (Loop && !Loop->contains(NestedExitBB))
177 continue;
178 if (Finalized.count(NestedExitBB))
179 continue;
180 PushedNodes = true;
181 Stack.push_back(NestedExitBB);
182 }
183 if (!PushedNodes) {
184 // All loop exits finalized -> finish this node
185 Stack.pop_back();
186 computeLoopPO(LI, *NestedLoop, CallBack, Finalized);
187 }
188 continue;
189 }
190
191 // DAG-style
192 bool PushedNodes = false;
193 for (const auto *SuccBB : successors(NextBB)) {
194 if (SuccBB == LoopHeader)
195 continue;
196 if (Loop && !Loop->contains(SuccBB))
197 continue;
198 if (Finalized.count(SuccBB))
199 continue;
200 PushedNodes = true;
201 Stack.push_back(SuccBB);
202 }
203 if (!PushedNodes) {
204 // Never push nodes twice
205 Stack.pop_back();
206 if (!Finalized.insert(NextBB).second)
207 continue;
208 CallBack(*NextBB);
209 }
210 }
211}
212
213static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) {
214 VisitedSet Finalized;
215 BlockStack Stack;
216 Stack.reserve(24); // FIXME made-up number
217 Stack.push_back(&F.getEntryBlock());
218 computeStackPO(Stack, LI, nullptr, CallBack, Finalized);
219}
220
221static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
222 VisitedSet &Finalized) {
223 /// Call CallBack on all loop blocks.
224 std::vector<const BasicBlock *> Stack;
225 const auto *LoopHeader = Loop.getHeader();
226
227 // Visit the header last
228 Finalized.insert(LoopHeader);
229 CallBack(*LoopHeader);
230
231 // Initialize with immediate successors
232 for (const auto *BB : successors(LoopHeader)) {
233 if (!Loop.contains(BB))
234 continue;
235 if (BB == LoopHeader)
236 continue;
237 Stack.push_back(BB);
238 }
239
240 // Compute PO inside region
241 computeStackPO(Stack, LI, &Loop, CallBack, Finalized);
242}
243
244} // namespace
245
246namespace llvm {
247
248ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc;
249
251 const PostDominatorTree &PDT,
252 const LoopInfo &LI)
253 : DT(DT), PDT(PDT), LI(LI) {
254 computeTopLevelPO(*DT.getRoot()->getParent(), LI,
255 [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); });
256}
257
259
260namespace {
261// divergence propagator for reducible CFGs
263 const ModifiedPO &LoopPOT;
264 const DominatorTree &DT;
265 const PostDominatorTree &PDT;
266 const LoopInfo &LI;
268
269 // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at
270 // block B
271 // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet
272 // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths
273 // from X or B is an immediate successor of X (initial value).
274 using BlockLabelVec = std::vector<const BasicBlock *>;
275 BlockLabelVec BlockLabels;
276 // divergent join and loop exit descriptor.
277 std::unique_ptr<ControlDivergenceDesc> DivDesc;
278
279 DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT,
280 const PostDominatorTree &PDT, const LoopInfo &LI,
282 : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock),
283 BlockLabels(LoopPOT.size(), nullptr),
285
286 void printDefs(raw_ostream &Out) {
287 Out << "Propagator::BlockLabels {\n";
288 for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) {
289 const auto *Label = BlockLabels[BlockIdx];
290 Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx
291 << ") : ";
292 if (!Label) {
293 Out << "<null>\n";
294 } else {
295 Out << Label->getName() << "\n";
296 }
297 }
298 Out << "}\n";
299 }
300
301 // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
302 // causes a divergent join.
303 bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) {
304 auto SuccIdx = LoopPOT.getIndexOf(SuccBlock);
305
306 // unset or same reaching label
307 const auto *OldLabel = BlockLabels[SuccIdx];
308 if (!OldLabel || (OldLabel == &PushedLabel)) {
309 BlockLabels[SuccIdx] = &PushedLabel;
310 return false;
311 }
312
313 // Update the definition
314 BlockLabels[SuccIdx] = &SuccBlock;
315 return true;
316 }
317
318 // visiting a virtual loop exit edge from the loop header --> temporal
319 // divergence on join
320 bool visitLoopExitEdge(const BasicBlock &ExitBlock,
321 const BasicBlock &DefBlock, bool FromParentLoop) {
322 // Pushing from a non-parent loop cannot cause temporal divergence.
323 if (!FromParentLoop)
324 return visitEdge(ExitBlock, DefBlock);
325
326 if (!computeJoin(ExitBlock, DefBlock))
327 return false;
328
329 // Identified a divergent loop exit
330 DivDesc->LoopDivBlocks.insert(&ExitBlock);
331 LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName()
332 << "\n");
333 return true;
334 }
335
336 // process \p SuccBlock with reaching definition \p DefBlock
337 bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) {
338 if (!computeJoin(SuccBlock, DefBlock))
339 return false;
340
341 // Divergent, disjoint paths join.
342 DivDesc->JoinDivBlocks.insert(&SuccBlock);
343 LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName());
344 return true;
345 }
346
347 std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() {
349
350 LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName()
351 << "\n");
352
353 const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock);
354
355 // Early stopping criterion
356 int FloorIdx = LoopPOT.size() - 1;
357 const BasicBlock *FloorLabel = nullptr;
358
359 // bootstrap with branch targets
360 int BlockIdx = 0;
361
362 for (const auto *SuccBlock : successors(&DivTermBlock)) {
363 auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock);
364 BlockLabels[SuccIdx] = SuccBlock;
365
366 // Find the successor with the highest index to start with
367 BlockIdx = std::max<int>(BlockIdx, SuccIdx);
368 FloorIdx = std::min<int>(FloorIdx, SuccIdx);
369
370 // Identify immediate divergent loop exits
371 if (!DivBlockLoop)
372 continue;
373
374 const auto *BlockLoop = LI.getLoopFor(SuccBlock);
375 if (BlockLoop && DivBlockLoop->contains(BlockLoop))
376 continue;
377 DivDesc->LoopDivBlocks.insert(SuccBlock);
378 LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: "
379 << SuccBlock->getName() << "\n");
380 }
381
382 // propagate definitions at the immediate successors of the node in RPO
383 for (; BlockIdx >= FloorIdx; --BlockIdx) {
384 LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs()));
385
386 // Any label available here
387 const auto *Label = BlockLabels[BlockIdx];
388 if (!Label)
389 continue;
390
391 // Ok. Get the block
392 const auto *Block = LoopPOT.getBlockAt(BlockIdx);
393 LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
394
395 auto *BlockLoop = LI.getLoopFor(Block);
396 bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block;
397 bool CausedJoin = false;
398 int LoweredFloorIdx = FloorIdx;
399 if (IsLoopHeader) {
400 // Disconnect from immediate successors and propagate directly to loop
401 // exits.
402 SmallVector<BasicBlock *, 4> BlockLoopExits;
403 BlockLoop->getExitBlocks(BlockLoopExits);
404
405 bool IsParentLoop = BlockLoop->contains(&DivTermBlock);
406 for (const auto *BlockLoopExit : BlockLoopExits) {
407 CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop);
408 LoweredFloorIdx = std::min<int>(LoweredFloorIdx,
409 LoopPOT.getIndexOf(*BlockLoopExit));
410 }
411 } else {
412 // Acyclic successor case
413 for (const auto *SuccBlock : successors(Block)) {
414 CausedJoin |= visitEdge(*SuccBlock, *Label);
415 LoweredFloorIdx =
416 std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock));
417 }
418 }
419
420 // Floor update
421 if (CausedJoin) {
422 // 1. Different labels pushed to successors
423 FloorIdx = LoweredFloorIdx;
424 } else if (FloorLabel != Label) {
425 // 2. No join caused BUT we pushed a label that is different than the
426 // last pushed label
427 FloorIdx = LoweredFloorIdx;
428 FloorLabel = Label;
429 }
430 }
431
432 LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
433
434 return std::move(DivDesc);
435 }
436};
437} // end anonymous namespace
438
439#ifndef NDEBUG
440static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) {
441 Out << "[";
442 ListSeparator LS;
443 for (const auto *BB : Blocks)
444 Out << LS << BB->getName();
445 Out << "]";
446}
447#endif
448
451 // trivial case
452 if (Term.getNumSuccessors() <= 1) {
453 return EmptyDivergenceDesc;
454 }
455
456 // already available in cache?
457 auto ItCached = CachedControlDivDescs.find(&Term);
458 if (ItCached != CachedControlDivDescs.end())
459 return *ItCached->second;
460
461 // compute all join points
462 // Special handling of divergent loop exits is not needed for LCSSA
463 const auto &TermBlock = *Term.getParent();
464 DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock);
465 auto DivDesc = Propagator.computeJoinPoints();
466
467 LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n";
468 dbgs() << "JoinDivBlocks: ";
469 printBlockSet(DivDesc->JoinDivBlocks, dbgs());
470 dbgs() << "\nLoopDivBlocks: ";
471 printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";);
472
473 auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc));
474 assert(ItInserted.second);
475 return *ItInserted.first->second;
476}
477
478} // namespace llvm
#define LLVM_DEBUG(X)
Definition: Debug.h:101
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
#define F(x, y, z)
Definition: MD5.cpp:55
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallPtrSet class.
LLVM Basic Block Representation.
Definition: BasicBlock.h:56
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:112
Compute divergence starting with a divergent branch.
bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel)
std::unique_ptr< DivergenceDescriptorT > DivDesc
void printDefs(raw_ostream &Out)
std::unique_ptr< DivergenceDescriptorT > computeJoinPoints()
bool visitEdge(const BlockT &SuccBlock, const BlockT &Label)
NodeT * getRoot() const
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:166
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
Definition: LoopInfo.h:139
BlockT * getHeader() const
Definition: LoopInfo.h:105
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
Definition: LoopInfo.h:992
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:547
PostDominatorTree Class - Concrete subclass of DominatorTree that is used to compute the post-dominat...
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1200
std::string str() const
str - Get the contents as an std::string.
Definition: StringRef.h:222
SyncDependenceAnalysis(const DominatorTree &DT, const PostDominatorTree &PDT, const LoopInfo &LI)
const ControlDivergenceDesc & getJoinBlocks(const Instruction &Term)
Computes divergent join points and loop exits caused by branch divergence in Term.
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:308
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1777
auto successors(const MachineBasicBlock *BB)
static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
const BasicBlock * getBlockAt(unsigned Idx) const
unsigned getIndexOf(const BasicBlock &BB) const
unsigned size() const