Stan Math Library  2.20.0
reverse mode automatic differentiation
matrix_multiply.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_OPENCL_KERNELS_MATRIX_MULTIPLY_HPP
2 #define STAN_MATH_OPENCL_KERNELS_MATRIX_MULTIPLY_HPP
3 #ifdef STAN_OPENCL
4 
7 
8 namespace stan {
9 namespace math {
10 namespace opencl_kernels {
11 // \cond
12 static const char* matrix_multiply_kernel_code = STRINGIFY(
13  // \endcond
26  __kernel void matrix_multiply(
27  const __global double* A, const __global double* B, __global double* C,
28  const int M, const int N, const int K, unsigned int lower_upper_A,
29  unsigned int lower_upper_B) {
30  // thread index inside the thread_block
31  const int thread_block_row = get_local_id(0);
32  const int thread_block_col = get_local_id(1);
33  // global thread index
34  const int i = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
35  const int j = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
36  // identify if the matrix multiply is split
37  const int split_id = get_global_id(2);
38  const int split_size = get_global_size(2);
39  // local memory
40  __local double A_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
41  __local double B_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
42 
43  double acc[WORK_PER_THREAD];
44  for (int w = 0; w < WORK_PER_THREAD; w++) {
45  acc[w] = 0.0;
46  }
47  // the number of tiles for each scalar product in the matrix mulitply
48  const int num_tiles = (K + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
49  // in case of splitting the matrix multiply we need
50  // use split_offset_tiles the threads assigned part
51  // of the scalar products, while the split_tiles
52  // determines the number of tiles a thread multiplies
53  // if split_size = 1, each thread calculates the
54  // the entire scalar product for all assigned
55  // elements of the resulting matrix, meaning that
56  // split_offset_tiles is 0 and split_tiles = num_tiles
57  int split_tiles = num_tiles / split_size;
58  const int split_remainder = num_tiles % split_size;
59  int split_offset_tiles = split_id * split_tiles;
60  if (split_id < split_remainder) {
61  split_offset_tiles = split_offset_tiles + split_id;
62  split_tiles++;
63  } else {
64  split_offset_tiles = split_offset_tiles + split_remainder;
65  }
66  // This kernel is based on the well known
67  // general matrix multiplication kernels that
68  // use tiling for shared memory
69  // In cases where a matrix is lower triangular
70  // its not necessary to multiply the elements
71  // over the diagonal, therefore those tiles
72  // in the matrix multiply can be skipped.
73  // With upper triangular matrices we dont need
74  // to multiply the elements under the diagonal,
75  // so those tiles can be skipped.
76  // The following code determines the start and
77  // end tile based on triangularity of the input matrices
78  // If no matrices are triangular the starting tile
79  // is 0 and the end tile is num_tiles-1 which
80  // is then a general matrix multiply
81  const int end_tile_A
82  = lower_upper_A == LOWER ? (i / THREAD_BLOCK_SIZE) : (num_tiles - 1);
83  const int end_tile_B
84  = lower_upper_B == UPPER ? (j / THREAD_BLOCK_SIZE) : (num_tiles - 1);
85  const int start_tile_A
86  = lower_upper_A == UPPER ? (i / THREAD_BLOCK_SIZE) : 0;
87  const int start_tile_B
88  = lower_upper_B == LOWER ? (j / THREAD_BLOCK_SIZE) : 0;
89  // the starting and end tiles for a thread are determined by
90  // split_offset_tiles and split_tiles. If the input matrix is
91  // triangular some tiles can be skipped in which case we
92  // either start the scalar product at larger cols/rows
93  // or end them at smaller cols/rows.
94  int start_tile = max(start_tile_A, start_tile_B);
95  start_tile = max(start_tile, split_offset_tiles);
96  int end_tile = min(end_tile_A, end_tile_B); // NOLINT
97  end_tile = min(end_tile, split_offset_tiles + split_tiles - 1); // NOLINT
98  for (int tile_idx = start_tile; tile_idx <= end_tile; tile_idx++) {
99  const int tiled_i = THREAD_BLOCK_SIZE * tile_idx + thread_block_row;
100  const int tiled_j = THREAD_BLOCK_SIZE * tile_idx + thread_block_col;
101  // each thread copies WORK_PER_THREAD values to the local
102  // memory
103  for (int w = 0; w < WORK_PER_THREAD; w++) {
104  // For the tiles on the diagonal we can ignore the values over
105  // the diagonal if the matrix is lower triangular or under
106  // the diagonal if the matrix is upper triangular
107  const A_curr_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
108  const B_curr_j = j + w * THREAD_BLOCK_SIZE_COL;
109  // check if the indexes are outside the matrix
110  // or under/above the diagonal with upper/lower
111  // triangular matrices
112  if (A_curr_j >= K || i >= M
113  || (lower_upper_A == LOWER && A_curr_j > i)
114  || (lower_upper_A == UPPER && A_curr_j < i)) {
115  A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
116  [thread_block_row]
117  = 0.0;
118  } else {
119  A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
120  [thread_block_row]
121  = A[A_curr_j * M + i];
122  }
123  if (B_curr_j >= N || tiled_i >= K
124  || (lower_upper_B == LOWER && B_curr_j > tiled_i)
125  || (lower_upper_B == UPPER && B_curr_j < tiled_i)) {
126  B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
127  [thread_block_row]
128  = 0.0;
129  } else {
130  B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
131  [thread_block_row]
132  = B[B_curr_j * K + tiled_i];
133  }
134  }
135  barrier(CLK_LOCAL_MEM_FENCE);
136  for (int block_idx = 0; block_idx < THREAD_BLOCK_SIZE; block_idx++) {
137  for (int w = 0; w < WORK_PER_THREAD; w++) {
138  acc[w] += A_local[block_idx][thread_block_row]
139  * B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
140  [block_idx];
141  }
142  }
143  barrier(CLK_LOCAL_MEM_FENCE);
144  }
145  // each thread saves WORK_PER_THREAD values
146  for (int w = 0; w < WORK_PER_THREAD; w++) {
147  // This prevents threads from accessing elements
148  // outside the allocated memory for C. The check
149  // is in the loop because some threads
150  // can be assigned elements in and out of
151  // the allocated memory.
152  if ((j + w * THREAD_BLOCK_SIZE_COL) < N && i < M) {
153  C[split_id * M * N + (j + w * THREAD_BLOCK_SIZE_COL) * M + i]
154  = acc[w];
155  }
156  }
157  }
158  // \cond
159 );
160 // \endcond
161 
165 const kernel_cl<in_buffer, in_buffer, out_buffer, int, int, int,
166  TriangularViewCL, TriangularViewCL>
167  matrix_multiply("matrix_multiply",
168  {thread_block_helpers, matrix_multiply_kernel_code},
169  {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}});
170 
171 // \cond
172 static const char* matrix_vector_multiply_kernel_code = STRINGIFY(
173  // \endcond
185  __kernel void matrix_vector_multiply(
186  const __global double* A, const __global double* B, __global double* R,
187  const int M, const int N, unsigned int lower_upper_A,
188  unsigned int lower_upper_B) {
189  const int gid = get_global_id(0);
190 
191  const int start = lower_upper_A == UPPER ? gid : 0;
192  const int stop
193  = lower_upper_B == UPPER ? 1 : (lower_upper_A == LOWER ? gid + 1 : N);
194 
195  double acc = 0;
196  for (int i = start, j = M * start; i < stop; i++, j += M) {
197  acc += A[j + gid] * B[i];
198  }
199  R[gid] = acc;
200  }
201  // \cond
202 );
203 // \endcond
204 
209 const kernel_cl<in_buffer, in_buffer, out_buffer, int, int, TriangularViewCL,
210  TriangularViewCL>
211  matrix_vector_multiply("matrix_vector_multiply",
212  matrix_vector_multiply_kernel_code);
213 
214 // \cond
215 static const char* row_vector_matrix_multiply_kernel_code = STRINGIFY(
216  // \endcond
230  const __global double* A, const __global double* B, __global double* R,
231  const int N, const int K, unsigned int lower_upper_A,
232  unsigned int lower_upper_B) {
233  const int lid = get_local_id(0);
234  const int gid = get_global_id(0);
235  const int wgid = get_group_id(0);
236 
237  const int start = lower_upper_B == LOWER ? wgid : 0;
238  const int stop = lower_upper_A == LOWER
239  ? 1
240  : (lower_upper_B == UPPER) ? wgid + 1 : N;
241 
242  double acc = 0;
243  for (int i = lid + start; i < stop; i += LOCAL_SIZE_) {
244  acc += A[i] * B[i + wgid * N];
245  }
246 
247  __local double res_loc[LOCAL_SIZE_];
248  res_loc[lid] = acc;
249  barrier(CLK_LOCAL_MEM_FENCE);
250  for (int step = LOCAL_SIZE_ / REDUCTION_STEP_SIZE; step > 0;
251  step /= REDUCTION_STEP_SIZE) {
252  if (lid < step) {
253  for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
254  res_loc[lid] += res_loc[lid + step * i];
255  }
256  }
257  barrier(CLK_LOCAL_MEM_FENCE);
258  }
259  if (lid == 0) {
260  R[wgid] = res_loc[0];
261  }
262  }
263  // \cond
264 );
265 // \endcond
266 
271 const kernel_cl<in_buffer, in_buffer, out_buffer, int, int, TriangularViewCL,
272  TriangularViewCL>
273  row_vector_matrix_multiply("row_vector_matrix_multiply",
274  row_vector_matrix_multiply_kernel_code,
275  {{"LOCAL_SIZE_", 64},
276  {"REDUCTION_STEP_SIZE", 4}});
277 
278 } // namespace opencl_kernels
279 } // namespace math
280 } // namespace stan
281 #endif
282 #endif
int min(const std::vector< int > &x)
Returns the minimum coefficient in the specified column vector.
Definition: min.hpp:20
#define STRINGIFY(src)
Definition: kernel_cl.hpp:22
double step(const T &y)
The step, or Heaviside, function.
Definition: step.hpp:29
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, TriangularViewCL, TriangularViewCL > row_vector_matrix_multiply("row_vector_matrix_multiply", row_vector_matrix_multiply_kernel_code, {{"LOCAL_SIZE_", 64}, {"REDUCTION_STEP_SIZE", 4}})
See the docs for row_vector_matrix_multiply() .
int max(const std::vector< int > &x)
Returns the maximum coefficient in the specified column vector.
Definition: max.hpp:21
Creates functor for kernels.
Definition: kernel_cl.hpp:201
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, int, TriangularViewCL, TriangularViewCL > matrix_multiply("matrix_multiply", {thread_block_helpers, matrix_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for matrix_multiply() .
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, TriangularViewCL, TriangularViewCL > matrix_vector_multiply("matrix_vector_multiply", matrix_vector_multiply_kernel_code)
See the docs for matrix_vector_multiply() .
static const char * thread_block_helpers
Definition: helpers.hpp:48

     [ Stan Home Page ] © 2011–2018, Stan Development Team.