improve precision of block probabilities
This commit is contained in:
parent
edd91ef692
commit
dc67745175
@ -55,8 +55,10 @@ static std::vector<std::vector<int32_t>> ReadIndices(
|
||||
std::istream&) noexcept;
|
||||
|
||||
static void Embed(int32_t t, Frame& dst, const Frame& base) {
|
||||
const auto bw = args::get(param::bw);
|
||||
const auto bh = args::get(param::bw);
|
||||
const auto bw = args::get(param::bw);
|
||||
const auto bh = args::get(param::bw);
|
||||
const auto hbw = bw/2;
|
||||
const auto hbh = bh/2;
|
||||
|
||||
const auto bx_cnt = dst.w / bw;
|
||||
const auto by_cnt = dst.h / bh;
|
||||
@ -70,7 +72,15 @@ static void Embed(int32_t t, Frame& dst, const Frame& base) {
|
||||
for (int32_t y = 0; y < bh; ++y) {
|
||||
for (int32_t x = 0; x < bw; ++x) {
|
||||
const auto off = (by*bh+y)*base.w + (bx*bw+x);
|
||||
dst.Y[off] = (x == 0 || y == 0)? 0: base.Y[off]; // TODO: remove test code
|
||||
dst.Y[off] = base.Y[off];
|
||||
}
|
||||
}
|
||||
|
||||
for (int32_t y = 0; y < hbh; ++y) {
|
||||
for (int32_t x = 0; x < hbw; ++x) {
|
||||
const auto off = (by*hbh+y)*base.hw + (bx*hbw+x);
|
||||
dst.U[off] = base.U[off];
|
||||
dst.V[off] = base.V[off];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -15,6 +15,8 @@ inline void Enforce(bool eval, const std::string& msg) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ---- MP4 utilities
|
||||
inline void CopyNal(std::vector<uint8_t>& v, const uint8_t* buf, size_t sz) noexcept {
|
||||
v.resize(sz+4);
|
||||
v[0] = 0;
|
||||
@ -23,8 +25,6 @@ inline void CopyNal(std::vector<uint8_t>& v, const uint8_t* buf, size_t sz) noex
|
||||
v[3] = 1;
|
||||
std::memcpy(&v[4], buf, sz);
|
||||
}
|
||||
|
||||
|
||||
struct Frame {
|
||||
std::vector<uint8_t> Y;
|
||||
std::vector<uint8_t> U;
|
||||
|
@ -36,11 +36,11 @@ ValueFlag<int32_t> utime {
|
||||
parser, "10", "duration of each feature (frame)", {"utime"}, 10,
|
||||
};
|
||||
|
||||
ValueFlag<int32_t> bmw {
|
||||
parser, "16", "width of blockmatch region (px)", {"bm-w"}, 16,
|
||||
ValueFlag<int32_t> bmix {
|
||||
parser, "8", "x interval of blockmatch (px)", {"bm-ix"}, 8,
|
||||
};
|
||||
ValueFlag<int32_t> bmh {
|
||||
parser, "16", "height of blockmatch region (px)", {"bm-h"}, 16,
|
||||
ValueFlag<int32_t> bmiy {
|
||||
parser, "8", "y interval of blockmatch (px)", {"bm-iy"}, 8,
|
||||
};
|
||||
ValueFlag<int32_t> bmsw {
|
||||
parser, "4", "width of blockmatch search region (px)", {"bm-sw"}, 4,
|
||||
@ -51,13 +51,14 @@ ValueFlag<int32_t> bmsh {
|
||||
|
||||
enum Output {
|
||||
kProb,
|
||||
kIndex,
|
||||
kLen,
|
||||
kVec,
|
||||
kNull,
|
||||
};
|
||||
const std::unordered_map<std::string, Output> kOutput = {
|
||||
{"default", kProb},
|
||||
{"prob", kProb},
|
||||
{"index", kIndex},
|
||||
{"len", kLen},
|
||||
{"vec", kVec},
|
||||
{"null", kNull},
|
||||
@ -74,13 +75,15 @@ Positional<std::string> vpath {
|
||||
|
||||
|
||||
struct Vec {
|
||||
double x, y;
|
||||
double x, y, score, len;
|
||||
};
|
||||
|
||||
|
||||
static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) {
|
||||
const auto bmw = args::get(param::bmw);
|
||||
const auto bmh = args::get(param::bmh);
|
||||
const auto bw = args::get(param::bw);
|
||||
const auto bh = args::get(param::bh);
|
||||
const auto bmix = args::get(param::bmix);
|
||||
const auto bmiy = args::get(param::bmiy);
|
||||
const auto bmsw = args::get(param::bmsw);
|
||||
const auto bmsh = args::get(param::bmsh);
|
||||
|
||||
@ -89,8 +92,8 @@ static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t b
|
||||
for (int32_t sy = -bmsh; sy < bmsh; ++sy) {
|
||||
for (int32_t sx = -bmsw; sx < bmsw; ++sx) {
|
||||
double score = 0;
|
||||
for (int32_t y = 0; y < bmw; ++y) {
|
||||
for (int32_t x = 0; x < bmh; ++x) {
|
||||
for (int32_t y = 0; y < bh; y += bmiy) {
|
||||
for (int32_t x = 0; x < bw; x += bmix) {
|
||||
const auto c_off = (bx+x) + (by+y)*cf.w;
|
||||
const auto p_off = (bx+x+sx) + (by+y+sy)*cf.w;
|
||||
const auto diff = static_cast<double>(cf.Y[c_off] - pf.Y[p_off]);
|
||||
@ -107,27 +110,26 @@ static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t b
|
||||
|
||||
const auto sxf = static_cast<double>(min_sx) / static_cast<double>(bmsw);
|
||||
const auto syf = static_cast<double>(min_sy) / static_cast<double>(bmsh);
|
||||
return { .x = sxf, .y = syf, };
|
||||
const auto scf = static_cast<double>(min_score) / static_cast<double>(UINT8_MAX*(bw/bmix)*(bh/bmiy));
|
||||
return { .x = sxf, .y = syf, .score = scf, .len = std::sqrt(sxf*sxf+syf*syf), };
|
||||
}
|
||||
|
||||
static double EachBlock(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) {
|
||||
static Vec EachBlock(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) {
|
||||
const auto v = BlockMatching(cf, pf, bx, by);
|
||||
|
||||
const auto len = std::sqrt(v.x*v.x + v.y*v.y);
|
||||
switch (args::get(param::output)) {
|
||||
case param::kLen:
|
||||
std::cout << len << '\n';
|
||||
std::cout << v.len << '\n';
|
||||
break;
|
||||
case param::kVec:
|
||||
std::cout << bx << " " << by << " " << v.x << " " << v.y << '\n';
|
||||
std::cout << bx << " " << by << " " << v.x << " " << v.y << " " << v.score << '\n';
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return len;
|
||||
return v;
|
||||
}
|
||||
|
||||
static void EachFrame(const Frame& cf, const Frame& pf) {
|
||||
static void EachFrame(int32_t t, const Frame& cf, const Frame& pf) {
|
||||
const auto bw = args::get(param::bw);
|
||||
const auto bh = args::get(param::bw);
|
||||
const auto ut = args::get(param::utime);
|
||||
@ -135,17 +137,22 @@ static void EachFrame(const Frame& cf, const Frame& pf) {
|
||||
Enforce(cf.w == pf.w && cf.h == pf.h, "variable frame size is not allowed");
|
||||
Enforce(cf.w > bw && cf.h > bh, "block size must be less than frame size");
|
||||
|
||||
static size_t cnt = 0;
|
||||
static std::vector<double> probs;
|
||||
if (cnt%ut == 0) {
|
||||
probs.clear();
|
||||
probs.resize((cf.w/bw) * (cf.h/bh));
|
||||
struct Block {
|
||||
double len, score;
|
||||
};
|
||||
static std::vector<Block> blocks;
|
||||
if (t == 1) {
|
||||
blocks.clear();
|
||||
blocks.resize((cf.w/bw) * (cf.h/bh));
|
||||
}
|
||||
|
||||
double* prob = probs.data();
|
||||
for (int32_t by = 0; by+bh < cf.h; by+=bh) {
|
||||
for (int32_t bx = 0; bx+bw < cf.w; bx+=bw) {
|
||||
*(prob++) += EachBlock(cf, pf, bx, by);
|
||||
auto block = blocks.data();
|
||||
for (int32_t by = 0; by+bh <= cf.h; by+=bh) {
|
||||
for (int32_t bx = 0; bx+bw <= cf.w; bx+=bw) {
|
||||
const auto v = EachBlock(cf, pf, bx, by);
|
||||
block->score += v.score;
|
||||
block->len += v.len;
|
||||
++block;
|
||||
}
|
||||
}
|
||||
|
||||
@ -154,10 +161,19 @@ static void EachFrame(const Frame& cf, const Frame& pf) {
|
||||
case param::kVec:
|
||||
std::cout << std::endl;
|
||||
break;
|
||||
case param::kIndex:
|
||||
case param::kProb:
|
||||
if ((cnt+1)%ut == 0) {
|
||||
for (const auto prob : probs) {
|
||||
std::cout << prob/(ut-1)/std::sqrt(2) << ' ';
|
||||
if (t == ut-1) {
|
||||
for (size_t i = 0; i < blocks.size(); ++i) {
|
||||
const auto len = blocks[i].len/(ut-1)/std::sqrt(2); // length calculation
|
||||
const auto score = blocks[i].score/(ut-1);
|
||||
const auto prob = std::clamp((1-len) * (1-score), 0., 1.);
|
||||
|
||||
if (args::get(param::output) == param::kIndex) {
|
||||
if (prob > 0.95) std::cout << i << ' ';
|
||||
} else {
|
||||
std::cout << prob << ' ';
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
@ -165,7 +181,6 @@ static void EachFrame(const Frame& cf, const Frame& pf) {
|
||||
default:
|
||||
break;
|
||||
}
|
||||
++cnt;
|
||||
}
|
||||
|
||||
static void Exec() {
|
||||
@ -175,11 +190,11 @@ static void Exec() {
|
||||
Enforce(bw > 0 && bh > 0, "block size must be greater than 0");
|
||||
Enforce(ut > 0, "utime must be greater than 0");
|
||||
|
||||
const auto bmw = args::get(param::bmw);
|
||||
const auto bmh = args::get(param::bmh);
|
||||
const auto bmsw = args::get(param::bmw);
|
||||
const auto bmsh = args::get(param::bmh);
|
||||
Enforce(bmw > 0 && bmh > 0, "block matching region size must be greater than 0");
|
||||
const auto bmix = args::get(param::bmix);
|
||||
const auto bmiy = args::get(param::bmiy);
|
||||
const auto bmsw = args::get(param::bmsw);
|
||||
const auto bmsh = args::get(param::bmsh);
|
||||
Enforce(bmix > 0 && bmiy > 0, "block matching search interval must be greater than 0");
|
||||
Enforce(bmsw > 0 && bmsh > 0, "block matching search region size must be greater than 0");
|
||||
|
||||
// open video stream
|
||||
@ -276,8 +291,9 @@ static void Exec() {
|
||||
Enforce(ret == 0, "frame decode failure");
|
||||
|
||||
Frame cf = {yuv, frame};
|
||||
if (fidx%ut > 0) {
|
||||
EachFrame(cf, pf);
|
||||
const auto utf = fidx%ut;
|
||||
if (utf > 0) {
|
||||
EachFrame(utf, cf, pf);
|
||||
}
|
||||
pf = std::move(cf);
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user