#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <args.hxx>
#include <minimp4.h>

#include <codec/api/wels/codec_api.h>

#include "common.hh"


namespace param {
using namespace ::args;

ArgumentParser parser {
  "converter: block indices + host -> stego"
};
HelpFlag help {
  parser, "help", "display this menu", {'h', "help"},
};

ValueFlag<int32_t> bw {
  parser, "128", "width of blocks (px)", {"block-w"}, 128,
};
ValueFlag<int32_t> bh {
  parser, "128", "height of blocks (px)", {"block-h"}, 128,
};
ValueFlag<int32_t> utime {
  parser, "10", "duration of each feature (frame)", {"utime"}, 10,
};

Flag uvfix {
  parser, "uvfix", "fix UV values in feature", {"uvfix"},
};

Positional<std::string> dst {
  parser, "path", "destination video file path",
};
Positional<std::string> src {
  parser, "path", "source video file path",
};

// from stdin
std::vector<std::vector<int32_t>> indices;

}  // namespace param


// Copies the blocks selected for feature `t` from `base` into `dst`.
//
// The luma plane is tiled into block-w x block-h blocks numbered in
// row-major order (index = by * blocks_per_row + bx). Every index in
// param::indices[t % param::indices.size()] selects one block whose Y
// samples are copied from `base`. With --uvfix the co-located U/V samples
// (a block of half the width and height, matching the half-size chroma
// offsets computed below) are copied as well.
//
// Throws (via Enforce) if the frame is smaller than one block, if --uvfix is
// used with an odd block size, or if an index lies outside the block grid.
static void Embed(int32_t t, Frame& dst, const Frame& base) {
  const auto bw  = args::get(param::bw);
  const auto bh  = args::get(param::bh);
  const auto hbw = bw/2;
  const auto hbh = bh/2;

  const auto bx_cnt = dst.w / bw;
  const auto by_cnt = dst.h / bh;
  // A frame narrower or shorter than one block would make `idx%bx_cnt`
  // below a division by zero.
  Enforce(bx_cnt > 0 && by_cnt > 0, "frame is smaller than one block");
  // With an odd block size the halved chroma block (bw/2 x bh/2) drifts out
  // of alignment with its luma block as bx/by grow.
  Enforce(!param::uvfix || (bw%2 == 0 && bh%2 == 0),
          "block size must be even when fixing UV");

  const auto& row = param::indices[t%param::indices.size()];
  for (const auto idx : row) {
    Enforce(idx >= 0, "block index underflow");

    const auto bx = idx%bx_cnt;
    const auto by = idx/bx_cnt;
    Enforce(by < by_cnt, "block index overflow");

    for (int32_t y = 0; y < bh; ++y) {
      const auto line = (by*bh+y)*base.w + bx*bw;
      for (int32_t x = 0; x < bw; ++x) {
        dst.Y[line+x] = base.Y[line+x];
      }
    }

    if (param::uvfix) {
      for (int32_t y = 0; y < hbh; ++y) {
        const auto line = (by*hbh+y)*base.hw + bx*hbw;
        for (int32_t x = 0; x < hbw; ++x) {
          dst.U[line+x] = base.U[line+x];
          dst.V[line+x] = base.V[line+x];
        }
      }
    }
  }
}

