1 #ifndef STAN_MATH_OPENCL_TRI_INVERSE_HPP 2 #define STAN_MATH_OPENCL_TRI_INVERSE_HPP 38 template <TriangularViewCL triangular_view>
41 "tri_inverse(OpenCL) only supports triangular input matrices");
44 int thread_block_2D_dim = 32;
47 int thread_block_size_1D
48 = (((A.
rows() / 32) + thread_block_2D_dim - 1) / thread_block_2D_dim)
49 * thread_block_2D_dim;
50 if (max_1D_thread_block_size < thread_block_size_1D) {
51 thread_block_size_1D = max_1D_thread_block_size;
53 int max_2D_thread_block_dim =
sqrt(max_1D_thread_block_size);
54 if (max_2D_thread_block_dim < thread_block_2D_dim) {
55 thread_block_2D_dim = max_2D_thread_block_dim;
58 if (thread_block_size_1D < 64) {
59 thread_block_size_1D = 32;
61 if (A.
rows() < thread_block_size_1D) {
62 thread_block_size_1D = A.
rows();
67 = ((A.
rows() + thread_block_size_1D - 1) / thread_block_size_1D)
68 * thread_block_size_1D;
70 matrix_cl temp(A_rows_padded, A_rows_padded);
71 matrix_cl inv_padded(A_rows_padded, A_rows_padded);
85 int parts = inv_padded.
rows() / thread_block_size_1D;
90 cl::NDRange(parts, thread_block_size_1D, thread_block_size_1D), temp,
91 thread_block_size_1D, temp.
size());
94 cl::NDRange(thread_block_size_1D), inv_padded,
95 temp, inv_padded.
rows());
96 }
catch (cl::Error&
e) {
100 inv_padded.
sub_block(zero_mat, 0, 0, inv_mat.
rows(), 0, zero_mat.rows(),
110 parts =
ceil(parts / 2.0);
112 auto result_matrix_dim = thread_block_size_1D;
113 auto thread_block_work2d_dim = thread_block_2D_dim / work_per_thread;
115 = cl::NDRange(thread_block_2D_dim, thread_block_work2d_dim, 1);
117 int result_matrix_dim_x = result_matrix_dim;
120 if (parts == 1 && (inv_padded.
rows() - result_matrix_dim * 2) < 0) {
121 result_matrix_dim_x = inv_padded.
rows() - result_matrix_dim;
123 auto result_work_dim = result_matrix_dim / work_per_thread;
125 = cl::NDRange(result_matrix_dim_x, result_work_dim, parts);
127 inv_padded, temp, inv_padded.
rows(),
130 result_ndrange, ndrange_2d, inv_padded, temp, inv_padded.
rows(),
136 parts =
ceil(parts / 2.0);
138 result_matrix_dim *= 2;
140 inv_padded.
sub_block(zero_mat, 0, 0, inv_mat.
rows(), 0, zero_mat.rows(),
fvar< T > sqrt(const fvar< T > &x)
void check_square(const char *function, const char *name, const matrix_cl &y)
Check if the matrix_cl is square.
The API to access the methods and values in opencl_context_base.
The matrix_cl class - allocates memory space on the OpenCL device, functions for transferring matrices...
const kernel_cl< in_out_buffer, in_buffer, int, int > neg_rect_lower_tri_multiply("neg_rect_lower_tri_multiply", {thread_block_helpers, neg_rect_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for neg_rect_lower_tri_multiply() .
const kernel_cl< out_buffer, int, int > batch_identity("batch_identity", {indexing_helpers, batch_identity_kernel_code})
See the docs for batch_identity() .
void zeros()
Stores zeros in the matrix on the OpenCL device.
matrix_cl tri_inverse(const matrix_cl &A)
Computes the inverse of a triangular matrix.
void sub_block(const matrix_cl &A, size_t A_i, size_t A_j, size_t this_i, size_t this_j, size_t nrows, size_t ncols)
Write the contents of A into this starting at the top left of this
const kernel_cl< in_buffer, out_buffer, int, int > inv_lower_tri_multiply("inv_lower_tri_multiply", {thread_block_helpers, inv_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for inv_lower_tri_multiply() .
Represents a matrix on the OpenCL device.
checking OpenCL error numbers
matrix_cl transpose(const matrix_cl &src)
Takes the transpose of the matrix on the OpenCL device.
double e()
Return the base of the natural logarithm.
int max_thread_block_size()
Returns the maximum thread block size defined by CL_DEVICE_MAX_WORK_GROUP_SIZE for the device in the ...
void check_opencl_error(const char *function, const cl::Error &e)
Throws a domain error specifying the OpenCL error that occurred.
const kernel_cl< in_out_buffer, in_out_buffer, int > diag_inv("diag_inv", {indexing_helpers, diag_inv_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}})
See the docs for diag_inv() .
fvar< T > ceil(const fvar< T > &x)