#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include <args.hxx>
#include <minimp4.h>

#include <codec/api/wels/codec_api.h>

#include "common.hh"


namespace param {
using namespace ::args;

// Command-line definition. All values are read through args::get() by the
// processing functions below.
ArgumentParser parser {
  "converter: stego -> alter-probability matrix"
};
HelpFlag help {
  parser, "help", "display this menu", {'h', "help"},
};

// Size of the analysis blocks the frame is tiled into.
ValueFlag<int32_t> bw {
  parser, "128", "width of blocks (px)", {"block-w"}, 128,
};
ValueFlag<int32_t> bh {
  parser, "128", "height of blocks (px)", {"block-h"}, 128,
};
// Number of frames per window; each window yields utime-1 frame pairs.
ValueFlag<int32_t> utime {
  parser, "10", "duration of each feature (frame)", {"utime"}, 10,
};

// Sampling stride inside a block while computing the matching cost.
ValueFlag<int32_t> bmix {
  parser, "8", "x interval of blockmatch (px)", {"bm-ix"}, 8,
};
ValueFlag<int32_t> bmiy {
  parser, "8", "y interval of blockmatch (px)", {"bm-iy"}, 8,
};
// Half-size of the displacement search window: displacements are searched in
// [-bmsw, bmsw) x [-bmsh, bmsh).
ValueFlag<int32_t> bmsw {
  parser, "4", "width of blockmatch search region (px)", {"bm-sw"}, 4,
};
ValueFlag<int32_t> bmsh {
  parser, "4", "height of blockmatch search region (px)", {"bm-sh"}, 4,
};

// Output modes:
//   kProb  : per window, one probability per block on a single line.
//   kIndex : per window, indices of blocks whose probability exceeds 0.95.
//   kLen   : per frame, one normalized displacement length per block.
//   kVec   : per frame, "bx by x y score" per block.
//   kNull  : no output.
enum Output {
  kProb,
  kIndex,
  kLen,
  kVec,
  kNull,
};
const std::unordered_map<std::string, Output> kOutput = {
  {"prob", kProb},
  {"index", kIndex},
  {"len", kLen},
  {"vec", kVec},
  {"null", kNull},
};
// The help text lists every accepted value. The default is kProb, stated
// explicitly rather than relying on Output() == kProb.
MapFlag<std::string, Output> output {
  parser, "prob", "output type (prob, index, len, vec, null)", {"output"}, kOutput, kProb,
};

Positional<std::string> vpath {
  parser, "path", "video file path",
};

}  // namespace param


// Result of matching one block against the previous frame, as produced by
// BlockMatching:
//   x, y  : best displacement divided by the search half-size.
//   score : normalized matching cost of that displacement.
//   len   : Euclidean norm of (x, y).
struct Vec {
  double x;
  double y;
  double score;
  double len;
};


// Exhaustive block matching of the bw x bh block at (bx, by) in the current
// frame `cf` against the previous frame `pf`, using the luma plane only.
//
// Only every bmix-th column and bmiy-th row of the block is sampled. Each
// candidate displacement (sx, sy) in [-bmsw, bmsw) x [-bmsh, bmsh) is scored by
// the sum of squared luma differences. The first minimum in (sy, sx) scan
// order wins.
//
// Displacements that would move a sampled pixel outside the frame are skipped.
// Otherwise they would index before the start of the luma plane or wrap into
// a neighbouring row. The block itself must lie inside the frame (EachFrame
// only passes such blocks), so the zero displacement is always a candidate.
//
// Returns x/y = best displacement / search half-size (in [-1, 1)),
// len = |(x, y)|, score = best SSD / (UINT8_MAX * number of samples).
// NOTE(review): dividing by UINT8_MAX rather than UINT8_MAX^2 means score is
// not bounded by 1. Kept as-is because downstream thresholds may be tuned to it.
static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) {
  const auto bw   = args::get(param::bw);
  const auto bh   = args::get(param::bh);
  const auto bmix = args::get(param::bmix);
  const auto bmiy = args::get(param::bmiy);
  const auto bmsw = args::get(param::bmsw);
  const auto bmsh = args::get(param::bmsh);

  const auto fw = static_cast<int32_t>(cf.w);
  const auto fh = static_cast<int32_t>(cf.h);

  // Offsets (inside the block) of the last sampled column / row.
  const int32_t last_x = (bw-1) / bmix * bmix;
  const int32_t last_y = (bh-1) / bmiy * bmiy;

  // Restrict the search window so every displaced sample stays in the frame:
  //   0 <= bx + sx  and  bx + last_x + sx < fw   (likewise for y).
  // The *_end bounds are exclusive.
  const int32_t sx_begin = std::max(-bmsw, -bx);
  const int32_t sx_end   = std::min(bmsw, fw - bx - last_x);
  const int32_t sy_begin = std::max(-bmsh, -by);
  const int32_t sy_end   = std::min(bmsh, fh - by - last_y);

  int32_t min_sx = 0, min_sy = 0;
  double  min_score = 1e+100;  // INF
  for (int32_t sy = sy_begin; sy < sy_end; ++sy) {
    for (int32_t sx = sx_begin; sx < sx_end; ++sx) {
      double score = 0;
      for (int32_t y = 0; y < bh; y += bmiy) {
        for (int32_t x = 0; x < bw; x += bmix) {
          const auto c_off = (bx+x) + (by+y)*cf.w;
          const auto p_off = (bx+x+sx) + (by+y+sy)*cf.w;
          const auto diff  = static_cast<double>(cf.Y[c_off] - pf.Y[p_off]);
          score += diff*diff;
        }
      }
      if (score < min_score) {
        min_score = score;
        min_sx    = sx;
        min_sy    = sy;
      }
    }
  }

  // Number of samples actually taken: ceil(bw/bmix) * ceil(bh/bmiy). This
  // equals (bw/bmix)*(bh/bmiy) when the intervals divide the block size. It
  // is never zero, even when an interval exceeds the block size.
  const int64_t nsamples = int64_t{last_x/bmix + 1} * (last_y/bmiy + 1);

  const auto sxf = static_cast<double>(min_sx) / static_cast<double>(bmsw);
  const auto syf = static_cast<double>(min_sy) / static_cast<double>(bmsh);
  const auto scf = min_score / static_cast<double>(UINT8_MAX * nsamples);
  return { .x = sxf, .y = syf, .score = scf, .len = std::sqrt(sxf*sxf+syf*syf), };
}

