1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#ifndef STAN_MATH_OPENCL_TRIANGULAR_HPP
#define STAN_MATH_OPENCL_TRIANGULAR_HPP
#ifdef STAN_OPENCL
#include <stan/math/opencl/stringify.hpp>
#include <Eigen/Core>
#include <type_traits>
namespace stan {
namespace math {
/**
 * Describes which triangular parts of a `matrix_cl` may hold nonzero values.
 *
 * The enumerator values act as bit flags: bit 0 is the lower part and bit 1
 * is the upper part, so `Entire` equals `Lower | Upper` and `Diagonal` has
 * neither bit set. The helpers `either`, `both` and `invert` rely on this
 * encoding when combining views.
 */
enum class matrix_cl_view : int {
  Diagonal = 0,
  Lower = 1,
  Upper = 2,
  Entire = 3
};
/**
 * Determines which parts are nonzero in any of the input views.
 *
 * The `matrix_cl_view` enumerator values are bit flags, so the union of the
 * nonzero parts is the bitwise OR of the underlying values
 * (e.g. `either(Lower, Upper) == Entire`).
 * @param left_view first view
 * @param right_view second view
 * @return combined view
 */
inline matrix_cl_view either(const matrix_cl_view left_view,
                             const matrix_cl_view right_view) {
  // The result is returned by value without a top-level `const`; on an enum
  // return type such a qualifier is ignored and only triggers warnings.
  using underlying = std::underlying_type<matrix_cl_view>::type;
  return static_cast<matrix_cl_view>(static_cast<underlying>(left_view)
                                     | static_cast<underlying>(right_view));
}
/**
 * Determines which parts are nonzero in both input views.
 *
 * The `matrix_cl_view` enumerator values are bit flags, so the intersection
 * of the nonzero parts is the bitwise AND of the underlying values
 * (e.g. `both(Lower, Upper) == Diagonal`).
 * @param left_view first view
 * @param right_view second view
 * @return common nonzero part
 */
inline matrix_cl_view both(const matrix_cl_view left_view,
                           const matrix_cl_view right_view) {
  // Returned by value without a top-level `const`, which would be ignored on
  // an enum return type.
  using underlying = std::underlying_type<matrix_cl_view>::type;
  return static_cast<matrix_cl_view>(static_cast<underlying>(left_view)
                                     & static_cast<underlying>(right_view));
}
/**
* Check whether a view contains certain nonzero part
* @param view view to check
* @param part part to check for (usually `Lower` or `Upper`)
* @return true, if `view` has `part` nonzero
*/
inline bool contains_nonzero(const matrix_cl_view view,
const matrix_cl_view part) {
return static_cast<bool>(both(view, part));
}
/**
 * Transposes a view by swapping its lower and upper parts.
 *
 * `Lower` maps to `Upper` and vice versa. Any other view (`Diagonal`,
 * `Entire`, or any other value) is returned unchanged.
 * @param view view to transpose
 * @return transposition of input
 */
inline matrix_cl_view transpose(const matrix_cl_view view) {
  // Returned by value without a top-level `const`, which would be ignored on
  // an enum return type.
  switch (view) {
    case matrix_cl_view::Lower:
      return matrix_cl_view::Upper;
    case matrix_cl_view::Upper:
      return matrix_cl_view::Lower;
    default:
      // Diagonal and Entire are symmetric under transposition.
      return view;
  }
}
/**
 * Inverts a view. Parts that are zero in the input become nonzero in
 * the output and vice versa (e.g. `invert(Lower) == Upper`,
 * `invert(Diagonal) == Entire`, `invert(Entire) == Diagonal`).
 * @param view view to invert
 * @return inverted view
 */
inline matrix_cl_view invert(const matrix_cl_view view) {
  using underlying = std::underlying_type<matrix_cl_view>::type;
  // `~` flips every bit of the underlying integer. Masking with `Entire`
  // keeps only the lower/upper flag bits, so the result stays a valid view.
  return static_cast<matrix_cl_view>(
      static_cast<underlying>(matrix_cl_view::Entire)
      & ~static_cast<underlying>(view));
}
/**
 * Creates a view from an `Eigen::UpLoType`.
 *
 * Any value with the `Eigen::Lower` bit set (`Eigen::Lower`,
 * `Eigen::StrictlyLower`, `Eigen::UnitLower`) becomes
 * `matrix_cl_view::Lower`. Otherwise, any value with the `Eigen::Upper` bit
 * set (`Eigen::Upper`, `Eigen::StrictlyUpper`, `Eigen::UnitUpper`) becomes
 * `matrix_cl_view::Upper`. Every other value becomes `matrix_cl_view::Entire`.
 * Because the lower bit is tested first, a value with both bits set maps to
 * `matrix_cl_view::Lower`.
 * @param eigen_type `UpLoType` to create a view from
 * @return view
 */
inline matrix_cl_view from_eigen_uplo_type(Eigen::UpLoType eigen_type) {
  if (eigen_type & Eigen::Lower) {
    return matrix_cl_view::Lower;
  }
  if (eigen_type & Eigen::Upper) {
    return matrix_cl_view::Upper;
  }
  return matrix_cl_view::Entire;
}
/**
 * Selects a direction for mapping one triangular part of a matrix onto the
 * other. Presumably `UpperToLower` copies the upper triangle into the lower
 * one and `LowerToUpper` the reverse.
 * NOTE(review): confirm against the kernels that consume this enum.
 */
enum class TriangularMapCL : int {
  UpperToLower = 0,
  LowerToUpper = 1
};
// OpenCL C source text for kernel-side counterparts of the host-side
// `either`, `both` and `contains_nonzero` functions above. On the kernel
// side, views are passed as plain `int`s (presumably the `matrix_cl_view`
// enumerator values; confirm at the use sites). The text comes from the
// STRINGIFY macro and is presumably prepended to kernel sources that call
// these helpers (NOTE(review): verify at the use sites).
// Comments inside the macro argument are removed by the preprocessor before
// macro expansion, so they do not end up in the kernel text.
// The `\cond` / `\endcond` markers hide the string declaration from Doxygen.
// \cond
static const char* view_kernel_helpers = STRINGIFY(
// \endcond
/**
* Determines which parts are nonzero in any of the input views.
* @param left_view first view
* @param right_view second view
* @return combined view
*/
int either(int left_view, int right_view) { return left_view | right_view; }
/**
* Determines which parts are nonzero in both input views.
* @param left_view first view
* @param right_view second view
* @return common nonzero part
*/
int both(int left_view, int right_view) { return left_view & right_view; }
/**
* Check whether a view contains certain nonzero part
* @param view view to check
* @param part part to check for (usually `Lower` or `Upper`)
* @return true, if `view` has `part` nonzero
*/
bool contains_nonzero(int view, int part) { return both(view, part); }
// \cond
);
// \endcond
} // namespace math
} // namespace stan
#endif
#endif