diff --git a/llvm/projects/hpvm-tensor-rt/CMakeLists.txt b/llvm/projects/hpvm-tensor-rt/CMakeLists.txt
index ac07b17c2372cb0fa6a11409b79b23685da13ffd..d980429560da047f4fdb7dab8eb228e3d1b3419f 100644
--- a/llvm/projects/hpvm-tensor-rt/CMakeLists.txt
+++ b/llvm/projects/hpvm-tensor-rt/CMakeLists.txt
@@ -4,17 +4,26 @@ project (cudnn-training)
 find_package(CUDA 6.5 REQUIRED)
 
 
+# Target Pascal (sm_60) GPUs; --expt-relaxed-constexpr lets device code call
+# constexpr host functions (e.g. std::get used by the functional map helpers).
+set(
+  CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};
+  -gencode;arch=compute_60,code=sm_60;
+  -gencode;arch=compute_60,code=compute_60;
+  -std=c++14 --expt-relaxed-constexpr
+)
+
 if (CMAKE_BUILD_TYPE STREQUAL "Debug")
   message("Debug mode")
-    set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_60,code=compute_60;-std=c++11;-g;-lineinfo;-Xcompiler;-ggdb;-lcurand)
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-g;-lineinfo;-Xcompiler;-ggdb;-lcurand)
 else()
-   set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_60,code=compute_60;-std=c++11;-DNDEBUG;-Xcompiler;-DNDEBUG;-lcurand)
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-DNDEBUG;-Xcompiler;-DNDEBUG;-lcurand)
 endif()
 
 set(CUDA_PROPAGATE_HOST_FLAGS OFF)
 
 # Addresses a bug where code is not compiled as C++11 in non-CUDA code and older g++ versions
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11  -I/  " )
+# Edit: host code is now compiled as C++14 as well, to match the NVCC flags.
+set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14  -I/  " )
 
 add_definitions(-DNO_INJECTION)
 add_definitions(-DPROMISE_TUNER_ENABLED)
@@ -43,9 +52,9 @@ cuda_add_cublas_to_target(tensor_runtime)
 cuda_add_library(tensor_cpu_runtime tensor_runtime/src/tensor_cpu_runtime.cc)
 
 if(USE_GFLAGS)
-  target_link_libraries(tensor_runtime gflags cudnn -lcurand)
+  target_link_libraries(tensor_runtime gflags cudnn cufft -lcurand)
 else()
-  target_link_libraries(tensor_runtime cudnn -lcurand)
+  target_link_libraries(tensor_runtime cudnn cufft -lcurand)
 endif()
 
 target_link_libraries(tensor_cpu_runtime)
diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/debug.h b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/debug.h
index 33864fed94f3a86d065f2f166adbfc36127cc42d..813351da7b0b1fa53252c097bde3d2630efe012c 100644
--- a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/debug.h
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/debug.h
@@ -9,6 +9,13 @@
 
 #include "tensor.h"
 
+#include <sstream>
+#include <iostream>
+#include <cstdarg>
+
+#include <cudnn.h>
+#include <cublas_v2.h>
+#include <cufft.h>
 
 #define FatalError(s) do {                                             \
     std::stringstream _where, _message;                                \
@@ -37,7 +44,13 @@
     }                                                                  \
 } while(0)
 
+void _checkCUBLAS(cublasStatus_t error, const char *file, int line);
+
+void _checkCUFFT(cufftResult error, const char *file, int line);
+
+#define checkCUBLAS(err) _checkCUBLAS(err, __FILE__, __LINE__)
 
+#define checkCUFFT(err) _checkCUFFT(err, __FILE__, __LINE__)
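+// _checkCUBLAS and _checkCUFFT are defined in src/debug.cpp.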
 
 void INFO(const char* format, ...){
   if(!LOG_INFO) // Don't print if logging info is disabled
diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/broadcast.h b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/broadcast.h
new file mode 100644
index 0000000000000000000000000000000000000000..52b70e08ff6cf601c44acbdc07132ada2f629c58
--- /dev/null
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/broadcast.h
@@ -0,0 +1,74 @@
+#ifndef IMAGE_PROCESSING_BROADCAST_H
+#define IMAGE_PROCESSING_BROADCAST_H
+
+#include <algorithm>
+#include <array>
+#include <cstddef>
+#include <stdexcept>
+#include <type_traits>
+#include <vector>
+
+#include "common.h"
+#include "tensor.h"
+
+// TODO: don't accept N == 1
+template <size_t N, typename std::enable_if<N >= 1, int>::type = 0>
+class BroadcastRemap {
+public:
+  BroadcastRemap(const std::array<Tensor *, N> &tensors)
+      : out_sizes(), sizes() {
+    this->in_dims = tensors[0]->dims.num_dims;
+    for (size_t i = 0; i < N; i++) {
+      Tensor *t = tensors[i];
+      this->sizes[i] = ::sizes(t);
+      if (this->in_dims != t->dims.num_dims)
+        throw std::runtime_error("Broadcast tensors have different dimensions");
+      this->tail_stride[i] = 1;
+    }
+    fill_broadcast_dims();
+  }
+
+  std::vector<size_t> getDim() const { return this->out_sizes; }
+
+  const size_t *getStrides() const { return tail_stride; }
+
+private:
+  void fill_broadcast_dims() {
+    // Simplified broadcasting rule:
+    // 1. Tensors must have the same dimension that is greater than 1.
+    // 2. Dimension size being 1 (instead of equal) is only allowed for each
+    // tensor for a continuous N dimensions starting from the last one.
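+    // Example: shapes (2, 3, 4, 4) and (2, 3, 1, 1) broadcast to (2, 3, 4, 4);
+    // the second tensor gets tail_stride 16 (= 4 * 4), i.e. output element
+    // `row` reads element `row / 16` of that tensor.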
+
+    // out_sizes starts as all 1s and is filled in from the last dimension
+    // backwards.
+    if (this->in_dims < 1)
+      throw std::runtime_error(
+          "Broadcast tensors should have at least 1 dimension");
+    bool broadcast_ended[N]{false};
+    this->out_sizes.resize(this->in_dims, 1);
+    for (long i = this->in_dims - 1; i >= 0; i--) {
+      // First get tensors agree on dim size
+      for (size_t j = 0; j < N; j++) {
+        size_t this_size = this->sizes[j][i];
+        if (this_size == 1)
+          continue;
+        if (this->out_sizes[i] != 1 && this->out_sizes[i] != this_size)
+          throw std::runtime_error("Dimension size mismatch");
+        this->out_sizes[i] = this_size;
+      }
+    }
+    for (size_t j = 0; j < N; j++)
+      for (long i = this->in_dims - 1; i >= 0; i--) {
+        // Check for continuity, calculate stride size
+        size_t this_size = this->sizes[j][i];
+        if (this_size != 1) {
+          // Broadcast cannot go on anymore
+          broadcast_ended[j] = true;
+          continue;
+        }
+        if (this->out_sizes[i] != this_size && broadcast_ended[j])
+          throw std::runtime_error("Broadcast dims must be continuous");
+        else
+          tail_stride[j] *= this->out_sizes[i];
+      }
+  }
+
+  size_t in_dims;
+  std::vector<size_t> out_sizes, sizes[N];
+  size_t tail_stride[N];
+};
+
+#endif // IMAGE_PROCESSING_BROADCAST_H
diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/common.h b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/common.h
new file mode 100644
index 0000000000000000000000000000000000000000..520c1954b237039aee681dccc44acdb9b94dc443
--- /dev/null
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/common.h
@@ -0,0 +1,32 @@
+#ifndef IMAGE_PROCESSING_COMMON_H
+#define IMAGE_PROCESSING_COMMON_H
+
+#include <cudnn.h>
+#include <device_launch_parameters.h>
+#include <vector>
+
+#include "debug.h"
+#include "tensor.h"
+
+template <typename T> __host__ __device__ __forceinline__ T ceilDiv(T a, T b) {
+  return (a + b - 1) / b;
+}
+
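+// `func_symbol_ptr` points to a host variable holding the address of a
+// __device__ symbol that stores a device function pointer. Copy that pointer
+// value back to the host so it can be passed to a kernel as an argument.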
+template <typename T> __host__ T resolve_func_ptr(void *func_symbol_ptr) {
+  void *v_func_ptr = nullptr;
+  checkCudaErrors(cudaMemcpyFromSymbol(
+      &v_func_ptr, *(void **)func_symbol_ptr, sizeof(void *)));
+  return (T)v_func_ptr;
+}
+
+std::vector<size_t> sizes(Tensor *t);
+
+std::vector<size_t> sizes(const Dimension &dim);
+
+size_t num_elems(const std::vector<size_t> &dim_sizes);
+
+size_t num_elems(const Dimension &dim);
+
+size_t num_elems(Tensor *t);
+
+#endif // IMAGE_PROCESSING_COMMON_H
diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/map.cuh b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/map.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..1375fc8476304af1e74bf6d6be9349cdd802e98d
--- /dev/null
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/map.cuh
@@ -0,0 +1,89 @@
+#ifndef RUNTIME_MAP_H
+#define RUNTIME_MAP_H
+
+#include <algorithm>
+#include <array>
+#include <cstddef>
+#include <device_launch_parameters.h>
+#include <stdexcept>
+#include <tuple>
+#include <type_traits>
+
+#include "broadcast.h"
+#include "common.h"
+#include "debug.h"
+#include "map_typing.h"
+#include "tensor.h"
+#include "tensor_utils.cu"
+
+template <size_t N> void mapPrecheck(const std::array<Tensor *, N> &srcs) {
+  for (Tensor *src : srcs) {
+    if (src->dims.num_dims != 4 || src->data_format != CUDNN_TENSOR_NCHW)
+      throw std::runtime_error("Not supported"); // TODO: support this
+  }
+}
+
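+// Computes out[row] = op(in_0[row / stride_0], ..., in_{N-1}[row / stride_{N-1}]);
+// a stride of 1 means that input is not broadcast, larger strides repeat each
+// input element across a run of output elements. Threads grid-stride over rows.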
+template <typename Scalar, size_t N>
+__global__ void kernelMapBroadcast(
+    Scalar *target, unsigned num_rows, NTo1MapF<Scalar, N> n_ary_op,
+    Scalar **srcs, size_t *tail_strides) {
+  unsigned threadId = blockIdx.x * blockDim.x + threadIdx.x,
+           stride = gridDim.x * blockDim.x;
+  Scalar buf[N];
+  for (unsigned row = threadId; row < num_rows; row += stride) {
+    for (unsigned i = 0; i < N; i++) {
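+      // __fdividef is fast but approximate float division; the quotient is
+      // truncated back to an integer index.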
+      unsigned j = (unsigned)__fdividef(row, tail_strides[i]);
+      buf[i] = srcs[i][j];
+    }
+    target[row] = call_on_c_array<Scalar, Scalar, N>(n_ary_op, buf);
+  }
+}
+
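+// Uploads the broadcast strides and per-input GPU data pointers to the device
+// so the map kernel can index its operands. The caller is expected to
+// cudaFree the two returned device buffers once the kernel has finished.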
+template <typename Scalar, size_t N>
+std::tuple<size_t *, Scalar **> make_cuda_params(
+    const BroadcastRemap<N> &br, const std::array<Tensor *, N> &srcs) {
+  std::array<Scalar *, N> gpu_datas;
+  std::transform(srcs.begin(), srcs.end(), gpu_datas.begin(), [](Tensor *t) {
+    hostToDeviceCopy(t);
+    return (Scalar *)t->gpu_data;
+  });
+  size_t *cuda_strides;
+  Scalar **cuda_gpu_data;
+  cudaMalloc(&cuda_strides, N * sizeof(size_t));
+  cudaMemcpy(
+      cuda_strides, br.getStrides(), N * sizeof(size_t),
+      cudaMemcpyHostToDevice);
+  cudaMalloc(&cuda_gpu_data, N * sizeof(Scalar *));
+  cudaMemcpy(
+      cuda_gpu_data, gpu_datas.data(), N * sizeof(Scalar *),
+      cudaMemcpyHostToDevice);
+  return std::make_tuple(cuda_strides, cuda_gpu_data);
+}
+
+template <
+    typename Scalar, size_t N, typename std::enable_if<N >= 1, int>::type = 0>
+__host__ Tensor *
+mapGeneral(void *host_func_ptr, const std::array<Tensor *, N> &srcs) {
+  mapPrecheck(srcs);
+  auto func_ptr = resolve_func_ptr<NTo1MapF<Scalar, N>>(host_func_ptr);
+
+  auto br = BroadcastRemap<N>(srcs);
+  std::vector<size_t> dim_sizes = br.getDim();
+  auto *target = (Tensor *)create4DTensor(
+      CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, dim_sizes[0], dim_sizes[1],
+      dim_sizes[2], dim_sizes[3]);
+  changeTensorPlacement(target, DEVICE);
+
+  size_t *cuda_strides;
+  Scalar **gpu_data;
+  std::tie(cuda_strides, gpu_data) = make_cuda_params<Scalar, N>(br, srcs);
+
+  unsigned n_elem = num_elems(dim_sizes);
+  unsigned max_threads = 512, max_grid = 1024;
+  unsigned threads = std::min(max_threads, n_elem);
+  unsigned grids = std::min(max_grid, ceilDiv(n_elem, threads));
+  kernelMapBroadcast<Scalar, N><<<grids, threads>>>(
+      (Scalar *)target->gpu_data, n_elem, func_ptr, gpu_data, cuda_strides);
+  cudaDeviceSynchronize();
+  checkCudaErrors(cudaGetLastError());
+  // Release the temporary parameter buffers allocated in make_cuda_params.
+  cudaFree(cuda_strides);
+  cudaFree(gpu_data);
+  return target;
+}
+
+#endif
diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/map_typing.h b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/map_typing.h
new file mode 100644
index 0000000000000000000000000000000000000000..4cb93661384a572425361a053371c409606de164
--- /dev/null
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/map_typing.h
@@ -0,0 +1,68 @@
+#ifndef IMAGE_PROCESSING_MAP_TYPING_H
+#define IMAGE_PROCESSING_MAP_TYPING_H
+
+// Constructs type T (*)(T, T, T, T, ... <n_times>) from T and N
+#include <cstddef>
+#include <device_launch_parameters.h>
+#include <tuple>
+#include <utility>
+
+namespace {
+template <class T, size_t> using Type = T;
+
+template <typename, template <typename...> class, typename> struct _RepNType;
+
+template <typename T, template <typename...> class W, size_t... Is>
+struct _RepNType<T, W, std::index_sequence<Is...>> {
+  using type = W<Type<T, Is>...>;
+};
+
+template <typename T, template <typename...> class W, size_t N>
+using RepNType = typename _RepNType<T, W, std::make_index_sequence<N>>::type;
+
+template <typename Ret, typename... Args> using FuncPtrT = Ret (*)(Args...);
+
+template <typename Ret, typename Arg, size_t N> struct _NAToBFunc {
+  template <typename... Args> using Wrapper = FuncPtrT<Ret, Args...>;
+
+  using type = RepNType<Arg, Wrapper, N>;
+};
+} // namespace
+
+template <typename Ret, typename Arg, size_t N>
+using NAToBF = typename _NAToBFunc<Ret, Arg, N>::type;
+
+template <typename Scalar, size_t N> using NTo1MapF = NAToBF<Scalar, Scalar, N>;
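+// e.g. NTo1MapF<float, 3> is float (*)(float, float, float).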
+
+template <typename T, size_t N> using RepNTuple = RepNType<T, std::tuple, N>;
+
+namespace {
+template <typename TIterable, typename T, size_t... Is>
+constexpr RepNTuple<T, sizeof...(Is)>
+as_tuple(TIterable arr, std::index_sequence<Is...>) {
+  return std::make_tuple(arr[Is]...);
+}
+
+template <typename Function, typename Tuple, size_t... I>
+__device__ auto call(Function f, Tuple t, std::index_sequence<I...>) {
+  return f(std::get<I>(t)...);
+}
+} // namespace
+
+template <typename TIterable, typename T, size_t N>
+constexpr RepNTuple<T, N> as_tuple(TIterable arr) {
+  return as_tuple<TIterable, T>(arr, std::make_index_sequence<N>{});
+}
+
+template <typename Function, typename Tuple>
+__device__ auto call_on_tuple(Function f, Tuple t) {
+  static constexpr auto size = std::tuple_size<Tuple>::value;
+  return call(f, t, std::make_index_sequence<size>{});
+}
+
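+// Expands a C array of N elements into N arguments:
+// call_on_c_array(f, arr) evaluates f(arr[0], ..., arr[N - 1]).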
+template <typename Ret, typename T, size_t N>
+__device__ Ret call_on_c_array(NAToBF<Ret, T, N> f, const T arr[N]) {
+  return call_on_tuple(f, as_tuple<const T *, T, N>(arr));
+}
+
+#endif // IMAGE_PROCESSING_MAP_TYPING_H
diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/reduce.cuh b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/reduce.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..bcd58f90bdf444266e511bb19d69dbb4d4d9bdcf
--- /dev/null
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/include/functional/reduce.cuh
@@ -0,0 +1,184 @@
+#ifndef RUNTIME_REDUCE_H
+#define RUNTIME_REDUCE_H
+
+#include <algorithm>
+#include <device_launch_parameters.h>
+#include <functional>
+#include <numeric>
+
+#include "common.h"
+#include "debug.h"
+#include "tensor.h"
+#include "tensor_utils.cu"
+
+// Across CUDA compute capabilities 1.0 - 7.5, the smallest "max threads per
+// block" is 512, so 512 threads are used for compatibility. A multiprocessor
+// runs at most 2048 threads, and core counts vary greatly between devices
+// (Titan X: 3072, Quadro P1000: 640); a bit of over-subscription does not
+// hurt, so the larger figure is used for the block budget.
+constexpr size_t NThreads = 512, MaxNBlocks = 2048 / NThreads * 3072;
+constexpr size_t MaxBlocksPerDim = 65535;
+
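+// Block shape when the reduced axis itself is parallelized: AlongDimTh threads
+// sweep the reduced axis while CrossDimTh threads cover distinct output rows.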
+constexpr size_t AlongDimTh = 16, CrossDimTh = 32;
+
+/*
+ * Reduce along one dimension with a single thread.
+ */
+template <typename K, class BinOp>
+__device__ void reduceAlongDim(
+    K *target, K *src, K init, BinOp binary_op, size_t num_irows,
+    size_t dim_size) {
+  K acc = init;
+  for (size_t col = 0; col < dim_size; ++col) {
+    acc = binary_op(acc, *src);
+    src += num_irows;
+  }
+  *target = acc;
+}
+
+/*
+ * Parallel reduce a dimension of tensor to a scalar value.
+ * Use `n_along_dim_threads` threads to sweep along the dim to be reduced.
+ * Intermediate values are collected in a divide-and-conquer manner,
+ * with thread 0 finally writing back the result.
+ */
+template <typename K, class BinOp>
+__device__ void parallelReduceAlongDim(
+    K *target, K *src, K *sbuf, K init, BinOp binary_op, size_t num_irows,
+    size_t dim_size, size_t along_dim_tid, size_t n_along_dim_threads) {
+  K acc = init;
+  // Sequential reduction within a thread.
+  for (size_t col = along_dim_tid; col < dim_size; col += n_along_dim_threads) {
+    acc = binary_op(acc, src[col * num_irows]);
+  }
+
+  sbuf[along_dim_tid] = acc;
+
+  __syncthreads();
+
+  // Reduce intermediate values to single value.
+  for (size_t s = AlongDimTh >> 1u; s > 0; s >>= 1u) {
+    if (along_dim_tid < s) {
+      K arg1 = sbuf[along_dim_tid];
+      K arg2 = sbuf[along_dim_tid + s];
+      K res = binary_op(arg1, arg2);
+      sbuf[along_dim_tid] = res;
+    }
+    __syncthreads();
+  }
+
+  if (along_dim_tid == 0) {
+    *target = sbuf[0];
+  }
+  __syncthreads();
+}
+
+/*
+ * Reduce the whole tensor with parallelism only on output.
+ * The reduce axis is reduced sequentially.
+ * Block is 2D, thread is 1D; block.y covers outer rows, block.x * thread.x
+ * covers inner rows.
+ */
+template <typename K, class BinOp>
+__global__ void kernelReduceDimSeq(
+    K *target_, K *src_, K init, BinOp binary_op, size_t num_irows,
+    size_t num_orows, size_t row_size, size_t approx_row_size) {
+  size_t start_orow = blockIdx.y,
+         start_irow = blockIdx.x * blockDim.x + threadIdx.x;
+  size_t orow_stride = gridDim.y, irow_stride = gridDim.x * blockDim.x;
+  for (size_t orow = start_orow; orow < num_orows; orow += orow_stride) {
+    for (size_t irow = start_irow; irow < num_irows; irow += irow_stride) {
+      K *src = src_ + orow * row_size * num_irows + irow;
+      K *target = target_ + orow * num_irows + irow;
+      reduceAlongDim(target, src, init, binary_op, num_irows, approx_row_size);
+    }
+  }
+}
+
+/*
+ * Reduce the whole tensor with parallelism on output and reduce axis.
+ * I.e., the reduce axis is reduced parallel.
+ * Block is 2D, thread is 2D;
+ * thread.x covers reduce axis,
+ * block.x * thread.y covers inner rows,
+ * and block.y covers outer rows.
+ */
+template <typename K, class BinOp>
+__global__ void kernelReduceDimParallel(
+    K *target_, K *src_, K init, BinOp binary_op, size_t num_irows,
+    size_t num_orows, size_t row_size, size_t approx_row_size) {
+  __shared__ K sbuf[CrossDimTh][AlongDimTh + 1]; // avoid bank conflict
+  size_t start_orow = blockIdx.y,
+         start_irow = blockIdx.x * blockDim.y + threadIdx.y;
+  size_t orow_stride = gridDim.y, irow_stride = gridDim.x * blockDim.y;
+  for (size_t orow = start_orow; orow < num_orows; orow += orow_stride) {
+    for (size_t irow = start_irow; irow < num_irows; irow += irow_stride) {
+      K *src = src_ + orow * row_size * num_irows + irow;
+      K *target = target_ + orow * num_irows + irow;
+      parallelReduceAlongDim(
+          target, src, sbuf[threadIdx.y], init, binary_op, num_irows,
+          approx_row_size, threadIdx.x, blockDim.x);
+    }
+  }
+}
+
+template <typename Scalar>
+__host__ Tensor *reduceDim(
+    Tensor *src, const Scalar &init, void *func, size_t axis, float skip_rate) {
+  // Copy input over
+  hostToDeviceCopy(src);
+
+  // Cast function ptr to right type
+  using bin_float_op = NTo1MapF<Scalar, 2>;
+  auto func_ptr = resolve_func_ptr<bin_float_op>(func);
+
+  // Prepare output
+  std::vector<size_t> in_sizes = sizes(src), out_sizes = in_sizes;
+  out_sizes[axis] = 1;
+  auto *target = (Tensor *)create4DTensor(
+      CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, out_sizes[0], out_sizes[1],
+      out_sizes[2], out_sizes[3]);
+  changeTensorPlacement(target, DEVICE);
+
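+  // View the input as [num_orows][row_size][num_irows], where row_size is the
+  // length of the reduced axis; each (orow, irow) pair is reduced independently.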
+  // Calculate schedule parameters
+  size_t num_orows = std::accumulate(
+      in_sizes.begin(), in_sizes.begin() + axis, 1, std::multiplies<>());
+  size_t row_size = in_sizes[axis];
+  size_t num_irows = std::accumulate(
+      in_sizes.begin() + axis + 1, in_sizes.end(), 1, std::multiplies<>());
+  size_t num_rows = num_irows * num_orows;
+
+  // Approximation: only the first (1 - skip_rate) fraction of the reduced axis
+  // is accumulated; the remaining entries are skipped.
+  if (skip_rate != 0.0f)
+    INFO("Approximation happening...");
+  size_t approx_row_size = (size_t)((1 - skip_rate) * row_size);
+
+  // If the number of output entries is small and the reduced axis is long
+  // enough to split across AlongDimTh threads, reduce the axis in parallel;
+  // the block shape must then be (AlongDimTh, CrossDimTh) = (16, 32).
+  if (num_rows < NThreads * MaxNBlocks && row_size >= AlongDimTh * 8) {
+    DEBUG(
+        "Reducing in parallel, row size = %lu, actually using %lu", row_size,
+        approx_row_size);
+    size_t grid_x = std::min(MaxBlocksPerDim, ceilDiv(num_irows, 32ul));
+    size_t grid_y = std::min(
+        std::min(MaxBlocksPerDim, num_orows), ceilDiv(MaxNBlocks, grid_x));
+    dim3 threads(AlongDimTh, CrossDimTh);
+    dim3 grid(grid_x, grid_y);
+    kernelReduceDimParallel<Scalar, bin_float_op><<<grid, threads>>>(
+        (Scalar *)target->gpu_data, (Scalar *)src->gpu_data, init, func_ptr,
+        num_irows, num_orows, row_size, approx_row_size);
+  } else {
+    DEBUG(
+        "Reducing sequentially, row size = %lu, actually using %lu", row_size,
+        approx_row_size);
+    // Reduce sequentially.
+    size_t threads = std::min(NThreads, num_irows);
+    size_t grid_x = std::min(MaxBlocksPerDim, ceilDiv(num_irows, threads));
+    size_t grid_y = std::min(
+        std::min(MaxBlocksPerDim, num_orows), ceilDiv(MaxNBlocks, grid_x));
+    dim3 grid(grid_x, grid_y);
+    kernelReduceDimSeq<Scalar, bin_float_op><<<grid, threads>>>(
+        (Scalar *)target->gpu_data, (Scalar *)src->gpu_data, init, func_ptr,
+        num_irows, num_orows, row_size, approx_row_size);
+  }
+  cudaDeviceSynchronize();
+  checkCudaErrors(cudaGetLastError());
+  return target;
+}
+
+#endif // RUNTIME_REDUCE_H
diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/common.cpp b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/common.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c30ddf25e5aeffef8123897af863b520846b62ce
--- /dev/null
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/common.cpp
@@ -0,0 +1,19 @@
+#include "functional/common.h"
+
+#include <numeric>
+#include <functional>
+
+std::vector<size_t> sizes(const Dimension &dim) {
+  return std::vector<size_t>(dim.dim_sizes, dim.dim_sizes + dim.num_dims);
+}
+
+std::vector<size_t> sizes(Tensor *t) { return sizes(t->dims); }
+
+size_t num_elems(const std::vector<size_t> &dim_sizes) {
+  return std::accumulate(
+      dim_sizes.begin(), dim_sizes.end(), size_t{1}, std::multiplies<>());
+}
+
+size_t num_elems(const Dimension &dim) { return num_elems(sizes(dim)); }
+
+size_t num_elems(Tensor *t) { return num_elems(sizes(t)); }
diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/debug.cpp b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/debug.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..467336dd411ebc8d805bccf3430b74be98f4fec0
--- /dev/null
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/debug.cpp
@@ -0,0 +1,130 @@
+#include <cstdarg>
+#include <cstdio>
+#include <stdexcept>
+#include <cuda_runtime_api.h>
+#include "debug.h"
+
+void throwError(const char *file, int line, const char *fmt, ...) {
+    char msg[2048];
+    va_list args;
+    /* vasprintf is not standard; use vsnprintf into a fixed-size buffer and
+       truncate the message if it does not fit. */
+    va_start(args, fmt);
+    int n = vsnprintf(msg, 2048, fmt, args);
+    va_end(args);
+    if (n < 2048) {
+        snprintf(msg + n, 2048 - n, " at %s:%d", file, line);
+    }
+
+    throw std::runtime_error(msg);
+}
+
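+// Compares a status value against the library's success constant; on failure,
+// prints the error to stderr (only the first time) and throws a
+// std::runtime_error carrying the error code, message and source location.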
+template<typename T, typename F>
+void checkCompareFlag(
+        T err, T success_const, F get_err_str, const char *error_kind, const char *file, int line
+) {
+    if (err != success_const) {
+        static int alreadyFailed = 0;
+        if (!alreadyFailed) {
+            fprintf(
+                    stderr, "%s Error file=%s line=%i error=%i : %s\n",
+                    error_kind, file, line, err,
+                    get_err_str(err)
+            );
+            alreadyFailed = 1;
+        }
+        throwError(
+                file, line, "%s Error (%d): %s", error_kind, err,
+                get_err_str(err)
+        );
+    }
+}
+
+void _checkCUDA(cudaError_t err, const char *file, int line) {
+    checkCompareFlag(err, cudaSuccess, cudaGetErrorString, "CUDA", file, line);
+}
+
+void _checkWarnCUDA(cudaError_t err, const char *file, int line) {
+    if (err != cudaSuccess) {
+        fprintf(stderr, "CUDA Warning file=%s line=%i error=%i : %s\n", file, line, err,
+                cudaGetErrorString(err));
+    }
+}
+
+void _checkCUDNN(cudnnStatus_t error, const char *file, int line) {
+    checkCompareFlag(error, CUDNN_STATUS_SUCCESS, cudnnGetErrorString, "CUDNN", file, line);
+}
+
+static const char *cublasGetErrorString(cublasStatus_t status) {
+    switch (status) {
+        case CUBLAS_STATUS_SUCCESS:
+            return "CUBLAS_STATUS_SUCCESS";
+        case CUBLAS_STATUS_NOT_INITIALIZED:
+            return "CUBLAS_STATUS_NOT_INITIALIZED";
+        case CUBLAS_STATUS_ALLOC_FAILED:
+            return "CUBLAS_STATUS_ALLOC_FAILED";
+        case CUBLAS_STATUS_INVALID_VALUE:
+            return "CUBLAS_STATUS_INVALID_VALUE";
+        case CUBLAS_STATUS_ARCH_MISMATCH:
+            return "CUBLAS_STATUS_ARCH_MISMATCH";
+        case CUBLAS_STATUS_MAPPING_ERROR:
+            return "CUBLAS_STATUS_MAPPING_ERROR";
+        case CUBLAS_STATUS_EXECUTION_FAILED:
+            return "CUBLAS_STATUS_EXECUTION_FAILED";
+        case CUBLAS_STATUS_INTERNAL_ERROR:
+            return "CUBLAS_STATUS_INTERNAL_ERROR";
+        case CUBLAS_STATUS_NOT_SUPPORTED:
+            return "CUBLAS_STATUS_NOT_SUPPORTED";
+        case CUBLAS_STATUS_LICENSE_ERROR:
+            return "CUBLAS_STATUS_LICENSE_ERROR";
+    }
+    return "unknown error";
+}
+
+void _checkCUBLAS(cublasStatus_t error, const char *file, int line) {
+    checkCompareFlag(error, CUBLAS_STATUS_SUCCESS, cublasGetErrorString, "CUBLAS", file, line);
+}
+
+static const char *cufftGetErrorString(cufftResult error) {
+    switch (error) {
+        case CUFFT_SUCCESS:
+            return "CUFFT_SUCCESS";
+        case CUFFT_INVALID_PLAN:
+            return "CUFFT_INVALID_PLAN";
+        case CUFFT_ALLOC_FAILED:
+            return "CUFFT_ALLOC_FAILED";
+        case CUFFT_INVALID_TYPE:
+            return "CUFFT_INVALID_TYPE";
+        case CUFFT_INVALID_VALUE:
+            return "CUFFT_INVALID_VALUE";
+        case CUFFT_INTERNAL_ERROR:
+            return "CUFFT_INTERNAL_ERROR";
+        case CUFFT_EXEC_FAILED:
+            return "CUFFT_EXEC_FAILED";
+        case CUFFT_SETUP_FAILED:
+            return "CUFFT_SETUP_FAILED";
+        case CUFFT_INVALID_SIZE:
+            return "CUFFT_INVALID_SIZE";
+        case CUFFT_UNALIGNED_DATA:
+            return "CUFFT_UNALIGNED_DATA";
+        case CUFFT_INCOMPLETE_PARAMETER_LIST:
+            return "CUFFT_INCOMPLETE_PARAMETER_LIST";
+        case CUFFT_INVALID_DEVICE:
+            return "CUFFT_INVALID_DEVICE";
+        case CUFFT_PARSE_ERROR:
+            return "CUFFT_PARSE_ERROR";
+        case CUFFT_NO_WORKSPACE:
+            return "CUFFT_NO_WORKSPACE";
+        case CUFFT_NOT_IMPLEMENTED:
+            return "CUFFT_NOT_IMPLEMENTED";
+        case CUFFT_LICENSE_ERROR:
+            return "CUFFT_LICENSE_ERROR";
+        case CUFFT_NOT_SUPPORTED:
+            return "CUFFT_NOT_SUPPORTED";
+    }
+    return "<unknown>";
+}
+
+void _checkCUFFT(cufftResult error, const char *file, int line) {
+    checkCompareFlag(error, CUFFT_SUCCESS, cufftGetErrorString, "CUFFT", file, line);
+}
diff --git a/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/img_tensor_runtime.cu b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/img_tensor_runtime.cu
index 5776e5c385de79ed96255763a770a31d2b6d3df3..497d9ec38bfbff86afc21e290c29c589fa12fc7c 100644
--- a/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/img_tensor_runtime.cu
+++ b/llvm/projects/hpvm-tensor-rt/tensor_runtime/src/img_tensor_runtime.cu
@@ -1,16 +1,108 @@
-#include "../include/approxhpvm_img_runtime_utils.h"
-#include "../include/debug.h"
-#include "../include/img_tensor_runtime.h"
+#include "approxhpvm_img_runtime_utils.h"
+#include "debug.h"
+#include "img_tensor_runtime.h"
+
+#include "functional/map.cuh"
+#include "functional/reduce.cuh"
+#include "tensor_utils.cu"
+
+#include <cufft.h>
+
+// FIXME: really just a hack to compile into a single .o
+#include "common.cpp"
+#include "debug.cpp"
 
 // ***                       Runtime implementation                      *** //
-void *tensorFft(void *input) {}
-void *tensorReduce(void *input, size_t axis, void *func) {}
+void *tensorFft(void *input) {
+  // https://docs.nvidia.com/cuda/cufft/index.html#twod-complex-to-real-transforms
+  // Tensor checking
+  INFO("FFT");
+  auto *t_input = (Tensor *)input;
+  if (t_input->data_type != CUDNN_DATA_FLOAT)
+    throw std::runtime_error("Only float32 is supported");
+  int total_rank = t_input->dims.num_dims;
+  if (total_rank != 4)
+    throw std::runtime_error("Only 4-dim tensor supported");
+  // Dimensions
+  size_t *all_dims = t_input->dims.dim_sizes;
+  int width = all_dims[2], height = all_dims[3];
+  int fft_dim[2] = {width, height};
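+  // One 2D transform per (n, c) plane, batched into a single cuFFT plan.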
+  int n_batch = int(all_dims[0]) * int(all_dims[1]);
+  // Prepare input data
+  hostToDeviceCopy(t_input);
+  auto *input_cuda = (cufftReal *)t_input->gpu_data;
+  // Define output data
+  // FIXME: the R2C output is complex (float2), so allocating the tensor as
+  // CUDNN_DATA_FLOAT under-sizes the buffer; a complex element type is needed.
+  auto *out_tensor = (Tensor *)create4DTensor(
+      CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, all_dims[0], all_dims[1], width,
+      (height / 2 + 1));
+  changeTensorPlacement(out_tensor, DEVICE);
+  auto *output_cuda = (cufftComplex *)out_tensor->gpu_data;
+  // Create a 2D FFT plan
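+  // (Null inembed/onembed arguments select the default packed data layout.)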
+  cufftHandle plan;
+  checkCUFFT(cufftPlanMany(
+      &plan, 2, fft_dim, nullptr, 1, 0, nullptr, 1, 0, CUFFT_R2C, n_batch));
+  // Execute the plan
+  checkCUFFT(cufftExecR2C(plan, input_cuda, output_cuda));
+  // Wait for the device to finish
+  checkCudaErrors(cudaDeviceSynchronize());
+  // Release memory
+  cufftDestroy(plan);
+  return out_tensor;
+}
+
+void *tensorReduce(void *input, size_t axis, void *func) {
+  INFO("Reduce");
+  auto *src = (Tensor *)input;
+  if (axis >= src->dims.num_dims)
+    throw std::runtime_error("Dimension out of range");
+  if (src->dims.num_dims != 4 || src->data_format != CUDNN_TENSOR_NCHW)
+    throw std::runtime_error("Not supported");
+
+  // skip_rate = 0: reduce over the full axis (no approximation).
+  return reduceDim<float>(src, 0.0f, func, axis, 0.0f);
+}
+
 void *tensorReductionSamplingReduce(
-    void *input, size_t axis, void *func, int skip_level) {}
-void *tensorProjectiveT(void *input, void *transformation) {}
-void *tensorMap1(void *f, void *i) {}
-void *tensorMap2(void *f2, void *i1, void *i2) {}
-void *tensorMap3(void *f3, void *i1, void *i2, void *i3) {}
+    void *input, size_t axis, void *func, int skip_level) {
+  INFO("Reduce with sampling");
+  auto *src = (Tensor *)input;
+  if (axis >= src->dims.num_dims)
+    throw std::runtime_error("Dimension out of range");
+  if (src->dims.num_dims != 4 || src->data_format != CUDNN_TENSOR_NCHW)
+    throw std::runtime_error("Not supported");
+
+  // skip_level selects how much of the reduced axis is skipped (approximated).
+  switch (skip_level) {
+  case 0:
+    return reduceDim<float>(src, 0.0f, func, axis, 0.1f);
+  case 1:
+    return reduceDim<float>(src, 0.0f, func, axis, 0.2f);
+  case 2:
+    return reduceDim<float>(src, 0.0f, func, axis, 0.4f);
+  default:
+    throw std::runtime_error("Unsupported skip level");
+  }
+}
+
+void *tensorProjectiveT(void *input, void *transformation) {
+  ERROR("ProjectiveT operation currently unsupported.\n");
+}
+
+void *tensorMap1(void *f, void *i) {
+  INFO("Map1");
+  auto *src = (Tensor *)i;
+  return mapGeneral<float, 1>(f, {src});
+}
+
+void *tensorMap2(void *f2, void *i1, void *i2) {
+  INFO("Map2");
+  auto *src1 = (Tensor *)i1, *src2 = (Tensor *)i2;
+  return mapGeneral<float, 2>(f2, {src1, src2});
+}
+
+void *tensorMap3(void *f3, void *i1, void *i2, void *i3) {
+  INFO("Map3");
+  auto *src1 = (Tensor *)i1, *src2 = (Tensor *)i2, *src3 = (Tensor *)i3;
+  return mapGeneral<float, 3>(f3, {src1, src2, src3});
+}
 
 // ***                     Wrapper API implementation                    *** //