// eigen_utils.h - helper utilities built on Eigen's unsupported Tensor module.
#include <array>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>

#include <unsupported/Eigen/CXX11/Tensor>

// Index type used for every dimension size and position in this file.
using Eigen::Index;
// Rank-N Eigen tensor of `Scalar` stored in row-major order.
template <typename Scalar, int N>
using RTensor = Eigen::Tensor<Scalar, N, Eigen::RowMajor>;
// Convenience aliases of RTensor for bool, float and Index elements.
template <int N> using BTensor = RTensor<bool, N>;
template <int N> using FTensor = RTensor<float, N>;
template <int N> using ITensor = RTensor<Index, N>;
// Fixed-length array of N indices; used for shapes, offsets and axis lists.
template <int N> using IArray = std::array<Index, N>;

// Expression type of `Nested` reshaped to an ODim-dimensional shape
// (the return type of unsqueeze() below).
template <typename Nested, int ODim>
using Reshape = Eigen::TensorReshapingOp<const IArray<ODim>, const Nested>;

// Expression type of `Nested` reshaped to ODim dimensions and then broadcast
// (the return type of the broadcast() overloads below).
template <typename Nested, int ODim>
using BcastReshape = Eigen::TensorBroadcastingOp<const IArray<ODim>,
                                                 const Reshape<Nested, ODim>>;

/// Inserts N size-1 axes into the shape of `input`.
///
/// @tparam N     Number of axes to insert.
/// @param input  Tensor to reshape.
/// @param dims   Positions, in the (Dim + N)-dimensional output shape, that
///               become size-1 axes. Entries must be distinct and lie in
///               [0, Dim + N); negative positions are not supported.
/// @return A reshape expression over `input` with Dim + N dimensions.
///         NOTE(review): the expression is built from `input` without copying
///         it here, so `input` presumably has to outlive the result - confirm.
template <int N, typename Scalar, int Dim>
Reshape<RTensor<Scalar, Dim>, Dim + N>
unsqueeze(const RTensor<Scalar, Dim> &input, const IArray<N> &dims) {
  IArray<Dim + N> reshapeTo;
  const std::unordered_set<Index> dimSet(dims.begin(), dims.end());
  // Duplicate positions would leave fewer than N unit axes and make the loop
  // below read more than Dim dimensions from `input`.
  assert(dimSet.size() == static_cast<std::size_t>(N) &&
         "unsqueeze: duplicate axis positions");
  Index fromDim = 0;
  // Walk every output axis (Dim + N of them, not Dim + 1) so that all entries
  // of reshapeTo are initialised even when more than one axis is inserted.
  for (Index dim = 0; dim < Dim + N; ++dim) {
    if (dimSet.count(dim) != 0) {
      reshapeTo[dim] = 1;
    } else {
      assert(fromDim < Dim && "unsqueeze: axis position out of range");
      reshapeTo[dim] = input.dimension(fromDim++);
    }
  }
  return input.reshape(reshapeTo);
}

/// Reshapes and broadcasts a tensor expression from `fromShape` to `toShape`.
///
/// The two shapes are aligned on their trailing dimensions: the first
/// ODim - IDim output dimensions are treated as size 1 in the input and are
/// replicated toShape[i] times. For each aligned dimension the input size must
/// be either 1 (replicated toShape[i] times) or equal to toShape[i] (kept).
///
/// @param iTensor    Tensor or tensor expression to broadcast.
/// @param fromShape  Shape of `iTensor` (IDim dimensions).
/// @param toShape    Target shape (ODim dimensions, ODim >= IDim).
/// @return A reshape-then-broadcast expression over `iTensor`.
/// @throws std::runtime_error if an aligned dimension is neither 1 nor equal
///         to the corresponding target dimension.
template <size_t IDim, size_t ODim, typename TensorOp>
BcastReshape<TensorOp, ODim> broadcast(const TensorOp &iTensor,
                                       const IArray<IDim> &fromShape,
                                       const IArray<ODim> &toShape) {
  // Both ranks are compile-time constants, so reject bad ranks at compile time
  // instead of relying on a runtime assert that vanishes under NDEBUG.
  static_assert(ODim >= IDim,
                "broadcast: target rank must be >= source rank");
  constexpr size_t offset = ODim - IDim;

  IArray<ODim> reshapeShape;
  reshapeShape.fill(1);
  // Leading (unaligned) dimensions keep toShape[i] as their replication count.
  IArray<ODim> broadcastN = toShape;

  for (size_t i = offset; i < ODim; ++i) {
    const Index size = fromShape[i - offset];
    reshapeShape[i] = size;
    if (size == 1)
      broadcastN[i] = toShape[i];
    else if (size == toShape[i])
      broadcastN[i] = 1;
    else
      throw std::runtime_error("Shape mismatch at dimension " +
                               std::to_string(i));
  }
  return iTensor.reshape(reshapeShape).broadcast(broadcastN);
}

