diff options
Diffstat (limited to 'llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp')
-rw-r--r-- | llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp | 222 |
1 files changed, 169 insertions, 53 deletions
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp index 23aaa5160abd..fe656753889f 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp @@ -279,6 +279,7 @@ #include "llvm/IR/IntrinsicsWebAssembly.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include "llvm/Transforms/Utils/SSAUpdaterBulk.h" @@ -454,12 +455,12 @@ static Function *getEmscriptenFunction(FunctionType *Ty, const Twine &Name, // Tell the linker that this function is expected to be imported from the // 'env' module. if (!F->hasFnAttribute("wasm-import-module")) { - llvm::AttrBuilder B; + llvm::AttrBuilder B(M->getContext()); B.addAttribute("wasm-import-module", "env"); F->addFnAttrs(B); } if (!F->hasFnAttribute("wasm-import-name")) { - llvm::AttrBuilder B; + llvm::AttrBuilder B(M->getContext()); B.addAttribute("wasm-import-name", F->getName()); F->addFnAttrs(B); } @@ -547,7 +548,7 @@ Value *WebAssemblyLowerEmscriptenEHSjLj::wrapInvoke(CallBase *CI) { for (unsigned I = 0, E = CI->arg_size(); I < E; ++I) ArgAttributes.push_back(InvokeAL.getParamAttrs(I)); - AttrBuilder FnAttrs(InvokeAL.getFnAttrs()); + AttrBuilder FnAttrs(CI->getContext(), InvokeAL.getFnAttrs()); if (FnAttrs.contains(Attribute::AllocSize)) { // The allocsize attribute (if any) referes to parameters by index and needs // to be adjusted. @@ -610,6 +611,8 @@ static bool canLongjmp(const Value *Callee) { return false; StringRef CalleeName = Callee->getName(); + // TODO Include more functions or consider checking with mangled prefixes + // The reason we include malloc/free here is to exclude the malloc/free // calls generated in setjmp prep / cleanup routines. 
if (CalleeName == "setjmp" || CalleeName == "malloc" || CalleeName == "free") @@ -626,11 +629,50 @@ static bool canLongjmp(const Value *Callee) { return false; // Exception-catching related functions - if (CalleeName == "__cxa_begin_catch" || CalleeName == "__cxa_end_catch" || + // + // We intentionally excluded __cxa_end_catch here even though it surely cannot + // longjmp, in order to maintain the unwind relationship from all existing + // catchpads (and calls within them) to catch.dispatch.longjmp. + // + // In Wasm EH + Wasm SjLj, we + // 1. Make all catchswitch and cleanuppad that unwind to caller unwind to + // catch.dispatch.longjmp instead + // 2. Convert all longjmpable calls to invokes that unwind to + // catch.dispatch.longjmp + // But catchswitch BBs are removed in isel, so if an EH catchswitch (generated + // from an exception)'s catchpad does not contain any calls that are converted + // into invokes unwinding to catch.dispatch.longjmp, this unwind relationship + // (EH catchswitch BB -> catch.dispatch.longjmp BB) is lost and + // catch.dispatch.longjmp BB can be placed before the EH catchswitch BB in + // CFGSort. + // int ret = setjmp(buf); + // try { + // foo(); // longjmps + // } catch (...) { + // } + // Then in this code, if 'foo' longjmps, it first unwinds to 'catch (...)' + // catchswitch, and is not caught by that catchswitch because it is a longjmp, + // then it should next unwind to catch.dispatch.longjmp BB. But if this 'catch + // (...)' catchswitch -> catch.dispatch.longjmp unwind relationship is lost, + // it will not unwind to catch.dispatch.longjmp, producing an incorrect + // result. + // + // Every catchpad generated by Wasm C++ contains __cxa_end_catch, so we + // intentionally treat it as longjmpable to work around this problem. This is + // a hacky fix but an easy one. + // + // The comment block in findWasmUnwindDestinations() in + // SelectionDAGBuilder.cpp is addressing a similar problem. 
+ if (CalleeName == "__cxa_begin_catch" || CalleeName == "__cxa_allocate_exception" || CalleeName == "__cxa_throw" || CalleeName == "__clang_call_terminate") return false; + // std::terminate, which is generated when another exception occurs while + // handling an exception, cannot longjmp. + if (CalleeName == "_ZSt9terminatev") + return false; + // Otherwise we don't know return true; } @@ -817,6 +859,32 @@ static bool containsLongjmpableCalls(const Function *F) { return false; } +// When a function contains a setjmp call but not other calls that can longjmp, +// we don't do setjmp transformation for that setjmp. But we need to convert the +// setjmp calls into "i32 0" so they don't cause link time errors. setjmp always +// returns 0 when called directly. +static void nullifySetjmp(Function *F) { + Module &M = *F->getParent(); + IRBuilder<> IRB(M.getContext()); + Function *SetjmpF = M.getFunction("setjmp"); + SmallVector<Instruction *, 1> ToErase; + + for (User *U : SetjmpF->users()) { + auto *CI = dyn_cast<CallInst>(U); + // FIXME 'invoke' to setjmp can happen when we use Wasm EH + Wasm SjLj, but + // we don't support two being used together yet. + if (!CI) + report_fatal_error("Wasm EH + Wasm SjLj is not fully supported yet"); + BasicBlock *BB = CI->getParent(); + if (BB->getParent() != F) // in other function + continue; + ToErase.push_back(CI); + CI->replaceAllUsesWith(IRB.getInt32(0)); + } + for (auto *I : ToErase) + I->eraseFromParent(); +} + bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) { LLVM_DEBUG(dbgs() << "********** Lower Emscripten EH & SjLj **********\n"); @@ -886,6 +954,10 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) { EHTypeIDF = getEmscriptenFunction(EHTypeIDTy, "llvm_eh_typeid_for", &M); } + // Functions that contain calls to setjmp but don't have other longjmpable + // calls within them.
+ SmallPtrSet<Function *, 4> SetjmpUsersToNullify; + if ((EnableEmSjLj || EnableWasmSjLj) && SetjmpF) { // Precompute setjmp users for (User *U : SetjmpF->users()) { @@ -896,6 +968,8 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) { // so can ignore it if (containsLongjmpableCalls(UserF)) SetjmpUsers.insert(UserF); + else + SetjmpUsersToNullify.insert(UserF); } else { std::string S; raw_string_ostream SS(S); @@ -975,6 +1049,14 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) { runSjLjOnFunction(*F); } + // Replace unnecessary setjmp calls with 0 + if ((EnableEmSjLj || EnableWasmSjLj) && !SetjmpUsersToNullify.empty()) { + Changed = true; + assert(SetjmpF); + for (Function *F : SetjmpUsersToNullify) + nullifySetjmp(F); + } + if (!Changed) { // Delete unused global variables and functions if (ResumeF) @@ -1078,20 +1160,7 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runEHOnFunction(Function &F) { } else { // This can't throw, and we don't need this invoke, just replace it with a // call+branch - SmallVector<Value *, 16> Args(II->args()); - CallInst *NewCall = - IRB.CreateCall(II->getFunctionType(), II->getCalledOperand(), Args); - NewCall->takeName(II); - NewCall->setCallingConv(II->getCallingConv()); - NewCall->setDebugLoc(II->getDebugLoc()); - NewCall->setAttributes(II->getAttributes()); - II->replaceAllUsesWith(NewCall); - ToErase.push_back(II); - - IRB.CreateBr(II->getNormalDest()); - - // Remove any PHI node entries from the exception destination - II->getUnwindDest()->removePredecessor(&BB); + changeToCall(II); } } @@ -1243,16 +1312,19 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) { // Setjmp transformation SmallVector<PHINode *, 4> SetjmpRetPHIs; Function *SetjmpF = M.getFunction("setjmp"); - for (User *U : SetjmpF->users()) { - auto *CI = dyn_cast<CallInst>(U); - // FIXME 'invoke' to setjmp can happen when we use Wasm EH + Wasm SjLj, but - // we don't support two being used together yet. 
- if (!CI) - report_fatal_error("Wasm EH + Wasm SjLj is not fully supported yet"); - BasicBlock *BB = CI->getParent(); + for (auto *U : make_early_inc_range(SetjmpF->users())) { + auto *CB = dyn_cast<CallBase>(U); + BasicBlock *BB = CB->getParent(); if (BB->getParent() != &F) // in other function continue; + CallInst *CI = nullptr; + // setjmp cannot throw. So if it is an invoke, lower it to a call + if (auto *II = dyn_cast<InvokeInst>(CB)) + CI = llvm::changeToCall(II); + else + CI = cast<CallInst>(CB); + // The tail is everything right after the call, and will be reached once // when setjmp is called, and later when longjmp returns to the setjmp BasicBlock *Tail = SplitBlock(BB, CI->getNextNode()); @@ -1568,6 +1640,13 @@ void WebAssemblyLowerEmscriptenEHSjLj::handleLongjmpableCallsForEmscriptenSjLj( I->eraseFromParent(); } +static BasicBlock *getCleanupRetUnwindDest(const CleanupPadInst *CPI) { + for (const User *U : CPI->users()) + if (const auto *CRI = dyn_cast<CleanupReturnInst>(U)) + return CRI->getUnwindDest(); + return nullptr; +} + // Create a catchpad in which we catch a longjmp's env and val arguments, test // if the longjmp corresponds to one of setjmps in the current function, and if // so, jump to the setjmp dispatch BB from which we go to one of post-setjmp @@ -1619,18 +1698,18 @@ void WebAssemblyLowerEmscriptenEHSjLj::handleLongjmpableCallsForWasmSjLj( BasicBlock::Create(C, "setjmp.dispatch", &F, OrigEntry); cast<BranchInst>(Entry->getTerminator())->setSuccessor(0, SetjmpDispatchBB); - // Create catch.dispatch.longjmp BB a catchswitch instruction - BasicBlock *CatchSwitchBB = + // Create catch.dispatch.longjmp BB and a catchswitch instruction + BasicBlock *CatchDispatchLongjmpBB = BasicBlock::Create(C, "catch.dispatch.longjmp", &F); - IRB.SetInsertPoint(CatchSwitchBB); - CatchSwitchInst *CatchSwitch = + IRB.SetInsertPoint(CatchDispatchLongjmpBB); + CatchSwitchInst *CatchSwitchLongjmp = IRB.CreateCatchSwitch(ConstantTokenNone::get(C), nullptr, 1); // 
Create catch.longjmp BB and a catchpad instruction BasicBlock *CatchLongjmpBB = BasicBlock::Create(C, "catch.longjmp", &F); - CatchSwitch->addHandler(CatchLongjmpBB); + CatchSwitchLongjmp->addHandler(CatchLongjmpBB); IRB.SetInsertPoint(CatchLongjmpBB); - CatchPadInst *CatchPad = IRB.CreateCatchPad(CatchSwitch, {}); + CatchPadInst *CatchPad = IRB.CreateCatchPad(CatchSwitchLongjmp, {}); // Wasm throw and catch instructions can throw and catch multiple values, but // that requires multivalue support in the toolchain, which is currently not @@ -1696,9 +1775,9 @@ void WebAssemblyLowerEmscriptenEHSjLj::handleLongjmpableCallsForWasmSjLj( // Convert all longjmpable call instructions to invokes that unwind to the // newly created catch.dispatch.longjmp BB. - SmallVector<Instruction *, 64> ToErase; + SmallVector<CallInst *, 64> LongjmpableCalls; for (auto *BB = &*F.begin(); BB; BB = BB->getNextNode()) { - for (Instruction &I : *BB) { + for (auto &I : *BB) { auto *CI = dyn_cast<CallInst>(&I); if (!CI) continue; @@ -1716,29 +1795,66 @@ void WebAssemblyLowerEmscriptenEHSjLj::handleLongjmpableCallsForWasmSjLj( // setjmps in this function. We should not convert this call to an invoke. if (CI == WasmLongjmpCI) continue; - ToErase.push_back(CI); + LongjmpableCalls.push_back(CI); + } + } - // Even if the callee function has attribute 'nounwind', which is true for - // all C functions, it can longjmp, which means it can throw a Wasm - // exception now. - CI->removeFnAttr(Attribute::NoUnwind); - if (Function *CalleeF = CI->getCalledFunction()) { - CalleeF->removeFnAttr(Attribute::NoUnwind); + for (auto *CI : LongjmpableCalls) { + // Even if the callee function has attribute 'nounwind', which is true for + // all C functions, it can longjmp, which means it can throw a Wasm + // exception now. 
+ CI->removeFnAttr(Attribute::NoUnwind); + if (Function *CalleeF = CI->getCalledFunction()) + CalleeF->removeFnAttr(Attribute::NoUnwind); + + // Change it to an invoke and make it unwind to the catch.dispatch.longjmp + // BB. If the call is enclosed in another catchpad/cleanuppad scope, unwind + // to its parent pad's unwind destination instead to preserve the scope + // structure. It will eventually unwind to the catch.dispatch.longjmp. + SmallVector<OperandBundleDef, 1> Bundles; + BasicBlock *UnwindDest = nullptr; + if (auto Bundle = CI->getOperandBundle(LLVMContext::OB_funclet)) { + Instruction *FromPad = cast<Instruction>(Bundle->Inputs[0]); + while (!UnwindDest && FromPad) { + if (auto *CPI = dyn_cast<CatchPadInst>(FromPad)) { + UnwindDest = CPI->getCatchSwitch()->getUnwindDest(); + FromPad = nullptr; // stop searching + } else if (auto *CPI = dyn_cast<CleanupPadInst>(FromPad)) { + // getCleanupRetUnwindDest() can return nullptr when + // 1. This cleanuppad's matching cleanupret unwinds to caller + // 2. There is no matching cleanupret because it ends with + // unreachable. + // In case of 2, we need to traverse the parent pad chain. + UnwindDest = getCleanupRetUnwindDest(CPI); + FromPad = cast<Instruction>(CPI->getParentPad()); + } } + } + if (!UnwindDest) + UnwindDest = CatchDispatchLongjmpBB; + changeToInvokeAndSplitBasicBlock(CI, UnwindDest); + } - IRB.SetInsertPoint(CI); - BasicBlock *Tail = SplitBlock(BB, CI->getNextNode()); - // We will add a new invoke.
So remove the branch created when we split - // the BB - ToErase.push_back(BB->getTerminator()); - SmallVector<Value *, 8> Args(CI->args()); - InvokeInst *II = - IRB.CreateInvoke(CI->getFunctionType(), CI->getCalledOperand(), Tail, - CatchSwitchBB, Args); - II->takeName(CI); - II->setDebugLoc(CI->getDebugLoc()); - II->setAttributes(CI->getAttributes()); - CI->replaceAllUsesWith(II); + SmallVector<Instruction *, 16> ToErase; + for (auto &BB : F) { + if (auto *CSI = dyn_cast<CatchSwitchInst>(BB.getFirstNonPHI())) { + if (CSI != CatchSwitchLongjmp && CSI->unwindsToCaller()) { + IRB.SetInsertPoint(CSI); + ToErase.push_back(CSI); + auto *NewCSI = IRB.CreateCatchSwitch(CSI->getParentPad(), + CatchDispatchLongjmpBB, 1); + NewCSI->addHandler(*CSI->handler_begin()); + NewCSI->takeName(CSI); + CSI->replaceAllUsesWith(NewCSI); + } + } + + if (auto *CRI = dyn_cast<CleanupReturnInst>(BB.getTerminator())) { + if (CRI->unwindsToCaller()) { + IRB.SetInsertPoint(CRI); + ToErase.push_back(CRI); + IRB.CreateCleanupRet(CRI->getCleanupPad(), CatchDispatchLongjmpBB); + } } } |