about | summary | refs | log | tree | commit | diff
path: root/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp
diff options
context:
space:
mode:
author: Dimitry Andric <dim@FreeBSD.org> 2022-01-27 22:17:16 +0000
committer: Dimitry Andric <dim@FreeBSD.org> 2022-06-04 11:59:19 +0000
commit: 390adc38fc112be360bd15499e5241bf4e675b6f (patch)
tree: 712d68d3aa03f7aa4902ba03dcac2a56f49ae0e5 /contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp
parent: 8a84287b0edc66fc6dede3db770d10ff41da5464 (diff)
download: src-390adc38fc112be360bd15499e5241bf4e675b6f.tar.gz
          src-390adc38fc112be360bd15499e5241bf4e675b6f.zip
Merge llvm-project main llvmorg-14-init-17616-g024a1fab5c35
This updates llvm, clang, compiler-rt, libc++, libunwind, lld, lldb and openmp to llvmorg-14-init-17616-g024a1fab5c35.

PR: 261742
MFC after: 2 weeks

(cherry picked from commit 04eeddc0aa8e0a417a16eaf9d7d095207f4a8623)
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp')
-rw-r--r-- contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp 244
1 files changed, 140 insertions, 104 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index 96aff563aa9b..24cd5747c5a4 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -829,39 +829,54 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
default: RetTy = Type::getInt16Ty(header->getContext()); break;
}
- std::vector<Type *> paramTy;
+ std::vector<Type *> ParamTy;
+ std::vector<Type *> AggParamTy;
+ ValueSet StructValues;
// Add the types of the input values to the function's argument list
for (Value *value : inputs) {
LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
- paramTy.push_back(value->getType());
+ if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
+ AggParamTy.push_back(value->getType());
+ StructValues.insert(value);
+ } else
+ ParamTy.push_back(value->getType());
}
// Add the types of the output values to the function's argument list.
for (Value *output : outputs) {
LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
- if (AggregateArgs)
- paramTy.push_back(output->getType());
- else
- paramTy.push_back(PointerType::getUnqual(output->getType()));
+ if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+ AggParamTy.push_back(output->getType());
+ StructValues.insert(output);
+ } else
+ ParamTy.push_back(PointerType::getUnqual(output->getType()));
+ }
+
+ assert(
+ (ParamTy.size() + AggParamTy.size()) ==
+ (inputs.size() + outputs.size()) &&
+ "Number of scalar and aggregate params does not match inputs, outputs");
+ assert(StructValues.empty() ||
+ AggregateArgs && "Expeced StructValues only with AggregateArgs set");
+
+ // Concatenate scalar and aggregate params in ParamTy.
+ size_t NumScalarParams = ParamTy.size();
+ StructType *StructTy = nullptr;
+ if (AggregateArgs && !AggParamTy.empty()) {
+ StructTy = StructType::get(M->getContext(), AggParamTy);
+ ParamTy.push_back(PointerType::getUnqual(StructTy));
}
LLVM_DEBUG({
dbgs() << "Function type: " << *RetTy << " f(";
- for (Type *i : paramTy)
+ for (Type *i : ParamTy)
dbgs() << *i << ", ";
dbgs() << ")\n";
});
- StructType *StructTy = nullptr;
- if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
- StructTy = StructType::get(M->getContext(), paramTy);
- paramTy.clear();
- paramTy.push_back(PointerType::getUnqual(StructTy));
- }
- FunctionType *funcType =
- FunctionType::get(RetTy, paramTy,
- AllowVarArgs && oldFunction->isVarArg());
+ FunctionType *funcType = FunctionType::get(
+ RetTy, ParamTy, AllowVarArgs && oldFunction->isVarArg());
std::string SuffixToUse =
Suffix.empty()
@@ -871,13 +886,6 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
Function *newFunction = Function::Create(
funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(),
oldFunction->getName() + "." + SuffixToUse, M);
- // If the old function is no-throw, so is the new one.
- if (oldFunction->doesNotThrow())
- newFunction->setDoesNotThrow();
-
- // Inherit the uwtable attribute if we need to.
- if (oldFunction->hasUWTable())
- newFunction->setHasUWTable();
// Inherit all of the target dependent attributes and white-listed
// target independent attributes.
@@ -893,53 +901,26 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
} else
switch (Attr.getKindAsEnum()) {
// Those attributes cannot be propagated safely. Explicitly list them
- // here so we get a warning if new attributes are added. This list also
- // includes non-function attributes.
- case Attribute::Alignment:
+ // here so we get a warning if new attributes are added.
case Attribute::AllocSize:
case Attribute::ArgMemOnly:
case Attribute::Builtin:
- case Attribute::ByVal:
case Attribute::Convergent:
- case Attribute::Dereferenceable:
- case Attribute::DereferenceableOrNull:
- case Attribute::ElementType:
- case Attribute::InAlloca:
- case Attribute::InReg:
case Attribute::InaccessibleMemOnly:
case Attribute::InaccessibleMemOrArgMemOnly:
case Attribute::JumpTable:
case Attribute::Naked:
- case Attribute::Nest:
- case Attribute::NoAlias:
case Attribute::NoBuiltin:
- case Attribute::NoCapture:
case Attribute::NoMerge:
case Attribute::NoReturn:
case Attribute::NoSync:
- case Attribute::NoUndef:
- case Attribute::None:
- case Attribute::NonNull:
- case Attribute::Preallocated:
case Attribute::ReadNone:
case Attribute::ReadOnly:
- case Attribute::Returned:
case Attribute::ReturnsTwice:
- case Attribute::SExt:
case Attribute::Speculatable:
case Attribute::StackAlignment:
- case Attribute::StructRet:
- case Attribute::SwiftError:
- case Attribute::SwiftSelf:
- case Attribute::SwiftAsync:
case Attribute::WillReturn:
case Attribute::WriteOnly:
- case Attribute::ZExt:
- case Attribute::ImmArg:
- case Attribute::ByRef:
- case Attribute::EndAttrKinds:
- case Attribute::EmptyKey:
- case Attribute::TombstoneKey:
continue;
// Those attributes should be safe to propagate to the extracted function.
case Attribute::AlwaysInline:
@@ -980,30 +961,62 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
case Attribute::MustProgress:
case Attribute::NoProfile:
break;
+ // These attributes cannot be applied to functions.
+ case Attribute::Alignment:
+ case Attribute::ByVal:
+ case Attribute::Dereferenceable:
+ case Attribute::DereferenceableOrNull:
+ case Attribute::ElementType:
+ case Attribute::InAlloca:
+ case Attribute::InReg:
+ case Attribute::Nest:
+ case Attribute::NoAlias:
+ case Attribute::NoCapture:
+ case Attribute::NoUndef:
+ case Attribute::NonNull:
+ case Attribute::Preallocated:
+ case Attribute::Returned:
+ case Attribute::SExt:
+ case Attribute::StructRet:
+ case Attribute::SwiftError:
+ case Attribute::SwiftSelf:
+ case Attribute::SwiftAsync:
+ case Attribute::ZExt:
+ case Attribute::ImmArg:
+ case Attribute::ByRef:
+ // These are not really attributes.
+ case Attribute::None:
+ case Attribute::EndAttrKinds:
+ case Attribute::EmptyKey:
+ case Attribute::TombstoneKey:
+ llvm_unreachable("Not a function attribute");
}
newFunction->addFnAttr(Attr);
}
newFunction->getBasicBlockList().push_back(newRootNode);
- // Create an iterator to name all of the arguments we inserted.
- Function::arg_iterator AI = newFunction->arg_begin();
+ // Create scalar and aggregate iterators to name all of the arguments we
+ // inserted.
+ Function::arg_iterator ScalarAI = newFunction->arg_begin();
+ Function::arg_iterator AggAI = std::next(ScalarAI, NumScalarParams);
// Rewrite all users of the inputs in the extracted region to use the
// arguments (or appropriate addressing into struct) instead.
- for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
+ for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) {
Value *RewriteVal;
- if (AggregateArgs) {
+ if (AggregateArgs && StructValues.contains(inputs[i])) {
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
- Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
+ Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx);
Instruction *TI = newFunction->begin()->getTerminator();
GetElementPtrInst *GEP = GetElementPtrInst::Create(
- StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
- RewriteVal = new LoadInst(StructTy->getElementType(i), GEP,
+ StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI);
+ RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP,
"loadgep_" + inputs[i]->getName(), TI);
+ ++aggIdx;
} else
- RewriteVal = &*AI++;
+ RewriteVal = &*ScalarAI++;
std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
for (User *use : Users)
@@ -1013,12 +1026,14 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
}
// Set names for input and output arguments.
- if (!AggregateArgs) {
- AI = newFunction->arg_begin();
- for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
- AI->setName(inputs[i]->getName());
- for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
- AI->setName(outputs[i]->getName()+".out");
+ if (NumScalarParams) {
+ ScalarAI = newFunction->arg_begin();
+ for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++ScalarAI)
+ if (!StructValues.contains(inputs[i]))
+ ScalarAI->setName(inputs[i]->getName());
+ for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++ScalarAI)
+ if (!StructValues.contains(outputs[i]))
+ ScalarAI->setName(outputs[i]->getName() + ".out");
}
// Rewrite branches to basic blocks outside of the loop to new dummy blocks
@@ -1126,7 +1141,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
ValueSet &outputs) {
// Emit a call to the new function, passing in: *pointer to struct (if
// aggregating parameters), or plan inputs and allocated memory for outputs
- std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;
+ std::vector<Value *> params, ReloadOutputs, Reloads;
+ ValueSet StructValues;
Module *M = newFunction->getParent();
LLVMContext &Context = M->getContext();
@@ -1134,23 +1150,24 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
CallInst *call = nullptr;
// Add inputs as params, or to be filled into the struct
- unsigned ArgNo = 0;
+ unsigned ScalarInputArgNo = 0;
SmallVector<unsigned, 1> SwiftErrorArgs;
for (Value *input : inputs) {
- if (AggregateArgs)
- StructValues.push_back(input);
+ if (AggregateArgs && !ExcludeArgsFromAggregate.contains(input))
+ StructValues.insert(input);
else {
params.push_back(input);
if (input->isSwiftError())
- SwiftErrorArgs.push_back(ArgNo);
+ SwiftErrorArgs.push_back(ScalarInputArgNo);
}
- ++ArgNo;
+ ++ScalarInputArgNo;
}
// Create allocas for the outputs
+ unsigned ScalarOutputArgNo = 0;
for (Value *output : outputs) {
- if (AggregateArgs) {
- StructValues.push_back(output);
+ if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+ StructValues.insert(output);
} else {
AllocaInst *alloca =
new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
@@ -1158,12 +1175,14 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
&codeReplacer->getParent()->front().front());
ReloadOutputs.push_back(alloca);
params.push_back(alloca);
+ ++ScalarOutputArgNo;
}
}
StructType *StructArgTy = nullptr;
AllocaInst *Struct = nullptr;
- if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
+ unsigned NumAggregatedInputs = 0;
+ if (AggregateArgs && !StructValues.empty()) {
std::vector<Type *> ArgTypes;
for (Value *V : StructValues)
ArgTypes.push_back(V->getType());
@@ -1175,14 +1194,18 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
&codeReplacer->getParent()->front().front());
params.push_back(Struct);
- for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
- Value *Idx[2];
- Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
- Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
- GetElementPtrInst *GEP = GetElementPtrInst::Create(
- StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
- codeReplacer->getInstList().push_back(GEP);
- new StoreInst(StructValues[i], GEP, codeReplacer);
+ // Store aggregated inputs in the struct.
+ for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
+ if (inputs.contains(StructValues[i])) {
+ Value *Idx[2];
+ Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+ Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
+ GetElementPtrInst *GEP = GetElementPtrInst::Create(
+ StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
+ codeReplacer->getInstList().push_back(GEP);
+ new StoreInst(StructValues[i], GEP, codeReplacer);
+ NumAggregatedInputs++;
+ }
}
}
@@ -1205,24 +1228,24 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
}
- Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
- unsigned FirstOut = inputs.size();
- if (!AggregateArgs)
- std::advance(OutputArgBegin, inputs.size());
-
- // Reload the outputs passed in by reference.
- for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
+ // Reload the outputs passed in by reference, use the struct if output is in
+ // the aggregate or reload from the scalar argument.
+ for (unsigned i = 0, e = outputs.size(), scalarIdx = 0,
+ aggIdx = NumAggregatedInputs;
+ i != e; ++i) {
Value *Output = nullptr;
- if (AggregateArgs) {
+ if (AggregateArgs && StructValues.contains(outputs[i])) {
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
- Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
+ Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
codeReplacer->getInstList().push_back(GEP);
Output = GEP;
+ ++aggIdx;
} else {
- Output = ReloadOutputs[i];
+ Output = ReloadOutputs[scalarIdx];
+ ++scalarIdx;
}
LoadInst *load = new LoadInst(outputs[i]->getType(), Output,
outputs[i]->getName() + ".reload",
@@ -1304,8 +1327,13 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
// Store the arguments right after the definition of output value.
// This should be proceeded after creating exit stubs to be ensure that invoke
// result restore will be placed in the outlined function.
- Function::arg_iterator OAI = OutputArgBegin;
- for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
+ Function::arg_iterator ScalarOutputArgBegin = newFunction->arg_begin();
+ std::advance(ScalarOutputArgBegin, ScalarInputArgNo);
+ Function::arg_iterator AggOutputArgBegin = newFunction->arg_begin();
+ std::advance(AggOutputArgBegin, ScalarInputArgNo + ScalarOutputArgNo);
+
+ for (unsigned i = 0, e = outputs.size(), aggIdx = NumAggregatedInputs; i != e;
+ ++i) {
auto *OutI = dyn_cast<Instruction>(outputs[i]);
if (!OutI)
continue;
@@ -1325,23 +1353,27 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
assert((InsertBefore->getFunction() == newFunction ||
Blocks.count(InsertBefore->getParent())) &&
"InsertPt should be in new function");
- assert(OAI != newFunction->arg_end() &&
- "Number of output arguments should match "
- "the amount of defined values");
- if (AggregateArgs) {
+ if (AggregateArgs && StructValues.contains(outputs[i])) {
+ assert(AggOutputArgBegin != newFunction->arg_end() &&
+ "Number of aggregate output arguments should match "
+ "the number of defined values");
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
- Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
+ Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
GetElementPtrInst *GEP = GetElementPtrInst::Create(
- StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(),
+ StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(),
InsertBefore);
new StoreInst(outputs[i], GEP, InsertBefore);
+ ++aggIdx;
// Since there should be only one struct argument aggregating
- // all the output values, we shouldn't increment OAI, which always
- // points to the struct argument, in this case.
+ // all the output values, we shouldn't increment AggOutputArgBegin, which
+ // always points to the struct argument, in this case.
} else {
- new StoreInst(outputs[i], &*OAI, InsertBefore);
- ++OAI;
+ assert(ScalarOutputArgBegin != newFunction->arg_end() &&
+ "Number of scalar output arguments should match "
+ "the number of defined values");
+ new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertBefore);
+ ++ScalarOutputArgBegin;
}
}
@@ -1840,3 +1872,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
}
return false;
}
+// Pass Arg as its own parameter even when AggregateArgs is set.
+void CodeExtractor::excludeArgFromAggregate(Value *Arg) {
+  ExcludeArgsFromAggregate.insert(Arg);
+}