/// Broadcasts a concrete tensor to `shape`, taking the source shape from the
/// tensor's own dimensions and delegating to the three-argument overload.
///
/// @param iTensor  Tensor to broadcast.
/// @param shape    Target shape (ODim dimensions).
/// @return A reshape-then-broadcast expression over `iTensor`.
template <size_t ODim, typename Scalar, int IDim>
BcastReshape<RTensor<Scalar, IDim>, ODim>
broadcast(const RTensor<Scalar, IDim> &iTensor, const IArray<ODim> &shape) {
  const auto &sourceShape = iTensor.dimensions();
  return broadcast(iTensor, sourceShape, shape);
}

/// Inserts one size-1 axis at `unsqueezeDim` and broadcasts the result to
/// `shape`.
///
/// @param iTensor       Tensor to broadcast.
/// @param shape         Target shape (ODim dimensions).
/// @param unsqueezeDim  Position of the inserted size-1 axis in the
///                      (IDim + 1)-dimensional intermediate shape.
/// @return A broadcast expression over the unsqueezed view of `iTensor`.
template <size_t ODim, typename Scalar, int IDim>
BcastReshape<Reshape<RTensor<Scalar, IDim>, IDim + 1>, ODim>
broadcast(const RTensor<Scalar, IDim> &iTensor, const IArray<ODim> &shape,
          Index unsqueezeDim) {
  const IArray<1> newAxes{{unsqueezeDim}};
  auto expanded = unsqueeze<1>(iTensor, newAxes);
  return broadcast(expanded, expanded.dimensions(), shape);
}

/// Returns the rank-1 float tensor [from, from + 1, ..., to - 1].
///
/// @param from  First value (inclusive).
/// @param to    End value (exclusive).
/// @return A tensor of `to - from` elements; empty when `to <= from`.
///
/// Declared `inline` because this non-template function is defined in a
/// header and would otherwise be defined once per including translation unit.
inline FTensor<1> arange(Index from, Index to) {
  // Clamp so that an empty or reversed range yields an empty tensor instead
  // of a non-positive element count.
  const Index count = to > from ? to - from : 0;
  FTensor<1> result(count);
  for (Index i = 0; i < count; ++i)
    result(i) = static_cast<float>(from + i);
  return result;
}

/// Builds flattened x/y coordinate lists for an h-by-w grid.
///
/// Both results have h * w elements laid out in row-major order, so for the
/// element at flat position i * w + j the first tensor holds j (from
/// arange(0, w)) and the second holds i (from arange(0, h)).
///
/// @param h  Number of grid rows.
/// @param w  Number of grid columns.
/// @return Tuple (x coordinates, y coordinates).
///
/// Declared `inline` because this non-template function is defined in a
/// header and would otherwise be defined once per including translation unit.
inline std::tuple<FTensor<1>, FTensor<1>> meshgrid(Index h, Index w) {
  const IArray<1> flatShape{{h * w}};
  FTensor<1> linx = broadcast<2>(arange(0, w), {h, w}, 0).reshape(flatShape);
  FTensor<1> liny = broadcast<2>(arange(0, h), {h, w}, 1).reshape(flatShape);
  // Move the tensors into the tuple instead of copying them through a pair.
  return std::make_tuple(std::move(linx), std::move(liny));
}

