// Tensor helper utilities built on Eigen's unsupported Tensor module.
#include <array>
#include <cassert>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_set>

#include <unsupported/Eigen/CXX11/Tensor>
// Index type used for every dimension, offset and shape entry in this file.
using Eigen::Index;
// Row-major Eigen tensor with element type `Scalar` and rank `N`.
template <typename Scalar, int N>
using RTensor = Eigen::Tensor<Scalar, N, Eigen::RowMajor>;
// Shorthands for row-major tensors of bool, float and Index elements.
template <int N> using BTensor = RTensor<bool, N>;
template <int N> using FTensor = RTensor<float, N>;
template <int N> using ITensor = RTensor<Index, N>;
// Fixed-size array of N Index values; used for shapes, offsets and axis lists.
template <int N> using IArray = std::array<Index, N>;
// Type of the expression produced by reshaping `Nested` to a rank-ODim shape.
template <typename Nested, int ODim>
using Reshape = Eigen::TensorReshapingOp<const IArray<ODim>, const Nested>;
// Type of the expression produced by broadcasting a Reshape<Nested, ODim>
// with a rank-ODim array of per-axis repeat counts.
template <typename Nested, int ODim>
using BcastReshape = Eigen::TensorBroadcastingOp<const IArray<ODim>,
const Reshape<Nested, ODim>>;
// Inserts N size-1 axes into the shape of `input`.
//
// Template parameters:
//   N      - number of axes to insert; it comes first so callers can spell it
//            explicitly (e.g. unsqueeze<1>(t, {0})).
//   Scalar - element type of the input tensor.
//   Dim    - rank of the input tensor.
// Parameters:
//   input - tensor whose shape is extended.
//   dims  - positions of the new size-1 axes, expressed in the OUTPUT shape
//           (rank Dim + N). Entries must be distinct and within [0, Dim + N).
// Returns the result of input.reshape(...) with the extended shape.
// NOTE(review): the result is an Eigen expression over `input`, so `input`
// presumably has to outlive it - confirm against callers.
template <int N, typename Scalar, int Dim>
Reshape<RTensor<Scalar, Dim>, Dim + N>
unsqueeze(const RTensor<Scalar, Dim> &input, const IArray<N> &dims) {
  IArray<Dim + N> reshapeTo;
  std::unordered_set<Index> dimSet(dims.begin(), dims.end());
  // Duplicate axes would make the loop below consume more than Dim input axes.
  assert(dimSet.size() == static_cast<size_t>(N));
  Index fromDim = 0;
  // Walk every OUTPUT axis (Dim + N of them, not Dim + 1) so that the whole
  // target shape is initialised even when more than one axis is inserted.
  for (Index dim = 0; dim < Dim + N; dim++) {
    if (dimSet.find(dim) == dimSet.end()) {
      // Not a new axis: take the next unread input extent. Running out of
      // input axes means some entry of `dims` was outside [0, Dim + N).
      assert(fromDim < Dim);
      reshapeTo[dim] = input.dimension(fromDim++);
    } else {
      reshapeTo[dim] = 1;
    }
  }
  return input.reshape(reshapeTo);
}
// Broadcasts a tensor expression of shape `fromShape` to `toShape`.
//
// `fromShape` is aligned with the trailing axes of `toShape`; the missing
// leading axes are added with extent 1 and then repeated. For each aligned
// axis the source extent must be either 1 (the axis is repeated toShape[i]
// times) or equal to toShape[i] (the axis is left as is).
//
// Template parameters:
//   IDim     - rank of the input expression.
//   ODim     - rank of the output; must be >= IDim.
//   TensorOp - type of the input expression.
// Parameters:
//   iTensor   - expression to broadcast.
//   fromShape - shape of `iTensor`.
//   toShape   - requested output shape.
// Returns iTensor.reshape(...).broadcast(...).
// Throws std::runtime_error when an aligned source extent is neither 1 nor
// equal to the target extent.
template <size_t IDim, size_t ODim, typename TensorOp>
BcastReshape<TensorOp, ODim> broadcast(const TensorOp &iTensor,
                                       const IArray<IDim> &fromShape,
                                       const IArray<ODim> &toShape) {
  // Both ranks are compile-time constants, so reject bad ranks at compile time.
  static_assert(ODim >= IDim,
                "broadcast: output rank must not be smaller than input rank");
  IArray<ODim> reshapeShape;
  reshapeShape.fill(1);
  // Leading (unaligned) axes keep extent 1 in the reshape and are repeated
  // toShape[i] times, which is exactly the initial value of broadcastN.
  IArray<ODim> broadcastN = toShape;
  constexpr size_t offset = ODim - IDim;
  for (size_t i = offset; i < ODim; i++) {
    const Index size = fromShape[i - offset];
    reshapeShape[i] = size;
    if (size == 1)
      broadcastN[i] = toShape[i];
    else if (size == toShape[i])
      broadcastN[i] = 1;
    else
      throw std::runtime_error("Shape mismatch at dimension " +
                               std::to_string(i));
  }
  return iTensor.reshape(reshapeShape).broadcast(broadcastN);
}
// Broadcasts a concrete tensor to `shape`, using the tensor's own dimensions
// as the source shape. Forwards to the shape-explicit broadcast overload.
//
//   iTensor - tensor to broadcast.
//   shape   - requested output shape (rank ODim).
template <size_t ODim, typename Scalar, int IDim>
BcastReshape<RTensor<Scalar, IDim>, ODim>
broadcast(const RTensor<Scalar, IDim> &iTensor, const IArray<ODim> &shape) {
  const auto &sourceShape = iTensor.dimensions();
  return broadcast(iTensor, sourceShape, shape);
}
// Inserts one size-1 axis at `unsqueezeDim` and then broadcasts the result to
// `shape`.
//
//   iTensor      - tensor to broadcast.
//   shape        - requested output shape (rank ODim).
//   unsqueezeDim - position of the inserted axis in the unsqueezed shape
//                  (rank IDim + 1).
template <size_t ODim, typename Scalar, int IDim>
BcastReshape<Reshape<RTensor<Scalar, IDim>, IDim + 1>, ODim>
broadcast(const RTensor<Scalar, IDim> &iTensor, const IArray<ODim> &shape,
          Index unsqueezeDim) {
  const IArray<1> newAxis = {unsqueezeDim};
  const auto expanded = unsqueeze<1>(iTensor, newAxis);
  // The reshape expression reports the unsqueezed shape, which is the source
  // shape for the shape-explicit overload.
  return broadcast(expanded, expanded.dimensions(), shape);
}
// Returns the 1-D float tensor [from, from + 1, ..., to - 1].
//
//   from - first value (inclusive).
//   to   - end of the range (exclusive).
// An empty tensor is returned when to <= from.
FTensor<1> arange(Index from, Index to) {
  // Clamp so an inverted range yields an empty tensor rather than a negative
  // size.
  const Index count = to > from ? to - from : 0;
  FTensor<1> ret(count);
  // Filling directly avoids a temporary array and the closed-interval
  // arithmetic of LinSpaced.
  for (Index i = 0; i < count; ++i)
    ret(i) = static_cast<float>(from + i);
  return ret;
}
// Builds flattened coordinate grids for an h-by-w raster.
//
//   h - number of rows.
//   w - number of columns.
// Returns (linx, liny), each of length h * w and flattened from an (h, w)
// grid: linx holds arange(0, w) repeated for every row, liny holds
// arange(0, h) repeated along every column.
std::tuple<FTensor<1>, FTensor<1>> meshgrid(Index h, Index w) {
  const IArray<2> gridShape = {h, w};
  const IArray<1> flatShape = {h * w};
  const FTensor<1> columnCoords = arange(0, w);
  const FTensor<1> rowCoords = arange(0, h);
  // Axis 0 is inserted for x (shape (1, w)), axis 1 for y (shape (h, 1)).
  FTensor<1> linx = broadcast<2>(columnCoords, gridShape, 0).reshape(flatShape);
  FTensor<1> liny = broadcast<2>(rowCoords, gridShape, 1).reshape(flatShape);
  return std::make_pair(linx, liny);
}
// Softmax of a rank-4 float tensor along axis `Dim`.
//
//   input - tensor to normalise.
// Returns a tensor of the same shape: exp(input - max) / sum(exp(input - max)),
// where max and sum are taken along axis `Dim`.
template <int Dim> FTensor<4> softmax(const FTensor<4> &input) {
  const IArray<4> shape = input.dimensions();
  const IArray<1> axis = {Dim};
  // Shift by the per-slice maximum before exponentiating; this leaves the
  // ratio unchanged (presumably done for numerical stability).
  const FTensor<3> sliceMax = input.maximum(axis);
  const FTensor<4> shiftedExp = (input - broadcast(sliceMax, shape, Dim)).exp();
  const FTensor<3> normaliser = shiftedExp.sum(axis);
  return shiftedExp / broadcast(normaliser, shape, Dim);
}
// Copies the slices of `input` along axis `AlongDim` whose mask entry is true.
//
//   input - source tensor.
//   mask  - 1-D boolean mask; entry i decides whether slice i along AlongDim
//           is kept. It must not be longer than input.dimension(AlongDim).
// Returns a tensor with the same shape as `input`, except that axis AlongDim
// has one entry per true mask element; kept slices stay in their original
// order.
template <int AlongDim, typename Scalar, int IDim>
RTensor<Scalar, IDim> maskSelect(const RTensor<Scalar, IDim> &input,
                                 const BTensor<1> &mask) {
  static_assert(0 <= AlongDim && AlongDim < IDim,
                "maskSelect: AlongDim is not an axis of the input");
  const Index maskLen = mask.dimension(0);
  // A longer mask would chip `input` past the end of the axis.
  assert(maskLen <= input.dimension(AlongDim));

  // First pass: count the selected slices to size the output.
  Index nSelected = 0;
  for (Index i = 0; i < maskLen; i++)
    if (mask[i])
      ++nSelected;

  IArray<IDim> retShape = input.dimensions();
  retShape[AlongDim] = nSelected;
  RTensor<Scalar, IDim> ret(retShape);

  // Second pass: j is the write position among the selected slices.
  for (Index i = 0, j = 0; i < maskLen; i++)
    if (mask[i]) {
      ret.chip(j, AlongDim) = input.chip(i, AlongDim);
      ++j;
    }
  return ret;
}
// Overwrites the masked slices of `tensor` along axis `AlongDim` with
// consecutive slices of `values`.
//
//   tensor - tensor modified in place.
//   mask   - 1-D boolean mask; slice i of `tensor` is overwritten when
//            mask[i] is true.
//   values - source tensor; its k-th slice along AlongDim is written to the
//            k-th true position of the mask.
// Returns a reference to `tensor`.
template <int AlongDim, typename Scalar, int IDim>
RTensor<Scalar, IDim> &maskAssign(RTensor<Scalar, IDim> &tensor,
                                  const BTensor<1> &mask,
                                  const RTensor<Scalar, IDim> &values) {
  const Index maskLen = mask.dimension(0);
  Index nextSource = 0;
  for (Index target = 0; target < maskLen; ++target) {
    if (!mask[target])
      continue;
    tensor.chip(target, AlongDim) = values.chip(nextSource, AlongDim);
    ++nextSource;
  }
  return tensor;
}
// Fills the masked slices of `tensor` along axis `AlongDim` with a constant.
//
//   tensor - tensor modified in place.
//   mask   - 1-D boolean mask; slice i of `tensor` is filled when mask[i] is
//            true.
//   value  - constant written to every element of each selected slice.
// Returns a reference to `tensor`.
template <int AlongDim, typename Scalar, int IDim>
RTensor<Scalar, IDim> &maskAssign(RTensor<Scalar, IDim> &tensor,
                                  const BTensor<1> &mask, const Scalar &value) {
  const Index maskLen = mask.dimension(0);
  for (Index i = 0; i < maskLen; i++)
    if (mask[i])
      tensor.chip(i, AlongDim).setConstant(value);
  return tensor;
}
// Returns the sub-tensor covering [from, to) along axis `dim`; every other
// axis is kept whole.
//
//   input - source tensor.
//   dim   - axis to slice; must be < IDim.
//   from  - first index kept. A negative value is offset by the axis length,
//           so -1 refers to the last element.
//   to    - end index (exclusive). A negative value is offset by the axis
//           length + 1, so -1 means "through the end of the axis".
// Returns a new tensor holding a copy of the selected range.
template <typename Scalar, int IDim>
RTensor<Scalar, IDim> dimSelect(const RTensor<Scalar, IDim> &input, size_t dim,
                                Index from, Index to) {
  assert(dim < static_cast<size_t>(IDim));
  const Index dimSize = input.dimension(dim);
  if (from < 0)
    from += dimSize;
  if (to < 0)
    to += dimSize + 1;
  // After normalisation the range has to lie inside the axis.
  assert(0 <= from && from <= to && to <= dimSize);

  IArray<IDim> slice_starts, slice_extents = input.dimensions();
  slice_starts.fill(0);
  slice_starts[dim] = from;
  slice_extents[dim] = to - from;
  return input.slice(slice_starts, slice_extents);
}
// Concatenates N float tensors of rank IDim along axis `AlongDim`.
//
//   tensors - tensors to join, in order. The output takes every extent other
//             than AlongDim from tensors[0].
//             NOTE(review): the remaining tensors presumably must match
//             tensors[0] on those axes - this is not checked here.
// Returns a new tensor whose AlongDim extent is the sum of the inputs'
// AlongDim extents.
// NOTE(review): N is declared `int` while std::array's extent is size_t, so
// conforming compilers may refuse to deduce it; callers likely have to spell
// out all template arguments.
template <Index AlongDim, int IDim, int N>
FTensor<IDim> concat(const std::array<FTensor<IDim>, N> &tensors) {
  // tensors[0] is read below, so an empty array cannot be supported.
  static_assert(N > 0, "concat: at least one tensor is required");
  static_assert(0 <= AlongDim && AlongDim < IDim,
                "concat: AlongDim is not an axis of the inputs");

  Index outputSize = 0;
  for (const auto &tensor : tensors)
    outputSize += tensor.dimension(AlongDim);

  auto dim = tensors[0].dimensions();
  dim[AlongDim] = outputSize;
  FTensor<IDim> ret(dim);

  // Copy each input into its own window of `ret`; only the AlongDim start
  // and extent change from one input to the next.
  IArray<IDim> slice_starts, slice_extents = dim;
  slice_starts.fill(0);
  for (const auto &tensor : tensors) {
    const Index size = tensor.dimension(AlongDim);
    slice_extents[AlongDim] = size;
    ret.slice(slice_starts, slice_extents) = tensor;
    slice_starts[AlongDim] += size;
  }
  return ret;
}
// Applies a binary functor to every pair of elements of a 1-D tensor.
//
//   xs   - input vector of length n.
//   func - callable invoked once as func(lhs, rhs), where both operands are
//          n-by-n broadcast expressions with lhs(i, j) = xs[j] and
//          rhs(i, j) = xs[i].
// Returns the value of func(lhs, rhs) converted to an n-by-n tensor of
// OutScalar.
template <typename OutScalar, typename InScalar, typename FuncT>
RTensor<OutScalar, 2> outerAction(const RTensor<InScalar, 1> &xs,
                                  const FuncT &func) {
  const Index count = xs.size();
  const IArray<2> squareShape = {count, count};
  // New axis 0 -> shape (1, n): the values vary along columns.
  auto variesAlongColumns = broadcast<2>(xs, squareShape, 0);
  // New axis 1 -> shape (n, 1): the values vary along rows.
  auto variesAlongRows = broadcast<2>(xs, squareShape, 1);
  return func(variesAlongColumns, variesAlongRows);
}
// Scatters `values` into a new 1-D tensor: ret[indices[i]] = values[i].
//
//   indices - target position for each value; every entry must lie in
//             [0, indices.size()).
//   values  - values to place; must have the same length as `indices`.
// Returns a tensor of length indices.size(). Positions that no index refers
// to are zero; if an index occurs more than once, the last write wins.
template <typename Scalar>
RTensor<Scalar, 1> scatter1D(const RTensor<Index, 1> &indices,
                             const RTensor<Scalar, 1> &values) {
  assert(indices.size() == values.size());
  const Index count = indices.size();
  RTensor<Scalar, 1> ret(count);
  // Slots that are never written would otherwise hold uninitialised data.
  ret.setZero();
  for (Index i = 0; i < count; i++) {
    const Index target = indices[i];
    assert(0 <= target && target < count);
    ret[target] = values[i];
  }
  return ret;
}