Stan Math Library  2.20.0
reverse mode automatic differentiation
multiply_transpose.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_OPENCL_KERNELS_MULTIPLY_TRANSPOSE_HPP
2 #define STAN_MATH_OPENCL_KERNELS_MULTIPLY_TRANSPOSE_HPP
3 #ifdef STAN_OPENCL
4 
7 
8 namespace stan {
9 namespace math {
10 namespace opencl_kernels {
11 // \cond
12 static const char* multiply_transpose_kernel_code = STRINGIFY(
13  // \endcond
23  __kernel void multiply_transpose(const __global double* A,
24  __global double* B, const int M,
25  const int N) {
26  // thread index inside the thread block
27  const int thread_block_row = get_local_id(0);
28  const int thread_block_col = get_local_id(1);
29 
30  // global thread index
31  const int i = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
32  const int j = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
33 
34  // indexes that determine the last indexes that need to compute
35  // in order to remove the unnecesary multiplications in the special
36  // multiplication of A*A^T
37  const int j_min = THREAD_BLOCK_SIZE * get_group_id(1);
38  const int i_max = THREAD_BLOCK_SIZE * get_group_id(0) + get_local_size(0);
39 
40  // local memory
41  __local double A_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
42  __local double B_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
43 
44  double acc[WORK_PER_THREAD];
45  for (int w = 0; w < WORK_PER_THREAD; w++) {
46  acc[w] = 0.0;
47  }
48  if (j_min <= i_max) {
49  const int num_tiles = (N + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
50  // iterate over all tiles
51  for (int tile_ind = 0; tile_ind < num_tiles; tile_ind++) {
52  // in each tile
53  const int tiled_i = THREAD_BLOCK_SIZE * tile_ind + thread_block_row;
54  const int tiled_j = THREAD_BLOCK_SIZE * tile_ind + thread_block_col;
55  // if the data needs to be loaded to local memory
56  // each thread copies WORK_PER_THREAD values to the
57  // local memory
58  for (int w = 0; w < WORK_PER_THREAD; w++) {
59  const A_temp_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
60  const AT_temp_j = j + w * THREAD_BLOCK_SIZE_COL;
61  if (A_temp_j >= N || i >= M) {
62  A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
63  [thread_block_row]
64  = 0.0;
65  } else {
66  A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
67  [thread_block_row]
68  = A[A_temp_j * M + i];
69  }
70  if (AT_temp_j >= M || tiled_i >= N) {
71  B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
72  [thread_block_row]
73  = 0.0;
74  } else {
75  B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
76  [thread_block_row]
77  = A[AT_temp_j + tiled_i * M];
78  }
79  }
80  // wait till all tile values are loaded to the local memory
81  barrier(CLK_LOCAL_MEM_FENCE);
82  // multiply the tile products
83  for (int block_ind = 0; block_ind < THREAD_BLOCK_SIZE; block_ind++) {
84  // each thread multiplies WORK_PER_THREAD values
85  for (int w = 0; w < WORK_PER_THREAD; w++) {
86  if ((j + w * THREAD_BLOCK_SIZE_COL) <= i) {
87  acc[w] += A_local[block_ind][thread_block_row]
88  * B_local[thread_block_col
89  + w * THREAD_BLOCK_SIZE_COL][block_ind];
90  }
91  }
92  }
93  barrier(CLK_LOCAL_MEM_FENCE);
94  }
95  // each thread saves WORK_PER_THREAD values to C
96  for (int w = 0; w < WORK_PER_THREAD; w++) {
97  // This prevents threads from accessing elements
98  // outside the allocated memory for C. The check
99  // is in the loop because some threads
100  // can be assigned elements in and out of
101  // the allocated memory.
102  if ((j + w * THREAD_BLOCK_SIZE_COL) < M && i < M) {
103  if ((j + w * THREAD_BLOCK_SIZE_COL) <= i) {
104  B[i + (j + w * THREAD_BLOCK_SIZE_COL) * M] = acc[w];
105  B[(j + w * THREAD_BLOCK_SIZE_COL) + i * M] = acc[w];
106  }
107  }
108  }
109  }
110  }
111  // \cond
112 );
113 // \endcond
114 
119  "multiply_transpose",
120  {thread_block_helpers, multiply_transpose_kernel_code},
121  {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 4}});
122 
123 } // namespace opencl_kernels
124 } // namespace math
125 } // namespace stan
126 #endif
127 #endif
#define STRINGIFY(src)
Definition: kernel_cl.hpp:22
const kernel_cl< in_buffer, out_buffer, int, int > multiply_transpose("multiply_transpose", {thread_block_helpers, multiply_transpose_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 4}})
See the docs for add().
Creates functor for kernels.
Definition: kernel_cl.hpp:201
static const char * thread_block_helpers
Definition: helpers.hpp:48

     [ Stan Home Page ] © 2011–2018, Stan Development Team.