Stan Math Library 2.20.0
reverse mode automatic differentiation
multiply.hpp
1 #ifndef STAN_MATH_OPENCL_MULTIPLY_HPP
2 #define STAN_MATH_OPENCL_MULTIPLY_HPP
3 #ifdef STAN_OPENCL
12 
13 namespace stan {
14 namespace math {
15 namespace opencl {
34 template <TriangularViewCL triangular_view_A = TriangularViewCL::Entire,
35  TriangularViewCL triangular_view_B = TriangularViewCL::Entire>
36 inline auto multiply(const matrix_cl& A, const matrix_cl& B) {
37  check_size_match("multiply ((OpenCL))", "A.cols()", A.cols(), "B.rows()",
38  B.rows());
39  matrix_cl temp(A.rows(), B.cols());
40  if (A.size() == 0 || B.size() == 0) {
41  temp.zeros();
42  return temp;
43  }
44  if (A.rows() == 1) {
45  const int local_size
46  = opencl_kernels::row_vector_matrix_multiply.make_functor.get_opts().at(
47  "LOCAL_SIZE_");
48  try {
49  opencl_kernels::row_vector_matrix_multiply(
50  cl::NDRange(temp.cols() * local_size), cl::NDRange(local_size), A, B,
51  temp, B.rows(), B.cols(), triangular_view_A, triangular_view_B);
52  } catch (cl::Error& e) {
53  check_opencl_error("row_vector - matrix multiply", e);
54  }
55  return temp;
56  }
57  if (B.cols() == 1) {
58  try {
59  opencl_kernels::matrix_vector_multiply(
60  cl::NDRange(temp.rows()), A, B, temp, A.rows(), A.cols(),
61  triangular_view_A, triangular_view_B);
62  } catch (cl::Error& e) {
63  check_opencl_error("matrix - vector multiply", e);
64  }
65  return temp;
66  }
67  int local = opencl_kernels::matrix_multiply.make_functor.get_opts().at(
68  "THREAD_BLOCK_SIZE");
69  const int Mpad = ((A.rows() + local - 1) / local) * local;
70  const int Npad = ((B.cols() + local - 1) / local) * local;
71  const int wpt = opencl_kernels::matrix_multiply.make_functor.get_opts().at(
72  "WORK_PER_THREAD");
73  int split = A.cols() / std::sqrt(A.rows() * B.cols());
74  if (split > 20) {
75  split = 20;
76  }
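  // Worked example (illustrative numbers, not from the source): with
  // A.rows() = 100, B.cols() = 50, A.cols() = 5000 and THREAD_BLOCK_SIZE = 32,
  // Mpad = ((100 + 31) / 32) * 32 = 128 and Npad = ((50 + 31) / 32) * 32 = 64,
  // so both result dimensions are rounded up to whole thread blocks, while
  // split = 5000 / sqrt(100 * 50) ~ 70 is clamped to 20.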
77  // when the result matrix is large, there is no benefit to splitting
78  // as the number of created threads is large enough to occupy all
79  // compute units in the OpenCL device
81  split = 1;
82  }
83  try {
84  if (split <= 1) {
85  opencl_kernels::matrix_multiply(cl::NDRange(Mpad, Npad / wpt),
86  cl::NDRange(local, local / wpt), A, B,
87  temp, A.rows(), B.cols(), B.rows(),
88  triangular_view_A, triangular_view_B);
89  } else {
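  // (added note) each slice along the third NDRange dimension writes its
  // partial product into its own A.rows() x B.cols() block of tempSplit;
  // add_batch below then sums those `split` blocks into temp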
90  matrix_cl tempSplit(A.rows(), B.cols() * split);
91  opencl_kernels::matrix_multiply(cl::NDRange(Mpad, Npad / wpt, split),
92  cl::NDRange(local, local / wpt, 1), A, B,
93  tempSplit, A.rows(), B.cols(), B.rows(),
94  triangular_view_A, triangular_view_B);
95  opencl_kernels::add_batch(cl::NDRange(A.rows(), B.cols()), temp,
96  tempSplit, A.rows(), B.cols(), split);
97  }
98  } catch (cl::Error& e) {
99  check_opencl_error("multiply", e);
100  }
101  return temp;
102 }
103 } // namespace opencl
104 
113 inline matrix_cl multiply(const matrix_cl& A, const double scalar) {
114  matrix_cl temp(A.rows(), A.cols());
115  if (A.size() == 0)
116  return temp;
117  try {
118  opencl_kernels::scalar_mul(cl::NDRange(A.rows(), A.cols()), temp, A, scalar,
119  A.rows(), A.cols());
120  } catch (const cl::Error& e) {
121  check_opencl_error("multiply scalar", e);
122  }
123  return temp;
124 }
125 
134 inline auto multiply(const double scalar, const matrix_cl& A) {
135  return multiply(A, scalar);
136 }
137 
150 inline auto multiply(const matrix_cl& A, const matrix_cl& B) {
151  return opencl::multiply(A, B);
152 }
153 
166 inline matrix_cl operator*(const matrix_cl& A, const matrix_cl& B) {
167  return opencl::multiply(A, B);
168 }
169 inline matrix_cl operator*(const matrix_cl& B, const double scalar) {
170  return multiply(B, scalar);
171 }
172 inline matrix_cl operator*(const double scalar, const matrix_cl& B) {
173  return multiply(scalar, B);
174 }
175 } // namespace math
176 } // namespace stan
177 
178 #endif
179 #endif
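
A minimal usage sketch (not part of the header above), assuming a build with the STAN_OPENCL flag and an OpenCL runtime, that matrix_cl exposes a constructor taking an Eigen matrix as in Stan Math 2.20, and that the include paths below match the installed source layout; only multiply, the scalar overload, and operator* are taken directly from this file.

#include <stan/math/opencl/matrix_cl.hpp>   // assumed path of the matrix_cl class
#include <stan/math/opencl/multiply.hpp>    // the header listed above
#include <Eigen/Dense>

int main() {
  Eigen::MatrixXd a = Eigen::MatrixXd::Random(64, 32);
  Eigen::MatrixXd b = Eigen::MatrixXd::Random(32, 16);

  // transfer both operands to the OpenCL device (assumed Eigen constructor)
  stan::math::matrix_cl a_cl(a);
  stan::math::matrix_cl b_cl(b);

  // matrix product and scalar product on the device; operator* forwards to
  // opencl::multiply() and multiply(A, scalar) respectively
  stan::math::matrix_cl c_cl = a_cl * b_cl;
  stan::math::matrix_cl d_cl = 2.0 * a_cl;

  // triangular inputs can be declared through the TriangularViewCL template
  // parameters of opencl::multiply; results stay on the device and are copied
  // back with the routines in stan/math/opencl/copy.hpp (not shown here)
  return 0;
}

Because the whole header is guarded by #ifdef STAN_OPENCL, this sketch compiles only when that flag is defined and the OpenCL library is linked.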