static void Exec() {
  const auto bw = args::get(param::bw);
  const auto bh = args::get(param::bh);
  const auto ut = args::get(param::utime);
  Enforce(bw > 0 && bh > 0, "block size must be greater than 0");
  Enforce(ut > 0, "utime must be greater than 0");

  // read indices
  param::indices = ReadMatrix<int32_t>(std::cin);
  Enforce(param::indices.size() > 0, "empty indices");

  // open source video stream
  const auto srcpath = args::get(param::src);
  std::ifstream srcst {srcpath.c_str(), std::ifstream::binary | std::ifstream::ate};
  Enforce(!!srcst, "source video stream is invalid");
  const int64_t srcsz = srcst.tellg();

  // open destination video stream
  const auto dstpath = args::get(param::dst);
  std::ofstream dstst {dstpath.c_str(), std::ifstream::binary};
  Enforce(!!dstst, "destination video stream is invalid");

  // init decoder
  ISVCDecoder* dec;
  Enforce(0 == WelsCreateDecoder(&dec), "decoder creation failure");

  SDecodingParam decp = {};
  decp.sVideoProperty.eVideoBsType = VIDEO_BITSTREAM_DEFAULT;
  decp.eEcActiveIdc                = ERROR_CON_SLICE_COPY;
  Enforce(0 == dec->Initialize(&decp), "decoder init failure");

  int declv = WELS_LOG_INFO;
  dec->SetOption(DECODER_OPTION_TRACE_LEVEL, &declv);

  uint8_t*    yuv[3] = {0};
  SBufferInfo frame  = {};

  // demuxer
  MP4D_demux_t dem = {};
  MP4D_open(&dem, [](int64_t off, void* buf, size_t sz, void* ptr) {
    auto& st = *reinterpret_cast<std::ifstream*>(ptr);
    st.seekg(off);
    Enforce(!!st, "seek failure");
    st.read(reinterpret_cast<char*>(buf), sz);
    Enforce(!!st, "read failure");
    return 0;
  }, &srcst, srcsz);

  // find video track
  size_t ti;
  for (ti = 0; ti < dem.track_count; ++ti) {
    const auto& t = dem.track[ti];
    if (t.handler_type == MP4D_HANDLER_TYPE_VIDE) {
      break;
    }
  }
  Enforce(ti < dem.track_count, "no video track");
  const auto& tra = dem.track[ti];

  // calc params
  const auto tscale = tra.timescale;
  const auto dur    =
      (static_cast<uint64_t>(tra.duration_hi) << 32) |
      static_cast<uint64_t>(tra.duration_lo);
  const auto dursec = static_cast<float>(dur)/static_cast<float>(tscale);

  const float   fps  = static_cast<float>(tra.sample_count)/dursec;
  const auto    fps9 = static_cast<int>(90000/fps);
  const int32_t w    = tra.SampleDescription.video.width;
  const int32_t h    = tra.SampleDescription.video.height;

  // init encoder
  ISVCEncoder* enc;
  Enforce(0 == WelsCreateSVCEncoder(&enc), "encoder creation failure");

  SEncParamBase encp = {};
  encp.iUsageType     = SCREEN_CONTENT_REAL_TIME;
  encp.fMaxFrameRate  = fps;
  encp.iPicWidth      = w;
  encp.iPicHeight     = h;
  encp.iTargetBitrate = 5000000;
  Enforce(0 == enc->Initialize(&encp), "encoder init failure");

  int enclv = WELS_LOG_INFO;
  enc->SetOption(ENCODER_OPTION_TRACE_LEVEL, &enclv);

  // init muxer
  MP4E_mux_t* mux = MP4E_open(
      false, false, &dstst,
      [](int64_t off, const void* buf, size_t size, void* ptr) {
        auto& st = *reinterpret_cast<std::ostream*>(ptr);
        st.seekp(off);
        Enforce(!!st, "muxer seek failure");
        st.write(reinterpret_cast<const char*>(buf), size);
        Enforce(!!st, "muxer write failure");
        return 0;
      });

  mp4_h26x_writer_t writer;
  Enforce(
      MP4E_STATUS_OK == mp4_h26x_write_init(&writer, mux, w, h, false),
      "failed to init mp4_h26x_writer_t");

  // consume SPS
  std::vector<uint8_t> nal;
  for (size_t si = 0;; ++si) {
    int sz;
    auto sps = reinterpret_cast<const uint8_t*>(MP4D_read_sps(&dem, ti, si, &sz));
    if (!sps) break;
    CopyNal(nal, sps, sz);

    const auto ret = dec->DecodeFrameNoDelay(nal.data(), nal.size(), yuv, &frame);
    Enforce(ret == 0, "SPS decode failure");
  }

  // consume PPS
  for (size_t si = 0;; ++si) {
    int sz;
    auto pps = reinterpret_cast<const uint8_t*>(MP4D_read_pps(&dem, ti, si, &sz));
    if (!pps) break;
    CopyNal(nal, pps, sz);

    const auto ret = dec->DecodeFrameNoDelay(nal.data(), nal.size(), yuv, &frame);
    Enforce(ret == 0, "PPS decode failure");
  }

  // decode frame
  Frame bf = {};
  int32_t t = 0;
  for (size_t si = 0; si < tra.sample_count; ++si) {
    unsigned fsz, time, dur;
    const auto off = MP4D_frame_offset(&dem, ti, si, &fsz, &time, &dur);

    srcst.seekg(off);
    Enforce(!!srcst, "NAL seek failure");

    nal.resize(fsz);
    srcst.read(reinterpret_cast<char*>(nal.data()), fsz);
    Enforce(!!srcst, "NAL read failure");

    // decode all nal blocks
    for (size_t i = 0; i < nal.size();) {
      uint32_t sz =
          (nal[i] << 24) | (nal[i+1] << 16) | (nal[i+2] <<  8) | (nal[i+3] <<  0);

      nal[i+0] = 0;
      nal[i+1] = 0;
      nal[i+2] = 0;
      nal[i+3] = 1;
      sz += 4;

      // retrieve a frame
      const auto ret = dec->DecodeFrameNoDelay(&nal[i], sz, yuv, &frame);
      Enforce(ret == 0, "frame decode failure");

      // handle decoded frame
      if (frame.iBufferStatus) {
        // alter the frame if it's not the first
        Frame cf = {yuv, frame};
        if (t%ut > 0) {
          Embed(t/ut, cf, bf);
        }

        // encode
        SFrameBSInfo   info;
        SSourcePicture pic = cf.GetSourcePic();
        Enforce(cmResultSuccess == enc->EncodeFrame(&pic, &info),
                "encode failure");

        // write buffer
        if (info.eFrameType != videoFrameTypeSkip) {
          for (int li = 0; li < info.iLayerNum; ++li) {
            const auto& l = info.sLayerInfo[li];

            uint8_t* buf = l.pBsBuf;
            for (int ni = 0; ni < l.iNalCount; ++ni) {
              mp4_h26x_write_nal(
                  &writer, buf, l.pNalLengthInByte[ni], fps9);
              buf += l.pNalLengthInByte[ni];
            }
          }
        }

        // save the frame if it's the first
        if (t%ut == 0) {
          bf = std::move(cf);
        }
        ++t;
      }
      i += sz;
    }
  }

  // tear down
  MP4E_close(mux);
  mp4_h26x_write_close(&writer);
}

int main(int argc, char** argv)
try {
  param::parser.ParseCLI(argc, argv);
  Exec();
  return EXIT_SUCCESS;
} catch (const args::Help&) {
  std::cout << param::parser << std::endl;
  return EXIT_SUCCESS;
} catch (const std::exception& e) {
  std::cerr << e.what() << std::endl;
  return EXIT_FAILURE;
}