From dc67745175fd77faeb6e1dcc0e359432723e5f1e Mon Sep 17 00:00:00 2001 From: falsycat Date: Tue, 6 Sep 2022 17:24:59 +0900 Subject: [PATCH] improve precision of block probabilities --- conv/bidx_video.cc | 16 ++++++-- conv/common.hh | 4 +- conv/video_bprob.cc | 92 ++++++++++++++++++++++++++------------------- 3 files changed, 69 insertions(+), 43 deletions(-) diff --git a/conv/bidx_video.cc b/conv/bidx_video.cc index 28e5aa0..03e0c43 100644 --- a/conv/bidx_video.cc +++ b/conv/bidx_video.cc @@ -55,8 +55,10 @@ static std::vector> ReadIndices( std::istream&) noexcept; static void Embed(int32_t t, Frame& dst, const Frame& base) { - const auto bw = args::get(param::bw); - const auto bh = args::get(param::bw); + const auto bw = args::get(param::bw); + const auto bh = args::get(param::bw); + const auto hbw = bw/2; + const auto hbh = bh/2; const auto bx_cnt = dst.w / bw; const auto by_cnt = dst.h / bh; @@ -70,7 +72,15 @@ static void Embed(int32_t t, Frame& dst, const Frame& base) { for (int32_t y = 0; y < bh; ++y) { for (int32_t x = 0; x < bw; ++x) { const auto off = (by*bh+y)*base.w + (bx*bw+x); - dst.Y[off] = (x == 0 || y == 0)? 0: base.Y[off]; // TODO: remove test code + dst.Y[off] = base.Y[off]; + } + } + + for (int32_t y = 0; y < hbh; ++y) { + for (int32_t x = 0; x < hbw; ++x) { + const auto off = (by*hbh+y)*base.hw + (bx*hbw+x); + dst.U[off] = base.U[off]; + dst.V[off] = base.V[off]; } } } diff --git a/conv/common.hh b/conv/common.hh index fc8496b..c7e5346 100644 --- a/conv/common.hh +++ b/conv/common.hh @@ -15,6 +15,8 @@ inline void Enforce(bool eval, const std::string& msg) { } } + +// ---- MP4 utilities inline void CopyNal(std::vector& v, const uint8_t* buf, size_t sz) noexcept { v.resize(sz+4); v[0] = 0; @@ -23,8 +25,6 @@ inline void CopyNal(std::vector& v, const uint8_t* buf, size_t sz) noex v[3] = 1; std::memcpy(&v[4], buf, sz); } - - struct Frame { std::vector Y; std::vector U; diff --git a/conv/video_bprob.cc b/conv/video_bprob.cc index 05b01e2..a172051 100644 --- a/conv/video_bprob.cc +++ b/conv/video_bprob.cc @@ -36,11 +36,11 @@ ValueFlag utime { parser, "10", "duration of each feature (frame)", {"utime"}, 10, }; -ValueFlag bmw { - parser, "16", "width of blockmatch region (px)", {"bm-w"}, 16, +ValueFlag bmix { + parser, "8", "x interval of blockmatch (px)", {"bm-ix"}, 8, }; -ValueFlag bmh { - parser, "16", "height of blockmatch region (px)", {"bm-h"}, 16, +ValueFlag bmiy { + parser, "8", "y interval of blockmatch (px)", {"bm-iy"}, 8, }; ValueFlag bmsw { parser, "4", "width of blockmatch search region (px)", {"bm-sw"}, 4, @@ -51,13 +51,14 @@ ValueFlag bmsh { enum Output { kProb, + kIndex, kLen, kVec, kNull, }; const std::unordered_map kOutput = { - {"default", kProb}, {"prob", kProb}, + {"index", kIndex}, {"len", kLen}, {"vec", kVec}, {"null", kNull}, @@ -74,13 +75,15 @@ Positional vpath { struct Vec { - double x, y; + double x, y, score, len; }; static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) { - const auto bmw = args::get(param::bmw); - const auto bmh = args::get(param::bmh); + const auto bw = args::get(param::bw); + const auto bh = args::get(param::bh); + const auto bmix = args::get(param::bmix); + const auto bmiy = args::get(param::bmiy); const auto bmsw = args::get(param::bmsw); const auto bmsh = args::get(param::bmsh); @@ -89,8 +92,8 @@ static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t b for (int32_t sy = -bmsh; sy < bmsh; ++sy) { for (int32_t sx = -bmsw; sx < bmsw; ++sx) { double score = 0; - for (int32_t y = 0; y < bmw; ++y) { - for (int32_t x = 0; x < bmh; ++x) { + for (int32_t y = 0; y < bh; y += bmiy) { + for (int32_t x = 0; x < bw; x += bmix) { const auto c_off = (bx+x) + (by+y)*cf.w; const auto p_off = (bx+x+sx) + (by+y+sy)*cf.w; const auto diff = static_cast(cf.Y[c_off] - pf.Y[p_off]); @@ -107,27 +110,26 @@ static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t b const auto sxf = static_cast(min_sx) / static_cast(bmsw); const auto syf = static_cast(min_sy) / static_cast(bmsh); - return { .x = sxf, .y = syf, }; + const auto scf = static_cast(min_score) / static_cast(UINT8_MAX*(bw/bmix)*(bh/bmiy)); + return { .x = sxf, .y = syf, .score = scf, .len = std::sqrt(sxf*sxf+syf*syf), }; } -static double EachBlock(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) { +static Vec EachBlock(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) { const auto v = BlockMatching(cf, pf, bx, by); - - const auto len = std::sqrt(v.x*v.x + v.y*v.y); switch (args::get(param::output)) { case param::kLen: - std::cout << len << '\n'; + std::cout << v.len << '\n'; break; case param::kVec: - std::cout << bx << " " << by << " " << v.x << " " << v.y << '\n'; + std::cout << bx << " " << by << " " << v.x << " " << v.y << " " << v.score << '\n'; break; default: break; } - return len; + return v; } -static void EachFrame(const Frame& cf, const Frame& pf) { +static void EachFrame(int32_t t, const Frame& cf, const Frame& pf) { const auto bw = args::get(param::bw); const auto bh = args::get(param::bw); const auto ut = args::get(param::utime); @@ -135,17 +137,22 @@ static void EachFrame(const Frame& cf, const Frame& pf) { Enforce(cf.w == pf.w && cf.h == pf.h, "variable frame size is not allowed"); Enforce(cf.w > bw && cf.h > bh, "block size must be less than frame size"); - static size_t cnt = 0; - static std::vector probs; - if (cnt%ut == 0) { - probs.clear(); - probs.resize((cf.w/bw) * (cf.h/bh)); + struct Block { + double len, score; + }; + static std::vector blocks; + if (t == 1) { + blocks.clear(); + blocks.resize((cf.w/bw) * (cf.h/bh)); } - double* prob = probs.data(); - for (int32_t by = 0; by+bh < cf.h; by+=bh) { - for (int32_t bx = 0; bx+bw < cf.w; bx+=bw) { - *(prob++) += EachBlock(cf, pf, bx, by); + auto block = blocks.data(); + for (int32_t by = 0; by+bh <= cf.h; by+=bh) { + for (int32_t bx = 0; bx+bw <= cf.w; bx+=bw) { + const auto v = EachBlock(cf, pf, bx, by); + block->score += v.score; + block->len += v.len; + ++block; } } @@ -154,10 +161,19 @@ static void EachFrame(const Frame& cf, const Frame& pf) { case param::kVec: std::cout << std::endl; break; + case param::kIndex: case param::kProb: - if ((cnt+1)%ut == 0) { - for (const auto prob : probs) { - std::cout << prob/(ut-1)/std::sqrt(2) << ' '; + if (t == ut-1) { + for (size_t i = 0; i < blocks.size(); ++i) { + const auto len = blocks[i].len/(ut-1)/std::sqrt(2); // length calculation + const auto score = blocks[i].score/(ut-1); + const auto prob = std::clamp((1-len) * (1-score), 0., 1.); + + if (args::get(param::output) == param::kIndex) { + if (prob > 0.95) std::cout << i << ' '; + } else { + std::cout << prob << ' '; + } } std::cout << std::endl; } @@ -165,7 +181,6 @@ static void EachFrame(const Frame& cf, const Frame& pf) { default: break; } - ++cnt; } static void Exec() { @@ -175,11 +190,11 @@ static void Exec() { Enforce(bw > 0 && bh > 0, "block size must be greater than 0"); Enforce(ut > 0, "utime must be greater than 0"); - const auto bmw = args::get(param::bmw); - const auto bmh = args::get(param::bmh); - const auto bmsw = args::get(param::bmw); - const auto bmsh = args::get(param::bmh); - Enforce(bmw > 0 && bmh > 0, "block matching region size must be greater than 0"); + const auto bmix = args::get(param::bmix); + const auto bmiy = args::get(param::bmiy); + const auto bmsw = args::get(param::bmsw); + const auto bmsh = args::get(param::bmsh); + Enforce(bmix > 0 && bmiy > 0, "block matching search interval must be greater than 0"); Enforce(bmsw > 0 && bmsh > 0, "block matching search region size must be greater than 0"); // open video stream @@ -276,8 +291,9 @@ static void Exec() { Enforce(ret == 0, "frame decode failure"); Frame cf = {yuv, frame}; - if (fidx%ut > 0) { - EachFrame(cf, pf); + const auto utf = fidx%ut; + if (utf > 0) { + EachFrame(utf, cf, pf); } pf = std::move(cf);