#ifndef STAN_MATH_GPU_KERNELS_LOWER_TRI_INVERSE_STEP3_HPP
#define STAN_MATH_GPU_KERNELS_LOWER_TRI_INVERSE_STEP3_HPP
#ifdef STAN_OPENCL

#include <stan/math/gpu/kernel_cl.hpp>

namespace stan {
namespace math {
namespace opencl_kernels {
// \cond
const char* negative_rectangular_lower_triangular_multiply_kernel_code
    = STRINGIFY(
        // \endcond
        /**
         * Calculates C = -B * A, where B is rectangular and A is lower
         * triangular.
         * The full inverse requires calculating the lower left rectangular
         * block of the result via C3 = -C2 * A3 * C1, where C2 is the
         * inverse of the bottom right lower triangular block, C1 is the
         * inverse of the upper left lower triangular block, and A3 is the
         * lower left rectangular block of the original lower triangular
         * matrix. This kernel performs these multiplications on submatrices
         * of the input matrix A in parallel and includes optimizations that
         * account for the lower triangular input matrix A.
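         *
         * In block form (a sketch of the identity this step relies on, with
         * block names as above):
         *
         *     A = [ A1  0  ]       A^-1 = [ C1  0  ]
         *         [ A3  A2 ]              [ C3  C2 ]
         *
         * where C1 = A1^-1 and C2 = A2^-1. Expanding A^-1 * A = I in the
         * lower left block gives C3 * A1 + C2 * A3 = 0, so
         * C3 = -C2 * A3 * A1^-1 = -C2 * A3 * C1.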
         *
         * @param[in, out] A Input matrix that is being inverted.
         * @param[in] temp Temporary matrix with the intermediate results.
         * @param[in] A_rows The number of rows of A.
         * @param[in] rows The number of rows in a single matrix of the batch.
         * @note Code is a <code>const char*</code> held in
         *  negative_rectangular_lower_triangular_multiply_kernel_code
         *  Used in math/gpu/lower_tri_inverse.hpp.
         *  This kernel uses the helper macros available in helpers.cl.
         */
        __kernel void negative_rectangular_lower_triangular_multiply(
            __global double* A, const __global double* temp, const int A_rows,
            const int rows) {
          int result_matrix_id = get_global_id(2);
          int offset = result_matrix_id * rows * 2;
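          // offset to the first row/column of this batch entry's
          // 2*rows x 2*rows block on the diagonal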
          // thread index inside the thread_block
          const int thread_block_row = get_local_id(0);
          const int thread_block_col = get_local_id(1);
          // global thread index
          const int i = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
          const int j = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
          // local memory
          __local double temp_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
          __local double C1_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];

          double acc[WORK_PER_THREAD] = {0};

          const int num_tiles
              = (rows + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
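          // integer ceiling division, so a partial tile at the matrix edge
          // is still processed (out-of-range elements are loaded as zeros)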
          // iterate over all tiles
          for (int tile_ind = 0; tile_ind < num_tiles; tile_ind++) {
            // each thread copies WORK_PER_THREAD values to the local
            // memory
            for (int w = 0; w < WORK_PER_THREAD; w++) {
              const int tiled_i
                  = THREAD_BLOCK_SIZE * tile_ind + thread_block_row;
              const int tiled_j
                  = THREAD_BLOCK_SIZE * tile_ind + thread_block_col;
              const int temp_global_col = tiled_j + w * THREAD_BLOCK_SIZE_COL;
              const int C1_global_col = offset + j + w * THREAD_BLOCK_SIZE_COL;
              const int C1_global_row = tiled_i + offset;
              const int local_col
                  = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
              const int local_row = thread_block_row;
              if (temp_global_col < rows && i < rows) {
                temp_local[local_col][local_row]
                    = temp[result_matrix_id * rows * rows
                           + temp_global_col * rows + i];
              } else {
                temp_local[local_col][local_row] = 0.0;
              }
              if (C1_global_col <= C1_global_row) {
                C1_local[local_col][local_row]
                    = A[C1_global_col * A_rows + C1_global_row];
              } else {
                C1_local[local_col][local_row] = 0.0;
              }
            }
            // wait until all tile values are loaded to the local memory
            barrier(CLK_LOCAL_MEM_FENCE);
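            // multiply the two cached tiles; each thread accumulates
            // WORK_PER_THREAD partial dot products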
            for (int block_ind = 0; block_ind < THREAD_BLOCK_SIZE;
                 block_ind++) {
              for (int w = 0; w < WORK_PER_THREAD; w++) {
                const int local_col
                    = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
                const int local_row = thread_block_row;
                acc[w] += temp_local[block_ind][local_row]
                          * C1_local[local_col][block_ind];
              }
            }
            barrier(CLK_LOCAL_MEM_FENCE);
          }
          // save the values
          const int A_global_row = i + rows + offset;
          const int A_global_col_offset = offset + j;
          for (int w = 0; w < WORK_PER_THREAD; w++) {
            const int A_global_col
                = A_global_col_offset + w * THREAD_BLOCK_SIZE_COL;
            // each thread saves WORK_PER_THREAD values
            A[A_global_col * A_rows + A_global_row] = -acc[w];
          }
        }
        // \cond
    );
// \endcond

/**
 * See the docs for \link kernels/lower_tri_inverse_step3.hpp
 * negative_rectangular_lower_triangular_multiply() \endlink
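 *
 * An illustrative call (a sketch only: the range values and the caller-side
 * variables <code>A</code>, <code>temp</code>, and <code>parts</code> are
 * assumptions about the surrounding code, not copied from
 * lower_tri_inverse.hpp):
 * @code
 * // global range: one thread per WORK_PER_THREAD output entries of a
 * // rows x rows block, for `parts` batch entries; the local range must
 * // match the THREAD_BLOCK_SIZE and WORK_PER_THREAD options below
 * opencl_kernels::negative_rectangular_lower_triangular_multiply(
 *     cl::NDRange(rows, rows / 8, parts), cl::NDRange(32, 32 / 8, 1),
 *     A.buffer(), temp.buffer(), A.rows(), rows);
 * @endcode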
 */
const local_range_kernel<cl::Buffer, cl::Buffer, int, int>
    negative_rectangular_lower_triangular_multiply(
        "negative_rectangular_lower_triangular_multiply",
        negative_rectangular_lower_triangular_multiply_kernel_code,
        {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}});
}  // namespace opencl_kernels
}  // namespace math
}  // namespace stan
#endif  // STAN_OPENCL
#endif  // STAN_MATH_GPU_KERNELS_LOWER_TRI_INVERSE_STEP3_HPP