1 #ifndef STAN_MATH_OPENCL_KERNELS_NEGATIVE_RECT_LOWER_TRI_MULTIPLY_HPP 2 #define STAN_MATH_OPENCL_KERNELS_NEGATIVE_RECT_LOWER_TRI_MULTIPLY_HPP 10 namespace opencl_kernels {
12 static const char* neg_rect_lower_tri_multiply_kernel_code =
STRINGIFY(
40 __global
double* A,
const __global
double* temp,
const int A_rows,
42 int result_matrix_id = get_global_id(2);
43 int offset = result_matrix_id * rows * 2;
44 const int thread_block_row = get_local_id(0);
45 const int thread_block_col = get_local_id(1);
46 const int i = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
47 const int j = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
49 __local
double temp_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
50 __local
double C1_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
52 double acc[WORK_PER_THREAD] = {0};
54 const int num_tiles = (rows + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
55 for (
int tile_ind = 0; tile_ind < num_tiles; tile_ind++) {
58 for (
int w = 0; w < WORK_PER_THREAD; w++) {
59 const int tiled_i = THREAD_BLOCK_SIZE * tile_ind + thread_block_row;
60 const int tiled_j = THREAD_BLOCK_SIZE * tile_ind + thread_block_col;
61 const int temp_global_col = tiled_j + w * THREAD_BLOCK_SIZE_COL;
64 const int C1_global_col = offset + j + w * THREAD_BLOCK_SIZE_COL;
65 const int C1_global_row = tiled_i + offset;
68 const int local_col = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
69 const int local_row = thread_block_row;
70 if ((temp_global_col) < rows && i <
rows) {
71 temp_local[local_col][local_row]
72 = temp[result_matrix_id * rows * rows + temp_global_col * rows
75 temp_local[local_col][local_row] = 0.0;
78 if (C1_global_col <= C1_global_row && C1_global_col < A_rows
79 && C1_global_row < A_rows) {
80 C1_local[local_col][local_row]
81 = A[C1_global_col * A_rows + C1_global_row];
83 C1_local[local_col][local_row] = 0;
87 barrier(CLK_LOCAL_MEM_FENCE);
88 for (
int block_ind = 0; block_ind < THREAD_BLOCK_SIZE; block_ind++) {
89 for (
int w = 0; w < WORK_PER_THREAD; w++) {
92 const int local_col = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
93 const int local_row = thread_block_row;
94 acc[w] += temp_local[block_ind][local_row]
95 * C1_local[local_col][block_ind];
98 barrier(CLK_LOCAL_MEM_FENCE);
102 const int A_global_row = i + rows + offset;
103 const int A_global_col_offset = offset + j;
105 for (
int w = 0; w < WORK_PER_THREAD; w++) {
106 const int A_global_col
107 = A_global_col_offset + w * THREAD_BLOCK_SIZE_COL;
108 if (A_global_col < A_rows && (i + rows + offset) < A_rows) {
109 A[A_global_col * A_rows + i + rows + offset] = -acc[w];
123 "neg_rect_lower_tri_multiply",
125 {{
"THREAD_BLOCK_SIZE", 32}, {
"WORK_PER_THREAD", 8}});
int rows(const Eigen::Matrix< T, R, C > &m)
Return the number of rows in the specified matrix, vector, or row vector.
const kernel_cl< in_out_buffer, in_buffer, int, int > neg_rect_lower_tri_multiply("neg_rect_lower_tri_multiply", {thread_block_helpers, neg_rect_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for neg_rect_lower_tri_multiply() .
Creates functor for kernels.
static const char * thread_block_helpers