/// Softmax of a rank-4 float tensor along axis `Dim`.
///
/// The maximum along `Dim` is subtracted before exponentiating, then each
/// exponential is divided by the sum of exponentials along `Dim`.
///
/// @tparam Dim   Axis to normalise over.
/// @param input  Rank-4 input tensor.
/// @return Tensor with the same shape as `input`.
template <int Dim> FTensor<4> softmax(const FTensor<4> &input) {
  const IArray<4> shape = input.dimensions();
  const IArray<1> reduceAxis{{Dim}};

  // Shift by the maximum along Dim, then exponentiate.
  const FTensor<3> peak = input.maximum(reduceAxis);
  const FTensor<4> exps = (input - broadcast(peak, shape, Dim)).exp();

  // Normalise by the sum of exponentials along Dim.
  const FTensor<3> total = exps.sum(reduceAxis);
  FTensor<4> normalised = exps / broadcast(total, shape, Dim);
  return normalised;
}

/// Keeps the slices of `input` along `AlongDim` whose mask entry is true.
///
/// @tparam AlongDim  Axis that the mask indexes.
/// @param input      Source tensor.
/// @param mask       Rank-1 boolean mask; entry i decides whether slice i of
///                   `input` along `AlongDim` is kept.
/// @return Tensor shaped like `input` except that `AlongDim` has as many
///         entries as there are true mask values; kept slices preserve order.
template <int AlongDim, typename Scalar, int IDim>
RTensor<Scalar, IDim> maskSelect(const RTensor<Scalar, IDim> &input,
                                 const BTensor<1> &mask) {
  const Index maskLen = mask.dimension(0);

  // First pass: count the selected slices to size the output.
  Index keepCount = 0;
  for (Index pos = 0; pos < maskLen; ++pos)
    keepCount += mask[pos] ? 1 : 0;

  IArray<IDim> outShape = input.dimensions();
  outShape[AlongDim] = keepCount;
  RTensor<Scalar, IDim> selected(outShape);

  // Second pass: copy each selected slice to the next free output slot.
  Index dst = 0;
  for (Index src = 0; src < maskLen; ++src) {
    if (!mask[src])
      continue;
    selected.chip(dst, AlongDim) = input.chip(src, AlongDim);
    ++dst;
  }
  return selected;
}

/// Overwrites the masked slices of `tensor` along `AlongDim` with consecutive
/// slices of `values`.
///
/// @tparam AlongDim  Axis that the mask indexes.
/// @param tensor     Tensor modified in place.
/// @param mask       Rank-1 boolean mask over `AlongDim` of `tensor`.
/// @param values     Source of replacement slices; its k-th slice along
///                   `AlongDim` is written to the k-th true mask position.
/// @return Reference to `tensor`.
template <int AlongDim, typename Scalar, int IDim>
RTensor<Scalar, IDim> &maskAssign(RTensor<Scalar, IDim> &tensor,
                                  const BTensor<1> &mask,
                                  const RTensor<Scalar, IDim> &values) {
  const Index maskLen = mask.dimension(0);
  Index src = 0; // next unread slice of `values`
  for (Index dst = 0; dst < maskLen; ++dst) {
    if (!mask[dst])
      continue;
    tensor.chip(dst, AlongDim) = values.chip(src, AlongDim);
    ++src;
  }
  return tensor;
}

/// Fills every masked slice of `tensor` along `AlongDim` with `value`.
///
/// @tparam AlongDim  Axis that the mask indexes.
/// @param tensor     Tensor modified in place.
/// @param mask       Rank-1 boolean mask over `AlongDim` of `tensor`.
/// @param value      Constant written to every element of each masked slice.
/// @return Reference to `tensor`.
template <int AlongDim, typename Scalar, int IDim>
RTensor<Scalar, IDim> &maskAssign(RTensor<Scalar, IDim> &tensor,
                                  const BTensor<1> &mask, const Scalar &value) {
  // No source cursor is needed here: every masked slice gets the same value.
  for (Index i = 0; i < mask.dimension(0); ++i)
    if (mask[i])
      tensor.chip(i, AlongDim).setConstant(value);
  return tensor;
}

