1 #ifndef STAN_MATH_OPENCL_KERNELS_INVERSE_LOWER_TRI_MULTIPLY_HPP 2 #define STAN_MATH_OPENCL_KERNELS_INVERSE_LOWER_TRI_MULTIPLY_HPP 10 namespace opencl_kernels {
12 static const char* inv_lower_tri_multiply_kernel_code =
STRINGIFY(
46 __global
double* temp,
47 const int A_rows,
const int rows) {
48 int result_matrix_id = get_global_id(2);
49 int offset = result_matrix_id * rows * 2;
50 const int thread_block_row = get_local_id(0);
51 const int thread_block_col = get_local_id(1);
52 const int global_thread_row
53 = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
54 const int global_thread_col
55 = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
57 __local
double C2_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
58 __local
double A3_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
60 double acc[WORK_PER_THREAD] = {0};
62 const int num_tiles = (rows + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
63 for (
int tile_ind = 0; tile_ind < num_tiles; tile_ind++) {
66 for (
int w = 0; w < WORK_PER_THREAD; w++) {
67 const int tiled_i = THREAD_BLOCK_SIZE * tile_ind + thread_block_row;
68 const int tiled_j = THREAD_BLOCK_SIZE * tile_ind + thread_block_col;
71 const int C2_global_col
72 = offset + rows + tiled_j + w * THREAD_BLOCK_SIZE_COL;
73 const int C2_global_row = offset + global_thread_row +
rows;
74 const int A3_global_col
75 = offset + global_thread_col + w * THREAD_BLOCK_SIZE_COL;
76 const int A3_global_row = tiled_i + rows + offset;
79 const int local_col = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
80 const int local_row = thread_block_row;
82 if (C2_global_col <= C2_global_row && C2_global_col < A_rows
83 && C2_global_row < A_rows) {
84 C2_local[local_col][local_row]
85 = A[C2_global_col * A_rows + C2_global_row];
87 C2_local[local_col][local_row] = 0;
89 if (A3_global_col < A_rows && A3_global_row < A_rows) {
90 A3_local[local_col][local_row]
91 = A[A3_global_col * A_rows + A3_global_row];
93 A3_local[local_col][local_row] = 0.0;
97 barrier(CLK_LOCAL_MEM_FENCE);
98 for (
int block_ind = 0; block_ind < THREAD_BLOCK_SIZE; block_ind++) {
99 for (
int w = 0; w < WORK_PER_THREAD; w++) {
100 const int local_col = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
101 const int local_row = thread_block_row;
102 acc[w] += C2_local[block_ind][local_row]
103 * A3_local[local_col][block_ind];
106 barrier(CLK_LOCAL_MEM_FENCE);
109 const int batch_offset = result_matrix_id * rows *
rows;
112 const int temp_global_row = global_thread_row;
114 for (
int w = 0; w < WORK_PER_THREAD; w++) {
116 const int temp_global_col
117 = global_thread_col + w * THREAD_BLOCK_SIZE_COL;
118 temp[batch_offset + temp_global_col * rows + temp_global_row] = acc[w];
129 "inv_lower_tri_multiply",
131 {{
"THREAD_BLOCK_SIZE", 32}, {
"WORK_PER_THREAD", 8}});
int rows(const Eigen::Matrix< T, R, C > &m)
Return the number of rows in the specified matrix, vector, or row vector.
const kernel_cl< in_buffer, out_buffer, int, int > inv_lower_tri_multiply("inv_lower_tri_multiply", {thread_block_helpers, inv_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for inv_lower_tri_multiply().
Creates functor for kernels.
static const char * thread_block_helpers