#ifndef STAN_MATH_GPU_ERR_CHECK_NAN_HPP
#define STAN_MATH_GPU_ERR_CHECK_NAN_HPP
#ifdef STAN_OPENCL

#include <stan/math/gpu/matrix_gpu.hpp>
#include <stan/math/gpu/kernels/check_nan.hpp>
#include <stan/math/prim/scal/err/domain_error.hpp>

namespace stan {
namespace math {

/**
 * Verify that a <code>matrix_gpu</code> contains no NaN entries.
 *
 * An empty matrix is accepted immediately. Otherwise a zero-initialised
 * integer flag is uploaded to a device buffer, the
 * <code>opencl_kernels::check_nan</code> kernel is launched over a
 * <code>rows x cols</code> range, and the flag is read back with a
 * blocking read. A non-zero flag is reported through
 * <code>domain_error</code>.
 *
 * @param function Function name (for error messages)
 * @param name Variable name (for error messages)
 * @param y <code>matrix_gpu</code> to test
 *
 * @throw <code>std::domain_error</code> if
 *    any element of the matrix is <code>NaN</code>.
 *
 * Any <code>cl::Error</code> raised by the OpenCL calls is handed to
 * <code>check_opencl_error</code>.
 */
inline void check_nan(const char* function, const char* name,
                      const matrix_gpu& y) {
  // Nothing to inspect in an empty matrix.
  if (y.size() == 0) {
    return;
  }

  cl::CommandQueue device_queue = opencl_context.queue();
  cl::Context& device_ctx = opencl_context.context();

  try {
    // Host-side copy of the flag; starts cleared and is overwritten by the
    // read-back below.
    int found_nan = 0;
    cl::Buffer flag_buffer(device_ctx, CL_MEM_READ_WRITE, sizeof(found_nan));

    // Blocking write (CL_TRUE) so the device flag is zeroed before the
    // kernel is launched.
    device_queue.enqueueWriteBuffer(flag_buffer, CL_TRUE, 0,
                                    sizeof(found_nan), &found_nan);

    const auto n_rows = y.rows();
    const auto n_cols = y.cols();
    opencl_kernels::check_nan(cl::NDRange(n_rows, n_cols), y.buffer(),
                              flag_buffer, n_rows, n_cols);

    // Blocking read (CL_TRUE): fetch the flag back into host memory.
    device_queue.enqueueReadBuffer(flag_buffer, CL_TRUE, 0,
                                   sizeof(found_nan), &found_nan);

    // A non-zero flag means NaN values were found in the matrix.
    if (found_nan != 0) {
      domain_error(function, name, "has NaN values", "");
    }
  } catch (const cl::Error& e) {
    // Only OpenCL failures are intercepted here; the exception is passed on
    // to check_opencl_error.
    check_opencl_error("nan_check", e);
  }
}

}  // namespace math
}  // namespace stan

#endif
#endif