aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp
blob: 6b1bbd7a2b0794725d86c2a8f3361db9ad6799e5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
//===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Fix bitcasted functions.
///
/// WebAssembly requires caller and callee signatures to match, however in LLVM,
/// some amount of slop is vaguely permitted. Detect mismatch by looking for
/// bitcasts of functions and rewrite them to use wrapper functions instead.
///
/// This doesn't catch all cases, such as when a function's address is taken in
/// one place and cast in another, but it works for many common cases.
///
/// Note that LLVM already optimizes away function bitcasts in common cases by
/// dropping arguments as needed, so this pass only ends up getting used in less
/// common cases.
///
//===----------------------------------------------------------------------===//

#include "WebAssembly.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

#define DEBUG_TYPE "wasm-fix-function-bitcasts"

namespace {
/// Module pass whose work is done in runOnModule(); the pass is identified
/// by the static ID member.
class FixFunctionBitcasts final : public ModulePass {
public:
  /// Pass identifier handed to the ModulePass base-class constructor.
  static char ID;

  FixFunctionBitcasts() : ModulePass(ID) {}

private:
  /// Entry point invoked on each module; defined out of line.
  bool runOnModule(Module &M) override;

  /// Human-readable name of this pass.
  StringRef getPassName() const override {
    return "WebAssembly Fix Function Bitcasts";
  }

  /// Marks the CFG as preserved, then defers to the ModulePass defaults.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    ModulePass::getAnalysisUsage(AU);
  }
};
} // End anonymous namespace

// Out-of-line definition of the static pass ID declared in the class.
char FixFunctionBitcasts::ID = 0;
// Pass initialization boilerplate: ties the class to the DEBUG_TYPE string
// ("wasm-fix-function-bitcasts") and a short description.
// NOTE(review): the two trailing `false` arguments are presumably the macro's
// "CFG only" and "is analysis" flags -- confirm against PassSupport.h.
INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
                "Fix mismatching bitcasts for WebAssembly", false, false)

/// Factory: returns a newly allocated FixFunctionBitcasts pass instance.
ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {
  ModulePass *NewPass = new FixFunctionBitcasts();
  return NewPass;
}

// Walk the use list of V, looking through bitcasts and global aliases, and
// record every call site that calls V directly while V's type differs from
// the type of F. Each recorded entry pairs the callee Use with F.
//
// Constant callees are recorded only once (tracked in ConstantBCs) because
// the caller replaces all of their uses in one go.
static void findUses(Value *V, Function &F,
                     SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
                     SmallPtrSetImpl<Constant *> &ConstantBCs) {
  for (Use &TheUse : V->uses()) {
    User *Consumer = TheUse.getUser();

    // Bitcasts and aliases are transparent: keep descending through them.
    if (auto *Cast = dyn_cast<BitCastOperator>(Consumer)) {
      findUses(Cast, F, Uses, ConstantBCs);
      continue;
    }
    if (auto *Alias = dyn_cast<GlobalAlias>(Consumer)) {
      findUses(Alias, F, Uses, ConstantBCs);
      continue;
    }

    // A use at the function's own type needs no fixing.
    Value *Used = TheUse.get();
    if (Used->getType() == F.getType())
      continue;

    // Only uses that are immediately called, with the value as the callee,
    // are of interest.
    CallSite Call(Consumer);
    if (!Call || Call.getCalledValue() != V)
      continue;

    // Constant bitcasts get RAUW'd later, so one entry per constant suffices.
    if (auto *UsedConst = dyn_cast<Constant>(Used))
      if (!ConstantBCs.insert(UsedConst).second)
        continue;

    Uses.emplace_back(&TheUse, &F);
  }
}

