Stan Math Library  2.20.0
reverse mode automatic differentiation
inv_lower_tri_multiply.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_OPENCL_KERNELS_INVERSE_LOWER_TRI_MULTIPLY_HPP
2 #define STAN_MATH_OPENCL_KERNELS_INVERSE_LOWER_TRI_MULTIPLY_HPP
3 #ifdef STAN_OPENCL
4 
7 
8 namespace stan {
9 namespace math {
10 namespace opencl_kernels {
11 // \cond
12 static const char* inv_lower_tri_multiply_kernel_code = STRINGIFY(
13  // \endcond
    /**
     * For every index along global dimension 2, multiplies two blocks read
     * from A and stores the product in temp.
     *
     * With offset = get_global_id(2) * rows * 2:
     *  - the left factor (C2) starts at row and column offset + rows of A;
     *    only its elements on or below the diagonal are read, the rest are
     *    taken as zero;
     *  - the right factor (A3) starts at row offset + rows and column offset
     *    of A.
     * Any element whose row or column index is not below A_rows is taken as
     * zero. A is addressed as A[col * A_rows + row]; the result is stored at
     * temp[get_global_id(2) * rows * rows + col * rows + row].
     *
     * @param[in] A matrix holding the input blocks (only read by this kernel)
     * @param[out] temp receives one product block per index of dimension 2
     * @param A_rows bound and column stride used when indexing A
     * @param rows size of the blocks being multiplied
     *
     * NOTE(review): THREAD_BLOCK_SIZE_COL is not defined in this string;
     * presumably it is provided by thread_block_helpers - confirm.
     * NOTE(review): the store into temp is not bounds-checked, unlike the
     * loads from A; this assumes the launch covers exactly rows x rows
     * elements per result - confirm against the host code.
     */
45  __kernel void inv_lower_tri_multiply(__global double* A,
46  __global double* temp,
47  const int A_rows, const int rows) {
    // Which result submatrix this work-item contributes to.
48  int result_matrix_id = get_global_id(2);
    // First row/column of A used for this result; consecutive results are
    // 2 * rows apart.
49  int offset = result_matrix_id * rows * 2;
50  const int thread_block_row = get_local_id(0);
51  const int thread_block_col = get_local_id(1);
    // Row and first column of the result element(s) this work-item computes.
52  const int global_thread_row
53  = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
54  const int global_thread_col
55  = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
56 
    // Tiles of the two factors staged in local memory, stored as [col][row].
57  __local double C2_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
58  __local double A3_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
59 
    // One running sum per output element owned by this work-item.
60  double acc[WORK_PER_THREAD] = {0};
61 
    // Number of tiles along the shared dimension, rounded up.
62  const int num_tiles = (rows + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
63  for (int tile_ind = 0; tile_ind < num_tiles; tile_ind++) {
64  // Each work-item copies WORK_PER_THREAD values of each factor to the
65  // local memory
66  for (int w = 0; w < WORK_PER_THREAD; w++) {
    // Position of this work-item along the shared dimension in this tile.
67  const int tiled_i = THREAD_BLOCK_SIZE * tile_ind + thread_block_row;
68  const int tiled_j = THREAD_BLOCK_SIZE * tile_ind + thread_block_col;
69  // {C2}{A3}_global_{col}{row} specifies which global element of each
70  // factor the work-item is in charge of moving to local memory.
71  const int C2_global_col
72  = offset + rows + tiled_j + w * THREAD_BLOCK_SIZE_COL;
73  const int C2_global_row = offset + global_thread_row + rows;
74  const int A3_global_col
75  = offset + global_thread_col + w * THREAD_BLOCK_SIZE_COL;
76  const int A3_global_row = tiled_i + rows + offset;
77  // Which {col}{row} location in the local memory the work-item is in
78  // charge of.
79  const int local_col = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
80  const int local_row = thread_block_row;
81  // Elements above the diagonal or outside A are stored as zero.
82  if (C2_global_col <= C2_global_row && C2_global_col < A_rows
83  && C2_global_row < A_rows) {
84  C2_local[local_col][local_row]
85  = A[C2_global_col * A_rows + C2_global_row];
86  } else {
87  C2_local[local_col][local_row] = 0;
88  }
    // A3 is read in full; only elements outside A are stored as zero.
89  if (A3_global_col < A_rows && A3_global_row < A_rows) {
90  A3_local[local_col][local_row]
91  = A[A3_global_col * A_rows + A3_global_row];
92  } else {
93  A3_local[local_col][local_row] = 0.0;
94  }
95  }
96  // Wait until all tile values are loaded to the local memory
97  barrier(CLK_LOCAL_MEM_FENCE);
    // acc[w] += sum over k of C2(row, k) * A3(k, col) for this tile.
98  for (int block_ind = 0; block_ind < THREAD_BLOCK_SIZE; block_ind++) {
99  for (int w = 0; w < WORK_PER_THREAD; w++) {
100  const int local_col = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
101  const int local_row = thread_block_row;
102  acc[w] += C2_local[block_ind][local_row]
103  * A3_local[local_col][block_ind];
104  }
105  }
    // Do not let any work-item overwrite the tiles in the next iteration
    // while others are still reading them.
106  barrier(CLK_LOCAL_MEM_FENCE);
107  }
108  // Global offset for each resulting submatrix
109  const int batch_offset = result_matrix_id * rows * rows;
110  // temp_global_{row}{col} is the element of the output submatrix that
111  // the work-item writes
112  const int temp_global_row = global_thread_row;
113  // save the values
114  for (int w = 0; w < WORK_PER_THREAD; w++) {
115  // each thread saves WORK_PER_THREAD values
116  const int temp_global_col
117  = global_thread_col + w * THREAD_BLOCK_SIZE_COL;
118  temp[batch_offset + temp_global_col * rows + temp_global_row] = acc[w];
119  }
120  }
121  // \cond
122 );
123 // \endcond
124 
// Kernel functor for inv_lower_tri_multiply. The template arguments
// (in_buffer, out_buffer, int, int) correspond, in order, to the kernel
// parameters (A, temp, A_rows, rows). The program is built from
// thread_block_helpers followed by inv_lower_tri_multiply_kernel_code, with
// THREAD_BLOCK_SIZE = 32 and WORK_PER_THREAD = 8.
128 const kernel_cl<in_buffer, out_buffer, int, int> inv_lower_tri_multiply(
129  "inv_lower_tri_multiply",
130  {thread_block_helpers, inv_lower_tri_multiply_kernel_code},
131  {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}});
132 
133 } // namespace opencl_kernels
134 } // namespace math
135 } // namespace stan
136 #endif
137 #endif
int rows(const Eigen::Matrix< T, R, C > &m)
Return the number of rows in the specified matrix, vector, or row vector.
Definition: rows.hpp:20
#define STRINGIFY(src)
Definition: kernel_cl.hpp:22
const kernel_cl< in_buffer, out_buffer, int, int > inv_lower_tri_multiply("inv_lower_tri_multiply", {thread_block_helpers, inv_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for inv_lower_tri_multiply().
Creates functor for kernels.
Definition: kernel_cl.hpp:201
static const char * thread_block_helpers
Definition: helpers.hpp:48

     [ Stan Home Page ] © 2011–2018, Stan Development Team.