/// Returns the sub-tensor covering [from, to) along dimension `dim`.
///
/// Negative bounds are taken relative to the size of `dim`: a negative `from`
/// has the size added to it, and a negative `to` has size + 1 added to it, so
/// `to == -1` means "up to and including the last element".
///
/// @param input  Source tensor.
/// @param dim    Dimension to slice; all other dimensions are kept whole.
/// @param from   Start position (inclusive).
/// @param to     End position (exclusive after the adjustment above).
/// @return Copy of the selected slice.
template <typename Scalar, int IDim>
RTensor<Scalar, IDim> dimSelect(const RTensor<Scalar, IDim> &input, size_t dim,
                                Index from, Index to) {
  const Index extent = input.dimension(dim);
  const Index begin = from < 0 ? from + extent : from;
  const Index end = to < 0 ? to + extent + 1 : to;

  IArray<IDim> offsets;
  offsets.fill(0);
  offsets[dim] = begin;

  IArray<IDim> sizes = input.dimensions();
  sizes[dim] = end - begin;

  return input.slice(offsets, sizes);
}

/// Concatenates N float tensors along dimension `AlongDim`.
///
/// @tparam AlongDim  Dimension along which the tensors are joined.
/// @tparam IDim      Rank of every tensor.
/// @tparam N         Number of tensors. Declared as size_t so that it matches
///                   the extent type of std::array and can be deduced from the
///                   argument; explicit template arguments keep working.
/// @param tensors    Tensors to join. The output shape is taken from
///                   tensors[0] with `AlongDim` replaced by the summed size.
///                   NOTE(review): the other dimensions are assumed to agree
///                   across all inputs; nothing here checks that.
/// @return Newly allocated tensor holding the inputs back to back.
template <Index AlongDim, int IDim, size_t N>
FTensor<IDim> concat(const std::array<FTensor<IDim>, N> &tensors) {
  static_assert(N > 0, "concat: at least one tensor is required");
  static_assert(AlongDim >= 0 && AlongDim < IDim,
                "concat: AlongDim is out of range");

  Index outputSize = 0;
  for (const auto &tensor : tensors)
    outputSize += tensor.dimension(AlongDim);

  IArray<IDim> outShape = tensors[0].dimensions();
  outShape[AlongDim] = outputSize;

  FTensor<IDim> ret(outShape);
  IArray<IDim> sliceStarts;
  sliceStarts.fill(0);
  IArray<IDim> sliceExtents = outShape;
  for (const auto &tensor : tensors) {
    const Index size = tensor.dimension(AlongDim);
    sliceExtents[AlongDim] = size;
    ret.slice(sliceStarts, sliceExtents) = tensor;
    // Advance the write cursor past the block just copied.
    sliceStarts[AlongDim] += size;
  }
  return ret;
}

/// Applies a binary functor to two n-by-n broadcasts of a rank-1 tensor.
///
/// The first operand is `xs` with a size-1 axis inserted at position 0 and
/// broadcast to (n, n); the second has the axis inserted at position 1.
///
/// @tparam OutScalar  Element type of the result.
/// @param xs          Rank-1 input of n elements.
/// @param func        Callable taking the two broadcast expressions; its
///                    result is converted to RTensor<OutScalar, 2>.
/// @return The n-by-n tensor produced by `func`.
template <typename OutScalar, typename InScalar, typename FuncT>
RTensor<OutScalar, 2> outerAction(const RTensor<InScalar, 1> &xs,
                                  const FuncT &func) {
  const Index count = xs.size();
  const IArray<2> square{{count, count}};
  auto alongColumns = broadcast<2>(xs, square, 0);
  auto alongRows = broadcast<2>(xs, square, 1);
  return func(alongColumns, alongRows);
}

/// Scatters `values` into a new rank-1 tensor: result[indices[i]] = values[i].
///
/// @param indices  Target position for each value; every entry must lie in
///                 [0, indices.size()).
/// @param values   Values to place; must have as many elements as `indices`.
/// @return Tensor of indices.size() elements. Positions that no index targets
///         are zero. If an index repeats, the value written last wins.
template <typename Scalar>
RTensor<Scalar, 1> scatter1D(const RTensor<Index, 1> &indices,
                             const RTensor<Scalar, 1> &values) {
  assert(indices.size() == values.size());
  const Index count = indices.size();
  RTensor<Scalar, 1> ret(count);
  // Zero-fill first so that untargeted slots are never returned uninitialised.
  ret.setZero();
  for (Index i = 0; i < count; ++i) {
    const Index target = indices[i];
    assert(target >= 0 && target < count && "scatter1D: index out of range");
    ret[target] = values[i];
  }
  return ret;
}