From e0f175c6baac157305a9841d90aa0817b5b18dd1 Mon Sep 17 00:00:00 2001 From: falsycat Date: Tue, 6 Sep 2022 21:22:28 +0900 Subject: [PATCH] add new converter: cprob -> code --- conv/CMakeLists.txt | 3 + conv/bidx_video.cc | 21 +------ conv/common.hh | 14 +++++ conv/cprob_code.cc | 150 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 170 insertions(+), 18 deletions(-) create mode 100644 conv/cprob_code.cc diff --git a/conv/CMakeLists.txt b/conv/CMakeLists.txt index 0a466cd..c207ab6 100644 --- a/conv/CMakeLists.txt +++ b/conv/CMakeLists.txt @@ -5,3 +5,6 @@ target_link_libraries(bidx_video PRIVATE args minimp4 openh264) add_executable(video_fprob common.hh video_fprob.cc) target_link_libraries(video_fprob PRIVATE args minimp4 openh264) + +add_executable(cprob_code common.hh cprob_code.cc) +target_link_libraries(cprob_code PRIVATE args minimp4 openh264) diff --git a/conv/bidx_video.cc b/conv/bidx_video.cc index 493b2b1..34bf6d3 100644 --- a/conv/bidx_video.cc +++ b/conv/bidx_video.cc @@ -54,9 +54,6 @@ std::vector> indices; } // namespace param -// util -static std::vector> ReadIndices(std::istream&) noexcept; - static void Embed(int32_t t, Frame& dst, const Frame& base) { const auto bw = args::get(param::bw); const auto bh = args::get(param::bw); @@ -68,6 +65,8 @@ static void Embed(int32_t t, Frame& dst, const Frame& base) { t = t%param::indices.size(); for (auto idx : param::indices[t]) { + Enforce(idx >= 0, "block index underflow"); + const auto bx = idx%bx_cnt; const auto by = idx/bx_cnt; Enforce(by < by_cnt, "block index overflow"); @@ -99,7 +98,7 @@ static void Exec() { Enforce(ut > 0, "utime must be greater than 0"); // read indices - param::indices = ReadIndices(std::cin); + param::indices = ReadMatrix(std::cin); Enforce(param::indices.size() > 0, "empty indices"); // open source video stream @@ -301,17 +300,3 @@ try { std::cerr << e.what() << std::endl; return EXIT_FAILURE; } - - - -static std::vector> ReadIndices(std::istream& st) noexcept { - std::vector> ret; - - std::string line; - while (std::getline(st, line)) { - std::istringstream sst {line}; - ret.emplace_back(std::istream_iterator {sst}, - std::istream_iterator {}); - } - return ret; -} diff --git a/conv/common.hh b/conv/common.hh index c7e5346..cce68bf 100644 --- a/conv/common.hh +++ b/conv/common.hh @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -15,6 +16,19 @@ inline void Enforce(bool eval, const std::string& msg) { } } +template +std::vector> ReadMatrix(std::istream& st) noexcept { + std::vector> ret; + + std::string line; + while (std::getline(st, line)) { + std::istringstream sst {line}; + ret.emplace_back(std::istream_iterator {sst}, + std::istream_iterator {}); + } + return ret; +} + // ---- MP4 utilities inline void CopyNal(std::vector& v, const uint8_t* buf, size_t sz) noexcept { diff --git a/conv/cprob_code.cc b/conv/cprob_code.cc new file mode 100644 index 0000000..75c371d --- /dev/null +++ b/conv/cprob_code.cc @@ -0,0 +1,150 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "conv/common.hh" + + +namespace param { +using namespace ::args; + +ArgumentParser parser { + "converter: code probability matrix -> code" +}; +HelpFlag help { + parser, "help", "display this menu", {'h', "help"}, +}; + +enum StepAlgo { + kIncrement, +}; +const std::unordered_map kStepAlgo = { + {"inc", kIncrement}, +}; +MapFlag algo { + parser, "inc", "step algorithm (inc)", {"algorithm", "algo"}, kStepAlgo, +}; + +Group inc { + parser, "increment algorithm parameters" +}; +ValueFlag inc_min { + inc, "1", "max stride of increment algorithm", {"inc-min"}, 1, +}; +ValueFlag inc_max { + inc, "1", "max stride of increment algorithm", {"inc-max"}, 1, +}; + +Flag output_prob { + parser, "output-prob", "prints path probability at last", {"output-prob", "prob"}, +}; + +} // namespace param + + +static auto GenerateLegalStepMap(size_t dur, size_t n) { + const auto inc_max = args::get(param::inc_max); + const auto inc_min = args::get(param::inc_min); + Enforce(0 <= inc_min && inc_min <= inc_max, "invalid increment stride"); + + std::vector> ret; + ret.resize(dur*n); + auto legals = &ret[0]; + + for (size_t t = 0; t < dur; ++t) { + for (size_t i = 0; i < n; ++i) { + switch (args::get(param::algo)) { + case param::kIncrement: + legals->reserve(inc_max-inc_min+1); + for (uint32_t j = inc_min; j <= inc_max; ++j) { + legals->push_back(static_cast((i+j)%n)); + } + break; + } + ++legals; + } + } + return ret; +} + +static void Exec() { + const auto cprobs = ReadMatrix(std::cin); + Enforce(cprobs.size() > 0 && cprobs[0].size() > 0, "empty matrix"); + + const auto dur = cprobs.size(); + const auto n = cprobs[0].size(); + + const auto lmap = GenerateLegalStepMap(dur, n); + + struct Step { + double prob = -1; + size_t from = 0; + }; + std::vector steps((dur+1)*n); + for (size_t i = 0; i < n; ++i) { + steps[i].prob = cprobs[0][i]; + } + for (size_t t = 1; t < dur; ++t) { + Enforce(cprobs[t].size() == n, "ill-formed matrix"); + for (size_t i = 0; i < n; ++i) { + const auto& cur = steps[(t-1)*n + i]; + for (auto j : lmap[t*n+i]) { + auto& next = steps[t*n + j]; + + const auto sum = cur.prob + cprobs[t][j]; + if (next.prob < sum) { + next.prob = sum; + next.from = i; + } + } + } + } + + double max_prob = -1; + size_t max_idx = 0; + for (size_t i = 0; i < n; ++i) { + const auto& step = steps[(dur-1)*n + i]; + if (max_prob < step.prob) { + max_prob = step.prob; + max_idx = i; + } + } + + std::vector path = {max_idx}; + path.reserve(dur); + for (size_t t = dur-1; t > 0; --t) { + path.push_back(steps[t*n + path.back()].from); + } + for (auto itr = path.rbegin(); itr < path.rend(); ++itr) { + std::cout << *itr << '\n'; + } + if (param::output_prob) { + std::cout << max_prob/static_cast(path.size())*100 << "%" << std::endl; + } +} + +int main(int argc, char** argv) +try { + param::parser.ParseCLI(argc, argv); + Exec(); + return EXIT_SUCCESS; +} catch (const args::Help&) { + std::cout << param::parser << std::endl; + return EXIT_SUCCESS; +} catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + return EXIT_FAILURE; +} +