// Matches the block at (bx, by) of `cf` against `pf`. In the per-block output
// modes it prints the result right away: kLen prints the displacement length,
// kVec prints "bx by x y score". Returns the match result.
static Vec EachBlock(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) {
  const Vec match = BlockMatching(cf, pf, bx, by);
  const auto mode = args::get(param::output);
  if (mode == param::kLen) {
    std::cout << match.len << '\n';
  } else if (mode == param::kVec) {
    std::cout << bx << " " << by << " " << match.x << " " << match.y << " " << match.score << '\n';
  }
  return match;
}

// Processes frame `t` of the current window (t runs 1..utime-1) against the
// previous frame `pf`.
//
// Every full block is matched (see EachBlock), and its displacement length
// and score are added to per-block running sums. These sums are kept in a
// function-local static and reset when t == 1. Output then depends on
// --output:
//   len/vec    : a newline (with flush) ends this frame's per-block lines.
//   prob/index : on the last frame of the window (t == utime-1), prints one
//                line. Each block's probability is (1 - mean len/sqrt(2)) *
//                (1 - mean score), clamped to [0, 1]. "prob" prints every
//                probability; "index" prints the indices of blocks whose
//                probability exceeds 0.95.
static void EachFrame(int32_t t, const Frame& cf, const Frame& pf) {
  const auto bw = args::get(param::bw);
  const auto bh = args::get(param::bh);
  const auto ut = args::get(param::utime);

  Enforce(cf.w == pf.w && cf.h == pf.h, "variable frame size is not allowed");
  Enforce(cf.w > bw && cf.h > bh, "block size must be less than frame size");

  struct Block {
    double len, score;
  };
  // Per-block sums over the current window. They persist across calls.
  static std::vector<Block> acc;
  if (t == 1) {
    // A new window starts: zero all sums (one entry per full block).
    acc.assign((cf.w/bw) * (cf.h/bh), Block{});
  }

  // Blocks are visited row by row, matching the index order used for output.
  size_t idx = 0;
  for (int32_t by = 0; by+bh <= cf.h; by+=bh) {
    for (int32_t bx = 0; bx+bw <= cf.w; bx+=bw, ++idx) {
      const auto v = EachBlock(cf, pf, bx, by);
      acc[idx].score += v.score;
      acc[idx].len   += v.len;
    }
  }

  const auto mode = args::get(param::output);
  if (mode == param::kLen || mode == param::kVec) {
    std::cout << std::endl;
    return;
  }
  if (mode != param::kIndex && mode != param::kProb) return;
  if (t != ut-1) return;

  for (size_t i = 0; i < acc.size(); ++i) {
    // len is divided by sqrt(2), the largest possible |(x, y)| per frame.
    const auto len   = acc[i].len/(ut-1)/std::sqrt(2);
    const auto score = acc[i].score/(ut-1);
    const auto prob  = std::clamp((1-len) * (1-score), 0., 1.);

    if (mode == param::kIndex) {
      if (prob > 0.95) std::cout << i << ' ';
    } else {
      std::cout << prob << ' ';
    }
  }
  std::cout << std::endl;
}

