1 #ifndef STAN_MATH_OPENCL_KERNELS_MATRIX_MULTIPLY_HPP 2 #define STAN_MATH_OPENCL_KERNELS_MATRIX_MULTIPLY_HPP 10 namespace opencl_kernels {
// NOTE(review): this text appears to be a scraped source listing. Statements
// still carry their listing numbers and several numbered lines are absent
// (e.g. 45-47, 62-63, 116-118), so the comments below describe only the
// statements that are visible here.
//
// OpenCL C source of the tiled matrix-product kernel, stored as a string via
// STRINGIFY.
//
// Kernel arguments, as used by the visible statements:
//   A  read as A[col * M + i]      with i < M and col < K
//   B  read as B[col * K + row]    with row < K and col < N
//   C  written as C[split_id * M * N + col * M + i]; each index along
//      dimension 2 therefore gets its own M * N slab of the output
//   M, N, K  extents used for bounds checks and index strides
//   lower_upper_A, lower_upper_B  compared against LOWER / UPPER (defined
//      outside this view) to skip work for triangular operands
12 static const char* matrix_multiply_kernel_code =
STRINGIFY(
27 const __global
double* A,
const __global
double* B, __global
double* C,
28 const int M,
const int N,
const int K,
unsigned int lower_upper_A,
29 unsigned int lower_upper_B) {
// Position of this work-item inside its work-group (dimensions 0 and 1).
31 const int thread_block_row = get_local_id(0);
32 const int thread_block_col = get_local_id(1);
// Indices derived from the work-group id; i is checked against M and
// j (plus a per-w offset) against N further down.
34 const int i = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
35 const int j = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
// Dimension 2 selects which share of the K-direction tiles this work-item
// handles, and how many shares exist.
37 const int split_id = get_global_id(2);
38 const int split_size = get_global_size(2);
// Tiles of A and B staged in __local memory.
40 __local
double A_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
41 __local
double B_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
// One accumulator per output element handled by this work-item.
// NOTE(review): the body of the loop below is not present in this extract;
// presumably it zero-initialises acc - confirm against the full file.
43 double acc[WORK_PER_THREAD];
44 for (
int w = 0; w < WORK_PER_THREAD; w++) {
// Number of THREAD_BLOCK_SIZE-wide tiles needed to cover K (rounded up).
48 const int num_tiles = (K + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
// Divide the tiles among the split_size shares. split_offset_tiles is the
// first tile of this share; it is shifted by split_id when
// split_id < split_remainder, and the other visible assignment shifts it by
// split_remainder (the statements between them are not present here).
57 int split_tiles = num_tiles / split_size;
58 const int split_remainder = num_tiles % split_size;
59 int split_offset_tiles = split_id * split_tiles;
60 if (split_id < split_remainder) {
61 split_offset_tiles = split_offset_tiles + split_id;
64 split_offset_tiles = split_offset_tiles + split_remainder;
// Tile range limits implied by the triangular views. The declarations that
// the next two initialisers belong to are not visible; presumably they are
// end_tile_A and end_tile_B, which are read below - confirm.
82 = lower_upper_A == LOWER ? (i / THREAD_BLOCK_SIZE) : (num_tiles - 1);
84 = lower_upper_B == UPPER ? (j / THREAD_BLOCK_SIZE) : (num_tiles - 1);
85 const int start_tile_A
86 = lower_upper_A == UPPER ? (i / THREAD_BLOCK_SIZE) : 0;
87 const int start_tile_B
88 = lower_upper_B == LOWER ? (j / THREAD_BLOCK_SIZE) : 0;
// Final tile range: the intersection of the triangular limits with the tiles
// assigned to this split.
94 int start_tile =
max(start_tile_A, start_tile_B);
95 start_tile =
max(start_tile, split_offset_tiles);
96 int end_tile =
min(end_tile_A, end_tile_B);
97 end_tile =
min(end_tile, split_offset_tiles + split_tiles - 1);
98 for (
int tile_idx = start_tile; tile_idx <= end_tile; tile_idx++) {
// Indices along K covered by the current tile for this work-item.
99 const int tiled_i = THREAD_BLOCK_SIZE * tile_idx + thread_block_row;
100 const int tiled_j = THREAD_BLOCK_SIZE * tile_idx + thread_block_col;
// Stage WORK_PER_THREAD entries of each tile per work-item.
// THREAD_BLOCK_SIZE_COL is not defined in this view (presumably supplied by
// thread_block_helpers - confirm).
103 for (
int w = 0; w < WORK_PER_THREAD; w++) {
// Column indices for this value of w. The type is spelled out as int:
// OpenCL C follows C99, which has no implicit-int rule, so a declaration
// without a type specifier is not valid.
107 const int A_curr_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
108 const int B_curr_j = j + w * THREAD_BLOCK_SIZE_COL;
// Entry of A is out of range, or on the side of the diagonal excluded by the
// triangular view. The right-hand side of the first store is not visible
// here (presumably zero - confirm); otherwise the entry is loaded from A.
112 if (A_curr_j >= K || i >= M
113 || (lower_upper_A == LOWER && A_curr_j > i)
114 || (lower_upper_A == UPPER && A_curr_j < i)) {
115 A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
119 A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
121 = A[A_curr_j * M + i];
// Same test and load for the entry of B.
123 if (B_curr_j >= N || tiled_i >= K
124 || (lower_upper_B == LOWER && B_curr_j > tiled_i)
125 || (lower_upper_B == UPPER && B_curr_j < tiled_i)) {
126 B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
130 B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
132 = B[B_curr_j * K + tiled_i];
// Wait until every work-item of the group has finished its local stores.
135 barrier(CLK_LOCAL_MEM_FENCE);
// Accumulate the products of the staged tiles into acc[w].
136 for (
int block_idx = 0; block_idx < THREAD_BLOCK_SIZE; block_idx++) {
137 for (
int w = 0; w < WORK_PER_THREAD; w++) {
138 acc[w] += A_local[block_idx][thread_block_row]
139 * B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
// Second barrier of the tile loop, placed before the next iteration's stores
// into A_local / B_local.
143 barrier(CLK_LOCAL_MEM_FENCE);
// Store the results that fall inside the M x N output. The right-hand side
// of the store is not visible here (presumably acc[w] - confirm).
146 for (
int w = 0; w < WORK_PER_THREAD; w++) {
152 if ((j + w * THREAD_BLOCK_SIZE_COL) < N && i < M) {
153 C[split_id * M * N + (j + w * THREAD_BLOCK_SIZE_COL) * M + i]
// Tail of the kernel registration: values supplied for THREAD_BLOCK_SIZE and
// WORK_PER_THREAD.
169 {{
"THREAD_BLOCK_SIZE", 32}, {
"WORK_PER_THREAD", 8}});
// NOTE(review): this text appears to be a scraped source listing. Statements
// still carry their listing numbers and several numbered lines are absent
// (e.g. 190, 192, 194-195, 198-211), so only visible statements are described.
//
// OpenCL C source of the matrix-times-vector kernel, stored as a string via
// STRINGIFY.
//
// Kernel arguments, as used by the visible statements:
//   A  read as A[i * M + gid], i.e. stride M between consecutive i
//   B  read as B[i]
//   R  output buffer; the store into it is not visible in this extract
//   M, N  stride of A and upper bound of the summation index
//   lower_upper_A, lower_upper_B  compared against LOWER / UPPER (defined
//      outside this view) to shorten the summation range
172 static const char* matrix_vector_multiply_kernel_code =
STRINGIFY(
186 const __global
double* A,
const __global
double* B, __global
double* R,
187 const int M,
const int N,
unsigned int lower_upper_A,
188 unsigned int lower_upper_B) {
// One work-item per value of gid, which selects the offset into each
// M-strided slice of A.
189 const int gid = get_global_id(0);
// Summation starts at gid when A is viewed as UPPER, otherwise at 0.
191 const int start = lower_upper_A == UPPER ? gid : 0;
// Initialiser of a declaration that is not visible here (presumably `stop`,
// which the loop below reads - confirm): 1 when B is viewed as UPPER,
// gid + 1 when A is viewed as LOWER, otherwise N.
193 = lower_upper_B == UPPER ? 1 : (lower_upper_A == LOWER ? gid + 1 : N);
// i walks the summation index while j tracks i * M, so A[j + gid] is the
// entry at offset gid of the i-th M-strided slice.
// NOTE(review): the declaration of acc is not present in this extract.
196 for (
int i = start, j = M * start; i < stop; i++, j += M) {
197 acc += A[j + gid] * B[i];
// Tail of the kernel registration; it passes the source string above.
212 matrix_vector_multiply_kernel_code);
// NOTE(review): this text appears to be a scraped source listing. Statements
// still carry their listing numbers and several numbered lines are absent
// (e.g. 236, 239, 241-242, 245-246, 248, 252, 255-256, 258-259), so only
// visible statements are described.
//
// OpenCL C source of the row-vector-times-matrix kernel, stored as a string
// via STRINGIFY.
//
// Kernel arguments, as used by the visible statements:
//   A  read as A[i]
//   B  read as B[i + wgid * N], i.e. stride N between work-groups
//   R  written as R[wgid], one result per work-group
//   N  stride of B and default upper bound of the summation index
//   K  not referenced by any visible statement
//   lower_upper_A, lower_upper_B  compared against LOWER / UPPER (defined
//      outside this view) to shorten the summation range
215 static const char* row_vector_matrix_multiply_kernel_code =
STRINGIFY(
230 const __global
double* A,
const __global
double* B, __global
double* R,
231 const int N,
const int K,
unsigned int lower_upper_A,
232 unsigned int lower_upper_B) {
// Work-item id within the group, global id, and work-group id.
233 const int lid = get_local_id(0);
234 const int gid = get_global_id(0);
235 const int wgid = get_group_id(0);
// Summation range for this work-group. The value chosen when
// lower_upper_A == LOWER is on a line that is not present in this extract.
237 const int start = lower_upper_B == LOWER ? wgid : 0;
238 const int stop = lower_upper_A == LOWER
240 : (lower_upper_B == UPPER) ? wgid + 1 : N;
// Each work-item sums every LOCAL_SIZE_-th product, starting at lid + start.
// NOTE(review): the declaration of acc is not present in this extract.
243 for (
int i = lid + start; i < stop; i += LOCAL_SIZE_) {
244 acc += A[i] * B[i + wgid * N];
// Per-work-group scratch array in __local memory.
// NOTE(review): the statement that stores into res_loc before the barrier is
// not visible (presumably res_loc[lid] = acc - confirm).
247 __local
double res_loc[LOCAL_SIZE_];
249 barrier(CLK_LOCAL_MEM_FENCE);
// Reduction over res_loc: step starts at LOCAL_SIZE_ / REDUCTION_STEP_SIZE
// and is divided by REDUCTION_STEP_SIZE each pass; a pass adds the entries
// at lid + step * i for i = 1 .. REDUCTION_STEP_SIZE - 1 into res_loc[lid].
// NOTE(review): listing line 252 is absent; presumably it restricts the
// inner loop to lid < step - confirm.
250 for (
int step = LOCAL_SIZE_ / REDUCTION_STEP_SIZE;
step > 0;
251 step /= REDUCTION_STEP_SIZE) {
253 for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
254 res_loc[lid] += res_loc[lid +
step * i];
// Barrier between reduction passes.
257 barrier(CLK_LOCAL_MEM_FENCE);
// The reduced value for this work-group is read from res_loc[0].
// NOTE(review): any guard around this store is not visible in this extract.
260 R[wgid] = res_loc[0];
// Tail of the kernel registration: the source string above plus the values
// supplied for LOCAL_SIZE_ and REDUCTION_STEP_SIZE.
274 row_vector_matrix_multiply_kernel_code,
275 {{
"LOCAL_SIZE_", 64},
276 {
"REDUCTION_STEP_SIZE", 4}});
int min(const std::vector< int > &x)
Returns the minimum coefficient in the specified column vector.
double step(const T &y)
The step, or Heaviside, function.
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, TriangularViewCL, TriangularViewCL > row_vector_matrix_multiply("row_vector_matrix_multiply", row_vector_matrix_multiply_kernel_code, {{"LOCAL_SIZE_", 64}, {"REDUCTION_STEP_SIZE", 4}})
See the docs for row_vector_matrix_multiply() .
int max(const std::vector< int > &x)
Returns the maximum coefficient in the specified column vector.
Creates a functor for kernels.
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, int, TriangularViewCL, TriangularViewCL > matrix_multiply("matrix_multiply", {thread_block_helpers, matrix_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for matrix_multiply() .
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, TriangularViewCL, TriangularViewCL > matrix_vector_multiply("matrix_vector_multiply", matrix_vector_multiply_kernel_code)
See the docs for matrix_vector_multiply() .
static const char * thread_block_helpers