Stan Math Library  2.20.0
reverse mode automatic differentiation
neg_rect_lower_tri_multiply.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_OPENCL_KERNELS_NEGATIVE_RECT_LOWER_TRI_MULTIPLY_HPP
2 #define STAN_MATH_OPENCL_KERNELS_NEGATIVE_RECT_LOWER_TRI_MULTIPLY_HPP
3 #ifdef STAN_OPENCL
4 
7 
8 namespace stan {
9 namespace math {
10 namespace opencl_kernels {
11 // \cond
// OpenCL C source of the kernel neg_rect_lower_tri_multiply, captured as a
// string by STRINGIFY. C++ comments inside the macro argument are removed in
// translation phase 3 (before macro expansion), so none of the comments
// below reach the OpenCL compiler.
// For each result_matrix_id m (global dimension 2), with offset = 2 * rows * m,
// the kernel stores
//   A(offset + rows + i, offset + j) = -sum_{k < rows} temp_m(i, k) * C1(k, j)
// where temp_m is the m-th rows x rows column-major matrix in temp and
// C1(k, j) = A(offset + k, offset + j) if j <= k, else 0 (only the part of
// that block on or below the diagonal is used).
// A is indexed column-major with A_rows as its leading dimension.
12 static const char* neg_rect_lower_tri_multiply_kernel_code = STRINGIFY(
13  // \endcond
    // NOTE(review): listing lines 14-39 are missing from this excerpt;
    // presumably they hold the kernel's doc comment and the start of its
    // signature. The parameter list resumes below.
40  __global double* A, const __global double* temp, const int A_rows,
41  const int rows) {
    // A:      column-major, leading dimension A_rows; the C1 tile is read
    //         from it and the result block is written back into it.
    // temp:   consecutive rows x rows column-major matrices, one per
    //         result_matrix_id.
    // A_rows: leading dimension of A; also the bound on its row/col indices.
    // rows:   order of each temp matrix (and of the C1 block).
42  int result_matrix_id = get_global_id(2);
    // Rows/columns of A handled by this result_matrix_id start at offset.
43  int offset = result_matrix_id * rows * 2;
44  const int thread_block_row = get_local_id(0);
45  const int thread_block_col = get_local_id(1);
    // (i, j): this thread's row and first column in the output block. The
    // thread also handles columns j + w * THREAD_BLOCK_SIZE_COL for
    // w < WORK_PER_THREAD. THREAD_BLOCK_SIZE_COL is not defined in this
    // excerpt (presumably by thread_block_helpers) -- confirm.
46  const int i = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
47  const int j = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
48 
    // Work-group tiles of temp and C1, indexed [column][row].
49  __local double temp_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
50  __local double C1_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
51 
    // One running sum per output column handled by this thread.
52  double acc[WORK_PER_THREAD] = {0};
53 
    // Number of tiles along the summation index k (rows rounded up).
54  const int num_tiles = (rows + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
55  for (int tile_ind = 0; tile_ind < num_tiles; tile_ind++) {
56  // each thread copies WORK_PER_THREAD values to the local
57  // memory
58  for (int w = 0; w < WORK_PER_THREAD; w++) {
59  const int tiled_i = THREAD_BLOCK_SIZE * tile_ind + thread_block_row;
60  const int tiled_j = THREAD_BLOCK_SIZE * tile_ind + thread_block_col;
61  const int temp_global_col = tiled_j + w * THREAD_BLOCK_SIZE_COL;
62  // {temp}{C1}_global_{col}{row} specifies which global element for each
63  // matrix the thread is in charge of moving to local memory.
64  const int C1_global_col = offset + j + w * THREAD_BLOCK_SIZE_COL;
65  const int C1_global_row = tiled_i + offset;
66  // Which {col}{row} location in the local memory the thread is in
67  // charge of.
68  const int local_col = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
69  const int local_row = thread_block_row;
    // temp entries with column >= rows or row i >= rows are zeroed, so the
    // rounded-up tile padding contributes nothing to the sum.
70  if ((temp_global_col) < rows && i < rows) {
71  temp_local[local_col][local_row]
72  = temp[result_matrix_id * rows * rows + temp_global_col * rows
73  + i];
74  } else {
75  temp_local[local_col][local_row] = 0.0;
76  }
77  // Element above the diagonal will not be transferred.
    // NOTE(review): C1_global_row is bounded only by A_rows, not by
    // offset + rows. Such entries (which can lie in the block this kernel
    // writes) are multiplied by zeroed temp entries, so they only matter if
    // they are non-finite.
78  if (C1_global_col <= C1_global_row && C1_global_col < A_rows
79  && C1_global_row < A_rows) {
80  C1_local[local_col][local_row]
81  = A[C1_global_col * A_rows + C1_global_row];
82  } else {
83  C1_local[local_col][local_row] = 0;
84  }
85  }
86  // wait until all tile values are loaded to the local memory
87  barrier(CLK_LOCAL_MEM_FENCE);
    // acc[w] += temp_m(i, k) * C1(k, column w),
    // with k = tile_ind * THREAD_BLOCK_SIZE + block_ind.
88  for (int block_ind = 0; block_ind < THREAD_BLOCK_SIZE; block_ind++) {
89  for (int w = 0; w < WORK_PER_THREAD; w++) {
90  // Which {col}{row} location in the local memory the thread is in
91  // charge of.
92  const int local_col = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
93  const int local_row = thread_block_row;
94  acc[w] += temp_local[block_ind][local_row]
95  * C1_local[local_col][block_ind];
96  }
97  }
    // Keep the tiles intact until every thread in the work-group has
    // finished reading them; the next iteration overwrites them.
98  barrier(CLK_LOCAL_MEM_FENCE);
99  }
100  // A_global_{row}{col} give the position in global A where this thread
101  // stores its results (the block directly below C1)
102  const int A_global_row = i + rows + offset;
    // NOTE(review): A_global_row is never used; the store below recomputes
    // i + rows + offset.
103  const int A_global_col_offset = offset + j;
104  // each thread saves WORK_PER_THREAD values
105  for (int w = 0; w < WORK_PER_THREAD; w++) {
106  const int A_global_col
107  = A_global_col_offset + w * THREAD_BLOCK_SIZE_COL;
    // NOTE(review): the store is bounded only by A_rows, not by i < rows or
    // A_global_col < offset + rows. Work items outside the rows x rows result
    // block would store -acc[w] (zero for finite inputs) into A. Presumably
    // the host launches exactly rows x rows work items per result matrix --
    // confirm against the caller.
108  if (A_global_col < A_rows && (i + rows + offset) < A_rows) {
109  A[A_global_col * A_rows + i + rows + offset] = -acc[w];
110  }
111  }
112  }
113  // \cond
114 );
115 // \endcond
116 
// Tail of the declaration of the kernel_cl functor neg_rect_lower_tri_multiply.
// Its opening (listing lines 117-122) is missing from this excerpt.
// Arguments:
//   - the kernel name;
//   - the source strings, thread_block_helpers then
//     neg_rect_lower_tri_multiply_kernel_code (presumably concatenated in
//     that order);
//   - compile-time options.
// The kernel source uses THREAD_BLOCK_SIZE and WORK_PER_THREAD without
// defining them, so presumably these options are injected as preprocessor
// defines -- confirm in kernel_cl.hpp.
123  "neg_rect_lower_tri_multiply",
124  {thread_block_helpers, neg_rect_lower_tri_multiply_kernel_code},
125  {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}});
126 } // namespace opencl_kernels
127 } // namespace math
128 } // namespace stan
129 #endif
130 #endif
int rows(const Eigen::Matrix< T, R, C > &m)
Return the number of rows in the specified matrix, vector, or row vector.
Definition: rows.hpp:20
#define STRINGIFY(src)
Definition: kernel_cl.hpp:22
const kernel_cl< in_out_buffer, in_buffer, int, int > neg_rect_lower_tri_multiply("neg_rect_lower_tri_multiply", {thread_block_helpers, neg_rect_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for neg_rect_lower_tri_multiply().
Creates functor for kernels.
Definition: kernel_cl.hpp:201
static const char * thread_block_helpers
Definition: helpers.hpp:48

     [ Stan Home Page ] © 2011–2018, Stan Development Team.