diff --git a/tests/test_parameter_interpolation.py b/tests/test_parameter_interpolation.py index cca4b389c..3f7d92ac7 100644 --- a/tests/test_parameter_interpolation.py +++ b/tests/test_parameter_interpolation.py @@ -1,7 +1,8 @@ from jax.config import config; config.update("jax_enable_x64", True) - +import copy import functools import numpy as np +import jax.numpy as jnp from common import GradientTest from timemachine.lib import potentials @@ -83,4 +84,104 @@ def test_nonbonded(self): test_interpolated_potential, rtol, precision=precision + ) + + + def test_nonbonded_advanced(self): + + # This test checks that we can supply arbitrary transformations of lambda to + # the nonbonded potential, and that the resulting derivatives (both du/dp and du/dl) + # are correct. + + np.random.seed(4321) + D = 3 + + cutoff = 1.0 + size = 36 + + water_coords = self.get_water_coords(D, sort=False) + coords = water_coords[:size] + padding = 0.2 + diag = np.amax(coords, axis=0) - np.amin(coords, axis=0) + padding + box = np.eye(3) + np.fill_diagonal(box, diag) + + N = coords.shape[0] + + lambda_plane_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32) + lambda_offset_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32) + + for precision, rtol in [(np.float64, 1e-8), (np.float32, 1e-4)]: + + # E = 0 # DEBUG! + qlj_src, ref_potential, test_potential = prepare_water_system( + coords, + lambda_plane_idxs, + lambda_offset_idxs, + p_scale=1.0, + cutoff=cutoff + ) + + qlj_dst, _, _ = prepare_water_system( + coords, + lambda_plane_idxs, + lambda_offset_idxs, + p_scale=1.0, + cutoff=cutoff + ) + + def transform_q(lamb): + return lamb*lamb + + def transform_s(lamb): + return jnp.sin(lamb*np.pi/2) + + def transform_e(lamb): + return jnp.cos(lamb*np.pi/2) + + def transform_w(lamb): + return (1-lamb*lamb) + + def interpolate_params(lamb, qlj_src, qlj_dst): + new_q = (1-transform_q(lamb))*qlj_src[:, 0] + transform_q(lamb)*qlj_dst[:, 0] + new_s = (1-transform_s(lamb))*qlj_src[:, 1] + transform_s(lamb)*qlj_dst[:, 1] + new_e = (1-transform_e(lamb))*qlj_src[:, 2] + transform_e(lamb)*qlj_dst[:, 2] + return jnp.stack([new_q, new_s, new_e], axis=1) + + def u_reference(x, params, box, lamb): + d4 = cutoff*(lambda_plane_idxs + lambda_offset_idxs*transform_w(lamb)) + d4 = jnp.expand_dims(d4, axis=-1) + x = jnp.concatenate((x, d4), axis=1) + + qlj_src = params[:len(params)//2] + qlj_dst = params[len(params)//2:] + qlj = interpolate_params(lamb, qlj_src, qlj_dst) + return ref_potential(x, qlj, box, lamb) + + + for lamb in [0.0, 0.2, 1.0]: + + qlj = np.concatenate([qlj_src, qlj_dst]) + + print("lambda", lamb, "cutoff", cutoff, "precision", precision, "xshape", coords.shape) + + args = copy.deepcopy(test_potential.args) + args.append("lambda*lambda") # transform q + args.append("sin(lambda*PI/2)") # transform sigma + args.append("cos(lambda*PI/2)") # transform epsilon + args.append("1-lambda*lambda") # transform w + + test_interpolated_potential = potentials.NonbondedInterpolated( + *args, + ) + + self.compare_forces( + coords, + qlj, + box, + lamb, + u_reference, + test_interpolated_potential, + rtol, + precision=precision ) \ No newline at end of file diff --git a/timemachine/cpp/CMakeLists.txt b/timemachine/cpp/CMakeLists.txt index 0115bca62..ae76e6f30 100644 --- a/timemachine/cpp/CMakeLists.txt +++ b/timemachine/cpp/CMakeLists.txt @@ -60,7 +60,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/eigen) include_directories(${CMAKE_CURRENT_BINARY_DIR}/${CUB_SRC_DIR}) set_property(TARGET ${LIBRARY_NAME} PROPERTY CUDA_STANDARD 14) -target_link_libraries(${LIBRARY_NAME} PRIVATE -lcurand -lcudart -lcudadevrt) +target_link_libraries(${LIBRARY_NAME} PRIVATE -lcurand -lcuda -lcudart -lcudadevrt -lnvrtc) set_target_properties(${LIBRARY_NAME} PROPERTIES PREFIX "") install(TARGETS ${LIBRARY_NAME} DESTINATION "lib") diff --git a/timemachine/cpp/src/gpu_utils.cuh b/timemachine/cpp/src/gpu_utils.cuh index b6f9a3a8b..cc34318f1 100644 --- a/timemachine/cpp/src/gpu_utils.cuh +++ b/timemachine/cpp/src/gpu_utils.cuh @@ -16,6 +16,8 @@ curandStatus_t templateCurandNormal( + + #define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); } inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) { @@ -36,6 +38,16 @@ inline void curandAssert(curandStatus_t code, const char *file, int line, bool a } } +#define NVRTC_SAFE_CALL(x) \ + do { \ + nvrtcResult result = x; \ + if (result != NVRTC_SUCCESS) { \ + std::cerr << "\nerror: " #x " failed with error " \ + << nvrtcGetErrorString(result) << '\n'; \ + exit(1); \ + } \ + } while(0) + // safe is for use of gpuErrchk template T* gpuErrchkCudaMallocAndCopy(const T *host_array, int count) { diff --git a/timemachine/cpp/src/jitify.hpp b/timemachine/cpp/src/jitify.hpp new file mode 100644 index 000000000..817eaab4a --- /dev/null +++ b/timemachine/cpp/src/jitify.hpp @@ -0,0 +1,4397 @@ +/* + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of NVIDIA CORPORATION nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + ----------- + Jitify 0.9 + ----------- + A C++ library for easy integration of CUDA runtime compilation into + existing codes. + + -------------- + How to compile + -------------- + Compiler dependencies: , -std=c++11 + Linker dependencies: dl cuda nvrtc + + -------------------------------------- + Embedding source files into executable + -------------------------------------- + g++ ... -ldl -rdynamic -DJITIFY_ENABLE_EMBEDDED_FILES=1 + -Wl,-b,binary,my_kernel.cu,include/my_header.cuh,-b,default nvcc ... -ldl + -Xcompiler "-rdynamic + -Wl\,-b\,binary\,my_kernel.cu\,include/my_header.cuh\,-b\,default" + JITIFY_INCLUDE_EMBEDDED_FILE(my_kernel_cu); + JITIFY_INCLUDE_EMBEDDED_FILE(include_my_header_cuh); + + ---- + TODO + ---- + Extract valid compile options and pass the rest to cuModuleLoadDataEx + See if can have stringified headers automatically looked-up + by having stringify add them to a (static) global map. + The global map can be updated by creating a static class instance + whose constructor performs the registration. + Can then remove all headers from JitCache constructor in example code + See other TODOs in code +*/ + +/*! \file jitify.hpp + * \brief The Jitify library header + */ + +/*! \mainpage Jitify - A C++ library that simplifies the use of NVRTC + * \p Use class jitify::JitCache to manage and launch JIT-compiled CUDA + * kernels. + * + * \p Use namespace jitify::reflection to reflect types and values into + * code-strings. + * + * \p Use JITIFY_INCLUDE_EMBEDDED_FILE() to declare files that have been + * embedded into the executable using the GCC linker. + * + * \p Use jitify::parallel_for and JITIFY_LAMBDA() to generate and launch + * simple kernels. + */ + +#pragma once + +#ifndef JITIFY_THREAD_SAFE +#define JITIFY_THREAD_SAFE 1 +#endif + +#if JITIFY_ENABLE_EMBEDDED_FILES +#include +#endif +#include +#include +#include +#include // For strtok_r etc. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if JITIFY_THREAD_SAFE +#include +#endif + +#include +#include // For dim3, cudaStream_t +#if CUDA_VERSION >= 8000 +#define NVRTC_GET_TYPE_NAME 1 +#endif +#include + +// For use by get_current_executable_path(). +#ifdef __linux__ +#include // For PATH_MAX + +#include // For realpath +#define JITIFY_PATH_MAX PATH_MAX +#elif defined(_WIN32) || defined(_WIN64) +#include +#define JITIFY_PATH_MAX MAX_PATH +#else +#error "Unsupported platform" +#endif + +#ifdef _MSC_VER // MSVC compiler +#include // For UnDecorateSymbolName +#else +#include // For abi::__cxa_demangle +#endif + +#if defined(_WIN32) || defined(_WIN64) +// WAR for strtok_r being called strtok_s on Windows +#pragma push_macro("strtok_r") +#undef strtok_r +#define strtok_r strtok_s +// WAR for min and max possibly being macros defined by windows.h +#pragma push_macro("min") +#pragma push_macro("max") +#undef min +#undef max +#endif + +#ifndef JITIFY_PRINT_LOG +#define JITIFY_PRINT_LOG 1 +#endif + +#if JITIFY_PRINT_ALL +#define JITIFY_PRINT_INSTANTIATION 1 +#define JITIFY_PRINT_SOURCE 1 +#define JITIFY_PRINT_LOG 1 +#define JITIFY_PRINT_PTX 1 +#define JITIFY_PRINT_LINKER_LOG 1 +#define JITIFY_PRINT_LAUNCH 1 +#define JITIFY_PRINT_HEADER_PATHS 1 +#endif + +#if JITIFY_ENABLE_EMBEDDED_FILES +#define JITIFY_FORCE_UNDEFINED_SYMBOL(x) void* x##_forced = (void*)&x +/*! Include a source file that has been embedded into the executable using the + * GCC linker. + * \param name The name of the source file (not as a string), which must + * be sanitized by replacing non-alpha-numeric characters with underscores. + * E.g., \code{.cpp}JITIFY_INCLUDE_EMBEDDED_FILE(my_header_h)\endcode will + * include the embedded file "my_header.h". + * \note Files declared with this macro can be referenced using + * their original (unsanitized) filenames when creating a \p + * jitify::Program instance. + */ +#define JITIFY_INCLUDE_EMBEDDED_FILE(name) \ + extern "C" uint8_t _jitify_binary_##name##_start[] asm("_binary_" #name \ + "_start"); \ + extern "C" uint8_t _jitify_binary_##name##_end[] asm("_binary_" #name \ + "_end"); \ + JITIFY_FORCE_UNDEFINED_SYMBOL(_jitify_binary_##name##_start); \ + JITIFY_FORCE_UNDEFINED_SYMBOL(_jitify_binary_##name##_end) +#endif // JITIFY_ENABLE_EMBEDDED_FILES + +/*! Jitify library namespace + */ +namespace jitify { + +/*! Source-file load callback. + * + * \param filename The name of the requested source file. + * \param tmp_stream A temporary stream that can be used to hold source code. + * \return A pointer to an input stream containing the source code, or NULL + * to defer loading of the file to Jitify's file-loading mechanisms. + */ +typedef std::istream* (*file_callback_type)(std::string filename, + std::iostream& tmp_stream); + +// Exclude from Doxygen +//! \cond + +class JitCache; + +// Simple cache using LRU discard policy +template +class ObjectCache { + public: + typedef KeyType key_type; + typedef ValueType value_type; + + private: + typedef std::map object_map; + typedef std::deque key_rank; + typedef typename key_rank::iterator rank_iterator; + object_map _objects; + key_rank _ranked_keys; + size_t _capacity; + + inline void discard_old(size_t n = 0) { + if (n > _capacity) { + throw std::runtime_error("Insufficient capacity in cache"); + } + while (_objects.size() > _capacity - n) { + key_type discard_key = _ranked_keys.back(); + _ranked_keys.pop_back(); + _objects.erase(discard_key); + } + } + + public: + inline ObjectCache(size_t capacity = 8) : _capacity(capacity) {} + inline void resize(size_t capacity) { + _capacity = capacity; + this->discard_old(); + } + inline bool contains(const key_type& k) const { + return (bool)_objects.count(k); + } + inline void touch(const key_type& k) { + if (!this->contains(k)) { + throw std::runtime_error("Key not found in cache"); + } + rank_iterator rank = std::find(_ranked_keys.begin(), _ranked_keys.end(), k); + if (rank != _ranked_keys.begin()) { + // Move key to front of ranks + _ranked_keys.erase(rank); + _ranked_keys.push_front(k); + } + } + inline value_type& get(const key_type& k) { + if (!this->contains(k)) { + throw std::runtime_error("Key not found in cache"); + } + this->touch(k); + return _objects[k]; + } + inline value_type& insert(const key_type& k, + const value_type& v = value_type()) { + this->discard_old(1); + _ranked_keys.push_front(k); + return _objects.insert(std::make_pair(k, v)).first->second; + } + template + inline value_type& emplace(const key_type& k, Args&&... args) { + this->discard_old(1); + // Note: Use of piecewise_construct allows non-movable non-copyable types + auto iter = _objects + .emplace(std::piecewise_construct, std::forward_as_tuple(k), + std::forward_as_tuple(args...)) + .first; + _ranked_keys.push_front(iter->first); + return iter->second; + } +}; + +namespace detail { + +// Convenience wrapper for std::vector that provides handy constructors +template +class vector : public std::vector { + typedef std::vector super_type; + + public: + vector() : super_type() {} + vector(size_t n) : super_type(n) {} // Note: Not explicit, allows =0 + vector(std::vector const& vals) : super_type(vals) {} + template + vector(T const (&vals)[N]) : super_type(vals, vals + N) {} + vector(std::vector&& vals) : super_type(vals) {} + vector(std::initializer_list vals) : super_type(vals) {} +}; + +// Helper functions for parsing/manipulating source code + +inline std::string replace_characters(std::string str, + std::string const& oldchars, + char newchar) { + size_t i = str.find_first_of(oldchars); + while (i != std::string::npos) { + str[i] = newchar; + i = str.find_first_of(oldchars, i + 1); + } + return str; +} +inline std::string sanitize_filename(std::string name) { + return replace_characters(name, "/\\.-: ?%*|\"<>", '_'); +} + +#if JITIFY_ENABLE_EMBEDDED_FILES +class EmbeddedData { + void* _app; + EmbeddedData(EmbeddedData const&); + EmbeddedData& operator=(EmbeddedData const&); + + public: + EmbeddedData() { + _app = dlopen(NULL, RTLD_LAZY); + if (!_app) { + throw std::runtime_error(std::string("dlopen failed: ") + dlerror()); + } + dlerror(); // Clear any existing error + } + ~EmbeddedData() { + if (_app) { + dlclose(_app); + } + } + const uint8_t* operator[](std::string key) const { + key = sanitize_filename(key); + key = "_binary_" + key; + uint8_t const* data = (uint8_t const*)dlsym(_app, key.c_str()); + if (!data) { + throw std::runtime_error(std::string("dlsym failed: ") + dlerror()); + } + return data; + } + const uint8_t* begin(std::string key) const { + return (*this)[key + "_start"]; + } + const uint8_t* end(std::string key) const { return (*this)[key + "_end"]; } +}; +#endif // JITIFY_ENABLE_EMBEDDED_FILES + +inline bool is_tokenchar(char c) { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || c == '_'; +} +inline std::string replace_token(std::string src, std::string token, + std::string replacement) { + size_t i = src.find(token); + while (i != std::string::npos) { + if (i == 0 || i == src.size() - token.size() || + (!is_tokenchar(src[i - 1]) && !is_tokenchar(src[i + token.size()]))) { + src.replace(i, token.size(), replacement); + i += replacement.size(); + } else { + i += token.size(); + } + i = src.find(token, i); + } + return src; +} +inline std::string path_base(std::string p) { + // "/usr/local/myfile.dat" -> "/usr/local" + // "foo/bar" -> "foo" + // "foo/bar/" -> "foo/bar" +#if defined _WIN32 || defined _WIN64 + char sep = '\\'; +#else + char sep = '/'; +#endif + size_t i = p.find_last_of(sep); + if (i != std::string::npos) { + return p.substr(0, i); + } else { + return ""; + } +} +inline std::string path_join(std::string p1, std::string p2) { +#ifdef _WIN32 + char sep = '\\'; +#else + char sep = '/'; +#endif + if (p1.size() && p2.size() && p2[0] == sep) { + throw std::invalid_argument("Cannot join to absolute path"); + } + if (p1.size() && p1[p1.size() - 1] != sep) { + p1 += sep; + } + return p1 + p2; +} +// Elides "/." and "/.." tokens from path. +inline std::string path_simplify(const std::string& path) { + std::vector dirs; + std::string cur_dir; + bool after_slash = false; + for (int i = 0; i < (int)path.size(); ++i) { + if (path[i] == '/') { + if (after_slash) continue; // Ignore repeat slashes + after_slash = true; + if (cur_dir == ".." && !dirs.empty() && dirs.back() != "..") { + if (dirs.size() == 1 && dirs.front().empty()) { + throw std::runtime_error( + "Invalid path: back-traversals exceed depth of absolute path"); + } + dirs.pop_back(); + } else if (cur_dir != ".") { // Ignore /./ + dirs.push_back(cur_dir); + } + cur_dir.clear(); + } else { + after_slash = false; + cur_dir.push_back(path[i]); + } + } + if (!after_slash) { + dirs.push_back(cur_dir); + } + std::stringstream ss; + for (int i = 0; i < (int)dirs.size() - 1; ++i) { + ss << dirs[i] << "/"; + } + if (!dirs.empty()) ss << dirs.back(); + if (after_slash) ss << "/"; + return ss.str(); +} +inline unsigned long long hash_larson64(const char* s, + unsigned long long seed = 0) { + unsigned long long hash = seed; + while (*s) { + hash = hash * 101 + *s++; + } + return hash; +} + +inline uint64_t hash_combine(uint64_t a, uint64_t b) { + // Note: The magic number comes from the golden ratio + return a ^ (0x9E3779B97F4A7C17ull + b + (b >> 2) + (a << 6)); +} + +inline bool extract_include_info_from_compile_error(std::string log, + std::string& name, + std::string& parent, + int& line_num) { + static const std::vector pattern = { + "could not open source file \"", "cannot open source file \""}; + + for (auto& p : pattern) { + size_t beg = log.find(p); + if (beg != std::string::npos) { + beg += p.size(); + size_t end = log.find("\"", beg); + name = log.substr(beg, end - beg); + + size_t line_beg = log.rfind("\n", beg); + if (line_beg == std::string::npos) { + line_beg = 0; + } else { + line_beg += 1; + } + + size_t split = log.find("(", line_beg); + parent = log.substr(line_beg, split - line_beg); + line_num = + atoi(log.substr(split + 1, log.find(")", split + 1) - (split + 1)) + .c_str()); + + return true; + } + } + + return false; +} + +inline bool is_include_directive_with_quotes(const std::string& source, + int line_num) { + // TODO: Check each find() for failure. + size_t beg = 0; + for (int i = 1; i < line_num; ++i) { + beg = source.find("\n", beg) + 1; + } + beg = source.find("include", beg) + 7; + beg = source.find_first_of("\"<", beg); + return source[beg] == '"'; +} + +inline std::string comment_out_code_line(int line_num, std::string source) { + size_t beg = 0; + for (int i = 1; i < line_num; ++i) { + beg = source.find("\n", beg) + 1; + } + return (source.substr(0, beg) + "//" + source.substr(beg)); +} + +inline void print_with_line_numbers(std::string const& source) { + int linenum = 1; + std::stringstream source_ss(source); + for (std::string line; std::getline(source_ss, line); ++linenum) { + std::cout << std::setfill(' ') << std::setw(3) << linenum << " " << line + << std::endl; + } +} + +inline void print_compile_log(std::string program_name, + std::string const& log) { + std::cout << "---------------------------------------------------" + << std::endl; + std::cout << "--- JIT compile log for " << program_name << " ---" + << std::endl; + std::cout << "---------------------------------------------------" + << std::endl; + std::cout << log << std::endl; + std::cout << "---------------------------------------------------" + << std::endl; +} + +inline std::vector split_string(std::string str, + long maxsplit = -1, + std::string delims = " \t") { + std::vector results; + if (maxsplit == 0) { + results.push_back(str); + return results; + } + // Note: +1 to include NULL-terminator + std::vector v_str(str.c_str(), str.c_str() + (str.size() + 1)); + char* c_str = v_str.data(); + char* saveptr = c_str; + char* token = nullptr; + for (long i = 0; i != maxsplit; ++i) { + token = ::strtok_r(c_str, delims.c_str(), &saveptr); + c_str = 0; + if (!token) { + return results; + } + results.push_back(token); + } + // Check if there's a final piece + token += ::strlen(token) + 1; + if (token - v_str.data() < (ptrdiff_t)str.size()) { + // Find the start of the final piece + token += ::strspn(token, delims.c_str()); + if (*token) { + results.push_back(token); + } + } + return results; +} + +static const std::map& get_jitsafe_headers_map(); + +inline bool load_source( + std::string filename, std::map& sources, + std::string current_dir = "", + std::vector include_paths = std::vector(), + file_callback_type file_callback = 0, std::string* program_name = nullptr, + std::map* fullpaths = nullptr, + bool search_current_dir = true) { + std::istream* source_stream = 0; + std::stringstream string_stream; + std::ifstream file_stream; + // First detect direct source-code string ("my_program\nprogram_code...") + size_t newline_pos = filename.find("\n"); + if (newline_pos != std::string::npos) { + std::string source = filename.substr(newline_pos + 1); + filename = filename.substr(0, newline_pos); + string_stream << source; + source_stream = &string_stream; + } + if (program_name) { + *program_name = filename; + } + if (sources.count(filename)) { + // Already got this one + return true; + } + if (!source_stream) { + std::string fullpath = path_join(current_dir, filename); + // Try loading from callback + if (!file_callback || + !((source_stream = file_callback(fullpath, string_stream)) != 0)) { +#if JITIFY_ENABLE_EMBEDDED_FILES + // Try loading as embedded file + EmbeddedData embedded; + std::string source; + try { + source.assign(embedded.begin(fullpath), embedded.end(fullpath)); + string_stream << source; + source_stream = &string_stream; + } catch (std::runtime_error const&) +#endif // JITIFY_ENABLE_EMBEDDED_FILES + { + // Try loading from filesystem + bool found_file = false; + if (search_current_dir) { + file_stream.open(fullpath.c_str()); + if (file_stream) { + source_stream = &file_stream; + found_file = true; + } + } + // Search include directories + if (!found_file) { + for (int i = 0; i < (int)include_paths.size(); ++i) { + fullpath = path_join(include_paths[i], filename); + file_stream.open(fullpath.c_str()); + if (file_stream) { + source_stream = &file_stream; + found_file = true; + break; + } + } + if (!found_file) { + // Try loading from builtin headers + fullpath = path_join("__jitify_builtin", filename); + auto it = get_jitsafe_headers_map().find(filename); + if (it != get_jitsafe_headers_map().end()) { + string_stream << it->second; + source_stream = &string_stream; + } else { + return false; + } + } + } + } + } + if (fullpaths) { + // Record the full file path corresponding to this include name. + (*fullpaths)[filename] = path_simplify(fullpath); + } + } + sources[filename] = std::string(); + std::string& source = sources[filename]; + std::string line; + size_t linenum = 0; + unsigned long long hash = 0; + bool pragma_once = false; + bool remove_next_blank_line = false; + while (std::getline(*source_stream, line)) { + ++linenum; + + // HACK WAR for static variables not allowed on the device (unless + // __shared__) + // TODO: This breaks static member variables + // line = replace_token(line, "static const", "/*static*/ const"); + + // TODO: Need to watch out for /* */ comments too + std::string cleanline = + line.substr(0, line.find("//")); // Strip line comments + // if( cleanline.back() == "\r" ) { // Remove Windows line ending + // cleanline = cleanline.substr(0, cleanline.size()-1); + //} + // TODO: Should trim whitespace before checking .empty() + if (cleanline.empty() && remove_next_blank_line) { + remove_next_blank_line = false; + continue; + } + // Maintain a file hash for use in #pragma once WAR + hash = hash_larson64(line.c_str(), hash); + if (cleanline.find("#pragma once") != std::string::npos) { + pragma_once = true; + // Note: This is an attempt to recover the original line numbering, + // which otherwise gets off-by-one due to the include guard. + remove_next_blank_line = true; + // line = "//" + line; // Comment out the #pragma once line + continue; + } + + // HACK WAR for Thrust using "#define FOO #pragma bar" + // TODO: This is not robust to block comments, line continuations, or tabs. + size_t pragma_beg = cleanline.find("#pragma "); + if (pragma_beg != std::string::npos) { + std::string line_after_pragma = line.substr(pragma_beg + 8); + // TODO: Handle block comments (currently they cause a compilation error). + size_t comment_start = line_after_pragma.find("//"); + std::string pragma_args = line_after_pragma.substr(0, comment_start); + std::string comment = comment_start != std::string::npos + ? line_after_pragma.substr(comment_start) + : ""; + line = line.substr(0, pragma_beg) + "_Pragma(\"" + pragma_args + "\")" + + comment; + } + + source += line + "\n"; + } + // HACK TESTING (WAR for cub) + // source = "#define cudaDeviceSynchronize() cudaSuccess\n" + source; + ////source = "cudaError_t cudaDeviceSynchronize() { return cudaSuccess; }\n" + + /// source; + + // WAR for #pragma once causing problems when there are multiple inclusions + // of the same header from different paths. + if (pragma_once) { + std::stringstream ss; + ss << std::uppercase << std::hex << std::setw(8) << std::setfill('0') + << hash; + std::string include_guard_name = "_JITIFY_INCLUDE_GUARD_" + ss.str() + "\n"; + std::string include_guard_header; + include_guard_header += "#ifndef " + include_guard_name; + include_guard_header += "#define " + include_guard_name; + std::string include_guard_footer; + include_guard_footer += "#endif // " + include_guard_name; + source = include_guard_header + source + "\n" + include_guard_footer; + } + // return filename; + return true; +} + +} // namespace detail + +//! \endcond + +/*! Jitify reflection utilities namespace + */ +namespace reflection { + +// Provides type and value reflection via a function 'reflect': +// reflect() -> "Type" +// reflect(value) -> "(T)value" +// reflect() -> "VAL" +// reflect -> "VAL" +// reflect_template,char>() -> "" +// reflect_template({"float", "7", "char"}) -> "" + +/*! A wrapper class for non-type template parameters. + */ +template +struct NonType { + constexpr static T VALUE = VALUE_; +}; + +// Forward declaration +template +inline std::string reflect(T const& value); + +//! \cond + +namespace detail { + +template +inline std::string value_string(const T& x) { + std::stringstream ss; + ss << x; + return ss.str(); +} +// WAR for non-printable characters +template <> +inline std::string value_string(const char& x) { + std::stringstream ss; + ss << (int)x; + return ss.str(); +} +template <> +inline std::string value_string(const signed char& x) { + std::stringstream ss; + ss << (int)x; + return ss.str(); +} +template <> +inline std::string value_string(const unsigned char& x) { + std::stringstream ss; + ss << (int)x; + return ss.str(); +} +template <> +inline std::string value_string(const wchar_t& x) { + std::stringstream ss; + ss << (long)x; + return ss.str(); +} +// Specialisation for bool true/false literals +template <> +inline std::string value_string(const bool& x) { + return x ? "true" : "false"; +} + +// Removes all tokens that start with double underscores. +inline void strip_double_underscore_tokens(char* s) { + using jitify::detail::is_tokenchar; + char* w = s; + do { + if (*s == '_' && *(s + 1) == '_') { + while (is_tokenchar(*++s)) + ; + } + } while ((*w++ = *s++)); +} + +//#if CUDA_VERSION < 8000 +#ifdef _MSC_VER // MSVC compiler +inline std::string demangle_cuda_symbol(const char* mangled_name) { + // We don't have a way to demangle CUDA symbol names under MSVC. + return mangled_name; +} +inline std::string demangle_native_type(const std::type_info& typeinfo) { + // Get the decorated name and skip over the leading '.'. + const char* decorated_name = typeinfo.raw_name() + 1; + char undecorated_name[4096]; + if (UnDecorateSymbolName( + decorated_name, undecorated_name, + sizeof(undecorated_name) / sizeof(*undecorated_name), + UNDNAME_NO_ARGUMENTS | // Treat input as a type name + UNDNAME_NAME_ONLY // No "class" and "struct" prefixes + /*UNDNAME_NO_MS_KEYWORDS*/)) { // No "__cdecl", "__ptr64" etc. + // WAR for UNDNAME_NO_MS_KEYWORDS messing up function types. + strip_double_underscore_tokens(undecorated_name); + return undecorated_name; + } + throw std::runtime_error("UnDecorateSymbolName failed"); +} +#else // not MSVC +inline std::string demangle_cuda_symbol(const char* mangled_name) { + size_t bufsize = 0; + char* buf = nullptr; + std::string demangled_name; + int status; + auto demangled_ptr = std::unique_ptr( + abi::__cxa_demangle(mangled_name, buf, &bufsize, &status), free); + if (status == 0) { + demangled_name = demangled_ptr.get(); // all worked as expected + } else if (status == -2) { + demangled_name = mangled_name; // we interpret this as plain C name + } else if (status == -1) { + throw std::runtime_error( + std::string("memory allocation failure in __cxa_demangle")); + } else if (status == -3) { + throw std::runtime_error(std::string("invalid argument to __cxa_demangle")); + } + return demangled_name; +} +inline std::string demangle_native_type(const std::type_info& typeinfo) { + return demangle_cuda_symbol(typeinfo.name()); +} +#endif // not MSVC +//#endif // CUDA_VERSION < 8000 + +template +class JitifyTypeNameWrapper_ {}; + +template +struct type_reflection { + inline static std::string name() { + //#if CUDA_VERSION < 8000 + // TODO: Use nvrtcGetTypeName once it has the same behavior as this. + // WAR for typeid discarding cv qualifiers on value-types + // Wrap type in dummy template class to preserve cv-qualifiers, then strip + // off the wrapper from the resulting string. + std::string wrapped_name = + demangle_native_type(typeid(JitifyTypeNameWrapper_)); + // Note: The reflected name of this class also has namespace prefixes. + const std::string wrapper_class_name = "JitifyTypeNameWrapper_<"; + size_t start = wrapped_name.find(wrapper_class_name); + if (start == std::string::npos) { + throw std::runtime_error("Type reflection failed: " + wrapped_name); + } + start += wrapper_class_name.size(); + std::string name = + wrapped_name.substr(start, wrapped_name.size() - (start + 1)); + return name; + //#else + // std::string ret; + // nvrtcResult status = nvrtcGetTypeName(&ret); + // if( status != NVRTC_SUCCESS ) { + // throw std::runtime_error(std::string("nvrtcGetTypeName + // failed: + //")+ nvrtcGetErrorString(status)); + // } + // return ret; + //#endif + } +}; // namespace detail +template +struct type_reflection > { + inline static std::string name() { + return jitify::reflection::reflect(VALUE); + } +}; + +} // namespace detail + +//! \endcond + +/*! Create an Instance object that contains a const reference to the + * value. We use this to wrap abstract objects from which we want to extract + * their type at runtime (e.g., derived type). This is used to facilitate + * templating on derived type when all we know at compile time is abstract + * type. + */ +template +struct Instance { + const T& value; + Instance(const T& value_arg) : value(value_arg) {} +}; + +/*! Create an Instance object from which we can extract the value's run-time + * type. + * \param value The const value to be captured. + */ +template +inline Instance instance_of(T const& value) { + return Instance(value); +} + +/*! A wrapper used for representing types as values. + */ +template +struct Type {}; + +// Type reflection +// E.g., reflect() -> "float" +// Note: This strips trailing const and volatile qualifiers +/*! Generate a code-string for a type. + * \code{.cpp}reflect() --> "float"\endcode + */ +template +inline std::string reflect() { + return detail::type_reflection::name(); +} +// Value reflection +// E.g., reflect(3.14f) -> "(float)3.14" +/*! Generate a code-string for a value. + * \code{.cpp}reflect(3.14f) --> "(float)3.14"\endcode + */ +template +inline std::string reflect(T const& value) { + return "(" + reflect() + ")" + detail::value_string(value); +} +// Non-type template arg reflection (implicit conversion to int64_t) +// E.g., reflect<7>() -> "(int64_t)7" +/*! Generate a code-string for an integer non-type template argument. + * \code{.cpp}reflect<7>() --> "(int64_t)7"\endcode + */ +template +inline std::string reflect() { + return reflect >(); +} +// Non-type template arg reflection (explicit type) +// E.g., reflect() -> "(int)7" +/*! Generate a code-string for a generic non-type template argument. + * \code{.cpp} reflect() --> "(int)7" \endcode + */ +template +inline std::string reflect() { + return reflect >(); +} +// Type reflection via value +// E.g., reflect(Type()) -> "float" +/*! Generate a code-string for a type wrapped as a Type instance. + * \code{.cpp}reflect(Type()) --> "float"\endcode + */ +template +inline std::string reflect(jitify::reflection::Type) { + return reflect(); +} + +/*! Generate a code-string for a type wrapped as an Instance instance. + * \code{.cpp}reflect(Instance(3.1f)) --> "float"\endcode + * or more simply when passed to a instance_of helper + * \code{.cpp}reflect(instance_of(3.1f)) --> "float"\endcodei + * This is specifically for the case where we want to extract the run-time + * type, e.g., derived type, of an object pointer. + */ +template +inline std::string reflect(jitify::reflection::Instance& value) { + return detail::demangle_native_type(typeid(value.value)); +} + +// Type from value +// E.g., type_of(3.14f) -> Type() +/*! Create a Type object representing a value's type. + * \param value The value whose type is to be captured. + */ +template +inline Type type_of(T&) { + return Type(); +} +/*! Create a Type object representing a value's type. + * \param value The const value whose type is to be captured. + */ +template +inline Type type_of(T const&) { + return Type(); +} + +// Multiple value reflections one call, returning list of strings +template +inline std::vector reflect_all(Args... args) { + return {reflect(args)...}; +} + +inline std::string reflect_list(jitify::detail::vector const& args, + std::string opener = "", + std::string closer = "") { + std::stringstream ss; + ss << opener; + for (int i = 0; i < (int)args.size(); ++i) { + if (i > 0) ss << ","; + ss << args[i]; + } + ss << closer; + return ss.str(); +} + +// Template instantiation reflection +// inline std::string reflect_template(std::vector const& args) { +inline std::string reflect_template( + jitify::detail::vector const& args) { + // Note: The space in " >" is a WAR to avoid '>>' appearing + return reflect_list(args, "<", " >"); +} +// TODO: See if can make this evaluate completely at compile-time +template +inline std::string reflect_template() { + return reflect_template({reflect()...}); + // return reflect_template({reflect()...}); +} + +} // namespace reflection + +//! \cond + +namespace detail { + +// Demangles nested variable names using the PTX name mangling scheme +// (which follows the Itanium64 ABI). E.g., _ZN1a3Foo2bcE -> a::Foo::bc. +inline std::string demangle_ptx_variable_name(const char* name) { + std::stringstream ss; + const char* c = name; + if (*c++ != '_' || *c++ != 'Z') return name; // Non-mangled name + if (*c++ != 'N') return ""; // Not a nested name, unsupported + while (true) { + // Parse identifier length. + int n = 0; + while (std::isdigit(*c)) { + n = n * 10 + (*c - '0'); + c++; + } + if (!n) return ""; // Invalid or unsupported mangled name + // Parse identifier. + const char* c0 = c; + while (n-- && *c) c++; + if (!*c) return ""; // Mangled name is truncated + std::string id(c0, c); + // Identifiers starting with "_GLOBAL" are anonymous namespaces. + ss << (id.substr(0, 7) == "_GLOBAL" ? "(anonymous namespace)" : id); + // Nested name specifiers end with 'E'. + if (*c == 'E') break; + // There are more identifiers to come, add join token. + ss << "::"; + } + return ss.str(); +} + +static const char* get_current_executable_path() { + static const char* path = []() -> const char* { + static char buffer[JITIFY_PATH_MAX] = {}; +#ifdef __linux__ + if (!::realpath("/proc/self/exe", buffer)) return nullptr; +#elif defined(_WIN32) || defined(_WIN64) + if (!GetModuleFileNameA(nullptr, buffer, JITIFY_PATH_MAX)) return nullptr; +#endif + return buffer; + }(); + return path; +} + +inline bool endswith(const std::string& str, const std::string& suffix) { + return str.size() >= suffix.size() && + str.substr(str.size() - suffix.size()) == suffix; +} + +// Infers the JIT input type from the filename suffix. If no known suffix is +// present, the filename is assumed to refer to a library, and the associated +// suffix (and possibly prefix) is automatically added to the filename. +inline CUjitInputType get_cuda_jit_input_type(std::string* filename) { + if (endswith(*filename, ".ptx")) { + return CU_JIT_INPUT_PTX; + } else if (endswith(*filename, ".cubin")) { + return CU_JIT_INPUT_CUBIN; + } else if (endswith(*filename, ".fatbin")) { + return CU_JIT_INPUT_FATBINARY; + } else if (endswith(*filename, +#if defined _WIN32 || defined _WIN64 + ".obj" +#else // Linux + ".o" +#endif + )) { + return CU_JIT_INPUT_OBJECT; + } else { // Assume library +#if defined _WIN32 || defined _WIN64 + if (!endswith(*filename, ".lib")) { + *filename += ".lib"; + } +#else // Linux + if (!endswith(*filename, ".a")) { + *filename = "lib" + *filename + ".a"; + } +#endif + return CU_JIT_INPUT_LIBRARY; + } +} + +class CUDAKernel { + std::vector _link_files; + std::vector _link_paths; + CUlinkState _link_state; + CUmodule _module; + CUfunction _kernel; + std::string _func_name; + std::string _ptx; + std::map _global_map; + std::vector _opts; + std::vector _optvals; +#ifdef JITIFY_PRINT_LINKER_LOG + static const unsigned int _log_size = 8192; + char _error_log[_log_size]; + char _info_log[_log_size]; +#endif + + inline void cuda_safe_call(CUresult res) const { + if (res != CUDA_SUCCESS) { + const char* msg; + cuGetErrorName(res, &msg); + throw std::runtime_error(msg); + } + } + inline void create_module(std::vector link_files, + std::vector link_paths) { + CUresult result; +#ifndef JITIFY_PRINT_LINKER_LOG + // WAR since linker log does not seem to be constructed using a single call + // to cuModuleLoadDataEx. + if (link_files.empty()) { + result = + cuModuleLoadDataEx(&_module, _ptx.c_str(), (unsigned)_opts.size(), + _opts.data(), _optvals.data()); + } else +#endif + { + cuda_safe_call(cuLinkCreate((unsigned)_opts.size(), _opts.data(), + _optvals.data(), &_link_state)); + cuda_safe_call(cuLinkAddData(_link_state, CU_JIT_INPUT_PTX, + (void*)_ptx.c_str(), _ptx.size(), + "jitified_source.ptx", 0, 0, 0)); + for (int i = 0; i < (int)link_files.size(); ++i) { + std::string link_file = link_files[i]; + CUjitInputType jit_input_type; + if (link_file == ".") { + // Special case for linking to current executable. + link_file = get_current_executable_path(); + jit_input_type = CU_JIT_INPUT_OBJECT; + } else { + // Infer based on filename. + jit_input_type = get_cuda_jit_input_type(&link_file); + } + result = cuLinkAddFile(_link_state, jit_input_type, link_file.c_str(), + 0, 0, 0); + int path_num = 0; + while (result == CUDA_ERROR_FILE_NOT_FOUND && + path_num < (int)link_paths.size()) { + std::string filename = path_join(link_paths[path_num++], link_file); + result = cuLinkAddFile(_link_state, jit_input_type, filename.c_str(), + 0, 0, 0); + } +#if JITIFY_PRINT_LINKER_LOG + if (result == CUDA_ERROR_FILE_NOT_FOUND) { + std::cerr << "Linker error: Device library not found: " << link_file + << std::endl; + } else if (result != CUDA_SUCCESS) { + std::cerr << "Linker error: Failed to add file: " << link_file + << std::endl; + std::cerr << _error_log << std::endl; + } +#endif + cuda_safe_call(result); + } + size_t cubin_size; + void* cubin; + result = cuLinkComplete(_link_state, &cubin, &cubin_size); + if (result == CUDA_SUCCESS) { + result = cuModuleLoadData(&_module, cubin); + } + } +#ifdef JITIFY_PRINT_LINKER_LOG + std::cout << "---------------------------------------" << std::endl; + std::cout << "--- Linker for " + << reflection::detail::demangle_cuda_symbol(_func_name.c_str()) + << " ---" << std::endl; + std::cout << "---------------------------------------" << std::endl; + std::cout << _info_log << std::endl; + std::cout << std::endl; + std::cout << _error_log << std::endl; + std::cout << "---------------------------------------" << std::endl; +#endif + cuda_safe_call(result); + // Allow _func_name to be empty to support cases where we want to generate + // PTX containing extern symbol definitions but no kernels. + if (!_func_name.empty()) { + cuda_safe_call( + cuModuleGetFunction(&_kernel, _module, _func_name.c_str())); + } + } + inline void destroy_module() { + if (_link_state) { + cuda_safe_call(cuLinkDestroy(_link_state)); + } + _link_state = 0; + if (_module) { + cuModuleUnload(_module); + } + _module = 0; + } + + // create a map of __constant__ and __device__ variables in the ptx file + // mapping demangled to mangled name + inline void create_global_variable_map() { + size_t pos = 0; + while (pos < _ptx.size()) { + pos = std::min(_ptx.find(".const .align", pos), + _ptx.find(".global .align", pos)); + if (pos == std::string::npos) break; + size_t end = _ptx.find_first_of(";=", pos); + if (_ptx[end] == '=') --end; + std::string line = _ptx.substr(pos, end - pos); + pos = end; + size_t symbol_start = line.find_last_of(" ") + 1; + size_t symbol_end = line.find_last_of("["); + std::string entry = line.substr(symbol_start, symbol_end - symbol_start); + std::string key = detail::demangle_ptx_variable_name(entry.c_str()); + // Skip unsupported mangled names. E.g., a static variable defined inside + // a function (such variables are not directly addressable from outside + // the function, so skipping them is the correct behavior). + if (key == "") continue; + _global_map[key] = entry; + } + } + + inline void set_linker_log() { +#ifdef JITIFY_PRINT_LINKER_LOG + _opts.push_back(CU_JIT_INFO_LOG_BUFFER); + _optvals.push_back((void*)_info_log); + _opts.push_back(CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES); + _optvals.push_back((void*)(long)_log_size); + _opts.push_back(CU_JIT_ERROR_LOG_BUFFER); + _optvals.push_back((void*)_error_log); + _opts.push_back(CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES); + _optvals.push_back((void*)(long)_log_size); + _opts.push_back(CU_JIT_LOG_VERBOSE); + _optvals.push_back((void*)1); +#endif + } + + public: + inline CUDAKernel() : _link_state(0), _module(0), _kernel(0) {} + inline CUDAKernel(const CUDAKernel& other) = delete; + inline CUDAKernel& operator=(const CUDAKernel& other) = delete; + inline CUDAKernel(CUDAKernel&& other) = delete; + inline CUDAKernel& operator=(CUDAKernel&& other) = delete; + inline CUDAKernel(const char* func_name, const char* ptx, + std::vector link_files, + std::vector link_paths, unsigned int nopts = 0, + CUjit_option* opts = 0, void** optvals = 0) + : _link_files(link_files), + _link_paths(link_paths), + _link_state(0), + _module(0), + _kernel(0), + _func_name(func_name), + _ptx(ptx), + _opts(opts, opts + nopts), + _optvals(optvals, optvals + nopts) { + this->set_linker_log(); + this->create_module(link_files, link_paths); + this->create_global_variable_map(); + } + + inline CUDAKernel& set(const char* func_name, const char* ptx, + std::vector link_files, + std::vector link_paths, + unsigned int nopts = 0, CUjit_option* opts = 0, + void** optvals = 0) { + this->destroy_module(); + _func_name = func_name; + _ptx = ptx; + _link_files = link_files; + _link_paths = link_paths; + _opts.assign(opts, opts + nopts); + _optvals.assign(optvals, optvals + nopts); + this->set_linker_log(); + this->create_module(link_files, link_paths); + this->create_global_variable_map(); + return *this; + } + inline ~CUDAKernel() { this->destroy_module(); } + inline operator CUfunction() const { return _kernel; } + + inline CUresult launch(dim3 grid, dim3 block, unsigned int smem, + CUstream stream, std::vector arg_ptrs) const { + return cuLaunchKernel(_kernel, grid.x, grid.y, grid.z, block.x, block.y, + block.z, smem, stream, arg_ptrs.data(), NULL); + } + + inline void safe_launch(dim3 grid, dim3 block, unsigned int smem, + CUstream stream, std::vector arg_ptrs) const { + return cuda_safe_call(cuLaunchKernel(_kernel, grid.x, grid.y, grid.z, + block.x, block.y, block.z, smem, + stream, arg_ptrs.data(), NULL)); + } + + inline int get_func_attribute(CUfunction_attribute attribute) const { + int value; + cuda_safe_call(cuFuncGetAttribute(&value, attribute, _kernel)); + return value; + } + + inline void set_func_attribute(CUfunction_attribute attribute, + int value) const { + cuda_safe_call(cuFuncSetAttribute(_kernel, attribute, value)); + } + + inline CUdeviceptr get_global_ptr(const char* name, + size_t* size = nullptr) const { + CUdeviceptr global_ptr = 0; + auto global = _global_map.find(name); + if (global != _global_map.end()) { + cuda_safe_call(cuModuleGetGlobal(&global_ptr, size, _module, + global->second.c_str())); + } else { + throw std::runtime_error(std::string("failed to look up global ") + name); + } + return global_ptr; + } + + template + inline CUresult get_global_data(const char* name, T* data, size_t count, + CUstream stream = 0) const { + size_t size_bytes; + CUdeviceptr ptr = get_global_ptr(name, &size_bytes); + size_t given_size_bytes = count * sizeof(T); + if (given_size_bytes != size_bytes) { + throw std::runtime_error( + std::string("Value for global variable ") + name + + " has wrong size: got " + std::to_string(given_size_bytes) + + " bytes, expected " + std::to_string(size_bytes)); + } + return cuMemcpyDtoHAsync(data, ptr, size_bytes, stream); + } + + template + inline CUresult set_global_data(const char* name, const T* data, size_t count, + CUstream stream = 0) const { + size_t size_bytes; + CUdeviceptr ptr = get_global_ptr(name, &size_bytes); + size_t given_size_bytes = count * sizeof(T); + if (given_size_bytes != size_bytes) { + throw std::runtime_error( + std::string("Value for global variable ") + name + + " has wrong size: got " + std::to_string(given_size_bytes) + + " bytes, expected " + std::to_string(size_bytes)); + } + return cuMemcpyHtoDAsync(ptr, data, size_bytes, stream); + } + + const std::string& function_name() const { return _func_name; } + const std::string& ptx() const { return _ptx; } + const std::vector& link_files() const { return _link_files; } + const std::vector& link_paths() const { return _link_paths; } +}; + +static const char* jitsafe_header_preinclude_h = R"( +//// WAR for Thrust (which appears to have forgotten to include this in result_of_adaptable_function.h +//#include + +//// WAR for Thrust (which appear to have forgotten to include this in error_code.h) +//#include + +// WAR for generics/shfl.h +#define THRUST_STATIC_ASSERT(x) + +// WAR for CUB +#ifdef __host__ +#undef __host__ +#endif +#define __host__ + +// WAR to allow exceptions to be parsed +#define try +#define catch(...) +)"; + +static const char* jitsafe_header_float_h = R"( +#pragma once + +#define FLT_RADIX 2 +#define FLT_MANT_DIG 24 +#define DBL_MANT_DIG 53 +#define FLT_DIG 6 +#define DBL_DIG 15 +#define FLT_MIN_EXP -125 +#define DBL_MIN_EXP -1021 +#define FLT_MIN_10_EXP -37 +#define DBL_MIN_10_EXP -307 +#define FLT_MAX_EXP 128 +#define DBL_MAX_EXP 1024 +#define FLT_MAX_10_EXP 38 +#define DBL_MAX_10_EXP 308 +#define FLT_MAX 3.4028234e38f +#define DBL_MAX 1.7976931348623157e308 +#define FLT_EPSILON 1.19209289e-7f +#define DBL_EPSILON 2.220440492503130e-16 +#define FLT_MIN 1.1754943e-38f +#define DBL_MIN 2.2250738585072013e-308 +#define FLT_ROUNDS 1 +#if defined __cplusplus && __cplusplus >= 201103L +#define FLT_EVAL_METHOD 0 +#define DECIMAL_DIG 21 +#endif +)"; + +static const char* jitsafe_header_limits_h = R"( +#pragma once + +#if defined _WIN32 || defined _WIN64 + #define __WORDSIZE 32 +#else + #if defined __x86_64__ && !defined __ILP32__ + #define __WORDSIZE 64 + #else + #define __WORDSIZE 32 + #endif +#endif +#define MB_LEN_MAX 16 +#define CHAR_BIT 8 +#define SCHAR_MIN (-128) +#define SCHAR_MAX 127 +#define UCHAR_MAX 255 +enum { + _JITIFY_CHAR_IS_UNSIGNED = (char)-1 >= 0, + CHAR_MIN = _JITIFY_CHAR_IS_UNSIGNED ? 0 : SCHAR_MIN, + CHAR_MAX = _JITIFY_CHAR_IS_UNSIGNED ? UCHAR_MAX : SCHAR_MAX, +}; +#define SHRT_MIN (-32768) +#define SHRT_MAX 32767 +#define USHRT_MAX 65535 +#define INT_MIN (-INT_MAX - 1) +#define INT_MAX 2147483647 +#define UINT_MAX 4294967295U +#if __WORDSIZE == 64 + # define LONG_MAX 9223372036854775807L +#else + # define LONG_MAX 2147483647L +#endif +#define LONG_MIN (-LONG_MAX - 1L) +#if __WORDSIZE == 64 + #define ULONG_MAX 18446744073709551615UL +#else + #define ULONG_MAX 4294967295UL +#endif +#define LLONG_MAX 9223372036854775807LL +#define LLONG_MIN (-LLONG_MAX - 1LL) +#define ULLONG_MAX 18446744073709551615ULL +)"; + +static const char* jitsafe_header_iterator = R"( +#pragma once + +namespace std { +struct output_iterator_tag {}; +struct input_iterator_tag {}; +struct forward_iterator_tag {}; +struct bidirectional_iterator_tag {}; +struct random_access_iterator_tag {}; +template +struct iterator_traits { + typedef typename Iterator::iterator_category iterator_category; + typedef typename Iterator::value_type value_type; + typedef typename Iterator::difference_type difference_type; + typedef typename Iterator::pointer pointer; + typedef typename Iterator::reference reference; +}; +template +struct iterator_traits { + typedef random_access_iterator_tag iterator_category; + typedef T value_type; + typedef ptrdiff_t difference_type; + typedef T* pointer; + typedef T& reference; +}; +template +struct iterator_traits { + typedef random_access_iterator_tag iterator_category; + typedef T value_type; + typedef ptrdiff_t difference_type; + typedef T const* pointer; + typedef T const& reference; +}; +} // namespace std +)"; + +// TODO: This is incomplete; need floating point limits +// Joe Eaton: added IEEE float and double types, none of the smaller types +// using type specific structs since we can't template on floats. +static const char* jitsafe_header_limits = R"( +#pragma once +#include +#include +#include +// TODO: epsilon(), infinity(), etc +namespace std { +namespace __jitify_detail { +#if __cplusplus >= 201103L +#define JITIFY_CXX11_CONSTEXPR constexpr +#define JITIFY_CXX11_NOEXCEPT noexcept +#else +#define JITIFY_CXX11_CONSTEXPR +#define JITIFY_CXX11_NOEXCEPT +#endif + +struct FloatLimits { +#if __cplusplus >= 201103L + static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ + float lowest() JITIFY_CXX11_NOEXCEPT { return -FLT_MAX;} + static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ + float min() JITIFY_CXX11_NOEXCEPT { return FLT_MIN; } + static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ + float max() JITIFY_CXX11_NOEXCEPT { return FLT_MAX; } +#endif // __cplusplus >= 201103L + enum { + is_specialized = true, + is_signed = true, + is_integer = false, + is_exact = false, + has_infinity = true, + has_quiet_NaN = true, + has_signaling_NaN = true, + has_denorm = 1, + has_denorm_loss = true, + round_style = 1, + is_iec559 = true, + is_bounded = true, + is_modulo = false, + digits = 24, + digits10 = 6, + max_digits10 = 9, + radix = 2, + min_exponent = -125, + min_exponent10 = -37, + max_exponent = 128, + max_exponent10 = 38, + tinyness_before = false, + traps = false + }; +}; +struct DoubleLimits { +#if __cplusplus >= 201103L + static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ + double lowest() noexcept { return -DBL_MAX; } + static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ + double min() noexcept { return DBL_MIN; } + static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ + double max() noexcept { return DBL_MAX; } +#endif // __cplusplus >= 201103L + enum { + is_specialized = true, + is_signed = true, + is_integer = false, + is_exact = false, + has_infinity = true, + has_quiet_NaN = true, + has_signaling_NaN = true, + has_denorm = 1, + has_denorm_loss = true, + round_style = 1, + is_iec559 = true, + is_bounded = true, + is_modulo = false, + digits = 53, + digits10 = 15, + max_digits10 = 17, + radix = 2, + min_exponent = -1021, + min_exponent10 = -307, + max_exponent = 1024, + max_exponent10 = 308, + tinyness_before = false, + traps = false + }; +}; +template +struct IntegerLimits { + static inline __host__ __device__ T min() { return Min; } + static inline __host__ __device__ T max() { return Max; } +#if __cplusplus >= 201103L + static constexpr inline __host__ __device__ T lowest() noexcept { + return Min; + } +#endif // __cplusplus >= 201103L + enum { + is_specialized = true, + digits = (Digits == -1) ? (int)(sizeof(T)*8 - (Min != 0)) : Digits, + digits10 = (digits * 30103) / 100000, + is_signed = ((T)(-1)<0), + is_integer = true, + is_exact = true, + radix = 2, + is_bounded = true, + is_modulo = false + }; +}; +} // namespace __jitify_detail +template struct numeric_limits { + enum { is_specialized = false }; +}; +template<> struct numeric_limits : public +__jitify_detail::IntegerLimits {}; +template<> struct numeric_limits : public +__jitify_detail::IntegerLimits +{}; +template<> struct numeric_limits : public +__jitify_detail::IntegerLimits +{}; +template<> struct numeric_limits : public +__jitify_detail::IntegerLimits +{}; +template<> struct numeric_limits : public +__jitify_detail::IntegerLimits {}; +template<> struct numeric_limits : public +__jitify_detail::IntegerLimits +{}; +template<> struct numeric_limits : public +__jitify_detail::IntegerLimits +{}; +template<> struct numeric_limits : public +__jitify_detail::IntegerLimits {}; +template<> struct numeric_limits : public +__jitify_detail::IntegerLimits +{}; +template<> struct numeric_limits : public +__jitify_detail::IntegerLimits +{}; +template<> struct numeric_limits : public +__jitify_detail::IntegerLimits +{}; +template<> struct numeric_limits : public +__jitify_detail::IntegerLimits +{}; +template<> struct numeric_limits : public +__jitify_detail::IntegerLimits +{}; +//template struct numeric_limits { static const bool +//is_signed = ((T)(-1)<0); }; +template<> struct numeric_limits : public +__jitify_detail::FloatLimits +{}; +template<> struct numeric_limits : public +__jitify_detail::DoubleLimits +{}; +} // namespace std +)"; + +// TODO: This is highly incomplete +static const char* jitsafe_header_type_traits = R"( + #pragma once + #if __cplusplus >= 201103L + namespace std { + + template struct enable_if {}; + template struct enable_if { typedef T type; }; + #if __cplusplus >= 201402L + template< bool B, class T = void > using enable_if_t = typename enable_if::type; + #endif + + struct true_type { + enum { value = true }; + operator bool() const { return true; } + }; + struct false_type { + enum { value = false }; + operator bool() const { return false; } + }; + + template struct is_floating_point : false_type {}; + template<> struct is_floating_point : true_type {}; + template<> struct is_floating_point : true_type {}; + template<> struct is_floating_point : true_type {}; + + template struct is_integral : false_type {}; + template<> struct is_integral : true_type {}; + template<> struct is_integral : true_type {}; + template<> struct is_integral : true_type {}; + template<> struct is_integral : true_type {}; + template<> struct is_integral : true_type {}; + template<> struct is_integral : true_type {}; + template<> struct is_integral : true_type {}; + template<> struct is_integral : true_type {}; + template<> struct is_integral : true_type {}; + template<> struct is_integral : true_type {}; + template<> struct is_integral : true_type {}; + template<> struct is_integral : true_type {}; + + template struct is_signed : false_type {}; + template<> struct is_signed : true_type {}; + template<> struct is_signed : true_type {}; + template<> struct is_signed : true_type {}; + template<> struct is_signed : true_type {}; + template<> struct is_signed : true_type {}; + template<> struct is_signed : true_type {}; + template<> struct is_signed : true_type {}; + template<> struct is_signed : true_type {}; + + template struct is_unsigned : false_type {}; + template<> struct is_unsigned : true_type {}; + template<> struct is_unsigned : true_type {}; + template<> struct is_unsigned : true_type {}; + template<> struct is_unsigned : true_type {}; + template<> struct is_unsigned : true_type {}; + + template struct is_same : false_type {}; + template struct is_same : true_type {}; + + template struct is_array : false_type {}; + template struct is_array : true_type {}; + template struct is_array : true_type {}; + + //partial implementation only of is_function + template struct is_function : false_type { }; + template struct is_function : true_type {}; //regular + template struct is_function : true_type {}; // variadic + + template struct result_of; + template + struct result_of { + // TODO: This is a hack; a proper implem is quite complicated. + typedef typename F::result_type type; + }; + + template struct remove_reference { typedef T type; }; + template struct remove_reference { typedef T type; }; + template struct remove_reference { typedef T type; }; + #if __cplusplus >= 201402L + template< class T > using remove_reference_t = typename remove_reference::type; + #endif + + template struct remove_extent { typedef T type; }; + template struct remove_extent { typedef T type; }; + template struct remove_extent { typedef T type; }; + #if __cplusplus >= 201402L + template< class T > using remove_extent_t = typename remove_extent::type; + #endif + + template< class T > struct remove_const { typedef T type; }; + template< class T > struct remove_const { typedef T type; }; + template< class T > struct remove_volatile { typedef T type; }; + template< class T > struct remove_volatile { typedef T type; }; + template< class T > struct remove_cv { typedef typename remove_volatile::type>::type type; }; + #if __cplusplus >= 201402L + template< class T > using remove_cv_t = typename remove_cv::type; + template< class T > using remove_const_t = typename remove_const::type; + template< class T > using remove_volatile_t = typename remove_volatile::type; + #endif + + template struct conditional { typedef T type; }; + template struct conditional { typedef F type; }; + #if __cplusplus >= 201402L + template< bool B, class T, class F > using conditional_t = typename conditional::type; + #endif + + namespace __jitify_detail { + template< class T, bool is_function_type = false > struct add_pointer { using type = typename remove_reference::type*; }; + template< class T > struct add_pointer { using type = T; }; + template< class T, class... Args > struct add_pointer { using type = T(*)(Args...); }; + template< class T, class... Args > struct add_pointer { using type = T(*)(Args..., ...); }; + } // namespace __jitify_detail + template< class T > struct add_pointer : __jitify_detail::add_pointer::value> {}; + #if __cplusplus >= 201402L + template< class T > using add_pointer_t = typename add_pointer::type; + #endif + + template< class T > struct decay { + private: + typedef typename remove_reference::type U; + public: + typedef typename conditional::value, typename remove_extent::type*, + typename conditional::value,typename add_pointer::type,typename remove_cv::type + >::type>::type type; + }; + #if __cplusplus >= 201402L + template< class T > using decay_t = typename decay::type; + #endif + + template + struct integral_constant { + static constexpr T value = v; + typedef T value_type; + typedef integral_constant type; // using injected-class-name + constexpr operator value_type() const noexcept { return value; } + #if __cplusplus >= 201402L + constexpr value_type operator()() const noexcept { return value; } + #endif + }; + + template struct is_lvalue_reference : false_type {}; + template struct is_lvalue_reference : true_type {}; + + template struct is_rvalue_reference : false_type {}; + template struct is_rvalue_reference : true_type {}; + + namespace __jitify_detail { + template struct type_identity { using type = T; }; + template auto add_lvalue_reference(int) -> type_identity; + template auto add_lvalue_reference(...) -> type_identity; + template auto add_rvalue_reference(int) -> type_identity; + template auto add_rvalue_reference(...) -> type_identity; + } // namespace _jitify_detail + + template struct add_lvalue_reference : decltype(__jitify_detail::add_lvalue_reference(0)) {}; + template struct add_rvalue_reference : decltype(__jitify_detail::add_rvalue_reference(0)) {}; + #if __cplusplus >= 201402L + template using add_lvalue_reference_t = typename add_lvalue_reference::type; + template using add_rvalue_reference_t = typename add_rvalue_reference::type; + #endif + + template struct is_const : public false_type {}; + template struct is_const : public true_type {}; + + template struct is_volatile : public false_type {}; + template struct is_volatile : public true_type {}; + + template struct is_void : public false_type {}; + template<> struct is_void : public true_type {}; + template<> struct is_void : public true_type {}; + + template struct is_reference : public false_type {}; + template struct is_reference : public true_type {}; + + template::value || is_reference<_Tp>::value)> + struct __add_reference_helper { typedef _Tp& type; }; + + template struct __add_reference_helper<_Tp, true> { typedef _Tp type; }; + template struct add_reference : public __add_reference_helper<_Tp>{}; + + namespace __jitify_detail { + template struct is_int_or_cref { + typedef typename remove_reference::type type_sans_ref; + static const bool value = (is_integral::value || (is_integral::value + && is_const::value && !is_volatile::value)); + }; // end is_int_or_cref + template struct is_convertible_sfinae { + private: + typedef char yes; + typedef struct { char two_chars[2]; } no; + static inline yes test(To) { return yes(); } + static inline no test(...) { return no(); } + static inline typename remove_reference::type& from() { typename remove_reference::type* ptr = 0; return *ptr; } + public: + static const bool value = sizeof(test(from())) == sizeof(yes); + }; // end is_convertible_sfinae + template struct is_convertible_needs_simple_test { + static const bool from_is_void = is_void::value; + static const bool to_is_void = is_void::value; + static const bool from_is_float = is_floating_point::type>::value; + static const bool to_is_int_or_cref = is_int_or_cref::value; + static const bool value = (from_is_void || to_is_void || (from_is_float && to_is_int_or_cref)); + }; // end is_convertible_needs_simple_test + template::value> + struct is_convertible { + static const bool value = (is_void::value || (is_int_or_cref::value && !is_void::value)); + }; // end is_convertible + template struct is_convertible { + static const bool value = (is_convertible_sfinae::type, To>::value); + }; // end is_convertible + } // end __jitify_detail + // implementation of is_convertible taken from thrust's pre C++11 path + template struct is_convertible + : public integral_constant::value> + { }; // end is_convertible + + template struct is_base_of { }; + + template struct aligned_storage { struct type { alignas(alignment) char data[len]; }; }; + template struct alignment_of : std::integral_constant {}; + + } // namespace std + #endif // c++11 +)"; + +// TODO: INT_FAST8_MAX et al. and a few other misc constants +static const char* jitsafe_header_stdint_h = + "#pragma once\n" + "#include \n" + "namespace __jitify_stdint_ns {\n" + "typedef signed char int8_t;\n" + "typedef signed short int16_t;\n" + "typedef signed int int32_t;\n" + "typedef signed long long int64_t;\n" + "typedef signed char int_fast8_t;\n" + "typedef signed short int_fast16_t;\n" + "typedef signed int int_fast32_t;\n" + "typedef signed long long int_fast64_t;\n" + "typedef signed char int_least8_t;\n" + "typedef signed short int_least16_t;\n" + "typedef signed int int_least32_t;\n" + "typedef signed long long int_least64_t;\n" + "typedef signed long long intmax_t;\n" + "typedef signed long intptr_t; //optional\n" + "typedef unsigned char uint8_t;\n" + "typedef unsigned short uint16_t;\n" + "typedef unsigned int uint32_t;\n" + "typedef unsigned long long uint64_t;\n" + "typedef unsigned char uint_fast8_t;\n" + "typedef unsigned short uint_fast16_t;\n" + "typedef unsigned int uint_fast32_t;\n" + "typedef unsigned long long uint_fast64_t;\n" + "typedef unsigned char uint_least8_t;\n" + "typedef unsigned short uint_least16_t;\n" + "typedef unsigned int uint_least32_t;\n" + "typedef unsigned long long uint_least64_t;\n" + "typedef unsigned long long uintmax_t;\n" + "#define INT8_MIN SCHAR_MIN\n" + "#define INT16_MIN SHRT_MIN\n" + "#if defined _WIN32 || defined _WIN64\n" + "#define WCHAR_MIN SHRT_MIN\n" + "#define WCHAR_MAX SHRT_MAX\n" + "typedef unsigned long long uintptr_t; //optional\n" + "#else\n" + "#define WCHAR_MIN INT_MIN\n" + "#define WCHAR_MAX INT_MAX\n" + "typedef unsigned long uintptr_t; //optional\n" + "#endif\n" + "#define INT32_MIN INT_MIN\n" + "#define INT64_MIN LLONG_MIN\n" + "#define INT8_MAX SCHAR_MAX\n" + "#define INT16_MAX SHRT_MAX\n" + "#define INT32_MAX INT_MAX\n" + "#define INT64_MAX LLONG_MAX\n" + "#define UINT8_MAX UCHAR_MAX\n" + "#define UINT16_MAX USHRT_MAX\n" + "#define UINT32_MAX UINT_MAX\n" + "#define UINT64_MAX ULLONG_MAX\n" + "#define INTPTR_MIN LONG_MIN\n" + "#define INTMAX_MIN LLONG_MIN\n" + "#define INTPTR_MAX LONG_MAX\n" + "#define INTMAX_MAX LLONG_MAX\n" + "#define UINTPTR_MAX ULONG_MAX\n" + "#define UINTMAX_MAX ULLONG_MAX\n" + "#define PTRDIFF_MIN INTPTR_MIN\n" + "#define PTRDIFF_MAX INTPTR_MAX\n" + "#define SIZE_MAX UINT64_MAX\n" + "} // namespace __jitify_stdint_ns\n" + "namespace std { using namespace __jitify_stdint_ns; }\n" + "using namespace __jitify_stdint_ns;\n"; + +// TODO: offsetof +static const char* jitsafe_header_stddef_h = + "#pragma once\n" + "#include \n" + "namespace __jitify_stddef_ns {\n" + "#if __cplusplus >= 201103L\n" + "typedef decltype(nullptr) nullptr_t;\n" + "#if defined(_MSC_VER)\n" + " typedef double max_align_t;\n" + "#elif defined(__APPLE__)\n" + " typedef long double max_align_t;\n" + "#else\n" + " // Define max_align_t to match the GCC definition.\n" + " typedef struct {\n" + " long long __jitify_max_align_nonce1\n" + " __attribute__((__aligned__(__alignof__(long long))));\n" + " long double __jitify_max_align_nonce2\n" + " __attribute__((__aligned__(__alignof__(long double))));\n" + " } max_align_t;\n" + "#endif\n" + "#endif // __cplusplus >= 201103L\n" + "#if __cplusplus >= 201703L\n" + "enum class byte : unsigned char {};\n" + "#endif // __cplusplus >= 201703L\n" + "} // namespace __jitify_stddef_ns\n" + "namespace std {\n" + " // NVRTC provides built-in definitions of ::size_t and ::ptrdiff_t.\n" + " using ::size_t;\n" + " using ::ptrdiff_t;\n" + " using namespace __jitify_stddef_ns;\n" + "} // namespace std\n" + "using namespace __jitify_stddef_ns;\n"; + +static const char* jitsafe_header_stdlib_h = + "#pragma once\n" + "#include \n"; +static const char* jitsafe_header_stdio_h = + "#pragma once\n" + "#include \n" + "#define FILE int\n" + "int fflush ( FILE * stream );\n" + "int fprintf ( FILE * stream, const char * format, ... );\n"; + +static const char* jitsafe_header_string_h = + "#pragma once\n" + "char* strcpy ( char * destination, const char * source );\n" + "int strcmp ( const char * str1, const char * str2 );\n" + "char* strerror( int errnum );\n"; + +static const char* jitsafe_header_cstring = + "#pragma once\n" + "\n" + "namespace __jitify_cstring_ns {\n" + "char* strcpy ( char * destination, const char * source );\n" + "int strcmp ( const char * str1, const char * str2 );\n" + "char* strerror( int errnum );\n" + "} // namespace __jitify_cstring_ns\n" + "namespace std { using namespace __jitify_cstring_ns; }\n" + "using namespace __jitify_cstring_ns;\n"; + +// HACK TESTING (WAR for cub) +static const char* jitsafe_header_iostream = + "#pragma once\n" + "#include \n" + "#include \n"; +// HACK TESTING (WAR for Thrust) +static const char* jitsafe_header_ostream = + "#pragma once\n" + "\n" + "namespace std {\n" + "template\n" // = std::char_traits + // >\n" + "struct basic_ostream {\n" + "};\n" + "typedef basic_ostream ostream;\n" + "ostream& endl(ostream& os);\n" + "ostream& operator<<( ostream&, ostream& (*f)( ostream& ) );\n" + "template< class CharT, class Traits > basic_ostream& endl( " + "basic_ostream& os );\n" + "template< class CharT, class Traits > basic_ostream& " + "operator<<( basic_ostream& os, const char* c );\n" + "#if __cplusplus >= 201103L\n" + "template< class CharT, class Traits, class T > basic_ostream& operator<<( basic_ostream&& os, const T& value );\n" + "#endif // __cplusplus >= 201103L\n" + "} // namespace std\n"; + +static const char* jitsafe_header_istream = + "#pragma once\n" + "\n" + "namespace std {\n" + "template\n" // = std::char_traits + // >\n" + "struct basic_istream {\n" + "};\n" + "typedef basic_istream istream;\n" + "} // namespace std\n"; + +static const char* jitsafe_header_sstream = + "#pragma once\n" + "#include \n" + "#include \n"; + +static const char* jitsafe_header_utility = + "#pragma once\n" + "namespace std {\n" + "template\n" + "struct pair {\n" + " T1 first;\n" + " T2 second;\n" + " inline pair() {}\n" + " inline pair(T1 const& first_, T2 const& second_)\n" + " : first(first_), second(second_) {}\n" + " // TODO: Standard includes many more constructors...\n" + " // TODO: Comparison operators\n" + "};\n" + "template\n" + "pair make_pair(T1 const& first, T2 const& second) {\n" + " return pair(first, second);\n" + "}\n" + "} // namespace std\n"; + +// TODO: incomplete +static const char* jitsafe_header_vector = + "#pragma once\n" + "namespace std {\n" + "template\n" // = std::allocator> \n" + "struct vector {\n" + "};\n" + "} // namespace std\n"; + +// TODO: incomplete +static const char* jitsafe_header_string = + "#pragma once\n" + "namespace std {\n" + "template\n" + "struct basic_string {\n" + "basic_string();\n" + "basic_string( const CharT* s );\n" //, const Allocator& alloc = + // Allocator() );\n" + "const CharT* c_str() const;\n" + "bool empty() const;\n" + "void operator+=(const char *);\n" + "void operator+=(const basic_string &);\n" + "};\n" + "typedef basic_string string;\n" + "} // namespace std\n"; + +// TODO: incomplete +static const char* jitsafe_header_stdexcept = + "#pragma once\n" + "namespace std {\n" + "struct runtime_error {\n" + "explicit runtime_error( const std::string& what_arg );" + "explicit runtime_error( const char* what_arg );" + "virtual const char* what() const;\n" + "};\n" + "} // namespace std\n"; + +// TODO: incomplete +static const char* jitsafe_header_complex = + "#pragma once\n" + "namespace std {\n" + "template\n" + "class complex {\n" + " T _real;\n" + " T _imag;\n" + "public:\n" + " complex() : _real(0), _imag(0) {}\n" + " complex(T const& real, T const& imag)\n" + " : _real(real), _imag(imag) {}\n" + " complex(T const& real)\n" + " : _real(real), _imag(static_cast(0)) {}\n" + " T const& real() const { return _real; }\n" + " T& real() { return _real; }\n" + " void real(const T &r) { _real = r; }\n" + " T const& imag() const { return _imag; }\n" + " T& imag() { return _imag; }\n" + " void imag(const T &i) { _imag = i; }\n" + " complex& operator+=(const complex z)\n" + " { _real += z.real(); _imag += z.imag(); return *this; }\n" + "};\n" + "template\n" + "complex operator*(const complex& lhs, const complex& rhs)\n" + " { return complex(lhs.real()*rhs.real()-lhs.imag()*rhs.imag(),\n" + " lhs.real()*rhs.imag()+lhs.imag()*rhs.real()); }\n" + "template\n" + "complex operator*(const complex& lhs, const T & rhs)\n" + " { return complexs(lhs.real()*rhs,lhs.imag()*rhs); }\n" + "template\n" + "complex operator*(const T& lhs, const complex& rhs)\n" + " { return complexs(rhs.real()*lhs,rhs.imag()*lhs); }\n" + "} // namespace std\n"; + +// TODO: This is incomplete (missing binary and integer funcs, macros, +// constants, types) +static const char* jitsafe_header_math_h = + "#pragma once\n" + "namespace __jitify_math_ns {\n" + "#if __cplusplus >= 201103L\n" + "#define DEFINE_MATH_UNARY_FUNC_WRAPPER(f) \\\n" + " inline double f(double x) { return ::f(x); } \\\n" + " inline float f##f(float x) { return ::f(x); } \\\n" + " /*inline long double f##l(long double x) { return ::f(x); }*/ \\\n" + " inline float f(float x) { return ::f(x); } \\\n" + " /*inline long double f(long double x) { return ::f(x); }*/\n" + "#else\n" + "#define DEFINE_MATH_UNARY_FUNC_WRAPPER(f) \\\n" + " inline double f(double x) { return ::f(x); } \\\n" + " inline float f##f(float x) { return ::f(x); } \\\n" + " /*inline long double f##l(long double x) { return ::f(x); }*/\n" + "#endif\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(cos)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(sin)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(tan)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(acos)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(asin)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(atan)\n" + "template inline T atan2(T y, T x) { return ::atan2(y, x); }\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(cosh)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(sinh)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(tanh)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(exp)\n" + "template inline T frexp(T x, int* exp) { return ::frexp(x, " + "exp); }\n" + "template inline T ldexp(T x, int exp) { return ::ldexp(x, " + "exp); }\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(log)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(log10)\n" + "template inline T modf(T x, T* intpart) { return ::modf(x, " + "intpart); }\n" + "template inline T pow(T x, T y) { return ::pow(x, y); }\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(sqrt)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(ceil)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(floor)\n" + "template inline T fmod(T n, T d) { return ::fmod(n, d); }\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(fabs)\n" + "template inline T abs(T x) { return ::abs(x); }\n" + "#if __cplusplus >= 201103L\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(acosh)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(asinh)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(atanh)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(exp2)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(expm1)\n" + "template inline int ilogb(T x) { return ::ilogb(x); }\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(log1p)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(log2)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(logb)\n" + "template inline T scalbn (T x, int n) { return ::scalbn(x, " + "n); }\n" + "template inline T scalbln(T x, long n) { return ::scalbn(x, " + "n); }\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(cbrt)\n" + "template inline T hypot(T x, T y) { return ::hypot(x, y); }\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(erf)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(erfc)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(tgamma)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(lgamma)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(trunc)\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(round)\n" + "template inline long lround(T x) { return ::lround(x); }\n" + "template inline long long llround(T x) { return ::llround(x); " + "}\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(rint)\n" + "template inline long lrint(T x) { return ::lrint(x); }\n" + "template inline long long llrint(T x) { return ::llrint(x); " + "}\n" + "DEFINE_MATH_UNARY_FUNC_WRAPPER(nearbyint)\n" + // TODO: remainder, remquo, copysign, nan, nextafter, nexttoward, fdim, + // fmax, fmin, fma + "#endif\n" + "#undef DEFINE_MATH_UNARY_FUNC_WRAPPER\n" + "} // namespace __jitify_math_ns\n" + "namespace std { using namespace __jitify_math_ns; }\n" + "#define M_PI 3.14159265358979323846\n" + // Note: Global namespace already includes CUDA math funcs + "//using namespace __jitify_math_ns;\n"; + +static const char* jitsafe_header_memory_h = R"( + #pragma once + #include + )"; + +// TODO: incomplete +static const char* jitsafe_header_mutex = R"( + #pragma once + #if __cplusplus >= 201103L + namespace std { + class mutex { + public: + void lock(); + bool try_lock(); + void unlock(); + }; + } // namespace std + #endif + )"; + +static const char* jitsafe_header_algorithm = R"( + #pragma once + #if __cplusplus >= 201103L + namespace std { + + #if __cplusplus == 201103L + #define JITIFY_CXX14_CONSTEXPR + #else + #define JITIFY_CXX14_CONSTEXPR constexpr + #endif + + template JITIFY_CXX14_CONSTEXPR const T& max(const T& a, const T& b) + { + return (b > a) ? b : a; + } + template JITIFY_CXX14_CONSTEXPR const T& min(const T& a, const T& b) + { + return (b < a) ? b : a; + } + + } // namespace std + #endif + )"; + +static const char* jitsafe_header_time_h = R"( + #pragma once + #define NULL 0 + #define CLOCKS_PER_SEC 1000000 + namespace __jitify_time_ns { + typedef long time_t; + struct tm { + int tm_sec; + int tm_min; + int tm_hour; + int tm_mday; + int tm_mon; + int tm_year; + int tm_wday; + int tm_yday; + int tm_isdst; + }; + #if __cplusplus >= 201703L + struct timespec { + time_t tv_sec; + long tv_nsec; + }; + #endif + } // namespace __jitify_time_ns + namespace std { + // NVRTC provides built-in definitions of ::size_t and ::clock_t. + using ::size_t; + using ::clock_t; + using namespace __jitify_time_ns; + } + using namespace __jitify_time_ns; + )"; + +static const char* jitsafe_header_tuple = R"( + #pragma once + #if __cplusplus >= 201103L + namespace std { + template class tuple; + } // namespace std + #endif + )"; + +static const char* jitsafe_header_assert = R"( + #pragma once + )"; + +// WAR: These need to be pre-included as a workaround for NVRTC implicitly using +// /usr/include as an include path. The other built-in headers will be included +// lazily as needed. +static const char* preinclude_jitsafe_header_names[] = {"jitify_preinclude.h", + "limits.h", + "math.h", + "memory.h", + "stdint.h", + "stdlib.h", + "stdio.h", + "string.h", + "time.h", + "assert.h"}; + +template +int array_size(T (&)[N]) { + return N; +} +const int preinclude_jitsafe_headers_count = + array_size(preinclude_jitsafe_header_names); + +static const std::map& get_jitsafe_headers_map() { + static const std::map jitsafe_headers_map = { + {"jitify_preinclude.h", jitsafe_header_preinclude_h}, + {"float.h", jitsafe_header_float_h}, + {"cfloat", jitsafe_header_float_h}, + {"limits.h", jitsafe_header_limits_h}, + {"climits", jitsafe_header_limits_h}, + {"stdint.h", jitsafe_header_stdint_h}, + {"cstdint", jitsafe_header_stdint_h}, + {"stddef.h", jitsafe_header_stddef_h}, + {"cstddef", jitsafe_header_stddef_h}, + {"stdlib.h", jitsafe_header_stdlib_h}, + {"cstdlib", jitsafe_header_stdlib_h}, + {"stdio.h", jitsafe_header_stdio_h}, + {"cstdio", jitsafe_header_stdio_h}, + {"string.h", jitsafe_header_string_h}, + {"cstring", jitsafe_header_cstring}, + {"iterator", jitsafe_header_iterator}, + {"limits", jitsafe_header_limits}, + {"type_traits", jitsafe_header_type_traits}, + {"utility", jitsafe_header_utility}, + {"math.h", jitsafe_header_math_h}, + {"cmath", jitsafe_header_math_h}, + {"memory.h", jitsafe_header_memory_h}, + {"complex", jitsafe_header_complex}, + {"iostream", jitsafe_header_iostream}, + {"ostream", jitsafe_header_ostream}, + {"istream", jitsafe_header_istream}, + {"sstream", jitsafe_header_sstream}, + {"vector", jitsafe_header_vector}, + {"string", jitsafe_header_string}, + {"stdexcept", jitsafe_header_stdexcept}, + {"mutex", jitsafe_header_mutex}, + {"algorithm", jitsafe_header_algorithm}, + {"time.h", jitsafe_header_time_h}, + {"ctime", jitsafe_header_time_h}, + {"tuple", jitsafe_header_tuple}, + {"assert.h", jitsafe_header_assert}, + {"cassert", jitsafe_header_assert}}; + return jitsafe_headers_map; +} + +inline void add_options_from_env(std::vector& options) { + // Add options from environment variable + const char* env_options = std::getenv("JITIFY_OPTIONS"); + if (env_options) { + std::stringstream ss; + ss << env_options; + std::string opt; + while (!(ss >> opt).fail()) { + options.push_back(opt); + } + } + // Add options from JITIFY_OPTIONS macro +#ifdef JITIFY_OPTIONS +#define JITIFY_TOSTRING_IMPL(x) #x +#define JITIFY_TOSTRING(x) JITIFY_TOSTRING_IMPL(x) + std::stringstream ss; + ss << JITIFY_TOSTRING(JITIFY_OPTIONS); + std::string opt; + while (!(ss >> opt).fail()) { + options.push_back(opt); + } +#undef JITIFY_TOSTRING +#undef JITIFY_TOSTRING_IMPL +#endif // JITIFY_OPTIONS +} + +inline void detect_and_add_cuda_arch(std::vector& options) { + for (int i = 0; i < (int)options.size(); ++i) { + // Note that this will also match the middle of "--gpu-architecture". + if (options[i].find("-arch") != std::string::npos) { + // Arch already specified in options + return; + } + } + // Use the compute capability of the current device + // TODO: Check these API calls for errors + cudaError_t status; + int device; + status = cudaGetDevice(&device); + if (status != cudaSuccess) { + throw std::runtime_error( + std::string( + "Failed to detect GPU architecture: cudaGetDevice failed: ") + + cudaGetErrorString(status)); + } + int cc_major; + cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device); + int cc_minor; + cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, device); + int cc = cc_major * 10 + cc_minor; + // Note: We must limit the architecture to the max supported by the current + // version of NVRTC, otherwise newer hardware will cause errors + // on older versions of CUDA. + // TODO: It would be better to detect this somehow, rather than hard-coding it + + // Tegra chips do not have forwards compatibility so we need to special case + // them. + bool is_tegra = ((cc_major == 3 && cc_minor == 2) || // Logan + (cc_major == 5 && cc_minor == 3) || // Erista + (cc_major == 6 && cc_minor == 2) || // Parker + (cc_major == 7 && cc_minor == 2)); // Xavier + if (!is_tegra) { + // ensure that future CUDA versions just work (even if suboptimal) + const int cuda_major = std::min(10, CUDA_VERSION / 1000); + // clang-format off + switch (cuda_major) { + case 10: cc = std::min(cc, 75); break; // Turing + case 9: cc = std::min(cc, 70); break; // Volta + case 8: cc = std::min(cc, 61); break; // Pascal + case 7: cc = std::min(cc, 52); break; // Maxwell + default: + throw std::runtime_error("Unexpected CUDA major version " + + std::to_string(cuda_major)); + } + // clang-format on + } + + std::stringstream ss; + ss << cc; + options.push_back("-arch=compute_" + ss.str()); +} + +inline void detect_and_add_cxx11_flag(std::vector& options) { + // Reverse loop so we can erase on the fly. + for (int i = (int)options.size() - 1; i >= 0; --i) { + if (options[i].find("-std=c++98") != std::string::npos) { + // NVRTC doesn't support specifying c++98 explicitly, so we remove it. + options.erase(options.begin() + i); + return; + } else if (options[i].find("-std") != std::string::npos) { + // Some other standard was explicitly specified, don't change anything. + return; + } + } + // Jitify must be compiled with C++11 support, so we default to enabling it + // for the JIT-compiled code too. + options.push_back("-std=c++11"); +} + +inline void split_compiler_and_linker_options( + std::vector options, + std::vector* compiler_options, + std::vector* linker_files, + std::vector* linker_paths) { + for (int i = 0; i < (int)options.size(); ++i) { + std::string opt = options[i]; + std::string flag = opt.substr(0, 2); + std::string value = opt.substr(2); + if (flag == "-l") { + linker_files->push_back(value); + } else if (flag == "-L") { + linker_paths->push_back(value); + } else { + compiler_options->push_back(opt); + } + } +} + +inline bool pop_remove_unused_globals_flag(std::vector* options) { + auto it = std::remove_if( + options->begin(), options->end(), [](const std::string& opt) { + return opt.find("-remove-unused-globals") != std::string::npos; + }); + if (it != options->end()) { + options->resize(it - options->begin()); + return true; + } + return false; +} + +inline std::string ptx_parse_decl_name(const std::string& line) { + size_t name_end = line.find_first_of("[;"); + if (name_end == std::string::npos) { + throw std::runtime_error( + "Failed to parse .global/.const declaration in PTX: expected a " + "semicolon"); + } + size_t name_start_minus1 = line.find_last_of(" \t", name_end); + if (name_start_minus1 == std::string::npos) { + throw std::runtime_error( + "Failed to parse .global/.const declaration in PTX: expected " + "whitespace"); + } + size_t name_start = name_start_minus1 + 1; + std::string name = line.substr(name_start, name_end - name_start); + return name; +} + +inline void ptx_remove_unused_globals(std::string* ptx) { + std::istringstream iss(*ptx); + std::vector lines; + std::unordered_map line_num_to_global_name; + std::unordered_set name_set; + for (std::string line; std::getline(iss, line);) { + size_t line_num = lines.size(); + lines.push_back(line); + auto terms = split_string(line); + if (terms.size() <= 1) continue; // Ignore lines with no arguments + if (terms[0].substr(0, 2) == "//") continue; // Ignore comment lines + if (terms[0].substr(0, 7) == ".global" || + terms[0].substr(0, 6) == ".const") { + line_num_to_global_name.emplace(line_num, ptx_parse_decl_name(line)); + continue; + } + if (terms[0][0] == '.') continue; // Ignore .version, .reg, .param etc. + // Note: The first term will always be an instruction name; starting at 1 + // also allows unchecked inspection of the previous term. + for (int i = 1; i < (int)terms.size(); ++i) { + if (terms[i].substr(0, 2) == "//") break; // Ignore comments + // Note: The characters '.' and '%' are not treated as delimiters. + const char* token_delims = " \t()[]{},;+-*/~&|^?:=!<>\"'\\"; + for (auto token : split_string(terms[i], -1, token_delims)) { + if ( // Ignore non-names + !(std::isalpha(token[0]) || token[0] == '_' || token[0] == '$') || + token.find('.') != std::string::npos || + // Ignore variable/parameter declarations + terms[i - 1][0] == '.' || + // Ignore branch instructions + (token == "bra" && terms[i - 1][0] == '@') || + // Ignore branch labels + (token.substr(0, 2) == "BB" && + terms[i - 1].substr(0, 3) == "bra")) { + continue; + } + name_set.insert(token); + } + } + } + std::ostringstream oss; + for (size_t line_num = 0; line_num < lines.size(); ++line_num) { + auto it = line_num_to_global_name.find(line_num); + if (it != line_num_to_global_name.end()) { + const std::string& name = it->second; + if (!name_set.count(name)) { + continue; // Remove unused .global declaration. + } + } + oss << lines[line_num] << '\n'; + } + *ptx = oss.str(); +} + +inline nvrtcResult compile_kernel(std::string program_name, + std::map sources, + std::vector options, + std::string instantiation = "", + std::string* log = 0, std::string* ptx = 0, + std::string* mangled_instantiation = 0) { + std::string program_source = sources[program_name]; + // Build arrays of header names and sources + std::vector header_names_c; + std::vector header_sources_c; + int num_headers = (int)(sources.size() - 1); + header_names_c.reserve(num_headers); + header_sources_c.reserve(num_headers); + typedef std::map source_map; + for (source_map::const_iterator iter = sources.begin(); iter != sources.end(); + ++iter) { + std::string const& name = iter->first; + std::string const& code = iter->second; + if (name == program_name) { + continue; + } + header_names_c.push_back(name.c_str()); + header_sources_c.push_back(code.c_str()); + } + + // TODO: This WAR is expected to be unnecessary as of CUDA > 10.2. + bool should_remove_unused_globals = + detail::pop_remove_unused_globals_flag(&options); + + std::vector options_c(options.size() + 2); + options_c[0] = "--device-as-default-execution-space"; + options_c[1] = "--pre-include=jitify_preinclude.h"; + for (int i = 0; i < (int)options.size(); ++i) { + options_c[i + 2] = options[i].c_str(); + } + +#if CUDA_VERSION < 8000 + std::string inst_dummy; + if (!instantiation.empty()) { + // WAR for no nvrtcAddNameExpression before CUDA 8.0 + // Force template instantiation by adding dummy reference to kernel + inst_dummy = "__jitify_instantiation"; + program_source += + "\nvoid* " + inst_dummy + " = (void*)" + instantiation + ";\n"; + } +#endif + +#define CHECK_NVRTC(call) \ + do { \ + nvrtcResult check_nvrtc_macro_ret = call; \ + if (check_nvrtc_macro_ret != NVRTC_SUCCESS) { \ + return check_nvrtc_macro_ret; \ + } \ + } while (0) + + nvrtcProgram nvrtc_program; + CHECK_NVRTC(nvrtcCreateProgram( + &nvrtc_program, program_source.c_str(), program_name.c_str(), num_headers, + header_sources_c.data(), header_names_c.data())); + +#if CUDA_VERSION >= 8000 + if (!instantiation.empty()) { + CHECK_NVRTC(nvrtcAddNameExpression(nvrtc_program, instantiation.c_str())); + } +#endif + + nvrtcResult ret = nvrtcCompileProgram(nvrtc_program, (int)options_c.size(), + options_c.data()); + if (log) { + size_t logsize; + CHECK_NVRTC(nvrtcGetProgramLogSize(nvrtc_program, &logsize)); + std::vector vlog(logsize, 0); + CHECK_NVRTC(nvrtcGetProgramLog(nvrtc_program, vlog.data())); + log->assign(vlog.data(), logsize); + } + if (ret != NVRTC_SUCCESS) { + return ret; + } + + if (ptx) { + size_t ptxsize; + CHECK_NVRTC(nvrtcGetPTXSize(nvrtc_program, &ptxsize)); + std::vector vptx(ptxsize); + CHECK_NVRTC(nvrtcGetPTX(nvrtc_program, vptx.data())); + ptx->assign(vptx.data(), ptxsize); + if (should_remove_unused_globals) { + detail::ptx_remove_unused_globals(ptx); + } + } + + if (!instantiation.empty() && mangled_instantiation) { +#if CUDA_VERSION >= 8000 + const char* mangled_instantiation_cstr; + // Note: The returned string pointer becomes invalid after + // nvrtcDestroyProgram has been called, so we save it. + CHECK_NVRTC(nvrtcGetLoweredName(nvrtc_program, instantiation.c_str(), + &mangled_instantiation_cstr)); + *mangled_instantiation = mangled_instantiation_cstr; +#else + // Extract mangled kernel template instantiation from PTX + inst_dummy += " = "; // Note: This must match how the PTX is generated + int mi_beg = ptx->find(inst_dummy) + inst_dummy.size(); + int mi_end = ptx->find(";", mi_beg); + *mangled_instantiation = ptx->substr(mi_beg, mi_end - mi_beg); +#endif + } + + CHECK_NVRTC(nvrtcDestroyProgram(&nvrtc_program)); +#undef CHECK_NVRTC + return NVRTC_SUCCESS; +} + +inline void load_program(std::string const& cuda_source, + std::vector const& headers, + file_callback_type file_callback, + std::vector* include_paths, + std::map* program_sources, + std::vector* program_options, + std::string* program_name) { + // Extract include paths from compile options + std::vector::iterator iter = program_options->begin(); + while (iter != program_options->end()) { + std::string const& opt = *iter; + if (opt.substr(0, 2) == "-I") { + include_paths->push_back(opt.substr(2)); + iter = program_options->erase(iter); + } else { + ++iter; + } + } + + // Load program source + if (!detail::load_source(cuda_source, *program_sources, "", *include_paths, + file_callback, program_name)) { + throw std::runtime_error("Source not found: " + cuda_source); + } + + // Maps header include names to their full file paths. + std::map header_fullpaths; + + // Load header sources + for (std::string const& header : headers) { + if (!detail::load_source(header, *program_sources, "", *include_paths, + file_callback, nullptr, &header_fullpaths)) { + // **TODO: Deal with source not found + throw std::runtime_error("Source not found: " + header); + } + } + +#if JITIFY_PRINT_SOURCE + std::string& program_source = (*program_sources)[*program_name]; + std::cout << "---------------------------------------" << std::endl; + std::cout << "--- Source of " << *program_name << " ---" << std::endl; + std::cout << "---------------------------------------" << std::endl; + detail::print_with_line_numbers(program_source); + std::cout << "---------------------------------------" << std::endl; +#endif + + std::vector compiler_options, linker_files, linker_paths; + detail::split_compiler_and_linker_options(*program_options, &compiler_options, + &linker_files, &linker_paths); + + // If no arch is specified at this point we use whatever the current + // context is. This ensures we pick up the correct internal headers + // for arch-dependent compilation, e.g., some intrinsics are only + // present for specific architectures. + detail::detect_and_add_cuda_arch(compiler_options); + detail::detect_and_add_cxx11_flag(compiler_options); + + // Iteratively try to compile the sources, and use the resulting errors to + // identify missing headers. + std::string log; + nvrtcResult ret; + while ((ret = detail::compile_kernel(*program_name, *program_sources, + compiler_options, "", &log)) == + NVRTC_ERROR_COMPILATION) { + std::string include_name; + std::string include_parent; + int line_num = 0; + if (!detail::extract_include_info_from_compile_error( + log, include_name, include_parent, line_num)) { +#if JITIFY_PRINT_LOG + detail::print_compile_log(*program_name, log); +#endif + // There was a non include-related compilation error + // TODO: How to handle error? + throw std::runtime_error("Runtime compilation failed"); + } + + bool is_included_with_quotes = false; + if (program_sources->count(include_parent)) { + const std::string& parent_source = (*program_sources)[include_parent]; + is_included_with_quotes = + is_include_directive_with_quotes(parent_source, line_num); + } + + // Try to load the new header + // Note: This fullpath lookup is needed because the compiler error + // messages have the include name of the header instead of its full path. + std::string include_parent_fullpath = header_fullpaths[include_parent]; + std::string include_path = detail::path_base(include_parent_fullpath); + if (detail::load_source(include_name, *program_sources, include_path, + *include_paths, file_callback, nullptr, + &header_fullpaths, is_included_with_quotes)) { +#if JITIFY_PRINT_HEADER_PATHS + std::cout << "Found #include " << include_name << " from " + << include_parent << ":" << line_num << " [" + << include_parent_fullpath << "]" + << " at:\n " << header_fullpaths[include_name] << std::endl; +#endif + } else { // Failed to find header file. + // Comment-out the include line and print a warning + if (!program_sources->count(include_parent)) { + // ***TODO: Unless there's another mechanism (e.g., potentially + // the parent path vs. filename problem), getting + // here means include_parent was found automatically + // in a system include path. + // We need a WAR to zap it from *its parent*. + + typedef std::map source_map; + for (source_map::const_iterator it = program_sources->begin(); + it != program_sources->end(); ++it) { + std::cout << " " << it->first << std::endl; + } + throw std::out_of_range(include_parent + + " not in loaded sources!" + " This may be due to a header being loaded by" + " NVRTC without Jitify's knowledge."); + } + std::string& parent_source = (*program_sources)[include_parent]; + parent_source = detail::comment_out_code_line(line_num, parent_source); +#if JITIFY_PRINT_LOG + std::cout << include_parent << "(" << line_num + << "): warning: " << include_name << ": [jitify] File not found" + << std::endl; +#endif + } + } + if (ret != NVRTC_SUCCESS) { +#if JITIFY_PRINT_LOG + if (ret == NVRTC_ERROR_INVALID_OPTION) { + std::cout << "Compiler options: "; + for (int i = 0; i < (int)compiler_options.size(); ++i) { + std::cout << compiler_options[i] << " "; + } + std::cout << std::endl; + } +#endif + throw std::runtime_error(std::string("NVRTC error: ") + + nvrtcGetErrorString(ret)); + } +} + +inline void instantiate_kernel( + std::string const& program_name, + std::map const& program_sources, + std::string const& instantiation, std::vector const& options, + std::string* log, std::string* ptx, std::string* mangled_instantiation, + std::vector* linker_files, + std::vector* linker_paths) { + std::vector compiler_options; + detail::split_compiler_and_linker_options(options, &compiler_options, + linker_files, linker_paths); + + nvrtcResult ret = + detail::compile_kernel(program_name, program_sources, compiler_options, + instantiation, log, ptx, mangled_instantiation); +#if JITIFY_PRINT_LOG + if (log->size() > 1) { + detail::print_compile_log(program_name, *log); + } +#endif + if (ret != NVRTC_SUCCESS) { + throw std::runtime_error(std::string("NVRTC error: ") + + nvrtcGetErrorString(ret)); + } + +#if JITIFY_PRINT_PTX + std::cout << "---------------------------------------" << std::endl; + std::cout << *mangled_instantiation << std::endl; + std::cout << "---------------------------------------" << std::endl; + std::cout << "--- PTX for " << mangled_instantiation << " in " << program_name + << " ---" << std::endl; + std::cout << "---------------------------------------" << std::endl; + std::cout << *ptx << std::endl; + std::cout << "---------------------------------------" << std::endl; +#endif +} + +inline void get_1d_max_occupancy(CUfunction func, + CUoccupancyB2DSize smem_callback, + unsigned int* smem, int max_block_size, + unsigned int flags, int* grid, int* block) { + if (!func) { + throw std::runtime_error( + "Kernel pointer is NULL; you may need to define JITIFY_THREAD_SAFE " + "1"); + } + CUresult res = cuOccupancyMaxPotentialBlockSizeWithFlags( + grid, block, func, smem_callback, *smem, max_block_size, flags); + if (res != CUDA_SUCCESS) { + const char* msg; + cuGetErrorName(res, &msg); + throw std::runtime_error(msg); + } + if (smem_callback) { + *smem = (unsigned int)smem_callback(*block); + } +} + +} // namespace detail + +//! \endcond + +class KernelInstantiation; +class Kernel; +class Program; +class JitCache; + +struct ProgramConfig { + std::vector options; + std::vector include_paths; + std::string name; + typedef std::map source_map; + source_map sources; +}; + +class JitCache_impl { + friend class Program_impl; + friend class KernelInstantiation_impl; + friend class KernelLauncher_impl; + typedef uint64_t key_type; + jitify::ObjectCache _kernel_cache; + jitify::ObjectCache _program_config_cache; + std::vector _options; +#if JITIFY_THREAD_SAFE + std::mutex _kernel_cache_mutex; + std::mutex _program_cache_mutex; +#endif + public: + inline JitCache_impl(size_t cache_size) + : _kernel_cache(cache_size), _program_config_cache(cache_size) { + detail::add_options_from_env(_options); + + // Bootstrap the cuda context to avoid errors + cudaFree(0); + } +}; + +class Program_impl { + // A friendly class + friend class Kernel_impl; + friend class KernelLauncher_impl; + friend class KernelInstantiation_impl; + // TODO: This can become invalid if JitCache is destroyed before the + // Program object is. However, this can't happen if JitCache + // instances are static. + JitCache_impl& _cache; + uint64_t _hash; + ProgramConfig* _config; + void load_sources(std::string source, std::vector headers, + std::vector options, + file_callback_type file_callback); + + public: + inline Program_impl(JitCache_impl& cache, std::string source, + jitify::detail::vector headers = 0, + jitify::detail::vector options = 0, + file_callback_type file_callback = 0); + inline Program_impl(Program_impl const&) = default; + inline Program_impl(Program_impl&&) = default; + inline std::vector const& options() const { + return _config->options; + } + inline std::string const& name() const { return _config->name; } + inline ProgramConfig::source_map const& sources() const { + return _config->sources; + } + inline std::vector const& include_paths() const { + return _config->include_paths; + } +}; + +class Kernel_impl { + friend class KernelLauncher_impl; + friend class KernelInstantiation_impl; + Program_impl _program; + std::string _name; + std::vector _options; + uint64_t _hash; + + public: + inline Kernel_impl(Program_impl const& program, std::string name, + jitify::detail::vector options = 0); + inline Kernel_impl(Kernel_impl const&) = default; + inline Kernel_impl(Kernel_impl&&) = default; +}; + +class KernelInstantiation_impl { + friend class KernelLauncher_impl; + Kernel_impl _kernel; + uint64_t _hash; + std::string _template_inst; + std::vector _options; + detail::CUDAKernel* _cuda_kernel; + inline void print() const; + void build_kernel(); + + public: + inline KernelInstantiation_impl( + Kernel_impl const& kernel, std::vector const& template_args); + inline KernelInstantiation_impl(KernelInstantiation_impl const&) = default; + inline KernelInstantiation_impl(KernelInstantiation_impl&&) = default; + detail::CUDAKernel const& cuda_kernel() const { return *_cuda_kernel; } +}; + +class KernelLauncher_impl { + KernelInstantiation_impl _kernel_inst; + dim3 _grid; + dim3 _block; + unsigned int _smem; + cudaStream_t _stream; + + public: + inline KernelLauncher_impl(KernelInstantiation_impl const& kernel_inst, + dim3 grid, dim3 block, unsigned int smem = 0, + cudaStream_t stream = 0) + : _kernel_inst(kernel_inst), + _grid(grid), + _block(block), + _smem(smem), + _stream(stream) {} + inline KernelLauncher_impl(KernelLauncher_impl const&) = default; + inline KernelLauncher_impl(KernelLauncher_impl&&) = default; + inline CUresult launch( + jitify::detail::vector arg_ptrs, + jitify::detail::vector arg_types = 0) const; + inline void safe_launch( + jitify::detail::vector arg_ptrs, + jitify::detail::vector arg_types = 0) const; + + private: + inline void pre_launch( + jitify::detail::vector arg_types = 0) const; +}; + +/*! An object representing a configured and instantiated kernel ready + * for launching. + */ +class KernelLauncher { + std::unique_ptr _impl; + + public: + inline KernelLauncher(KernelInstantiation const& kernel_inst, dim3 grid, + dim3 block, unsigned int smem = 0, + cudaStream_t stream = 0); + + // Note: It's important that there is no implicit conversion required + // for arg_ptrs, because otherwise the parameter pack version + // below gets called instead (probably resulting in a segfault). + /*! Launch the kernel. + * + * \param arg_ptrs A vector of pointers to each function argument for the + * kernel. + * \param arg_types A vector of function argument types represented + * as code-strings. This parameter is optional and is only used to print + * out the function signature. + */ + inline CUresult launch( + std::vector arg_ptrs = std::vector(), + jitify::detail::vector arg_types = 0) const { + return _impl->launch(arg_ptrs, arg_types); + } + + /*! Launch the kernel and check for cuda errors. + * + * \see launch + */ + inline void safe_launch( + std::vector arg_ptrs = std::vector(), + jitify::detail::vector arg_types = 0) const { + _impl->safe_launch(arg_ptrs, arg_types); + } + + // Regular function call syntax + /*! Launch the kernel. + * + * \see launch + */ + template + inline CUresult operator()(const ArgTypes&... args) const { + return this->launch(args...); + } + /*! Launch the kernel. + * + * \param args Function arguments for the kernel. + */ + template + inline CUresult launch(const ArgTypes&... args) const { + return this->launch(std::vector({(void*)&args...}), + {reflection::reflect()...}); + } + /*! Launch the kernel and check for cuda errors. + * + * \param args Function arguments for the kernel. + */ + template + inline void safe_launch(const ArgTypes&... args) const { + this->safe_launch(std::vector({(void*)&args...}), + {reflection::reflect()...}); + } +}; + +/*! An object representing a kernel instantiation made up of a Kernel and + * template arguments. + */ +class KernelInstantiation { + friend class KernelLauncher; + std::unique_ptr _impl; + + public: + inline KernelInstantiation(Kernel const& kernel, + std::vector const& template_args); + + /*! Implicit conversion to the underlying CUfunction object. + * + * \note This allows use of CUDA APIs like + * cuOccupancyMaxActiveBlocksPerMultiprocessor. + */ + inline operator CUfunction() const { return _impl->cuda_kernel(); } + + /*! Configure the kernel launch. + * + * \see configure + */ + inline KernelLauncher operator()(dim3 grid, dim3 block, unsigned int smem = 0, + cudaStream_t stream = 0) const { + return this->configure(grid, block, smem, stream); + } + /*! Configure the kernel launch. + * + * \param grid The thread grid dimensions for the launch. + * \param block The thread block dimensions for the launch. + * \param smem The amount of shared memory to dynamically allocate, in + * bytes. + * \param stream The CUDA stream to launch the kernel in. + */ + inline KernelLauncher configure(dim3 grid, dim3 block, unsigned int smem = 0, + cudaStream_t stream = 0) const { + return KernelLauncher(*this, grid, block, smem, stream); + } + /*! Configure the kernel launch with a 1-dimensional block and grid chosen + * automatically to maximise occupancy. + * + * \param max_block_size The upper limit on the block size, or 0 for no + * limit. + * \param smem The amount of shared memory to dynamically allocate, in bytes. + * \param smem_callback A function returning smem for a given block size (overrides \p smem). + * \param stream The CUDA stream to launch the kernel in. + * \param flags The flags to pass to cuOccupancyMaxPotentialBlockSizeWithFlags. + */ + inline KernelLauncher configure_1d_max_occupancy( + int max_block_size = 0, unsigned int smem = 0, + CUoccupancyB2DSize smem_callback = 0, cudaStream_t stream = 0, + unsigned int flags = 0) const { + int grid; + int block; + CUfunction func = _impl->cuda_kernel(); + detail::get_1d_max_occupancy(func, smem_callback, &smem, max_block_size, + flags, &grid, &block); + return this->configure(grid, block, smem, stream); + } + + /* + * Returns the function attribute requested from the kernel + */ + inline int get_func_attribute(CUfunction_attribute attribute) const { + return _impl->cuda_kernel().get_func_attribute(attribute); + } + + /* + * Set the function attribute requested for the kernel + */ + inline void set_func_attribute(CUfunction_attribute attribute, + int value) const { + _impl->cuda_kernel().set_func_attribute(attribute, value); + } + + /* + * \deprecated Use \p get_global_ptr instead. + */ + inline CUdeviceptr get_constant_ptr(const char* name, + size_t* size = nullptr) const { + return get_global_ptr(name, size); + } + + /* + * Get a device pointer to a global __constant__ or __device__ variable using + * its un-mangled name. If provided, *size is set to the size of the variable + * in bytes. + */ + inline CUdeviceptr get_global_ptr(const char* name, + size_t* size = nullptr) const { + return _impl->cuda_kernel().get_global_ptr(name, size); + } + + /* + * Copy data from a global __constant__ or __device__ array to the host using + * its un-mangled name. + */ + template + inline CUresult get_global_array(const char* name, T* data, size_t count, + CUstream stream = 0) const { + return _impl->cuda_kernel().get_global_data(name, data, count, stream); + } + + /* + * Copy a value from a global __constant__ or __device__ variable to the host + * using its un-mangled name. + */ + template + inline CUresult get_global_value(const char* name, T* value, + CUstream stream = 0) const { + return get_global_array(name, value, 1, stream); + } + + /* + * Copy data from the host to a global __constant__ or __device__ array using + * its un-mangled name. + */ + template + inline CUresult set_global_array(const char* name, const T* data, + size_t count, CUstream stream = 0) const { + return _impl->cuda_kernel().set_global_data(name, data, count, stream); + } + + /* + * Copy a value from the host to a global __constant__ or __device__ variable + * using its un-mangled name. + */ + template + inline CUresult set_global_value(const char* name, const T& value, + CUstream stream = 0) const { + return set_global_array(name, &value, 1, stream); + } + + const std::string& mangled_name() const { + return _impl->cuda_kernel().function_name(); + } + + const std::string& ptx() const { return _impl->cuda_kernel().ptx(); } + + const std::vector& link_files() const { + return _impl->cuda_kernel().link_files(); + } + + const std::vector& link_paths() const { + return _impl->cuda_kernel().link_paths(); + } +}; + +/*! An object representing a kernel made up of a Program, a name and options. + */ +class Kernel { + friend class KernelInstantiation; + std::unique_ptr _impl; + + public: + Kernel(Program const& program, std::string name, + jitify::detail::vector options = 0); + + /*! Instantiate the kernel. + * + * \param template_args A vector of template arguments represented as + * code-strings. These can be generated using + * \code{.cpp}jitify::reflection::reflect()\endcode or + * \code{.cpp}jitify::reflection::reflect(value)\endcode + * + * \note Template type deduction is not possible, so all types must be + * explicitly specified. + */ + // inline KernelInstantiation instantiate(std::vector const& + // template_args) const { + inline KernelInstantiation instantiate( + std::vector const& template_args = + std::vector()) const { + return KernelInstantiation(*this, template_args); + } + + // Regular template instantiation syntax (note limited flexibility) + /*! Instantiate the kernel. + * + * \note The template arguments specified on this function are + * used to instantiate the kernel. Non-type template arguments must + * be wrapped with + * \code{.cpp}jitify::reflection::NonType\endcode + * + * \note Template type deduction is not possible, so all types must be + * explicitly specified. + */ + template + inline KernelInstantiation instantiate() const { + return this->instantiate( + std::vector({reflection::reflect()...})); + } + // Template-like instantiation syntax + // E.g., instantiate(myvar,Type())(grid,block) + /*! Instantiate the kernel. + * + * \param targs The template arguments for the kernel, represented as + * values. Types must be wrapped with + * \code{.cpp}jitify::reflection::Type()\endcode or + * \code{.cpp}jitify::reflection::type_of(value)\endcode + * + * \note Template type deduction is not possible, so all types must be + * explicitly specified. + */ + template + inline KernelInstantiation instantiate(TemplateArgs... targs) const { + return this->instantiate( + std::vector({reflection::reflect(targs)...})); + } +}; + +/*! An object representing a program made up of source code, headers + * and options. + */ +class Program { + friend class Kernel; + std::unique_ptr _impl; + + public: + Program(JitCache& cache, std::string source, + jitify::detail::vector headers = 0, + jitify::detail::vector options = 0, + file_callback_type file_callback = 0); + + /*! Select a kernel. + * + * \param name The name of the kernel (unmangled and without + * template arguments). + * \param options A vector of options to be passed to the NVRTC + * compiler when compiling this kernel. + */ + inline Kernel kernel(std::string name, + jitify::detail::vector options = 0) const { + return Kernel(*this, name, options); + } + /*! Select a kernel. + * + * \see kernel + */ + inline Kernel operator()( + std::string name, jitify::detail::vector options = 0) const { + return this->kernel(name, options); + } +}; + +/*! An object that manages a cache of JIT-compiled CUDA kernels. + * + */ +class JitCache { + friend class Program; + std::unique_ptr _impl; + + public: + /*! JitCache constructor. + * \param cache_size The number of kernels to hold in the cache + * before overwriting the least-recently-used ones. + */ + enum { DEFAULT_CACHE_SIZE = 128 }; + JitCache(size_t cache_size = DEFAULT_CACHE_SIZE) + : _impl(new JitCache_impl(cache_size)) {} + + /*! Create a program. + * + * \param source A string containing either the source filename or + * the source itself; in the latter case, the first line must be + * the name of the program. + * \param headers A vector of strings representing the source of + * each header file required by the program. Each entry can be + * either the header filename or the header source itself; in + * the latter case, the first line must be the name of the header + * (i.e., the name by which the header is #included). + * \param options A vector of options to be passed to the + * NVRTC compiler. Include paths specified with \p -I + * are added to the search paths used by Jitify. The environment + * variable JITIFY_OPTIONS can also be used to define additional + * options. + * \param file_callback A pointer to a callback function that is + * invoked whenever a source file needs to be loaded. Inside this + * function, the user can either load/specify the source themselves + * or defer to Jitify's file-loading mechanisms. + * \note Program or header source files referenced by filename are + * looked-up using the following mechanisms (in this order): + * \note 1) By calling file_callback. + * \note 2) By looking for the file embedded in the executable via the GCC + * linker. + * \note 3) By looking for the file in the filesystem. + * + * \note Jitify recursively scans all source files for \p #include + * directives and automatically adds them to the set of headers needed + * by the program. + * If a \p #include directive references a header that cannot be found, + * the directive is automatically removed from the source code to prevent + * immediate compilation failure. This may result in compilation errors + * if the header was required by the program. + * + * \note Jitify automatically includes NVRTC-safe versions of some + * standard library headers. + */ + inline Program program(std::string source, + jitify::detail::vector headers = 0, + jitify::detail::vector options = 0, + file_callback_type file_callback = 0) { + return Program(*this, source, headers, options, file_callback); + } +}; + +inline Program::Program(JitCache& cache, std::string source, + jitify::detail::vector headers, + jitify::detail::vector options, + file_callback_type file_callback) + : _impl(new Program_impl(*cache._impl, source, headers, options, + file_callback)) {} + +inline Kernel::Kernel(Program const& program, std::string name, + jitify::detail::vector options) + : _impl(new Kernel_impl(*program._impl, name, options)) {} + +inline KernelInstantiation::KernelInstantiation( + Kernel const& kernel, std::vector const& template_args) + : _impl(new KernelInstantiation_impl(*kernel._impl, template_args)) {} + +inline KernelLauncher::KernelLauncher(KernelInstantiation const& kernel_inst, + dim3 grid, dim3 block, unsigned int smem, + cudaStream_t stream) + : _impl(new KernelLauncher_impl(*kernel_inst._impl, grid, block, smem, + stream)) {} + +inline std::ostream& operator<<(std::ostream& stream, dim3 d) { + if (d.y == 1 && d.z == 1) { + stream << d.x; + } else { + stream << "(" << d.x << "," << d.y << "," << d.z << ")"; + } + return stream; +} + +inline void KernelLauncher_impl::pre_launch( + jitify::detail::vector arg_types) const { + (void)arg_types; +#if JITIFY_PRINT_LAUNCH + Kernel_impl const& kernel = _kernel_inst._kernel; + std::string arg_types_string = + (arg_types.empty() ? "..." : reflection::reflect_list(arg_types)); + std::cout << "Launching " << kernel._name << _kernel_inst._template_inst + << "<<<" << _grid << "," << _block << "," << _smem << "," << _stream + << ">>>" + << "(" << arg_types_string << ")" << std::endl; +#endif + if (!_kernel_inst._cuda_kernel) { + throw std::runtime_error( + "Kernel pointer is NULL; you may need to define JITIFY_THREAD_SAFE 1"); + } +} + +inline CUresult KernelLauncher_impl::launch( + jitify::detail::vector arg_ptrs, + jitify::detail::vector arg_types) const { + pre_launch(arg_types); + return _kernel_inst._cuda_kernel->launch(_grid, _block, _smem, _stream, + arg_ptrs); +} + +inline void KernelLauncher_impl::safe_launch( + jitify::detail::vector arg_ptrs, + jitify::detail::vector arg_types) const { + pre_launch(arg_types); + _kernel_inst._cuda_kernel->safe_launch(_grid, _block, _smem, _stream, + arg_ptrs); +} + +inline KernelInstantiation_impl::KernelInstantiation_impl( + Kernel_impl const& kernel, std::vector const& template_args) + : _kernel(kernel), _options(kernel._options) { + _template_inst = + (template_args.empty() ? "" + : reflection::reflect_template(template_args)); + using detail::hash_combine; + using detail::hash_larson64; + _hash = _kernel._hash; + _hash = hash_combine(_hash, hash_larson64(_template_inst.c_str())); + JitCache_impl& cache = _kernel._program._cache; + uint64_t cache_key = _hash; +#if JITIFY_THREAD_SAFE + std::lock_guard lock(cache._kernel_cache_mutex); +#endif + if (cache._kernel_cache.contains(cache_key)) { +#if JITIFY_PRINT_INSTANTIATION + std::cout << "Found "; + this->print(); +#endif + _cuda_kernel = &cache._kernel_cache.get(cache_key); + } else { +#if JITIFY_PRINT_INSTANTIATION + std::cout << "Building "; + this->print(); +#endif + _cuda_kernel = &cache._kernel_cache.emplace(cache_key); + this->build_kernel(); + } +} + +inline void KernelInstantiation_impl::print() const { + std::string options_string = reflection::reflect_list(_options); + std::cout << _kernel._name << _template_inst << " [" << options_string << "]" + << std::endl; +} + +inline void KernelInstantiation_impl::build_kernel() { + Program_impl const& program = _kernel._program; + + std::string instantiation = _kernel._name + _template_inst; + + std::string log, ptx, mangled_instantiation; + std::vector linker_files, linker_paths; + detail::instantiate_kernel(program.name(), program.sources(), instantiation, + _options, &log, &ptx, &mangled_instantiation, + &linker_files, &linker_paths); + + _cuda_kernel->set(mangled_instantiation.c_str(), ptx.c_str(), linker_files, + linker_paths); +} + +Kernel_impl::Kernel_impl(Program_impl const& program, std::string name, + jitify::detail::vector options) + : _program(program), _name(name), _options(options) { + // Merge options from parent + _options.insert(_options.end(), _program.options().begin(), + _program.options().end()); + detail::detect_and_add_cuda_arch(_options); + detail::detect_and_add_cxx11_flag(_options); + std::string options_string = reflection::reflect_list(_options); + using detail::hash_combine; + using detail::hash_larson64; + _hash = _program._hash; + _hash = hash_combine(_hash, hash_larson64(_name.c_str())); + _hash = hash_combine(_hash, hash_larson64(options_string.c_str())); +} + +Program_impl::Program_impl(JitCache_impl& cache, std::string source, + jitify::detail::vector headers, + jitify::detail::vector options, + file_callback_type file_callback) + : _cache(cache) { + // Compute hash of source, headers and options + std::string options_string = reflection::reflect_list(options); + using detail::hash_combine; + using detail::hash_larson64; + _hash = hash_combine(hash_larson64(source.c_str()), + hash_larson64(options_string.c_str())); + for (size_t i = 0; i < headers.size(); ++i) { + _hash = hash_combine(_hash, hash_larson64(headers[i].c_str())); + } + _hash = hash_combine(_hash, (uint64_t)file_callback); + // Add pre-include built-in JIT-safe headers + for (int i = 0; i < detail::preinclude_jitsafe_headers_count; ++i) { + const char* hdr_name = detail::preinclude_jitsafe_header_names[i]; + const std::string& hdr_source = + detail::get_jitsafe_headers_map().at(hdr_name); + headers.push_back(std::string(hdr_name) + "\n" + hdr_source); + } + // Merge options from parent + options.insert(options.end(), _cache._options.begin(), _cache._options.end()); + // Load sources +#if JITIFY_THREAD_SAFE + std::lock_guard lock(cache._program_cache_mutex); +#endif + if (!cache._program_config_cache.contains(_hash)) { + _config = &cache._program_config_cache.insert(_hash); + this->load_sources(source, headers, options, file_callback); + } else { + _config = &cache._program_config_cache.get(_hash); + } +} + +inline void Program_impl::load_sources(std::string source, + std::vector headers, + std::vector options, + file_callback_type file_callback) { + _config->options = options; + detail::load_program(source, headers, file_callback, &_config->include_paths, + &_config->sources, &_config->options, &_config->name); +} + +enum Location { HOST, DEVICE }; + +/*! Specifies location and parameters for execution of an algorithm. + * \param stream The CUDA stream on which to execute. + * \param headers A vector of headers to include in the code. + * \param options Options to pass to the NVRTC compiler. + * \param file_callback See jitify::Program. + * \param block_size The size of the CUDA thread block with which to + * execute. + * \param cache_size The number of kernels to store in the cache + * before overwriting the least-recently-used ones. + */ +struct ExecutionPolicy { + /*! Location (HOST or DEVICE) on which to execute.*/ + Location location; + /*! List of headers to include when compiling the algorithm.*/ + std::vector headers; + /*! List of compiler options.*/ + std::vector options; + /*! Optional callback for loading source files.*/ + file_callback_type file_callback; + /*! CUDA stream on which to execute.*/ + cudaStream_t stream; + /*! CUDA device on which to execute.*/ + int device; + /*! CUDA block size with which to execute.*/ + int block_size; + /*! The number of instantiations to store in the cache before overwriting + * the least-recently-used ones.*/ + size_t cache_size; + ExecutionPolicy(Location location_ = DEVICE, + jitify::detail::vector headers_ = 0, + jitify::detail::vector options_ = 0, + file_callback_type file_callback_ = 0, + cudaStream_t stream_ = 0, int device_ = 0, + int block_size_ = 256, + size_t cache_size_ = JitCache::DEFAULT_CACHE_SIZE) + : location(location_), + headers(headers_), + options(options_), + file_callback(file_callback_), + stream(stream_), + device(device_), + block_size(block_size_), + cache_size(cache_size_) {} +}; + +template +class Lambda; + +/*! An object that captures a set of variables for use in a parallel_for + * expression. See JITIFY_CAPTURE(). + */ +class Capture { + public: + std::vector _arg_decls; + std::vector _arg_ptrs; + + public: + template + inline Capture(std::vector arg_names, Args const&... args) + : _arg_ptrs{(void*)&args...} { + std::vector arg_types = {reflection::reflect()...}; + _arg_decls.resize(arg_names.size()); + for (int i = 0; i < (int)arg_names.size(); ++i) { + _arg_decls[i] = arg_types[i] + " " + arg_names[i]; + } + } +}; + +/*! An object that captures the instantiated Lambda function for use + in a parallel_for expression and the function string for NVRTC + compilation + */ +template +class Lambda { + public: + Capture _capture; + std::string _func_string; + Func _func; + + public: + inline Lambda(Capture const& capture, std::string func_string, Func func) + : _capture(capture), _func_string(func_string), _func(func) {} +}; + +template +inline Lambda make_Lambda(Capture const& capture, std::string func, + T lambda) { + return Lambda(capture, func, lambda); +} + +#define JITIFY_CAPTURE(...) \ + jitify::Capture(jitify::detail::split_string(#__VA_ARGS__, -1, ","), \ + __VA_ARGS__) + +#define JITIFY_MAKE_LAMBDA(capture, x, ...) \ + jitify::make_Lambda(capture, std::string(#__VA_ARGS__), \ + [x](int i) { __VA_ARGS__; }) + +#define JITIFY_ARGS(...) __VA_ARGS__ + +#define JITIFY_LAMBDA_(x, ...) \ + JITIFY_MAKE_LAMBDA(JITIFY_CAPTURE(x), JITIFY_ARGS(x), __VA_ARGS__) + +// macro sequence to strip surrounding brackets +#define JITIFY_STRIP_PARENS(X) X +#define JITIFY_PASS_PARAMETERS(X) JITIFY_STRIP_PARENS(JITIFY_ARGS X) + +/*! Creates a Lambda object with captured variables and a function + * definition. + * \param capture A bracket-enclosed list of variables to capture. + * \param ... The function definition. + * + * \code{.cpp} + * float* capture_me; + * int capture_me_too; + * auto my_lambda = JITIFY_LAMBDA( (capture_me, capture_me_too), + * capture_me[i] = i*capture_me_too ); + * \endcode + */ +#define JITIFY_LAMBDA(capture, ...) \ + JITIFY_LAMBDA_(JITIFY_ARGS(JITIFY_PASS_PARAMETERS(capture)), \ + JITIFY_ARGS(__VA_ARGS__)) + +// TODO: Try to implement for_each that accepts iterators instead of indices +// Add compile guard for NOCUDA compilation +/*! Call a function for a range of indices + * + * \param policy Determines the location and device parameters for + * execution of the parallel_for. + * \param begin The starting index. + * \param end The ending index. + * \param lambda A Lambda object created using the JITIFY_LAMBDA() macro. + * + * \code{.cpp} + * char const* in; + * float* out; + * parallel_for(0, 100, JITIFY_LAMBDA( (in, out), {char x = in[i]; out[i] = + * x*x; } ); \endcode + */ +template +CUresult parallel_for(ExecutionPolicy policy, IndexType begin, IndexType end, + Lambda const& lambda) { + using namespace jitify; + + if (policy.location == HOST) { +#ifdef _OPENMP +#pragma omp parallel for +#endif + for (IndexType i = begin; i < end; i++) { + lambda._func(i); + } + return CUDA_SUCCESS; // FIXME - replace with non-CUDA enum type? + } + + thread_local static JitCache kernel_cache(policy.cache_size); + + std::vector arg_decls; + arg_decls.push_back("I begin, I end"); + arg_decls.insert(arg_decls.end(), lambda._capture._arg_decls.begin(), + lambda._capture._arg_decls.end()); + + std::stringstream source_ss; + source_ss << "parallel_for_program\n"; + for (auto const& header : policy.headers) { + std::string header_name = header.substr(0, header.find("\n")); + source_ss << "#include <" << header_name << ">\n"; + } + source_ss << "template\n" + "__global__\n" + "void parallel_for_kernel(" + << reflection::reflect_list(arg_decls) + << ") {\n" + " I i0 = threadIdx.x + blockDim.x*blockIdx.x;\n" + " for( I i=i0+begin; i arg_ptrs; + arg_ptrs.push_back(&begin); + arg_ptrs.push_back(&end); + arg_ptrs.insert(arg_ptrs.end(), lambda._capture._arg_ptrs.begin(), + lambda._capture._arg_ptrs.end()); + + size_t n = end - begin; + dim3 block(policy.block_size); + dim3 grid((unsigned int)std::min((n - 1) / block.x + 1, size_t(65535))); + cudaSetDevice(policy.device); + return program.kernel("parallel_for_kernel") + .instantiate() + .configure(grid, block, 0, policy.stream) + .launch(arg_ptrs); +} + +namespace experimental { + +using jitify::file_callback_type; + +namespace serialization { + +namespace detail { + +// This should be incremented whenever the serialization format changes in any +// incompatible way. +static constexpr const size_t kSerializationVersion = 2; + +inline void serialize(std::ostream& stream, size_t u) { + uint64_t u64 = u; + char bytes[8]; + for (int i = 0; i < (int)sizeof(bytes); ++i) { + // Convert to little-endian bytes. + bytes[i] = (unsigned char)(u64 >> (i * CHAR_BIT)); + } + stream.write(bytes, sizeof(bytes)); +} + +inline bool deserialize(std::istream& stream, size_t* size) { + char bytes[8]; + stream.read(bytes, sizeof(bytes)); + uint64_t u64 = 0; + for (int i = 0; i < (int)sizeof(bytes); ++i) { + // Convert from little-endian bytes. + u64 |= uint64_t((unsigned char)(bytes[i])) << (i * CHAR_BIT); + } + *size = u64; + return stream.good(); +} + +inline void serialize(std::ostream& stream, std::string const& s) { + serialize(stream, s.size()); + stream.write(s.data(), s.size()); +} + +inline bool deserialize(std::istream& stream, std::string* s) { + size_t size; + if (!deserialize(stream, &size)) return false; + s->resize(size); + if (s->size()) { + stream.read(&(*s)[0], s->size()); + } + return stream.good(); +} + +inline void serialize(std::ostream& stream, std::vector const& v) { + serialize(stream, v.size()); + for (auto const& s : v) { + serialize(stream, s); + } +} + +inline bool deserialize(std::istream& stream, std::vector* v) { + size_t size; + if (!deserialize(stream, &size)) return false; + v->resize(size); + for (auto& s : *v) { + if (!deserialize(stream, &s)) return false; + } + return true; +} + +inline void serialize(std::ostream& stream, + std::map const& m) { + serialize(stream, m.size()); + for (auto const& kv : m) { + serialize(stream, kv.first); + serialize(stream, kv.second); + } +} + +inline bool deserialize(std::istream& stream, + std::map* m) { + size_t size; + if (!deserialize(stream, &size)) return false; + for (size_t i = 0; i < size; ++i) { + std::string key; + if (!deserialize(stream, &key)) return false; + if (!deserialize(stream, &(*m)[key])) return false; + } + return true; +} + +template +inline void serialize(std::ostream& stream, T const& value, Rest... rest) { + serialize(stream, value); + serialize(stream, rest...); +} + +template +inline bool deserialize(std::istream& stream, T* value, Rest... rest) { + if (!deserialize(stream, value)) return false; + return deserialize(stream, rest...); +} + +inline void serialize_magic_number(std::ostream& stream) { + stream.write("JTFY", 4); + serialize(stream, kSerializationVersion); +} + +inline bool deserialize_magic_number(std::istream& stream) { + char magic_number[4] = {0, 0, 0, 0}; + stream.read(&magic_number[0], 4); + if (!(magic_number[0] == 'J' && magic_number[1] == 'T' && + magic_number[2] == 'F' && magic_number[3] == 'Y')) { + return false; + } + size_t serialization_version; + if (!deserialize(stream, &serialization_version)) return false; + return serialization_version == kSerializationVersion; +} + +} // namespace detail + +template +inline std::string serialize(Values const&... values) { + std::ostringstream ss(std::stringstream::out | std::stringstream::binary); + detail::serialize_magic_number(ss); + detail::serialize(ss, values...); + return ss.str(); +} + +template +inline bool deserialize(std::string const& serialized, Values*... values) { + std::istringstream ss(serialized, + std::stringstream::in | std::stringstream::binary); + if (!detail::deserialize_magic_number(ss)) return false; + return detail::deserialize(ss, values...); +} + +} // namespace serialization + +class Program; +class Kernel; +class KernelInstantiation; +class KernelLauncher; + +/*! An object representing a program made up of source code, headers + * and options. + */ +class Program { + private: + friend class KernelInstantiation; + std::string _name; + std::vector _options; + std::map _sources; + + // Private constructor used by deserialize() + Program() {} + + public: + /*! Create a program. + * + * \param source A string containing either the source filename or + * the source itself; in the latter case, the first line must be + * the name of the program. + * \param headers A vector of strings representing the source of + * each header file required by the program. Each entry can be + * either the header filename or the header source itself; in + * the latter case, the first line must be the name of the header + * (i.e., the name by which the header is #included). + * \param options A vector of options to be passed to the + * NVRTC compiler. Include paths specified with \p -I + * are added to the search paths used by Jitify. The environment + * variable JITIFY_OPTIONS can also be used to define additional + * options. + * \param file_callback A pointer to a callback function that is + * invoked whenever a source file needs to be loaded. Inside this + * function, the user can either load/specify the source themselves + * or defer to Jitify's file-loading mechanisms. + * \note Program or header source files referenced by filename are + * looked-up using the following mechanisms (in this order): + * \note 1) By calling file_callback. + * \note 2) By looking for the file embedded in the executable via the GCC + * linker. + * \note 3) By looking for the file in the filesystem. + * + * \note Jitify recursively scans all source files for \p #include + * directives and automatically adds them to the set of headers needed + * by the program. + * If a \p #include directive references a header that cannot be found, + * the directive is automatically removed from the source code to prevent + * immediate compilation failure. This may result in compilation errors + * if the header was required by the program. + * + * \note Jitify automatically includes NVRTC-safe versions of some + * standard library headers. + */ + Program(std::string const& cuda_source, + std::vector const& given_headers = {}, + std::vector const& given_options = {}, + file_callback_type file_callback = nullptr) { + // Add pre-include built-in JIT-safe headers + std::vector headers = given_headers; + for (int i = 0; i < detail::preinclude_jitsafe_headers_count; ++i) { + const char* hdr_name = detail::preinclude_jitsafe_header_names[i]; + const std::string& hdr_source = + detail::get_jitsafe_headers_map().at(hdr_name); + headers.push_back(std::string(hdr_name) + "\n" + hdr_source); + } + + _options = given_options; + detail::add_options_from_env(_options); + std::vector include_paths; + detail::load_program(cuda_source, headers, file_callback, &include_paths, + &_sources, &_options, &_name); + } + + /*! Restore a serialized program. + * + * \param serialized_program The serialized program to restore. + * + * \see serialize + */ + static Program deserialize(std::string const& serialized_program) { + Program program; + if (!serialization::deserialize(serialized_program, &program._name, + &program._options, &program._sources)) { + throw std::runtime_error("Failed to deserialize program"); + } + return program; + } + + /*! Save the program. + * + * \see deserialize + */ + std::string serialize() const { + // Note: Must update kSerializationVersion if this is changed. + return serialization::serialize(_name, _options, _sources); + }; + + /*! Select a kernel. + * + * \param name The name of the kernel (unmangled and without + * template arguments). + * \param options A vector of options to be passed to the NVRTC + * compiler when compiling this kernel. + */ + Kernel kernel(std::string const& name, + std::vector const& options = {}) const; +}; + +class Kernel { + friend class KernelInstantiation; + Program const* _program; + std::string _name; + std::vector _options; + + public: + Kernel(Program const* program, std::string const& name, + std::vector const& options = {}) + : _program(program), _name(name), _options(options) {} + + /*! Instantiate the kernel. + * + * \param template_args A vector of template arguments represented as + * code-strings. These can be generated using + * \code{.cpp}jitify::reflection::reflect()\endcode or + * \code{.cpp}jitify::reflection::reflect(value)\endcode + * + * \note Template type deduction is not possible, so all types must be + * explicitly specified. + */ + KernelInstantiation instantiate( + std::vector const& template_args = + std::vector()) const; + + // Regular template instantiation syntax (note limited flexibility) + /*! Instantiate the kernel. + * + * \note The template arguments specified on this function are + * used to instantiate the kernel. Non-type template arguments must + * be wrapped with + * \code{.cpp}jitify::reflection::NonType\endcode + * + * \note Template type deduction is not possible, so all types must be + * explicitly specified. + */ + template + KernelInstantiation instantiate() const; + + // Template-like instantiation syntax + // E.g., instantiate(myvar,Type())(grid,block) + /*! Instantiate the kernel. + * + * \param targs The template arguments for the kernel, represented as + * values. Types must be wrapped with + * \code{.cpp}jitify::reflection::Type()\endcode or + * \code{.cpp}jitify::reflection::type_of(value)\endcode + * + * \note Template type deduction is not possible, so all types must be + * explicitly specified. + */ + template + KernelInstantiation instantiate(TemplateArgs... targs) const; +}; + +class KernelInstantiation { + friend class KernelLauncher; + std::unique_ptr _cuda_kernel; + + // Private constructor used by deserialize() + KernelInstantiation(std::string const& func_name, std::string const& ptx, + std::vector const& link_files, + std::vector const& link_paths) + : _cuda_kernel(new detail::CUDAKernel(func_name.c_str(), ptx.c_str(), + link_files, link_paths)) {} + + public: + KernelInstantiation(Kernel const& kernel, + std::vector const& template_args) { + Program const* program = kernel._program; + + std::string template_inst = + (template_args.empty() ? "" + : reflection::reflect_template(template_args)); + std::string instantiation = kernel._name + template_inst; + + std::vector options; + options.insert(options.begin(), program->_options.begin(), + program->_options.end()); + options.insert(options.begin(), kernel._options.begin(), + kernel._options.end()); + detail::detect_and_add_cuda_arch(options); + detail::detect_and_add_cxx11_flag(options); + + std::string log, ptx, mangled_instantiation; + std::vector linker_files, linker_paths; + detail::instantiate_kernel(program->_name, program->_sources, instantiation, + options, &log, &ptx, &mangled_instantiation, + &linker_files, &linker_paths); + + _cuda_kernel.reset(new detail::CUDAKernel(mangled_instantiation.c_str(), + ptx.c_str(), linker_files, + linker_paths)); + } + + /*! Implicit conversion to the underlying CUfunction object. + * + * \note This allows use of CUDA APIs like + * cuOccupancyMaxActiveBlocksPerMultiprocessor. + */ + operator CUfunction() const { return *_cuda_kernel; } + + /*! Restore a serialized kernel instantiation. + * + * \param serialized_kernel_inst The serialized kernel instantiation to + * restore. + * + * \see serialize + */ + static KernelInstantiation deserialize( + std::string const& serialized_kernel_inst) { + std::string func_name, ptx; + std::vector link_files, link_paths; + if (!serialization::deserialize(serialized_kernel_inst, &func_name, &ptx, + &link_files, &link_paths)) { + throw std::runtime_error("Failed to deserialize kernel instantiation"); + } + return KernelInstantiation(func_name, ptx, link_files, link_paths); + } + + /*! Save the program. + * + * \see deserialize + */ + std::string serialize() const { + // Note: Must update kSerializationVersion if this is changed. + return serialization::serialize( + _cuda_kernel->function_name(), _cuda_kernel->ptx(), + _cuda_kernel->link_files(), _cuda_kernel->link_paths()); + } + + /*! Configure the kernel launch. + * + * \param grid The thread grid dimensions for the launch. + * \param block The thread block dimensions for the launch. + * \param smem The amount of shared memory to dynamically allocate, in + * bytes. + * \param stream The CUDA stream to launch the kernel in. + */ + KernelLauncher configure(dim3 grid, dim3 block, unsigned int smem = 0, + cudaStream_t stream = 0) const; + + /*! Configure the kernel launch with a 1-dimensional block and grid chosen + * automatically to maximise occupancy. + * + * \param max_block_size The upper limit on the block size, or 0 for no + * limit. + * \param smem The amount of shared memory to dynamically allocate, in bytes. + * \param smem_callback A function returning smem for a given block size + * (overrides \p smem). + * \param stream The CUDA stream to launch the kernel in. + * \param flags The flags to pass to + * cuOccupancyMaxPotentialBlockSizeWithFlags. + */ + KernelLauncher configure_1d_max_occupancy( + int max_block_size = 0, unsigned int smem = 0, + CUoccupancyB2DSize smem_callback = 0, cudaStream_t stream = 0, + unsigned int flags = 0) const; + + /* + * Returns the function attribute requested from the kernel + */ + inline int get_func_attribute(CUfunction_attribute attribute) const { + return _cuda_kernel->get_func_attribute(attribute); + } + + /* + * Set the function attribute requested for the kernel + */ + inline void set_func_attribute(CUfunction_attribute attribute, + int value) const { + _cuda_kernel->set_func_attribute(attribute, value); + } + + /* + * \deprecated Use \p get_global_ptr instead. + */ + CUdeviceptr get_constant_ptr(const char* name, size_t* size = nullptr) const { + return get_global_ptr(name, size); + } + + /* + * Get a device pointer to a global __constant__ or __device__ variable using + * its un-mangled name. If provided, *size is set to the size of the variable + * in bytes. + */ + CUdeviceptr get_global_ptr(const char* name, size_t* size = nullptr) const { + return _cuda_kernel->get_global_ptr(name, size); + } + + /* + * Copy data from a global __constant__ or __device__ array to the host using + * its un-mangled name. + */ + template + CUresult get_global_array(const char* name, T* data, size_t count, + CUstream stream = 0) const { + return _cuda_kernel->get_global_data(name, data, count, stream); + } + + /* + * Copy a value from a global __constant__ or __device__ variable to the host + * using its un-mangled name. + */ + template + CUresult get_global_value(const char* name, T* value, + CUstream stream = 0) const { + return get_global_array(name, value, 1, stream); + } + + /* + * Copy data from the host to a global __constant__ or __device__ array using + * its un-mangled name. + */ + template + CUresult set_global_array(const char* name, const T* data, size_t count, + CUstream stream = 0) const { + return _cuda_kernel->set_global_data(name, data, count, stream); + } + + /* + * Copy a value from the host to a global __constant__ or __device__ variable + * using its un-mangled name. + */ + template + CUresult set_global_value(const char* name, const T& value, + CUstream stream = 0) const { + return set_global_array(name, &value, 1, stream); + } + + const std::string& mangled_name() const { + return _cuda_kernel->function_name(); + } + + const std::string& ptx() const { return _cuda_kernel->ptx(); } + + const std::vector& link_files() const { + return _cuda_kernel->link_files(); + } + + const std::vector& link_paths() const { + return _cuda_kernel->link_paths(); + } +}; + +class KernelLauncher { + KernelInstantiation const* _kernel_inst; + dim3 _grid; + dim3 _block; + unsigned int _smem; + cudaStream_t _stream; + + private: + void pre_launch(std::vector arg_types = {}) const { + (void)arg_types; +#if JITIFY_PRINT_LAUNCH + std::string arg_types_string = + (arg_types.empty() ? "..." : reflection::reflect_list(arg_types)); + std::cout << "Launching " << _kernel_inst->_cuda_kernel->function_name() + << "<<<" << _grid << "," << _block << "," << _smem << "," + << _stream << ">>>" + << "(" << arg_types_string << ")" << std::endl; +#endif + } + + public: + KernelLauncher(KernelInstantiation const* kernel_inst, dim3 grid, dim3 block, + unsigned int smem = 0, cudaStream_t stream = 0) + : _kernel_inst(kernel_inst), + _grid(grid), + _block(block), + _smem(smem), + _stream(stream) {} + + // Note: It's important that there is no implicit conversion required + // for arg_ptrs, because otherwise the parameter pack version + // below gets called instead (probably resulting in a segfault). + /*! Launch the kernel. + * + * \param arg_ptrs A vector of pointers to each function argument for the + * kernel. + * \param arg_types A vector of function argument types represented + * as code-strings. This parameter is optional and is only used to print + * out the function signature. + */ + CUresult launch(std::vector arg_ptrs = {}, + std::vector arg_types = {}) const { + pre_launch(arg_types); + return _kernel_inst->_cuda_kernel->launch(_grid, _block, _smem, _stream, + arg_ptrs); + } + + void safe_launch(std::vector arg_ptrs = {}, + std::vector arg_types = {}) const { + pre_launch(arg_types); + _kernel_inst->_cuda_kernel->safe_launch(_grid, _block, _smem, _stream, + arg_ptrs); + } + + /*! Launch the kernel. + * + * \param args Function arguments for the kernel. + */ + template + CUresult launch(const ArgTypes&... args) const { + return this->launch(std::vector({(void*)&args...}), + {reflection::reflect()...}); + } + + /*! Launch the kernel and check for cuda errors. + * + * \param args Function arguments for the kernel. + */ + template + void safe_launch(const ArgTypes&... args) const { + return this->safe_launch(std::vector({(void*)&args...}), + {reflection::reflect()...}); + } +}; + +inline Kernel Program::kernel(std::string const& name, + std::vector const& options) const { + return Kernel(this, name, options); +} + +inline KernelInstantiation Kernel::instantiate( + std::vector const& template_args) const { + return KernelInstantiation(*this, template_args); +} + +template +inline KernelInstantiation Kernel::instantiate() const { + return this->instantiate( + std::vector({reflection::reflect()...})); +} + +template +inline KernelInstantiation Kernel::instantiate(TemplateArgs... targs) const { + return this->instantiate( + std::vector({reflection::reflect(targs)...})); +} + +inline KernelLauncher KernelInstantiation::configure( + dim3 grid, dim3 block, unsigned int smem, cudaStream_t stream) const { + return KernelLauncher(this, grid, block, smem, stream); +} + +inline KernelLauncher KernelInstantiation::configure_1d_max_occupancy( + int max_block_size, unsigned int smem, CUoccupancyB2DSize smem_callback, + cudaStream_t stream, unsigned int flags) const { + int grid; + int block; + CUfunction func = *_cuda_kernel; + detail::get_1d_max_occupancy(func, smem_callback, &smem, max_block_size, + flags, &grid, &block); + return this->configure(grid, block, smem, stream); +} + +} // namespace experimental + +} // namespace jitify + +#if defined(_WIN32) || defined(_WIN64) +#pragma pop_macro("max") +#pragma pop_macro("min") +#pragma pop_macro("strtok_r") +#endif diff --git a/timemachine/cpp/src/kernels/k_fixed_point.cuh b/timemachine/cpp/src/kernels/k_fixed_point.cuh index 10340c198..4aebd0bea 100644 --- a/timemachine/cpp/src/kernels/k_fixed_point.cuh +++ b/timemachine/cpp/src/kernels/k_fixed_point.cuh @@ -3,6 +3,12 @@ // cuda specific version #include "../fixed_point.hpp" +// we need to use a different level of precision for parameter derivatives +#define FIXED_EXPONENT_DU_DCHARGE 0x1000000000 +#define FIXED_EXPONENT_DU_DSIG 0x2000000000 +#define FIXED_EXPONENT_DU_DEPS 0x4000000000 // this is just getting silly + + template RealType __device__ __forceinline__ FIXED_TO_FLOAT_DU_DP(unsigned long long v) { return static_cast(static_cast(v))/EXPONENT; diff --git a/timemachine/cpp/src/kernels/k_lambda_transformer_jit.cuh b/timemachine/cpp/src/kernels/k_lambda_transformer_jit.cuh new file mode 100644 index 000000000..fdbe9c50e --- /dev/null +++ b/timemachine/cpp/src/kernels/k_lambda_transformer_jit.cuh @@ -0,0 +1,137 @@ +jit_program + +#include "timemachine/cpp/src/kernels/surreal.cuh" +#include "timemachine/cpp/src/kernels/k_fixed_point.cuh" + +#define PI 3.141592653589793115997963468544185161 + +template +NumericType __device__ __forceinline__ transform_lambda_charge(NumericType lambda) { + return CUSTOM_EXPRESSION_CHARGE; +} + +template +NumericType __device__ __forceinline__ transform_lambda_sigma(NumericType lambda) { + return CUSTOM_EXPRESSION_SIGMA; +} + +template +NumericType __device__ __forceinline__ transform_lambda_epsilon(NumericType lambda) { + return CUSTOM_EXPRESSION_EPSILON; +} + +template +NumericType __device__ __forceinline__ transform_lambda_w(NumericType lambda) { + return CUSTOM_EXPRESSION_W; +} + +void __global__ k_compute_w_coords( + const int N, + const double lambda, + const double cutoff, + const int * __restrict__ lambda_plane_idxs, // 0 or 1, shift + const int * __restrict__ lambda_offset_idxs, + double * __restrict__ coords_w, + double * __restrict__ dw_dl) { + + int atom_i_idx = blockIdx.x*blockDim.x + threadIdx.x; + + if(atom_i_idx >= N) { + return; + } + + int lambda_offset_i = atom_i_idx < N ? lambda_offset_idxs[atom_i_idx] : 0; + int lambda_plane_i = atom_i_idx < N ? lambda_plane_idxs[atom_i_idx] : 0; + + double f_lambda = transform_lambda_w(lambda); + + double step = 1e-7; + Surreal lambda_surreal(lambda, step); + double f_lambda_grad = (transform_lambda_w(lambda_surreal).imag)/step; + + double coords_w_i = (lambda_plane_i + lambda_offset_i*f_lambda)*cutoff; + double dw_dl_i = lambda_offset_i*f_lambda_grad*cutoff; + + coords_w[atom_i_idx] = coords_w_i; + dw_dl[atom_i_idx] = dw_dl_i; + +} // 0 or 1, how much we offset from the plane by ) + +void __global__ k_permute_interpolated( + const double lambda, + const int N, + const unsigned int * __restrict__ perm, + const double * __restrict__ d_p, + double * __restrict__ d_sorted_p, + double * __restrict__ d_sorted_dp_dl) { + + int idx = blockIdx.x*blockDim.x + threadIdx.x; + int stride = gridDim.y; + int stride_idx = blockIdx.y; + + if(idx >= N) { + return; + } + + int size = N*stride; + + int source_idx = idx*stride+stride_idx; + int target_idx = perm[idx]*stride+stride_idx; + + double step = 1e-7; + Surreal lambda_surreal(lambda, step); + + double f_lambda; + double f_lambda_grad; + + if(stride_idx == 0) { + f_lambda = transform_lambda_charge(lambda); + f_lambda_grad = (transform_lambda_charge(lambda_surreal).imag)/step; + } + if(stride_idx == 1) { + f_lambda = transform_lambda_sigma(lambda); + f_lambda_grad = (transform_lambda_sigma(lambda_surreal).imag)/step; + } + if(stride_idx == 2) { + f_lambda = transform_lambda_epsilon(lambda); + f_lambda_grad = (transform_lambda_epsilon(lambda_surreal).imag)/step; + } + + d_sorted_p[source_idx] = (1-f_lambda)*d_p[target_idx] + f_lambda*d_p[size+target_idx]; + d_sorted_dp_dl[source_idx] = f_lambda_grad*(d_p[size+target_idx] - d_p[target_idx]); + +} + +void __global__ k_add_ull_to_real_interpolated( + const double lambda, + const int N, + const unsigned long long * __restrict__ ull_array, + double * __restrict__ real_array) { + + int idx = blockIdx.x*blockDim.x + threadIdx.x; + int stride = gridDim.y; + int stride_idx = blockIdx.y; + + if(idx >= N) { + return; + } + + int size = N*stride; + int target_idx = idx*stride+stride_idx; + + // handle charges, sigmas, epsilons with different exponents + if(stride_idx == 0) { + double f_lambda = transform_lambda_charge(lambda); + real_array[target_idx] += (1-f_lambda)*FIXED_TO_FLOAT_DU_DP(ull_array[target_idx]); + real_array[size+target_idx] += f_lambda*FIXED_TO_FLOAT_DU_DP(ull_array[target_idx]); + } else if(stride_idx == 1) { + double f_lambda = transform_lambda_sigma(lambda); + real_array[target_idx] += (1-f_lambda)*FIXED_TO_FLOAT_DU_DP(ull_array[target_idx]); + real_array[size+target_idx] += f_lambda*FIXED_TO_FLOAT_DU_DP(ull_array[target_idx]); + } else if(stride_idx == 2) { + double f_lambda = transform_lambda_epsilon(lambda); + real_array[target_idx] += (1-f_lambda)*FIXED_TO_FLOAT_DU_DP(ull_array[target_idx]); + real_array[size+target_idx] += f_lambda*FIXED_TO_FLOAT_DU_DP(ull_array[target_idx]); + } + +} diff --git a/timemachine/cpp/src/kernels/k_nonbonded.cuh b/timemachine/cpp/src/kernels/k_nonbonded.cuh index 348adaf2b..a7f775ca5 100644 --- a/timemachine/cpp/src/kernels/k_nonbonded.cuh +++ b/timemachine/cpp/src/kernels/k_nonbonded.cuh @@ -9,11 +9,6 @@ #define PI 3.141592653589793115997963468544185161 #define TWO_OVER_SQRT_PI 1.128379167095512595889238330988549829708 -// we need to use a different level of precision for parameter derivatives -#define FIXED_EXPONENT_DU_DCHARGE 0x1000000000 -#define FIXED_EXPONENT_DU_DSIG 0x2000000000 -#define FIXED_EXPONENT_DU_DEPS 0x4000000000 // this is just getting silly - // generate kv values from coordinates to be radix sorted void __global__ k_coords_to_kv( @@ -122,51 +117,6 @@ void __global__ k_check_rebuild_coords_and_box( } - -__global__ void k_interpolate_parameters( - const double lambda, - const int P_base, - const double * __restrict__ params_src, - const double * __restrict__ params_dst, - double *params_out) { - - int idx = blockIdx.x*blockDim.x + threadIdx.x; - - if(idx >= P_base) { - return; - } - - params_out[idx] = (1-lambda)*params_src[idx] + lambda*params_dst[idx]; -} - -template -void __global__ k_permute_interpolated( - const double lambda, - const int N, - const unsigned int * __restrict__ perm, - const RealType * __restrict__ d_p, - RealType * __restrict__ d_sorted_p, - RealType * __restrict__ d_sorted_dp_dl) { - - int idx = blockIdx.x*blockDim.x + threadIdx.x; - int stride = gridDim.y; - int stride_idx = blockIdx.y; - - if(idx >= N) { - return; - } - - int size = N*stride; - - int source_idx = idx*stride+stride_idx; - int target_idx = perm[idx]*stride+stride_idx; - - d_sorted_p[source_idx] = (1-lambda)*d_p[target_idx] + lambda*d_p[size+target_idx]; - d_sorted_dp_dl[source_idx] = d_p[size+target_idx] - d_p[target_idx]; - -} - - template void __global__ k_permute( const int N, @@ -273,38 +223,6 @@ void __global__ k_add_ull_to_real( } -template -void __global__ k_add_ull_to_real_interpolated( - const double lambda, - const int N, - const unsigned long long * __restrict__ ull_array, - RealType * __restrict__ real_array) { - - int idx = blockIdx.x*blockDim.x + threadIdx.x; - int stride = gridDim.y; - int stride_idx = blockIdx.y; - - if(idx >= N) { - return; - } - - int size = N*stride; - int target_idx = idx*stride+stride_idx; - - // handle charges, sigmas, epsilons with different exponents - if(stride_idx == 0) { - real_array[target_idx] += (1-lambda)*FIXED_TO_FLOAT_DU_DP(ull_array[target_idx]); - real_array[size+target_idx] += lambda*FIXED_TO_FLOAT_DU_DP(ull_array[target_idx]); - } else if(stride_idx == 1) { - real_array[target_idx] += (1-lambda)*FIXED_TO_FLOAT_DU_DP(ull_array[target_idx]); - real_array[size+target_idx] += lambda*FIXED_TO_FLOAT_DU_DP(ull_array[target_idx]); - } else if(stride_idx == 2) { - real_array[target_idx] += (1-lambda)*FIXED_TO_FLOAT_DU_DP(ull_array[target_idx]); - real_array[size+target_idx] += lambda*FIXED_TO_FLOAT_DU_DP(ull_array[target_idx]); - } - -} - template void __global__ k_reduce_buffer( @@ -362,6 +280,32 @@ double __device__ __forceinline__ fix_nvidia_fmad(double a, double b, double c, return __dmul_rn(a, b) + __dmul_rn(c, d); } +// void __global__ k_compute_w_coords( +// const int N, +// const double lambda, +// const double cutoff, +// const int * __restrict__ lambda_plane_idxs, // 0 or 1, shift +// const int * __restrict__ lambda_offset_idxs, +// double * __restrict__ coords_w, +// double * __restrict__ dw_dl) { + +// int atom_i_idx = blockIdx.x*blockDim.x + threadIdx.x; + +// if(atom_i_idx >= N) { +// return; +// } + +// int lambda_offset_i = atom_i_idx < N ? lambda_offset_idxs[atom_i_idx] : 0; +// int lambda_plane_i = atom_i_idx < N ? lambda_plane_idxs[atom_i_idx] : 0; + +// double coords_w_i = (lambda_plane_i + lambda_offset_i*lambda)*cutoff; +// double dw_dl_i = lambda_offset_i*cutoff; + +// coords_w[atom_i_idx] = coords_w_i; +// dw_dl[atom_i_idx] = dw_dl_i; + +// } // 0 or 1, how much we offset from the plane by ) + // ALCHEMICAL == false guarantees that the tile's atoms are such that // 1. src_param and dst_params are equal for every i in R and j in C // 2. w_i and w_j are identical for every (i,j) in (RxC) @@ -382,9 +326,11 @@ void __device__ v_nonbonded_unified( const double * __restrict__ params, // [N] const double * __restrict__ box, const double * __restrict__ dp_dl, + const double * __restrict__ coords_w, // 4D coords + const double * __restrict__ dw_dl, // 4D derivatives const double lambda, - const int * __restrict__ lambda_plane_idxs, // 0 or 1, shift - const int * __restrict__ lambda_offset_idxs, // 0 or 1, how much we offset from the plane by cutoff + // const int * __restrict__ lambda_plane_idxs, // 0 or 1, shift + // const int * __restrict__ lambda_offset_idxs, // 0 or 1, how much we offset from the plane by cutoff const double beta, const double cutoff, const int * __restrict__ ixn_tiles, @@ -407,16 +353,18 @@ void __device__ v_nonbonded_unified( int row_block_idx = ixn_tiles[tile_idx]; int atom_i_idx = row_block_idx*32 + threadIdx.x; - int lambda_offset_i = atom_i_idx < N ? lambda_offset_idxs[atom_i_idx] : 0; - int lambda_plane_i = atom_i_idx < N ? lambda_plane_idxs[atom_i_idx] : 0; + // int lambda_offset_i = atom_i_idx < N ? lambda_offset_idxs[atom_i_idx] : 0; + // int lambda_plane_i = atom_i_idx < N ? lambda_plane_idxs[atom_i_idx] : 0; RealType ci_x = atom_i_idx < N ? coords[atom_i_idx*3+0] : 0; RealType ci_y = atom_i_idx < N ? coords[atom_i_idx*3+1] : 0; RealType ci_z = atom_i_idx < N ? coords[atom_i_idx*3+2] : 0; + RealType ci_w = atom_i_idx < N ? coords_w[atom_i_idx] : 0; RealType dq_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx*3+0] : 0; RealType dsig_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx*3+1] : 0; RealType deps_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx*3+2] : 0; + RealType dw_dl_i = atom_i_idx < N ? dw_dl[atom_i_idx] : 0; unsigned long long gi_x = 0; unsigned long long gi_y = 0; @@ -437,16 +385,18 @@ void __device__ v_nonbonded_unified( // i idx is contiguous but j is not, so we should swap them to avoid having to shuffle atom_j_idx int atom_j_idx = ixn_atoms[tile_idx*32 + threadIdx.x]; - int lambda_offset_j = atom_j_idx < N ? lambda_offset_idxs[atom_j_idx] : 0; - int lambda_plane_j = atom_j_idx < N ? lambda_plane_idxs[atom_j_idx] : 0; + // int lambda_offset_j = atom_j_idx < N ? lambda_offset_idxs[atom_j_idx] : 0; + // int lambda_plane_j = atom_j_idx < N ? lambda_plane_idxs[atom_j_idx] : 0; RealType cj_x = atom_j_idx < N ? coords[atom_j_idx*3+0] : 0; RealType cj_y = atom_j_idx < N ? coords[atom_j_idx*3+1] : 0; RealType cj_z = atom_j_idx < N ? coords[atom_j_idx*3+2] : 0; + RealType cj_w = atom_j_idx < N ? coords_w[atom_j_idx] : 0; RealType dq_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx*3+0] : 0; RealType dsig_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx*3+1] : 0; RealType deps_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx*3+2] : 0; + RealType dw_dl_j = atom_j_idx < N ? dw_dl[atom_j_idx] : 0; unsigned long long gj_x = 0; unsigned long long gj_y = 0; @@ -489,7 +439,8 @@ void __device__ v_nonbonded_unified( if(ALCHEMICAL) { // (ytz): we are guaranteed that delta_w is zero if ALCHEMICAL == false - delta_w = (lambda_plane_i - lambda_plane_j)*real_cutoff + (lambda_offset_i - lambda_offset_j)*real_lambda*real_cutoff; + // delta_w = (lambda_plane_i - lambda_plane_j)*real_cutoff + (lambda_offset_i - lambda_offset_j)*real_lambda*real_cutoff; + delta_w = ci_w - cj_w; d2ij += delta_w * delta_w; } @@ -592,8 +543,7 @@ void __device__ v_nonbonded_unified( if(COMPUTE_DU_DL && ALCHEMICAL) { // needed for cancellation of nans (if one term blows up) - real_du_dl += delta_w*cutoff*delta_prefactor*(lambda_offset_i - lambda_offset_j); - // this extra ebd kills as it requires the erfc to be fully evaluated. SHIT + real_du_dl += delta_w*delta_prefactor*(dw_dl_i - dw_dl_j); real_du_dl += inv_dij*ebd*fix_nvidia_fmad(qj, dq_dl_i, qi, dq_dl_j); du_dl += FLOAT_TO_FIXED_NONBONDED(real_du_dl); } @@ -614,8 +564,8 @@ void __device__ v_nonbonded_unified( cj_z = __shfl_sync(0xffffffff, cj_z, srcLane); if(ALCHEMICAL) { - lambda_offset_j = __shfl_sync(0xffffffff, lambda_offset_j, srcLane); // this also can be optimized away - lambda_plane_j = __shfl_sync(0xffffffff, lambda_plane_j, srcLane); + cj_w = __shfl_sync(0xffffffff, cj_w, srcLane); // this also can be optimized away + dw_dl_j = __shfl_sync(0xffffffff, dw_dl_j, srcLane); } if(COMPUTE_DU_DX) { @@ -694,9 +644,9 @@ void __global__ k_nonbonded_unified( const double * __restrict__ params, // [N] const double * __restrict__ box, const double * __restrict__ dp_dl, + const double * __restrict__ coords_w, // 4D coords + const double * __restrict__ dw_dl, // 4D derivatives const double lambda, - const int * __restrict__ lambda_plane_idxs, // 0 or 1, shift - const int * __restrict__ lambda_offset_idxs, // 0 or 1, how much we offset from the plane by cutoff const double beta, const double cutoff, const int * __restrict__ ixn_tiles, @@ -709,29 +659,25 @@ void __global__ k_nonbonded_unified( int tile_idx = blockIdx.x; int row_block_idx = ixn_tiles[tile_idx]; int atom_i_idx = row_block_idx*32 + threadIdx.x; - int lambda_offset_i = atom_i_idx < N ? lambda_offset_idxs[atom_i_idx] : 0; - int lambda_plane_i = atom_i_idx < N ? lambda_plane_idxs[atom_i_idx] : 0; RealType dq_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx*3+0] : 0; RealType dsig_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx*3+1] : 0; RealType deps_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx*3+2] : 0; + RealType cw_i = atom_i_idx < N ? coords_w[atom_i_idx] : 0; int atom_j_idx = ixn_atoms[tile_idx*32 + threadIdx.x]; - int lambda_offset_j = atom_j_idx < N ? lambda_offset_idxs[atom_j_idx] : 0; - int lambda_plane_j = atom_j_idx < N ? lambda_plane_idxs[atom_j_idx] : 0; RealType dq_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx*3+0] : 0; RealType dsig_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx*3+1] : 0; RealType deps_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx*3+2] : 0; + RealType cw_j = atom_j_idx < N ? coords_w[atom_j_idx] : 0; int is_vanilla = ( - lambda_offset_i == 0 && - lambda_plane_i == 0 && + cw_i == 0 && dq_dl_i == 0 && dsig_dl_i == 0 && deps_dl_i == 0 && - lambda_offset_j == 0 && - lambda_plane_j == 0 && + cw_j == 0 && dq_dl_j == 0 && dsig_dl_j == 0 && deps_dl_j == 0 @@ -746,9 +692,9 @@ void __global__ k_nonbonded_unified( params, box, dp_dl, + coords_w, + dw_dl, lambda, - lambda_plane_idxs, - lambda_offset_idxs, beta, cutoff, ixn_tiles, @@ -765,9 +711,9 @@ void __global__ k_nonbonded_unified( params, box, dp_dl, + coords_w, + dw_dl, lambda, - lambda_plane_idxs, - lambda_offset_idxs, beta, cutoff, ixn_tiles, @@ -790,9 +736,9 @@ void __global__ k_nonbonded_exclusions( const double * __restrict__ params, const double * __restrict__ box, const double * __restrict__ dp_dl, + const double * __restrict__ coords_w, // 4D coords + const double * __restrict__ dw_dl, // 4D derivatives const double lambda, - const int * __restrict__ lambda_plane_idxs, // 0 or 1, shift - const int * __restrict__ lambda_offset_idxs, // 0 or 1, if we alolw this atom to be decoupled const int * __restrict__ exclusion_idxs, // [E, 2] pair-list of atoms to be excluded const double * __restrict__ scales, // [E] const double beta, @@ -815,16 +761,16 @@ void __global__ k_nonbonded_exclusions( } int atom_i_idx = exclusion_idxs[e_idx*2 + 0]; - int lambda_offset_i = lambda_offset_idxs[atom_i_idx]; - int lambda_plane_i = lambda_plane_idxs[atom_i_idx]; RealType ci_x = coords[atom_i_idx*3+0]; RealType ci_y = coords[atom_i_idx*3+1]; RealType ci_z = coords[atom_i_idx*3+2]; + RealType ci_w = coords_w[atom_i_idx]; RealType dq_dl_i = dp_dl[atom_i_idx*3+0]; RealType dsig_dl_i = dp_dl[atom_i_idx*3+1]; RealType deps_dl_i = dp_dl[atom_i_idx*3+2]; + RealType dw_dl_i = dw_dl[atom_i_idx]; unsigned long long gi_x = 0; unsigned long long gi_y = 0; @@ -843,16 +789,16 @@ void __global__ k_nonbonded_exclusions( unsigned long long g_epsi = 0; int atom_j_idx = exclusion_idxs[e_idx*2 + 1]; - int lambda_offset_j = lambda_offset_idxs[atom_j_idx]; - int lambda_plane_j = lambda_plane_idxs[atom_j_idx]; RealType cj_x = coords[atom_j_idx*3+0]; RealType cj_y = coords[atom_j_idx*3+1]; RealType cj_z = coords[atom_j_idx*3+2]; + RealType cj_w = coords_w[atom_j_idx]; RealType dq_dl_j = dp_dl[atom_j_idx*3+0]; RealType dsig_dl_j = dp_dl[atom_j_idx*3+1]; RealType deps_dl_j = dp_dl[atom_j_idx*3+2]; + RealType dw_dl_j = dw_dl[atom_j_idx]; unsigned long long gj_x = 0; unsigned long long gj_y = 0; @@ -895,8 +841,7 @@ void __global__ k_nonbonded_exclusions( delta_y -= box_y*nearbyint(delta_y*inv_box_y); delta_z -= box_z*nearbyint(delta_z*inv_box_z); - RealType delta_w = (lambda_plane_i - lambda_plane_j)*real_cutoff + (lambda_offset_i - lambda_offset_j)*real_lambda*real_cutoff; - + RealType delta_w = ci_w - cj_w; RealType d2ij = delta_x*delta_x + delta_y*delta_y + delta_z*delta_z + delta_w*delta_w; unsigned long long energy = 0; @@ -991,7 +936,7 @@ void __global__ k_nonbonded_exclusions( g_qi -= FLOAT_TO_FIXED_DU_DP(charge_scale*qj*inv_dij*ebd); g_qj -= FLOAT_TO_FIXED_DU_DP(charge_scale*qi*inv_dij*ebd); - real_du_dl -= delta_w*cutoff*delta_prefactor*(lambda_offset_i - lambda_offset_j); + real_du_dl -= delta_w*delta_prefactor*(dw_dl_i - dw_dl_j); real_du_dl -= charge_scale*inv_dij*ebd*fix_nvidia_fmad(qj, dq_dl_i, qi, dq_dl_j); if(du_dx) { diff --git a/timemachine/cpp/src/nonbonded.cu b/timemachine/cpp/src/nonbonded.cu index 04ece8be2..3e683e34c 100644 --- a/timemachine/cpp/src/nonbonded.cu +++ b/timemachine/cpp/src/nonbonded.cu @@ -6,6 +6,7 @@ #include #include #include +#include "jitify.hpp" #include "nonbonded.hpp" #include "hilbert.h" @@ -13,16 +14,28 @@ #include "k_nonbonded.cuh" +#include +#include +#include + namespace timemachine { +static jitify::JitCache kernel_cache; + + template Nonbonded::Nonbonded( const std::vector &exclusion_idxs, // [E,2] const std::vector &scales, // [E, 2] const std::vector &lambda_plane_idxs, // [N] const std::vector &lambda_offset_idxs, // [N] - double beta, - double cutoff + const double beta, + const double cutoff, + const std::string &kernel_src + // const std::string &transform_lambda_charge, + // const std::string &transform_lambda_sigma, + // const std::string &transform_lambda_epsilon, + // const std::string &transform_lambda_w ) : N_(lambda_offset_idxs.size()), cutoff_(cutoff), E_(exclusion_idxs.size()/2), @@ -55,7 +68,10 @@ Nonbonded::Nonbonded( &k_nonbonded_unified, &k_nonbonded_unified, &k_nonbonded_unified - }){ + }), + compute_w_coords_instance_(kernel_cache.program(kernel_src.c_str()).kernel("k_compute_w_coords").instantiate()), + compute_permute_interpolated_(kernel_cache.program(kernel_src.c_str()).kernel("k_permute_interpolated").instantiate()), + compute_add_ull_to_real_interpolated_(kernel_cache.program(kernel_src.c_str()).kernel("k_add_ull_to_real_interpolated").instantiate()) { if(lambda_offset_idxs.size() != N_) { throw std::runtime_error("lambda offset idxs need to have size N"); @@ -77,10 +93,13 @@ Nonbonded::Nonbonded( gpuErrchk(cudaMalloc(&d_perm_, N_*sizeof(*d_perm_))); - gpuErrchk(cudaMalloc(&d_sorted_lambda_plane_idxs_, N_*sizeof(*d_sorted_lambda_plane_idxs_))); - gpuErrchk(cudaMalloc(&d_sorted_lambda_offset_idxs_, N_*sizeof(*d_sorted_lambda_offset_idxs_))); gpuErrchk(cudaMalloc(&d_sorted_x_, N_*3*sizeof(*d_sorted_x_))); + gpuErrchk(cudaMalloc(&d_w_, N_*sizeof(*d_w_))); + gpuErrchk(cudaMalloc(&d_dw_dl_, N_*sizeof(*d_dw_dl_))); + gpuErrchk(cudaMalloc(&d_sorted_w_, N_*sizeof(*d_sorted_w_))); + gpuErrchk(cudaMalloc(&d_sorted_dw_dl_, N_*sizeof(*d_sorted_dw_dl_))); + gpuErrchk(cudaMalloc(&d_unsorted_p_, N_*3*sizeof(*d_unsorted_p_))); // interpolated gpuErrchk(cudaMalloc(&d_sorted_p_, N_*3*sizeof(*d_sorted_p_))); // interpolated gpuErrchk(cudaMalloc(&d_unsorted_dp_dl_, N_*3*sizeof(*d_unsorted_dp_dl_))); // interpolated @@ -141,10 +160,8 @@ Nonbonded::Nonbonded( ); gpuErrchk(cudaPeekAtLastError()); - gpuErrchk(cudaMalloc(&d_sort_storage_, d_sort_storage_bytes_)); - }; template @@ -159,22 +176,23 @@ Nonbonded::~Nonbonded() { gpuErrchk(cudaFree(d_bin_to_idx_)); gpuErrchk(cudaFree(d_sorted_x_)); + + gpuErrchk(cudaFree(d_w_)); + gpuErrchk(cudaFree(d_dw_dl_)); + gpuErrchk(cudaFree(d_sorted_w_)); + gpuErrchk(cudaFree(d_sorted_dw_dl_)); gpuErrchk(cudaFree(d_unsorted_p_)); gpuErrchk(cudaFree(d_sorted_p_)); gpuErrchk(cudaFree(d_unsorted_dp_dl_)); gpuErrchk(cudaFree(d_sorted_dp_dl_)); gpuErrchk(cudaFree(d_sorted_du_dx_)); gpuErrchk(cudaFree(d_sorted_du_dp_)); - gpuErrchk(cudaFree(d_sorted_lambda_plane_idxs_)); - gpuErrchk(cudaFree(d_sorted_lambda_offset_idxs_)); gpuErrchk(cudaFree(d_sort_keys_in_)); gpuErrchk(cudaFree(d_sort_keys_out_)); gpuErrchk(cudaFree(d_sort_vals_in_)); gpuErrchk(cudaFree(d_sort_storage_)); - // gpuErrchk(cudaFree(d_sum_storage_)); - gpuErrchk(cudaFreeHost(p_ixn_count_)); gpuErrchk(cudaFree(d_nblist_x_)); @@ -295,8 +313,6 @@ void Nonbonded::execute_device( d_rebuild_nblist_ ); gpuErrchk(cudaPeekAtLastError()); - // k_check_rebuild_box<<<1, tpb, 0, stream>>>(N, d_box, d_nblist_box_, d_rebuild_nblist_); - // gpuErrchk(cudaPeekAtLastError()); // we can optimize this away by doing the check on the GPU directly. gpuErrchk(cudaMemcpyAsync(p_rebuild_nblist_, d_rebuild_nblist_, 1*sizeof(*p_rebuild_nblist_), cudaMemcpyDeviceToHost, stream)); @@ -316,10 +332,6 @@ void Nonbonded::execute_device( // compute new coordinates, new lambda_idxs, new_plane_idxs k_permute<<>>(N, d_perm_, d_x, d_sorted_x_); gpuErrchk(cudaPeekAtLastError()); - k_permute<<>>(N, d_perm_, d_lambda_plane_idxs_, d_sorted_lambda_plane_idxs_); - gpuErrchk(cudaPeekAtLastError()); - k_permute<<>>(N, d_perm_, d_lambda_offset_idxs_, d_sorted_lambda_offset_idxs_); - gpuErrchk(cudaPeekAtLastError()); nblist_.build_nblist_device( N, d_sorted_x_, @@ -340,7 +352,8 @@ void Nonbonded::execute_device( // do parameter interpolation here if(Interpolated) { - k_permute_interpolated<<>>( + CUresult result = compute_permute_interpolated_.configure(dimGrid, 32, 0, stream) + .launch( lambda, N, d_perm_, @@ -348,7 +361,9 @@ void Nonbonded::execute_device( d_sorted_p_, d_sorted_dp_dl_ ); - gpuErrchk(cudaPeekAtLastError()); + if(result != 0) { + throw std::runtime_error("Driver call to k_permute_interpolated failed"); + } } else { k_permute<<>>(N, d_perm_, d_p, d_sorted_p_); gpuErrchk(cudaPeekAtLastError()); @@ -364,13 +379,27 @@ void Nonbonded::execute_device( gpuErrchk(cudaMemsetAsync(d_sorted_du_dp_, 0, N*3*sizeof(*d_sorted_du_dp_), stream)) } - // if(d_du_dl) { - // this->reset_du_dl_buffer(stream); - // } + // update new w coordinates + // (tbd): cache lambda value for equilibrium calculations + CUresult result = compute_w_coords_instance_.configure(B, tpb, 0, stream) + .launch( + N, + lambda, + cutoff_, + d_lambda_plane_idxs_, + d_lambda_offset_idxs_, + d_w_, + d_dw_dl_ + ); + if(result != 0) { + throw std::runtime_error("Driver call to k_compute_w_coords"); + } - // if(d_u) { - // this->reset_u_buffer(stream); - // } + gpuErrchk(cudaPeekAtLastError()); + k_permute<<>>(N, d_perm_, d_w_, d_sorted_w_); + gpuErrchk(cudaPeekAtLastError()); + k_permute<<>>(N, d_perm_, d_dw_dl_, d_sorted_dw_dl_); + gpuErrchk(cudaPeekAtLastError()); // look up which kernel we need for this computation int kernel_idx = 0; @@ -378,7 +407,6 @@ void Nonbonded::execute_device( kernel_idx |= d_du_dl ? 1 << 1 : 0; kernel_idx |= d_du_dx ? 1 << 2 : 0; kernel_idx |= d_u ? 1 << 3 : 0; - // kernel_idx |= 1 << 4; // force set alchemical = True for now before we start optimizations kernel_ptrs_[kernel_idx]<<>>( N, @@ -386,9 +414,9 @@ void Nonbonded::execute_device( d_sorted_p_, d_box, d_sorted_dp_dl_, + d_sorted_w_, + d_sorted_dw_dl_, lambda, - d_sorted_lambda_plane_idxs_, - d_sorted_lambda_offset_idxs_, beta_, cutoff_, nblist_.get_ixn_tiles(), @@ -438,9 +466,9 @@ void Nonbonded::execute_device( Interpolated ? d_unsorted_p_ : d_p, d_box, Interpolated ? d_unsorted_dp_dl_ : d_sorted_dp_dl_, + d_w_, + d_dw_dl_, lambda, - d_lambda_plane_idxs_, - d_lambda_offset_idxs_, d_exclusion_idxs_, d_scales_, beta_, @@ -455,12 +483,16 @@ void Nonbonded::execute_device( if(d_du_dp) { if(Interpolated) { - k_add_ull_to_real_interpolated<<>>( + CUresult result = compute_add_ull_to_real_interpolated_.configure(dimGrid, tpb, 0, stream) + .launch( lambda, N, d_du_dp_buffer_, d_du_dp ); + if(result != 0) { + throw std::runtime_error("Driver call to k_add_ull_to_real_interpolated failed"); + } } else { k_add_ull_to_real<<>>( N, diff --git a/timemachine/cpp/src/nonbonded.hpp b/timemachine/cpp/src/nonbonded.hpp index 22bc3a85a..ff3b6ab6a 100644 --- a/timemachine/cpp/src/nonbonded.hpp +++ b/timemachine/cpp/src/nonbonded.hpp @@ -1,5 +1,6 @@ #pragma once +#include "jitify.hpp" #include "neighborlist.hpp" #include "potential.hpp" #include @@ -12,9 +13,9 @@ typedef void (*k_nonbonded_fn)(const int N, const double * __restrict__ params, // [N] const double * __restrict__ box, const double * __restrict__ dl_dp, + const double * __restrict__ coords_w, // 4D coords + const double * __restrict__ dw_dl, // 4D derivatives const double lambda, - const int * __restrict__ lambda_plane_idxs, // 0 or 1, shift - const int * __restrict__ lambda_offset_idxs, // 0 or 1, how much we offset from the plane by cutoff const double beta, const double cutoff, const int * __restrict__ ixn_tiles, @@ -52,9 +53,12 @@ class Nonbonded : public Potential { unsigned int *d_perm_; // hilbert curve permutation - int *d_sorted_lambda_plane_idxs_; - int *d_sorted_lambda_offset_idxs_; + double *d_w_; // + double *d_dw_dl_; // + double *d_sorted_x_; // + double *d_sorted_w_; // + double *d_sorted_dw_dl_; // double *d_sorted_p_; // double *d_unsorted_p_; // double *d_sorted_dp_dl_; @@ -78,6 +82,10 @@ class Nonbonded : public Potential { cudaStream_t stream ); + jitify::KernelInstantiation compute_w_coords_instance_; + jitify::KernelInstantiation compute_permute_interpolated_; + jitify::KernelInstantiation compute_add_ull_to_real_interpolated_; + public: // these are marked public but really only intended for testing. @@ -89,8 +97,9 @@ class Nonbonded : public Potential { const std::vector &scales, // [E, 2] const std::vector &lambda_plane_idxs, // N const std::vector &lambda_offset_idxs, // N - double beta, - double cutoff + const double beta, + const double cutoff, + const std::string &kernel_src ); ~Nonbonded(); diff --git a/timemachine/cpp/src/wrap_kernels.cpp b/timemachine/cpp/src/wrap_kernels.cpp index 3edc5186f..5cef69193 100644 --- a/timemachine/cpp/src/wrap_kernels.cpp +++ b/timemachine/cpp/src/wrap_kernels.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include "context.hpp" #include "potential.hpp" @@ -857,6 +858,16 @@ void declare_periodic_torsion(py::module &m, const char *typestr) { // } + +// stackoverflow +std::string dirname(const std::string& fname) { + size_t pos = fname.find_last_of("\\/"); + return (std::string::npos == pos) + ? "" + : fname.substr(0, pos); +} + + template void declare_nonbonded(py::module &m, const char *typestr) { @@ -875,8 +886,12 @@ void declare_nonbonded(py::module &m, const char *typestr) { const py::array_t &scales_i, // [E, 2] const py::array_t &lambda_plane_idxs_i, // const py::array_t &lambda_offset_idxs_i, // - double beta, - double cutoff) { + const double beta, + const double cutoff, + const std::string &transform_lambda_charge="lambda", + const std::string &transform_lambda_sigma="lambda", + const std::string &transform_lambda_epsilon="lambda", + const std::string &transform_lambda_w="lambda") { std::vector exclusion_idxs(exclusion_i.size()); std::memcpy(exclusion_idxs.data(), exclusion_i.data(), exclusion_i.size()*sizeof(int)); @@ -890,16 +905,35 @@ void declare_nonbonded(py::module &m, const char *typestr) { std::vector lambda_offset_idxs(lambda_offset_idxs_i.size()); std::memcpy(lambda_offset_idxs.data(), lambda_offset_idxs_i.data(), lambda_offset_idxs_i.size()*sizeof(int)); + std::string dir_path = dirname(__FILE__); + std::string src_path = dir_path + "/kernels/k_lambda_transformer_jit.cuh"; + std::ifstream t(src_path); + std::string source_str((std::istreambuf_iterator(t)), std::istreambuf_iterator()); + source_str = std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_CHARGE"), transform_lambda_charge); + source_str = std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_SIGMA"), transform_lambda_sigma); + source_str = std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_EPSILON"), transform_lambda_epsilon); + source_str = std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_W"), transform_lambda_w); + return new timemachine::Nonbonded( exclusion_idxs, scales, lambda_plane_idxs, lambda_offset_idxs, beta, - cutoff + cutoff, + source_str ); - } - )); + }), + py::arg("exclusion_i"), + py::arg("scales_i"), + py::arg("lambda_plane_idxs_i"), + py::arg("lambda_offset_idxs_i"), + py::arg("beta"), + py::arg("cutoff"), + py::arg("transform_lambda_charge")="lambda", + py::arg("transform_lambda_sigma")="lambda", + py::arg("transform_lambda_epsilon")="lambda", + py::arg("transform_lambda_w")="lambda"); } diff --git a/timemachine/potentials/nonbonded.py b/timemachine/potentials/nonbonded.py index 454b24004..bf951fbf7 100644 --- a/timemachine/potentials/nonbonded.py +++ b/timemachine/potentials/nonbonded.py @@ -127,10 +127,11 @@ def nonbonded_v3( cutoff, lambda_plane_idxs, lambda_offset_idxs): - + N = conf.shape[0] - conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff) + if conf.shape[-1] == 3: + conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff) # make 4th dimension of box large enough so its roughly aperiodic if box is not None: