improve precision of block probabilities
This commit is contained in:
parent
edd91ef692
commit
dc67745175
@ -57,6 +57,8 @@ static std::vector<std::vector<int32_t>> ReadIndices(
|
|||||||
static void Embed(int32_t t, Frame& dst, const Frame& base) {
|
static void Embed(int32_t t, Frame& dst, const Frame& base) {
|
||||||
const auto bw = args::get(param::bw);
|
const auto bw = args::get(param::bw);
|
||||||
const auto bh = args::get(param::bw);
|
const auto bh = args::get(param::bw);
|
||||||
|
const auto hbw = bw/2;
|
||||||
|
const auto hbh = bh/2;
|
||||||
|
|
||||||
const auto bx_cnt = dst.w / bw;
|
const auto bx_cnt = dst.w / bw;
|
||||||
const auto by_cnt = dst.h / bh;
|
const auto by_cnt = dst.h / bh;
|
||||||
@ -70,7 +72,15 @@ static void Embed(int32_t t, Frame& dst, const Frame& base) {
|
|||||||
for (int32_t y = 0; y < bh; ++y) {
|
for (int32_t y = 0; y < bh; ++y) {
|
||||||
for (int32_t x = 0; x < bw; ++x) {
|
for (int32_t x = 0; x < bw; ++x) {
|
||||||
const auto off = (by*bh+y)*base.w + (bx*bw+x);
|
const auto off = (by*bh+y)*base.w + (bx*bw+x);
|
||||||
dst.Y[off] = (x == 0 || y == 0)? 0: base.Y[off]; // TODO: remove test code
|
dst.Y[off] = base.Y[off];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int32_t y = 0; y < hbh; ++y) {
|
||||||
|
for (int32_t x = 0; x < hbw; ++x) {
|
||||||
|
const auto off = (by*hbh+y)*base.hw + (bx*hbw+x);
|
||||||
|
dst.U[off] = base.U[off];
|
||||||
|
dst.V[off] = base.V[off];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,8 @@ inline void Enforce(bool eval, const std::string& msg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ---- MP4 utilities
|
||||||
inline void CopyNal(std::vector<uint8_t>& v, const uint8_t* buf, size_t sz) noexcept {
|
inline void CopyNal(std::vector<uint8_t>& v, const uint8_t* buf, size_t sz) noexcept {
|
||||||
v.resize(sz+4);
|
v.resize(sz+4);
|
||||||
v[0] = 0;
|
v[0] = 0;
|
||||||
@ -23,8 +25,6 @@ inline void CopyNal(std::vector<uint8_t>& v, const uint8_t* buf, size_t sz) noex
|
|||||||
v[3] = 1;
|
v[3] = 1;
|
||||||
std::memcpy(&v[4], buf, sz);
|
std::memcpy(&v[4], buf, sz);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
struct Frame {
|
struct Frame {
|
||||||
std::vector<uint8_t> Y;
|
std::vector<uint8_t> Y;
|
||||||
std::vector<uint8_t> U;
|
std::vector<uint8_t> U;
|
||||||
|
@ -36,11 +36,11 @@ ValueFlag<int32_t> utime {
|
|||||||
parser, "10", "duration of each feature (frame)", {"utime"}, 10,
|
parser, "10", "duration of each feature (frame)", {"utime"}, 10,
|
||||||
};
|
};
|
||||||
|
|
||||||
ValueFlag<int32_t> bmw {
|
ValueFlag<int32_t> bmix {
|
||||||
parser, "16", "width of blockmatch region (px)", {"bm-w"}, 16,
|
parser, "8", "x interval of blockmatch (px)", {"bm-ix"}, 8,
|
||||||
};
|
};
|
||||||
ValueFlag<int32_t> bmh {
|
ValueFlag<int32_t> bmiy {
|
||||||
parser, "16", "height of blockmatch region (px)", {"bm-h"}, 16,
|
parser, "8", "y interval of blockmatch (px)", {"bm-iy"}, 8,
|
||||||
};
|
};
|
||||||
ValueFlag<int32_t> bmsw {
|
ValueFlag<int32_t> bmsw {
|
||||||
parser, "4", "width of blockmatch search region (px)", {"bm-sw"}, 4,
|
parser, "4", "width of blockmatch search region (px)", {"bm-sw"}, 4,
|
||||||
@ -51,13 +51,14 @@ ValueFlag<int32_t> bmsh {
|
|||||||
|
|
||||||
enum Output {
|
enum Output {
|
||||||
kProb,
|
kProb,
|
||||||
|
kIndex,
|
||||||
kLen,
|
kLen,
|
||||||
kVec,
|
kVec,
|
||||||
kNull,
|
kNull,
|
||||||
};
|
};
|
||||||
const std::unordered_map<std::string, Output> kOutput = {
|
const std::unordered_map<std::string, Output> kOutput = {
|
||||||
{"default", kProb},
|
|
||||||
{"prob", kProb},
|
{"prob", kProb},
|
||||||
|
{"index", kIndex},
|
||||||
{"len", kLen},
|
{"len", kLen},
|
||||||
{"vec", kVec},
|
{"vec", kVec},
|
||||||
{"null", kNull},
|
{"null", kNull},
|
||||||
@ -74,13 +75,15 @@ Positional<std::string> vpath {
|
|||||||
|
|
||||||
|
|
||||||
struct Vec {
|
struct Vec {
|
||||||
double x, y;
|
double x, y, score, len;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) {
|
static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) {
|
||||||
const auto bmw = args::get(param::bmw);
|
const auto bw = args::get(param::bw);
|
||||||
const auto bmh = args::get(param::bmh);
|
const auto bh = args::get(param::bh);
|
||||||
|
const auto bmix = args::get(param::bmix);
|
||||||
|
const auto bmiy = args::get(param::bmiy);
|
||||||
const auto bmsw = args::get(param::bmsw);
|
const auto bmsw = args::get(param::bmsw);
|
||||||
const auto bmsh = args::get(param::bmsh);
|
const auto bmsh = args::get(param::bmsh);
|
||||||
|
|
||||||
@ -89,8 +92,8 @@ static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t b
|
|||||||
for (int32_t sy = -bmsh; sy < bmsh; ++sy) {
|
for (int32_t sy = -bmsh; sy < bmsh; ++sy) {
|
||||||
for (int32_t sx = -bmsw; sx < bmsw; ++sx) {
|
for (int32_t sx = -bmsw; sx < bmsw; ++sx) {
|
||||||
double score = 0;
|
double score = 0;
|
||||||
for (int32_t y = 0; y < bmw; ++y) {
|
for (int32_t y = 0; y < bh; y += bmiy) {
|
||||||
for (int32_t x = 0; x < bmh; ++x) {
|
for (int32_t x = 0; x < bw; x += bmix) {
|
||||||
const auto c_off = (bx+x) + (by+y)*cf.w;
|
const auto c_off = (bx+x) + (by+y)*cf.w;
|
||||||
const auto p_off = (bx+x+sx) + (by+y+sy)*cf.w;
|
const auto p_off = (bx+x+sx) + (by+y+sy)*cf.w;
|
||||||
const auto diff = static_cast<double>(cf.Y[c_off] - pf.Y[p_off]);
|
const auto diff = static_cast<double>(cf.Y[c_off] - pf.Y[p_off]);
|
||||||
@ -107,27 +110,26 @@ static Vec BlockMatching(const Frame& cf, const Frame& pf, int32_t bx, int32_t b
|
|||||||
|
|
||||||
const auto sxf = static_cast<double>(min_sx) / static_cast<double>(bmsw);
|
const auto sxf = static_cast<double>(min_sx) / static_cast<double>(bmsw);
|
||||||
const auto syf = static_cast<double>(min_sy) / static_cast<double>(bmsh);
|
const auto syf = static_cast<double>(min_sy) / static_cast<double>(bmsh);
|
||||||
return { .x = sxf, .y = syf, };
|
const auto scf = static_cast<double>(min_score) / static_cast<double>(UINT8_MAX*(bw/bmix)*(bh/bmiy));
|
||||||
|
return { .x = sxf, .y = syf, .score = scf, .len = std::sqrt(sxf*sxf+syf*syf), };
|
||||||
}
|
}
|
||||||
|
|
||||||
static double EachBlock(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) {
|
static Vec EachBlock(const Frame& cf, const Frame& pf, int32_t bx, int32_t by) {
|
||||||
const auto v = BlockMatching(cf, pf, bx, by);
|
const auto v = BlockMatching(cf, pf, bx, by);
|
||||||
|
|
||||||
const auto len = std::sqrt(v.x*v.x + v.y*v.y);
|
|
||||||
switch (args::get(param::output)) {
|
switch (args::get(param::output)) {
|
||||||
case param::kLen:
|
case param::kLen:
|
||||||
std::cout << len << '\n';
|
std::cout << v.len << '\n';
|
||||||
break;
|
break;
|
||||||
case param::kVec:
|
case param::kVec:
|
||||||
std::cout << bx << " " << by << " " << v.x << " " << v.y << '\n';
|
std::cout << bx << " " << by << " " << v.x << " " << v.y << " " << v.score << '\n';
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
return len;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void EachFrame(const Frame& cf, const Frame& pf) {
|
static void EachFrame(int32_t t, const Frame& cf, const Frame& pf) {
|
||||||
const auto bw = args::get(param::bw);
|
const auto bw = args::get(param::bw);
|
||||||
const auto bh = args::get(param::bw);
|
const auto bh = args::get(param::bw);
|
||||||
const auto ut = args::get(param::utime);
|
const auto ut = args::get(param::utime);
|
||||||
@ -135,17 +137,22 @@ static void EachFrame(const Frame& cf, const Frame& pf) {
|
|||||||
Enforce(cf.w == pf.w && cf.h == pf.h, "variable frame size is not allowed");
|
Enforce(cf.w == pf.w && cf.h == pf.h, "variable frame size is not allowed");
|
||||||
Enforce(cf.w > bw && cf.h > bh, "block size must be less than frame size");
|
Enforce(cf.w > bw && cf.h > bh, "block size must be less than frame size");
|
||||||
|
|
||||||
static size_t cnt = 0;
|
struct Block {
|
||||||
static std::vector<double> probs;
|
double len, score;
|
||||||
if (cnt%ut == 0) {
|
};
|
||||||
probs.clear();
|
static std::vector<Block> blocks;
|
||||||
probs.resize((cf.w/bw) * (cf.h/bh));
|
if (t == 1) {
|
||||||
|
blocks.clear();
|
||||||
|
blocks.resize((cf.w/bw) * (cf.h/bh));
|
||||||
}
|
}
|
||||||
|
|
||||||
double* prob = probs.data();
|
auto block = blocks.data();
|
||||||
for (int32_t by = 0; by+bh < cf.h; by+=bh) {
|
for (int32_t by = 0; by+bh <= cf.h; by+=bh) {
|
||||||
for (int32_t bx = 0; bx+bw < cf.w; bx+=bw) {
|
for (int32_t bx = 0; bx+bw <= cf.w; bx+=bw) {
|
||||||
*(prob++) += EachBlock(cf, pf, bx, by);
|
const auto v = EachBlock(cf, pf, bx, by);
|
||||||
|
block->score += v.score;
|
||||||
|
block->len += v.len;
|
||||||
|
++block;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -154,10 +161,19 @@ static void EachFrame(const Frame& cf, const Frame& pf) {
|
|||||||
case param::kVec:
|
case param::kVec:
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
break;
|
break;
|
||||||
|
case param::kIndex:
|
||||||
case param::kProb:
|
case param::kProb:
|
||||||
if ((cnt+1)%ut == 0) {
|
if (t == ut-1) {
|
||||||
for (const auto prob : probs) {
|
for (size_t i = 0; i < blocks.size(); ++i) {
|
||||||
std::cout << prob/(ut-1)/std::sqrt(2) << ' ';
|
const auto len = blocks[i].len/(ut-1)/std::sqrt(2); // length calculation
|
||||||
|
const auto score = blocks[i].score/(ut-1);
|
||||||
|
const auto prob = std::clamp((1-len) * (1-score), 0., 1.);
|
||||||
|
|
||||||
|
if (args::get(param::output) == param::kIndex) {
|
||||||
|
if (prob > 0.95) std::cout << i << ' ';
|
||||||
|
} else {
|
||||||
|
std::cout << prob << ' ';
|
||||||
|
}
|
||||||
}
|
}
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
@ -165,7 +181,6 @@ static void EachFrame(const Frame& cf, const Frame& pf) {
|
|||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
++cnt;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void Exec() {
|
static void Exec() {
|
||||||
@ -175,11 +190,11 @@ static void Exec() {
|
|||||||
Enforce(bw > 0 && bh > 0, "block size must be greater than 0");
|
Enforce(bw > 0 && bh > 0, "block size must be greater than 0");
|
||||||
Enforce(ut > 0, "utime must be greater than 0");
|
Enforce(ut > 0, "utime must be greater than 0");
|
||||||
|
|
||||||
const auto bmw = args::get(param::bmw);
|
const auto bmix = args::get(param::bmix);
|
||||||
const auto bmh = args::get(param::bmh);
|
const auto bmiy = args::get(param::bmiy);
|
||||||
const auto bmsw = args::get(param::bmw);
|
const auto bmsw = args::get(param::bmsw);
|
||||||
const auto bmsh = args::get(param::bmh);
|
const auto bmsh = args::get(param::bmsh);
|
||||||
Enforce(bmw > 0 && bmh > 0, "block matching region size must be greater than 0");
|
Enforce(bmix > 0 && bmiy > 0, "block matching search interval must be greater than 0");
|
||||||
Enforce(bmsw > 0 && bmsh > 0, "block matching search region size must be greater than 0");
|
Enforce(bmsw > 0 && bmsh > 0, "block matching search region size must be greater than 0");
|
||||||
|
|
||||||
// open video stream
|
// open video stream
|
||||||
@ -276,8 +291,9 @@ static void Exec() {
|
|||||||
Enforce(ret == 0, "frame decode failure");
|
Enforce(ret == 0, "frame decode failure");
|
||||||
|
|
||||||
Frame cf = {yuv, frame};
|
Frame cf = {yuv, frame};
|
||||||
if (fidx%ut > 0) {
|
const auto utf = fidx%ut;
|
||||||
EachFrame(cf, pf);
|
if (utf > 0) {
|
||||||
|
EachFrame(utf, cf, pf);
|
||||||
}
|
}
|
||||||
pf = std::move(cf);
|
pf = std::move(cf);
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user