// Create a wrapper function with type Ty that calls F (which may have a
// different type). Attempt to support common bitcasted function idioms:
//  - Call with more arguments than needed: arguments are dropped
//  - Call with fewer arguments than needed: arguments are filled in with undef
//  - Return value is not needed: drop it
//  - Return value needed but not present: supply an undef
//
// If all the argument types are trivially castable to one another (i.e.
// I32 vs pointer type) then we don't create a wrapper at all (return nullptr
// instead).
//
// If there is a type mismatch that we know would result in an invalid wasm
// module then generate a wrapper that contains unreachable (i.e. abort at
// runtime).  Such programs are deep into undefined behaviour territory,
// but we choose to fail at runtime rather than generate an invalid module
// or fail at compile time.  The reason we delay the error is that we want
// to support CMake, which expects to be able to compile and link programs
// that refer to functions with entirely incorrect signatures (this is how
// CMake detects the existence of a function in a toolchain).
//
// For bitcasts that involve struct types we don't know at this stage if they
// would be equivalent at the wasm level and so we can't know if we need to
// generate a wrapper.
//
// Returns the wrapper (private linkage, created with the name
// "<F>_bitcast", or "<F>_bitcast_invalid" for the unreachable stub), or
// nullptr when no wrapper is generated.
static Function *createWrapper(Function *F, FunctionType *Ty) {
  Module *M = F->getParent();
  FunctionType *CalleeTy = F->getFunctionType();

  // The wrapper is created eagerly so that its arguments can be referenced
  // while building the call; it is erased again below if it turns out to be
  // unnecessary or invalid.
  Function *Wrapper = Function::Create(Ty, Function::PrivateLinkage,
                                       F->getName() + "_bitcast", M);
  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
  const DataLayout &DL = BB->getModule()->getDataLayout();

  // Determine what arguments to pass.
  SmallVector<Value *, 4> Args;
  Function::arg_iterator AI = Wrapper->arg_begin();
  Function::arg_iterator AE = Wrapper->arg_end();
  FunctionType::param_iterator PI = CalleeTy->param_begin();
  FunctionType::param_iterator PE = CalleeTy->param_end();
  bool TypeMismatch = false;

  Type *ExpectedRtnType = CalleeTy->getReturnType();
  Type *RtnType = Ty->getReturnType();

  // A differing parameter count, varargs-ness or return type always calls for
  // a wrapper; the struct cases below may still veto it.
  bool WrapperNeeded = CalleeTy->getNumParams() != Ty->getNumParams() ||
                       CalleeTy->isVarArg() != Ty->isVarArg() ||
                       ExpectedRtnType != RtnType;

  // Pair up the wrapper's arguments with the callee's parameters.
  for (; AI != AE && PI != PE; ++AI, ++PI) {
    Type *ArgType = AI->getType();
    Type *ParamType = *PI;

    if (ArgType == ParamType) {
      Args.push_back(&*AI);
    } else {
      if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) {
        Instruction *PtrCast =
            CastInst::CreateBitOrPointerCast(AI, ParamType, "cast");
        BB->getInstList().push_back(PtrCast);
        Args.push_back(PtrCast);
      } else if (ArgType->isStructTy() || ParamType->isStructTy()) {
        // Struct equivalence is unknown here, so give up on wrapping. The
        // loop keeps going: a later hard mismatch still yields the invalid
        // stub.
        LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: "
                          << F->getName() << "\n");
        WrapperNeeded = false;
      } else {
        LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: "
                          << F->getName() << "\n");
        LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "
                          << *ParamType << " Got: " << *ArgType << "\n");
        TypeMismatch = true;
        break;
      }
    }
  }

  if (WrapperNeeded && !TypeMismatch) {
    // Callee parameters the wrapper has no argument for are filled with undef.
    for (; PI != PE; ++PI)
      Args.push_back(UndefValue::get(*PI));
    // Surplus wrapper arguments are forwarded only to a varargs callee;
    // otherwise they are dropped.
    if (F->isVarArg())
      for (; AI != AE; ++AI)
        Args.push_back(&*AI);

    CallInst *Call = CallInst::Create(F, Args, "", BB);

    // Determine what value to return.
    if (RtnType->isVoidTy()) {
      ReturnInst::Create(M->getContext(), BB);
    } else if (ExpectedRtnType->isVoidTy()) {
      LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n");
      ReturnInst::Create(M->getContext(), UndefValue::get(RtnType), BB);
    } else if (RtnType == ExpectedRtnType) {
      ReturnInst::Create(M->getContext(), Call, BB);
    } else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,
                                                    DL)) {
      Instruction *Cast =
          CastInst::CreateBitOrPointerCast(Call, RtnType, "cast");
      BB->getInstList().push_back(Cast);
      ReturnInst::Create(M->getContext(), Cast, BB);
    } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {
      LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: "
                        << F->getName() << "\n");
      WrapperNeeded = false;
    } else {
      LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: "
                        << F->getName() << "\n");
      LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType
                        << " Got: " << *RtnType << "\n");
      TypeMismatch = true;
    }
  }

  if (TypeMismatch) {
    // Create a new wrapper that simply contains `unreachable`. Erasing the
    // first attempt also destroys BB, so the stub gets its own block.
    Wrapper->eraseFromParent();
    Wrapper = Function::Create(Ty, Function::PrivateLinkage,
                               F->getName() + "_bitcast_invalid", M);
    BasicBlock *InvalidBB =
        BasicBlock::Create(M->getContext(), "body", Wrapper);
    new UnreachableInst(M->getContext(), InvalidBB);
  } else if (!WrapperNeeded) {
    LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName()
                      << "\n");
    Wrapper->eraseFromParent();
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n");
  return Wrapper;
}