static void Exec() {
  const auto bw = args::get(param::bw);
  const auto bh = args::get(param::bh);
  const auto ut = args::get(param::utime);
  Enforce(bw > 0 && bh > 0, "block size must be greater than 0");
  Enforce(ut > 0, "utime must be greater than 0");

  const auto bmix = args::get(param::bmix);
  const auto bmiy = args::get(param::bmiy);
  const auto bmsw = args::get(param::bmsw);
  const auto bmsh = args::get(param::bmsh);
  Enforce(bmix > 0 && bmiy > 0, "block matching search interval must be greater than 0");
  Enforce(bmsw > 0 && bmsh > 0, "block matching search region size must be greater than 0");

  // open video stream
  const auto vpath = args::get(param::vpath);
  std::ifstream vst {vpath.c_str(), std::ifstream::binary | std::ifstream::ate};
  Enforce(!!vst, "video stream is invalid");
  const auto vsz = vst.tellg();

  // init decoder
  ISVCDecoder* dec;
  Enforce(0 == WelsCreateDecoder(&dec), "decoder creation failure");

  SDecodingParam decp = {};
  decp.sVideoProperty.eVideoBsType = VIDEO_BITSTREAM_DEFAULT;
  decp.eEcActiveIdc                = ERROR_CON_SLICE_COPY;
  Enforce(0 == dec->Initialize(&decp), "decoder init failure");

  int declv = WELS_LOG_DEBUG;
  dec->SetOption(DECODER_OPTION_TRACE_LEVEL, &declv);

  uint8_t*    yuv[3] = {0};
  SBufferInfo frame  = {};

  // demux
  MP4D_demux_t dem = {};
  MP4D_open(&dem, [](int64_t off, void* buf, size_t sz, void* ptr) {
    auto& vst = *reinterpret_cast<std::ifstream*>(ptr);
    vst.seekg(off);
    Enforce(!!vst, "seek failure");
    vst.read(reinterpret_cast<char*>(buf), sz);
    Enforce(!!vst, "read failure");
    return 0;
  }, &vst, vsz);

  // find video track
  size_t ti;
  for (ti = 0; ti < dem.track_count; ++ti) {
    const auto& tra = dem.track[ti];
    if (tra.handler_type == MP4D_HANDLER_TYPE_VIDE) {
      break;
    }
  }
  Enforce(ti < dem.track_count, "no video track");
  const auto& tra = dem.track[ti];

  // consume SPS
  std::vector<uint8_t> nal;
  for (size_t si = 0;; ++si) {
    int sz;
    auto sps = reinterpret_cast<const uint8_t*>(MP4D_read_sps(&dem, ti, si, &sz));
    if (!sps) break;
    CopyNal(nal, sps, sz);

    const auto ret = dec->DecodeFrameNoDelay(nal.data(), nal.size(), yuv, &frame);
    Enforce(ret == 0, "SPS decode failure");
  }

  // consume PPS
  for (size_t si = 0;; ++si) {
    int sz;
    auto pps = reinterpret_cast<const uint8_t*>(MP4D_read_pps(&dem, ti, si, &sz));
    if (!pps) break;
    CopyNal(nal, pps, sz);

    const auto ret = dec->DecodeFrameNoDelay(nal.data(), nal.size(), yuv, &frame);
    Enforce(ret == 0, "PPS decode failure");
  }

  // decode frame
  Frame pf = {};
  size_t t = 0;
  for (size_t si = 0; si < tra.sample_count; ++si) {
    unsigned fsz, time, dur;
    const auto off = MP4D_frame_offset(&dem, ti, si, &fsz, &time, &dur);

    vst.seekg(off);
    Enforce(!!vst, "NAL seek failure");

    nal.resize(fsz);
    vst.read(reinterpret_cast<char*>(nal.data()), fsz);
    Enforce(!!vst, "NAL read failure");

    for (size_t i = 0; i < nal.size();) {
      uint32_t sz =
          (nal[i] << 24) | (nal[i+1] << 16) | (nal[i+2] << 8) | nal[i+3];

      nal[i+0] = 0;
      nal[i+1] = 0;
      nal[i+2] = 0;
      nal[i+3] = 1;
      sz += 4;

      const auto ret = dec->DecodeFrameNoDelay(&nal[i], sz, yuv, &frame);
      Enforce(ret == 0, "frame decode failure");
      i += sz;

      Frame cf = {yuv, frame};
      if (cf.w == 0 || cf.h == 0) continue;

      const auto utf = t%ut;
      if (utf > 0) {
        EachFrame(utf, cf, pf);
      }
      pf = std::move(cf);

      ++t;
    }
  }
}

int main(int argc, char** argv)
try {
  param::parser.ParseCLI(argc, argv);
  Exec();
  return EXIT_SUCCESS;
} catch (const args::Help&) {
  std::cout << param::parser << std::endl;
  return EXIT_SUCCESS;
} catch (const std::exception& e) {
  std::cerr << e.what() << std::endl;
  return EXIT_FAILURE;
}