Stan Math Library  2.20.0
reverse mode automatic differentiation
tri_inverse.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_OPENCL_TRI_INVERSE_HPP
2 #define STAN_MATH_OPENCL_TRI_INVERSE_HPP
3 
4 #ifdef STAN_OPENCL
16 
17 #include <string>
18 #include <vector>
19 
20 namespace stan {
21 namespace math {
38 template <TriangularViewCL triangular_view>
39 inline matrix_cl tri_inverse(const matrix_cl& A) {
40   static_assert(triangular_view != TriangularViewCL::Entire,
41                 "tri_inverse(OpenCL) only supports triangular input matrices");
42   check_square("tri_inverse (OpenCL)", "A", A);
43 
44   int thread_block_2D_dim = 32;
45   int max_1D_thread_block_size = opencl_context.max_thread_block_size();
46   // we split the input matrix into 32 blocks
47   int thread_block_size_1D
48       = (((A.rows() / 32) + thread_block_2D_dim - 1) / thread_block_2D_dim)
49         * thread_block_2D_dim;
50   if (max_1D_thread_block_size < thread_block_size_1D) {
51     thread_block_size_1D = max_1D_thread_block_size;
52   }
53   int max_2D_thread_block_dim = sqrt(max_1D_thread_block_size);
54   if (max_2D_thread_block_dim < thread_block_2D_dim) {
55     thread_block_2D_dim = max_2D_thread_block_dim;
56   }
57   // for small sizes, split into at most 2 parts
58   if (thread_block_size_1D < 64) {
59     thread_block_size_1D = 32;
60   }
61   if (A.rows() < thread_block_size_1D) {
62     thread_block_size_1D = A.rows();
63   }
64 
65   // pad the input matrix
66   int A_rows_padded
67       = ((A.rows() + thread_block_size_1D - 1) / thread_block_size_1D)
68         * thread_block_size_1D;
69 
70   matrix_cl temp(A_rows_padded, A_rows_padded);
71   matrix_cl inv_padded(A_rows_padded, A_rows_padded);
72   matrix_cl inv_mat(A);
73   matrix_cl zero_mat(A_rows_padded - A.rows(), A_rows_padded);
74   zero_mat.zeros();
75   temp.zeros();
76   inv_padded.zeros();
77   if (triangular_view == TriangularViewCL::Upper) {
78     inv_mat = transpose(inv_mat);
79   }
80   int work_per_thread
81       = opencl_kernels::inv_lower_tri_multiply.make_functor.get_opts().at(
82           "WORK_PER_THREAD");
83   // the number of blocks in the first step
84   // each block is inverted using regular forward substitution
85   int parts = inv_padded.rows() / thread_block_size_1D;
86   inv_padded.sub_block(inv_mat, 0, 0, 0, 0, inv_mat.rows(), inv_mat.rows());
87   try {
88     // create a batch of identity matrices to be used in the first step
89     opencl_kernels::batch_identity(
90         cl::NDRange(parts, thread_block_size_1D, thread_block_size_1D), temp,
91         thread_block_size_1D, temp.size());
92     // spawn parts thread blocks, each responsible for one block
93     opencl_kernels::diag_inv(cl::NDRange(parts * thread_block_size_1D),
94                              cl::NDRange(thread_block_size_1D), inv_padded,
95                              temp, inv_padded.rows());
96   } catch (cl::Error& e) {
97     check_opencl_error("inverse step1", e);
98   }
99   // set the padded part of the matrix and the upper triangular part to zeros
100   inv_padded.sub_block(zero_mat, 0, 0, inv_mat.rows(), 0, zero_mat.rows(),
101                        zero_mat.cols());
102   inv_padded.zeros<TriangularViewCL::Upper>();
103   if (parts == 1) {
104     inv_mat.sub_block(inv_padded, 0, 0, 0, 0, inv_mat.rows(), inv_mat.rows());
105     if (triangular_view == TriangularViewCL::Upper) {
106       inv_mat = transpose(inv_mat);
107     }
108     return inv_mat;
109   }
110   parts = ceil(parts / 2.0);
111 
112   auto result_matrix_dim = thread_block_size_1D;
113   auto thread_block_work2d_dim = thread_block_2D_dim / work_per_thread;
114   auto ndrange_2d
115       = cl::NDRange(thread_block_2D_dim, thread_block_work2d_dim, 1);
116   while (parts > 0) {
117     int result_matrix_dim_x = result_matrix_dim;
118     // when calculating the last submatrix we can reduce its width
119     // to the actual remaining size instead of the full (doubled) block size
120     if (parts == 1 && (inv_padded.rows() - result_matrix_dim * 2) < 0) {
121       result_matrix_dim_x = inv_padded.rows() - result_matrix_dim;
122     }
123     auto result_work_dim = result_matrix_dim / work_per_thread;
124     auto result_ndrange
125         = cl::NDRange(result_matrix_dim_x, result_work_dim, parts);
126     opencl_kernels::inv_lower_tri_multiply(result_ndrange, ndrange_2d,
127                                            inv_padded, temp, inv_padded.rows(),
128                                            result_matrix_dim);
129     opencl_kernels::neg_rect_lower_tri_multiply(
130         result_ndrange, ndrange_2d, inv_padded, temp, inv_padded.rows(),
131         result_matrix_dim);
132     // if this is the last submatrix, end
133     if (parts == 1) {
134       parts = 0;
135     } else {
136       parts = ceil(parts / 2.0);
137     }
138     result_matrix_dim *= 2;
139     // set the padded part and the upper triangular part to zeros
140     inv_padded.sub_block(zero_mat, 0, 0, inv_mat.rows(), 0, zero_mat.rows(),
141                          zero_mat.cols());
142     inv_padded.zeros<TriangularViewCL::Upper>();
143   }
144   // un-pad and return
145   inv_mat.sub_block(inv_padded, 0, 0, 0, 0, inv_mat.rows(), inv_mat.rows());
146   if (triangular_view == TriangularViewCL::Upper) {
147     inv_mat = transpose(inv_mat);
148   }
149   return inv_mat;
150 }
151 } // namespace math
152 } // namespace stan
153 
154 #endif
155 #endif
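
For orientation, the listing above is a blocked triangular inversion: diag_inv inverts each thread_block_size_1D-sized diagonal block against a batch of identity matrices (forward substitution), and each pass of the while loop then merges pairs of adjacent inverted blocks by filling in the rectangular off-diagonal block between them. The underlying identity is standard linear algebra, added here for orientation rather than taken from the Stan sources; for a lower-triangular matrix written in 2x2 block form,

$$
\begin{pmatrix} A_{11} & 0 \\ A_{21} & A_{22} \end{pmatrix}^{-1}
=
\begin{pmatrix} A_{11}^{-1} & 0 \\ -A_{22}^{-1}\, A_{21}\, A_{11}^{-1} & A_{22}^{-1} \end{pmatrix},
$$

which is what the inv_lower_tri_multiply / neg_rect_lower_tri_multiply pair appears to compute for each pair of blocks while result_matrix_dim doubles. As a worked example of the sizing logic: for A.rows() = 100 (ignoring device-specific caps), thread_block_size_1D works out to 32, A_rows_padded to 128 and parts to 4, so one diag_inv launch inverts four 32x32 diagonal blocks and two passes of the while loop assemble them into 64x64 and then 128x128 inverses before the result is un-padded back to 100x100.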
fvar< T > sqrt(const fvar< T > &x)
Definition: sqrt.hpp:13
void check_square(const char *function, const char *name, const matrix_cl &y)
Check if the matrix_cl is square.
The API to access the methods and values in opencl_context_base.
The matrix_cl class - allocates memory space on the OpenCL device, functions for transferring matrices...
const kernel_cl< in_out_buffer, in_buffer, int, int > neg_rect_lower_tri_multiply("neg_rect_lower_tri_multiply", {thread_block_helpers, neg_rect_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for neg_rect_lower_tri_multiply() .
const kernel_cl< out_buffer, int, int > batch_identity("batch_identity", {indexing_helpers, batch_identity_kernel_code})
See the docs for batch_identity() .
void zeros()
Stores zeros in the matrix on the OpenCL device.
Definition: zeros.hpp:27
matrix_cl tri_inverse(const matrix_cl &A)
Computes the inverse of a triangular matrix.
Definition: tri_inverse.hpp:39
void sub_block(const matrix_cl &A, size_t A_i, size_t A_j, size_t this_i, size_t this_j, size_t nrows, size_t ncols)
Write the contents of A into this starting at the top left of this
Definition: sub_block.hpp:28
const kernel_cl< in_buffer, out_buffer, int, int > inv_lower_tri_multiply("inv_lower_tri_multiply", {thread_block_helpers, inv_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for inv_lower_tri_multiply().
Represents a matrix on the OpenCL device.
Definition: matrix_cl.hpp:29
checking OpenCL error numbers
matrix_cl transpose(const matrix_cl &src)
Takes the transpose of the matrix on the OpenCL device.
Definition: transpose.hpp:20
int max_thread_block_size()
Returns the maximum thread block size defined by CL_DEVICE_MAX_WORK_GROUP_SIZE for the device in the ...
void check_opencl_error(const char *function, const cl::Error &e)
Throws a domain error specifying the OpenCL error that occurred.
const kernel_cl< in_out_buffer, in_out_buffer, int > diag_inv("diag_inv", {indexing_helpers, diag_inv_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}})
See the docs for diag_inv().
fvar< T > ceil(const fvar< T > &x)
Definition: ceil.hpp:12
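
To round out the cross-references above, here is a minimal usage sketch for tri_inverse(). It assumes STAN_OPENCL is defined and an OpenCL device is configured; the host-to-device transfer shown (a matrix_cl constructor taking an Eigen matrix and a from_matrix_cl() helper) is an assumption about this version's copy API, so check stan/math/opencl/copy.hpp for the exact names. Only the tri_inverse<TriangularViewCL::Lower>(matrix_cl) call itself comes from the listing above.

    #include <stan/math.hpp>  // build with STAN_OPENCL defined and link against OpenCL

    int main() {
      using stan::math::matrix_cl;
      using stan::math::TriangularViewCL;

      // a small lower-triangular matrix on the host
      Eigen::MatrixXd L(3, 3);
      L << 1, 0, 0,
           2, 3, 0,
           4, 5, 6;

      // transfer to the device (assumed API: matrix_cl constructible from an Eigen matrix)
      matrix_cl L_cl(L);

      // invert the lower-triangular matrix on the OpenCL device
      matrix_cl L_inv_cl = stan::math::tri_inverse<TriangularViewCL::Lower>(L_cl);

      // copy the result back to the host (assumed helper name: from_matrix_cl)
      Eigen::MatrixXd L_inv = stan::math::from_matrix_cl(L_inv_cl);
      return 0;
    }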

     [ Stan Home Page ] © 2011–2018, Stan Development Team.