// Decide whether a main function of type FuncTy should be given a wrapper of
// type MainTy.
//
// Only the standard zero-argument, non-varargs form whose return type matches
// MainTy qualifies. That way the standard cases work as expected, and users
// see signature mismatches from the linker for non-standard cases.
static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) {
  if (FuncTy->getReturnType() != MainTy->getReturnType())
    return false;
  if (FuncTy->getNumParams() != 0)
    return false;
  return !FuncTy->isVarArg();
}

// Rewrite every call through a mismatching function bitcast in M so that it
// calls a wrapper of the expected type instead, and give a zero-argument
// `main` a two-argument wrapper named `main`. Always reports the module as
// modified.
bool FixFunctionBitcasts::runOnModule(Module &M) {
  LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n");

  Function *Main = nullptr;
  CallInst *CallMain = nullptr;
  SmallVector<std::pair<Use *, Function *>, 0> Uses;
  SmallPtrSet<Constant *, 2> ConstantBCs;

  // Collect all the places that need wrappers.
  for (Function &F : M) {
    findUses(&F, F, Uses, ConstantBCs);

    // If we have a "main" function in the zero-argument form (see
    // shouldFixMainFunction) rather than "int main(int argc, char *argv[])",
    // create an artificial call with it bitcasted to that type so that we
    // generate a wrapper for it, so that the C runtime can call it.
    if (F.getName() == "main") {
      Main = &F;
      LLVMContext &C = M.getContext();
      Type *MainArgTys[] = {Type::getInt32Ty(C),
                            PointerType::get(Type::getInt8PtrTy(C), 0)};
      FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
                                               /*isVarArg=*/false);
      if (shouldFixMainFunction(F.getFunctionType(), MainTy)) {
        LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "
                          << *F.getFunctionType() << "\n");
        Value *Args[] = {UndefValue::get(MainArgTys[0]),
                         UndefValue::get(MainArgTys[1])};
        Value *Casted =
            ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0));
        // The call is never inserted into a block; it exists only to carry a
        // Use of the bitcast and is deleted below. With two arguments, the
        // callee is operand 2.
        CallMain = CallInst::Create(MainTy, Casted, Args, "call_main");
        Use *UseMain = &CallMain->getOperandUse(2);
        Uses.push_back(std::make_pair(UseMain, &F));
      }
    }
  }

  // One wrapper per (function, expected type) pair; a null entry records that
  // createWrapper decided no wrapper is needed.
  DenseMap<std::pair<Function *, FunctionType *>, Function *> Wrappers;

  for (auto &UseFunc : Uses) {
    Use *U = UseFunc.first;
    Function *F = UseFunc.second;
    auto *PTy = cast<PointerType>(U->get()->getType());
    auto *Ty = dyn_cast<FunctionType>(PTy->getElementType());

    // If the function is casted to something like i8* as a "generic pointer"
    // to be later casted to something else, we can't generate a wrapper for it.
    // Just ignore such casts for now.
    if (!Ty)
      continue;

    auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
    if (Pair.second)
      Pair.first->second = createWrapper(F, Ty);

    Function *Wrapper = Pair.first->second;
    if (!Wrapper)
      continue;

    // An earlier replaceAllUsesWith on the same constant bitcast may already
    // have redirected this use (the artificial `main` call shares its bitcast
    // constant with any real call of that form). Replacing a value with
    // itself is invalid, so skip uses that already refer to the wrapper.
    if (U->get() == Wrapper)
      continue;

    if (isa<Constant>(U->get()))
      U->get()->replaceAllUsesWith(Wrapper);
    else
      U->set(Wrapper);
  }

  // If we created a wrapper for main, rename the wrapper so that it's the
  // one that gets called from startup.
  if (CallMain) {
    Main->setName("__original_main");
    auto *MainWrapper =
        cast<Function>(CallMain->getCalledValue()->stripPointerCasts());
    delete CallMain;
    if (Main->isDeclaration()) {
      // The wrapper is not needed in this case as we don't need to export
      // it to anyone else. It must stay if real calls were redirected to it,
      // since erasing a function that still has uses is invalid.
      if (MainWrapper->use_empty())
        MainWrapper->eraseFromParent();
    } else {
      // Otherwise give the wrapper the same linkage as the original main
      // function, so that it can be called from the same places.
      MainWrapper->setName("main");
      MainWrapper->setLinkage(Main->getLinkage());
      MainWrapper->setVisibility(Main->getVisibility());
    }
  }

  return true;
}