1 #ifndef STAN_MATH_OPENCL_KERNELS_MULTIPLY_TRANSPOSE_HPP 2 #define STAN_MATH_OPENCL_KERNELS_MULTIPLY_TRANSPOSE_HPP 10 namespace opencl_kernels {
// OpenCL kernel source text for "multiply_transpose", wrapped in STRINGIFY.
// From the visible statements: each work-item accumulates WORK_PER_THREAD
// sums of products of elements of A (both local tiles are loaded from A) and
// writes every sum to two index-mirrored locations of B.
// NOTE(review): the numeric prefixes on these lines look like original source
// line numbers and are not contiguous, so statements (the start of the kernel
// signature, several assignments, closing braces) appear to be missing from
// this listing. Comments below describe only what is visible.
12 static const char* multiply_transpose_kernel_code =
STRINGIFY(
// B is a __global double buffer; it is the array written at the end.
24 __global
double* B,
// M bounds the index i and the output indices (see the checks below).
const int M,
// Work-item position inside its work-group, along dimensions 0 and 1.
27 const int thread_block_row = get_local_id(0);
28 const int thread_block_col = get_local_id(1);
// Indices derived from the work-group id scaled by THREAD_BLOCK_SIZE plus the
// local position.
31 const int i = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
32 const int j = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
// First j-index of this work-group and one past this work-group's largest
// i-index. NOTE(review): their use is not visible in this listing.
37 const int j_min = THREAD_BLOCK_SIZE * get_group_id(1);
38 const int i_max = THREAD_BLOCK_SIZE * get_group_id(0) + get_local_size(0);
// Work-group local tiles; both are filled from A further down.
41 __local
double A_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
42 __local
double B_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
// Per-work-item accumulators, one per unit of work handled by this work-item.
44 double acc[WORK_PER_THREAD];
// NOTE(review): loop body not visible here; presumably zero-initializes acc.
45 for (
int w = 0; w < WORK_PER_THREAD; w++) {
// Number of tiles along the N dimension, rounded up.
49 const int num_tiles = (N + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
51 for (
int tile_ind = 0; tile_ind < num_tiles; tile_ind++) {
// Position along the tiled (N) dimension for this work-item in this tile.
53 const int tiled_i = THREAD_BLOCK_SIZE * tile_ind + thread_block_row;
54 const int tiled_j = THREAD_BLOCK_SIZE * tile_ind + thread_block_col;
// Each work-item loads WORK_PER_THREAD elements per tile, spaced
// THREAD_BLOCK_SIZE_COL apart (THREAD_BLOCK_SIZE_COL is defined outside this
// listing).
58 for (
int w = 0; w < WORK_PER_THREAD; w++) {
// NOTE(review): these two declarations have no type specifier (`const` only);
// `const int` looks intended -- confirm and fix in the real source.
59 const A_temp_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
60 const AT_temp_j = j + w * THREAD_BLOCK_SIZE_COL;
// Out-of-range guard for the A tile. The value stored in the out-of-range
// case is not visible here (presumably 0 -- confirm).
61 if (A_temp_j >= N || i >= M) {
62 A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
// In-range case: read A at A_temp_j * M + i.
66 A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
68 = A[A_temp_j * M + i];
// Same pattern for the second tile, which is also read from A, at
// AT_temp_j + tiled_i * M.
70 if (AT_temp_j >= M || tiled_i >= N) {
71 B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
75 B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
77 = A[AT_temp_j + tiled_i * M];
// Work-group barrier with a local-memory fence, placed after the tile loads
// and before the tiles are read.
81 barrier(CLK_LOCAL_MEM_FENCE);
// Accumulate products over the current tile.
83 for (
int block_ind = 0; block_ind < THREAD_BLOCK_SIZE; block_ind++) {
85 for (
int w = 0; w < WORK_PER_THREAD; w++) {
// Only accumulate when the second index does not exceed i; the mirrored
// entry is filled in at the final store.
86 if ((j + w * THREAD_BLOCK_SIZE_COL) <= i) {
87 acc[w] += A_local[block_ind][thread_block_row]
88 * B_local[thread_block_col
89 + w * THREAD_BLOCK_SIZE_COL][block_ind];
// Second barrier after the accumulation (presumably so no work-item
// overwrites the local tiles while others still read them -- the enclosing
// loop end is not visible here).
93 barrier(CLK_LOCAL_MEM_FENCE);
// Write results back to B.
96 for (
int w = 0; w < WORK_PER_THREAD; w++) {
// Skip indices outside M in either dimension.
102 if ((j + w * THREAD_BLOCK_SIZE_COL) < M && i < M) {
103 if ((j + w * THREAD_BLOCK_SIZE_COL) <= i) {
// Store the same value at both index-mirrored positions of B.
104 B[i + (j + w * THREAD_BLOCK_SIZE_COL) * M] = acc[w];
105 B[(j + w * THREAD_BLOCK_SIZE_COL) + i * M] = acc[w];
// Kernel name string followed by name/value pairs THREAD_BLOCK_SIZE = 32 and
// WORK_PER_THREAD = 4.
// NOTE(review): the start of this statement is not present in this listing;
// presumably these are arguments used when building the kernel -- confirm
// against the full declaration.
119 "multiply_transpose",
121 {{
"THREAD_BLOCK_SIZE", 32}, {
"WORK_PER_THREAD", 4}});
const kernel_cl< in_buffer, out_buffer, int, int > multiply_transpose("multiply_transpose", {thread_block_helpers, multiply_transpose_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 4}})
See the docs for add().
Creates functor for kernels.
static const char * thread_block_helpers