improve precision of block probabilities

This commit is contained in:
falsycat 2022-09-06 17:24:59 +09:00
parent edd91ef692
commit dc67745175
3 changed files with 69 additions and 43 deletions

View File

@ -55,8 +55,10 @@ static std::vector<std::vector<int32_t>> ReadIndices(
std::istream&) noexcept; std::istream&) noexcept;
static void Embed(int32_t t, Frame& dst, const Frame& base) { static void Embed(int32_t t, Frame& dst, const Frame& base) {
const auto bw = args::get(param::bw); const auto bw = args::get(param::bw);
const auto bh = args::get(param::bw); const auto bh = args::get(param::bw);
const auto hbw = bw/2;
const auto hbh = bh/2;
const auto bx_cnt = dst.w / bw; const auto bx_cnt = dst.w / bw;
const auto by_cnt = dst.h / bh; const auto by_cnt = dst.h / bh;
@ -70,7 +72,15 @@ static void Embed(int32_t t, Frame& dst, const Frame& base) {
for (int32_t y = 0; y < bh; ++y) { for (int32_t y = 0; y < bh; ++y) {
for (int32_t x = 0; x < bw; ++x) { for (int32_t x = 0; x < bw; ++x) {
const auto off = (by*bh+y)*base.w + (bx*bw+x); const auto off = (by*bh+y)*base.w + (bx*bw+x);
dst.Y[off] = (x == 0 || y == 0)? 0: base.Y[off]; // TODO: remove test code dst.Y[off] = base.Y[off];
}
}
for (int32_t y = 0; y < hbh; ++y) {
for (int32_t x = 0; x < hbw; ++x) {
const auto off = (by*hbh+y)*base.hw + (bx*hbw+x);
dst.U[off] = base.U[off];
dst.V[off] = base.V[off];
} }
} }
} }

View File

@ -15,6 +15,8 @@ inline void Enforce(bool eval, const std::string& msg) {
} }
} }
// ---- MP4 utilities
inline void CopyNal(std::vector<uint8_t>& v, const uint8_t* buf, size_t sz) noexcept { inline void CopyNal(std::vector<uint8_t>& v, const uint8_t* buf, size_t sz) noexcept {
v.resize(sz+4); v.resize(sz+4);
v[0] = 0; v[0] = 0;
@ -23,8 +25,6 @@ inline void CopyNal(std::vector<uint8_t>& v, const uint8_t* buf, size_t sz) noex
v[3] = 1; v[3] = 1;
std::memcpy(&v[4], buf, sz); std::memcpy(&v[4], buf, sz);
} }
struct Frame { struct Frame {
std::vector<uint8_t> Y; std::vector<uint8_t> Y;
std::vector<uint8_t> U; std::vector<uint8_t> U;

View File

@ -36,11 +36,11 @@ ValueFlag<int32_t> utime {
parser, "10", "duration of each feature (frame)", {"utime"}, 10, parser, "10", "duration of each feature (frame)", {"utime"}, 10,
}; };
ValueFlag<int32_t> bmw { ValueFlag<int32_t> bmix {
parser, "16", "width of blockmatch region (px)", {"bm-w"}, 16, parser, "8", "x interval of blockmatch (px)", {"bm-ix"}, 8,
}; };
ValueFlag<int32_t> bmh { ValueFlag<int32_t> bmiy {
parser, "16", "height of blockmatch region (px)", {"bm-h"}, 16, parser, "8", "y interval of blockmatch (px)", {"bm-iy"}, 8,
}; };
ValueFlag<int32_t> bmsw { ValueFlag<int32_t> bmsw {
parser, "4", "width of blockmatch search region (px)", {"bm-sw"}, 4, parser, "4", "width of blockmatch search region (px)", {"bm-sw"}, 4,
@ -51,13 +51,14 @@ ValueFlag<int32_t> bmsh {
enum Output { enum Output {
kProb, kProb,
kIndex,
kLen, kLen,
kVec, kVec,
kNull, kNull,
}; };
const std::unordered_map<std::string, Output> kOutput = { const std::unordered_map<std::string, Output> kOutput = {
{"default", kProb},
{"prob", kProb}, {"prob", kProb},
{"index", kIndex},
{"len", kLen}, {"len", kLen},
{"vec", kVec}, {"vec", kVec},
{"null", kNull}, {"null", kNull},
@ -74,13 +75,15 @@ Positional<std::string> vpath {
struct Vec { struct Vec {
double x, y; double x, y, score, len;
}; };
static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) { static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) {
const auto bmw = args::get(param::bmw); const auto bw = args::get(param::bw);
const auto bmh = args::get(param::bmh); const auto bh = args::get(param::bh);
const auto bmix = args::get(param::bmix);
const auto bmiy = args::get(param::bmiy);
const auto bmsw = args::get(param::bmsw); const auto bmsw = args::get(param::bmsw);
const auto bmsh = args::get(param::bmsh); const auto bmsh = args::get(param::bmsh);
@ -89,8 +92,8 @@ static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t b
for (int32_t sy = -bmsh; sy < bmsh; ++sy) { for (int32_t sy = -bmsh; sy < bmsh; ++sy) {
for (int32_t sx = -bmsw; sx < bmsw; ++sx) { for (int32_t sx = -bmsw; sx < bmsw; ++sx) {
double score = 0; double score = 0;
for (int32_t y = 0; y < bmw; ++y) { for (int32_t y = 0; y < bh; y += bmiy) {
for (int32_t x = 0; x < bmh; ++x) { for (int32_t x = 0; x < bw; x += bmix) {
const auto c_off = (bx+x) + (by+y)*cf.w; const auto c_off = (bx+x) + (by+y)*cf.w;
const auto p_off = (bx+x+sx) + (by+y+sy)*cf.w; const auto p_off = (bx+x+sx) + (by+y+sy)*cf.w;
const auto diff = static_cast<double>(cf.Y[c_off] - pf.Y[p_off]); const auto diff = static_cast<double>(cf.Y[c_off] - pf.Y[p_off]);
@ -107,27 +110,26 @@ static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t b
const auto sxf = static_cast<double>(min_sx) / static_cast<double>(bmsw); const auto sxf = static_cast<double>(min_sx) / static_cast<double>(bmsw);
const auto syf = static_cast<double>(min_sy) / static_cast<double>(bmsh); const auto syf = static_cast<double>(min_sy) / static_cast<double>(bmsh);
return { .x = sxf, .y = syf, }; const auto scf = static_cast<double>(min_score) / static_cast<double>(UINT8_MAX*(bw/bmix)*(bh/bmiy));
return { .x = sxf, .y = syf, .score = scf, .len = std::sqrt(sxf*sxf+syf*syf), };
} }
static double EachBlock(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) { static Vec EachBlock(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) {
const auto v = BlockMatching(cf, pf, bx, by); const auto v = BlockMatching(cf, pf, bx, by);
const auto len = std::sqrt(v.x*v.x + v.y*v.y);
switch (args::get(param::output)) { switch (args::get(param::output)) {
case param::kLen: case param::kLen:
std::cout << len << '\n'; std::cout << v.len << '\n';
break; break;
case param::kVec: case param::kVec:
std::cout << bx << " " << by << " " << v.x << " " << v.y << '\n'; std::cout << bx << " " << by << " " << v.x << " " << v.y << " " << v.score << '\n';
break; break;
default: default:
break; break;
} }
return len; return v;
} }
static void EachFrame(const Frame& cf, const Frame& pf) { static void EachFrame(int32_t t, const Frame& cf, const Frame& pf) {
const auto bw = args::get(param::bw); const auto bw = args::get(param::bw);
const auto bh = args::get(param::bw); const auto bh = args::get(param::bw);
const auto ut = args::get(param::utime); const auto ut = args::get(param::utime);
@ -135,17 +137,22 @@ static void EachFrame(const Frame& cf, const Frame& pf) {
Enforce(cf.w == pf.w && cf.h == pf.h, "variable frame size is not allowed"); Enforce(cf.w == pf.w && cf.h == pf.h, "variable frame size is not allowed");
Enforce(cf.w > bw && cf.h > bh, "block size must be less than frame size"); Enforce(cf.w > bw && cf.h > bh, "block size must be less than frame size");
static size_t cnt = 0; struct Block {
static std::vector<double> probs; double len, score;
if (cnt%ut == 0) { };
probs.clear(); static std::vector<Block> blocks;
probs.resize((cf.w/bw) * (cf.h/bh)); if (t == 1) {
blocks.clear();
blocks.resize((cf.w/bw) * (cf.h/bh));
} }
double* prob = probs.data(); auto block = blocks.data();
for (int32_t by = 0; by+bh < cf.h; by+=bh) { for (int32_t by = 0; by+bh <= cf.h; by+=bh) {
for (int32_t bx = 0; bx+bw < cf.w; bx+=bw) { for (int32_t bx = 0; bx+bw <= cf.w; bx+=bw) {
*(prob++) += EachBlock(cf, pf, bx, by); const auto v = EachBlock(cf, pf, bx, by);
block->score += v.score;
block->len += v.len;
++block;
} }
} }
@ -154,10 +161,19 @@ static void EachFrame(const Frame& cf, const Frame& pf) {
case param::kVec: case param::kVec:
std::cout << std::endl; std::cout << std::endl;
break; break;
case param::kIndex:
case param::kProb: case param::kProb:
if ((cnt+1)%ut == 0) { if (t == ut-1) {
for (const auto prob : probs) { for (size_t i = 0; i < blocks.size(); ++i) {
std::cout << prob/(ut-1)/std::sqrt(2) << ' '; const auto len = blocks[i].len/(ut-1)/std::sqrt(2); // length calculation
const auto score = blocks[i].score/(ut-1);
const auto prob = std::clamp((1-len) * (1-score), 0., 1.);
if (args::get(param::output) == param::kIndex) {
if (prob > 0.95) std::cout << i << ' ';
} else {
std::cout << prob << ' ';
}
} }
std::cout << std::endl; std::cout << std::endl;
} }
@ -165,7 +181,6 @@ static void EachFrame(const Frame& cf, const Frame& pf) {
default: default:
break; break;
} }
++cnt;
} }
static void Exec() { static void Exec() {
@ -175,11 +190,11 @@ static void Exec() {
Enforce(bw > 0 && bh > 0, "block size must be greater than 0"); Enforce(bw > 0 && bh > 0, "block size must be greater than 0");
Enforce(ut > 0, "utime must be greater than 0"); Enforce(ut > 0, "utime must be greater than 0");
const auto bmw = args::get(param::bmw); const auto bmix = args::get(param::bmix);
const auto bmh = args::get(param::bmh); const auto bmiy = args::get(param::bmiy);
const auto bmsw = args::get(param::bmw); const auto bmsw = args::get(param::bmsw);
const auto bmsh = args::get(param::bmh); const auto bmsh = args::get(param::bmsh);
Enforce(bmw > 0 && bmh > 0, "block matching region size must be greater than 0"); Enforce(bmix > 0 && bmiy > 0, "block matching search interval must be greater than 0");
Enforce(bmsw > 0 && bmsh > 0, "block matching search region size must be greater than 0"); Enforce(bmsw > 0 && bmsh > 0, "block matching search region size must be greater than 0");
// open video stream // open video stream
@ -276,8 +291,9 @@ static void Exec() {
Enforce(ret == 0, "frame decode failure"); Enforce(ret == 0, "frame decode failure");
Frame cf = {yuv, frame}; Frame cf = {yuv, frame};
if (fidx%ut > 0) { const auto utf = fidx%ut;
EachFrame(cf, pf); if (utf > 0) {
EachFrame(utf, cf, pf);
} }
pf = std::move(cf); pf = std::move(cf);