aboutsummaryrefslogtreecommitdiff
path: root/clang/lib/Driver/ToolChains/ROCm.h
blob: 21e62a465d7be387324cbe0dd154b3860f4cc7df (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
//===--- ROCm.h - ROCm installation detector --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_CLANG_LIB_DRIVER_TOOLCHAINS_ROCM_H
#define LLVM_CLANG_LIB_DRIVER_TOOLCHAINS_ROCM_H

#include "clang/Basic/Cuda.h"
#include "clang/Basic/LLVM.h"
#include "clang/Driver/Driver.h"
#include "clang/Driver/Options.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/Triple.h"
#include "llvm/Option/ArgList.h"
#include "llvm/Support/VersionTuple.h"

namespace clang {
namespace driver {

/// A class to find a viable ROCm installation.
///
/// Stores the installation paths, the device bitcode library locations and the
/// values of the ROCm/HIP related driver options, and exposes them through
/// read-only accessors.
/// TODO: Generalize to handle libclc.
class RocmInstallationDetector {
private:
  /// A pair of library paths of which one is selected by a boolean setting.
  struct ConditionalLibrary {
    SmallString<0> On;
    SmallString<0> Off;

    /// True once both variants have a non-empty path.
    bool isValid() const { return !On.empty() && !Off.empty(); }

    /// Return the path matching \p Enabled. The pair must be valid.
    StringRef get(bool Enabled) const {
      assert(isValid() && "both On and Off library paths must be set");
      return Enabled ? On : Off;
    }
  };

  // Installation path candidate.
  struct Candidate {
    llvm::SmallString<0> Path;
    // NOTE(review): presumably requests a stricter validation of this
    // candidate; the consumer lives outside this header - confirm.
    bool StrictChecking;

    Candidate(std::string Path, bool StrictChecking = false)
        : Path(Path), StrictChecking(StrictChecking) {}
  };

  const Driver &D;
  bool HasHIPRuntime = false;
  bool HasDeviceLibrary = false;

  // Default version if not detected or specified.
  const unsigned DefaultVersionMajor = 3;
  const unsigned DefaultVersionMinor = 5;
  const char *DefaultVersionPatch = "0";

  // The version string in Major.Minor.Patch format.
  std::string DetectedVersion;
  // Version containing major and minor.
  llvm::VersionTuple VersionMajorMinor;
  // Version containing patch.
  std::string VersionPatch;

  // ROCm path specified by --rocm-path.
  StringRef RocmPathArg;
  // ROCm device library paths specified by --rocm-device-lib-path.
  std::vector<std::string> RocmDeviceLibPathArg;
  // HIP version specified by --hip-version.
  StringRef HIPVersionArg;
  // Whether -nogpulib is specified.
  bool NoBuiltinLibs = false;

  // Paths
  SmallString<0> InstallPath;
  SmallString<0> BinPath;
  SmallString<0> LibPath;
  SmallString<0> LibDevicePath;
  SmallString<0> IncludePath;
  // Looked up by GPU name in getLibDeviceFile().
  llvm::StringMap<std::string> LibDeviceMap;

  // Libraries that are always linked.
  SmallString<0> OCML;
  SmallString<0> OCKL;

  // Libraries that are always linked depending on the language
  SmallString<0> OpenCL;
  SmallString<0> HIP;

  // Libraries swapped based on compile flags.
  ConditionalLibrary WavefrontSize64;
  ConditionalLibrary FiniteOnly;
  ConditionalLibrary UnsafeMath;
  ConditionalLibrary DenormalsAreZero;
  ConditionalLibrary CorrectlyRoundedSqrt;

  /// True when every always-linked library path is non-empty and every
  /// conditional library pair is valid.
  bool allGenericLibsValid() const {
    return !OCML.empty() && !OCKL.empty() && !OpenCL.empty() && !HIP.empty() &&
           WavefrontSize64.isValid() && FiniteOnly.isValid() &&
           UnsafeMath.isValid() && DenormalsAreZero.isValid() &&
           CorrectlyRoundedSqrt.isValid();
  }

  void scanLibDevicePath(llvm::StringRef Path);
  bool parseHIPVersionFile(llvm::StringRef V);
  SmallVector<Candidate, 4> getInstallationPathCandidates();

public:
  // NOTE(review): DetectHIPRuntime / DetectDeviceLib presumably select which
  // detection steps run during construction; the definition is out of line -
  // confirm there.
  RocmInstallationDetector(const Driver &D, const llvm::Triple &HostTriple,
                           const llvm::opt::ArgList &Args,
                           bool DetectHIPRuntime = true,
                           bool DetectDeviceLib = false);

  /// Add arguments needed to link default bitcode libraries.
  void addCommonBitcodeLibCC1Args(const llvm::opt::ArgList &DriverArgs,
                                  llvm::opt::ArgStringList &CC1Args,
                                  StringRef LibDeviceFile, bool Wave64,
                                  bool DAZ, bool FiniteOnly, bool UnsafeMathOpt,
                                  bool FastRelaxedMath, bool CorrectSqrt) const;

  /// Check whether we detected a valid HIP runtime.
  bool hasHIPRuntime() const { return HasHIPRuntime; }

  /// Check whether we detected a valid ROCm device library.
  bool hasDeviceLibrary() const { return HasDeviceLibrary; }

  /// Print information about the detected ROCm installation.
  void print(raw_ostream &OS) const;

  // Disabled accessor for the detected Rocm install's version:
  // RocmVersion version() const { return Version; }

  /// Get the detected Rocm installation path.
  StringRef getInstallPath() const { return InstallPath; }

  // Disabled accessor for the detected path to Rocm's bin directory:
  // StringRef getBinPath() const { return BinPath; }

  /// Get the detected Rocm Include path.
  StringRef getIncludePath() const { return IncludePath; }

  /// Get the detected Rocm library path.
  StringRef getLibPath() const { return LibPath; }

  /// Get the detected Rocm device library path.
  StringRef getLibDevicePath() const { return LibDevicePath; }

  /// Get the OCML library path; it must be non-empty.
  StringRef getOCMLPath() const {
    assert(!OCML.empty() && "OCML library path is not set");
    return OCML;
  }

  /// Get the OCKL library path; it must be non-empty.
  StringRef getOCKLPath() const {
    assert(!OCKL.empty() && "OCKL library path is not set");
    return OCKL;
  }

  /// Get the OpenCL library path; it must be non-empty.
  StringRef getOpenCLPath() const {
    assert(!OpenCL.empty() && "OpenCL library path is not set");
    return OpenCL;
  }

  /// Get the HIP library path; it must be non-empty.
  StringRef getHIPPath() const {
    assert(!HIP.empty() && "HIP library path is not set");
    return HIP;
  }

  /// Get the wavefront-size-64 library variant selected by \p Enabled.
  StringRef getWavefrontSize64Path(bool Enabled) const {
    return WavefrontSize64.get(Enabled);
  }

  /// Get the finite-only library variant selected by \p Enabled.
  StringRef getFiniteOnlyPath(bool Enabled) const {
    return FiniteOnly.get(Enabled);
  }

  /// Get the unsafe-math library variant selected by \p Enabled.
  StringRef getUnsafeMathPath(bool Enabled) const {
    return UnsafeMath.get(Enabled);
  }

  /// Get the denormals-are-zero library variant selected by \p Enabled.
  StringRef getDenormalsAreZeroPath(bool Enabled) const {
    return DenormalsAreZero.get(Enabled);
  }

  /// Get the correctly-rounded-sqrt library variant selected by \p Enabled.
  StringRef getCorrectlyRoundedSqrtPath(bool Enabled) const {
    return CorrectlyRoundedSqrt.get(Enabled);
  }

  /// Get libdevice file for given architecture. Returns an empty string when
  /// \p Gpu has no entry in the map.
  std::string getLibDeviceFile(StringRef Gpu) const {
    return LibDeviceMap.lookup(Gpu);
  }

  void AddHIPIncludeArgs(const llvm::opt::ArgList &DriverArgs,
                         llvm::opt::ArgStringList &CC1Args) const;

  void detectDeviceLibrary();
  void detectHIPRuntime();

  /// Get the values for --rocm-device-lib-path arguments.
  ///
  /// Returned by const reference so that callers which only iterate do not
  /// pay for a copy of every path; the reference is valid for as long as this
  /// detector is alive.
  const std::vector<std::string> &getRocmDeviceLibPathArg() const {
    return RocmDeviceLibPathArg;
  }

  /// Get the value for --rocm-path argument
  StringRef getRocmPathArg() const { return RocmPathArg; }

  /// Get the value for --hip-version argument
  StringRef getHIPVersionArg() const { return HIPVersionArg; }

  /// Get the version string in Major.Minor.Patch format.
  std::string getHIPVersion() const { return DetectedVersion; }
};

} // end namespace driver
} // end namespace clang

#endif // LLVM_CLANG_LIB_DRIVER_TOOLCHAINS_ROCM_H