From dd4f0262d2291c77ae3575c1060a401a7a2aeb27 Mon Sep 17 00:00:00 2001 From: falsycat Date: Wed, 9 Nov 2022 22:00:56 +0900 Subject: [PATCH] implement negative-sample option to conv/aprob_fprob --- conv/aprob_fprob.cc | 47 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/conv/aprob_fprob.cc b/conv/aprob_fprob.cc index df3a452..5688945 100644 --- a/conv/aprob_fprob.cc +++ b/conv/aprob_fprob.cc @@ -1,5 +1,7 @@ +#include #include #include +#include #include #include @@ -21,6 +23,10 @@ ValueFlag fmap { parser, "path", "feature map file path", {"fmap"}, }; +ValueFlag negative_sample { + parser, "samples", "number of samples used to calculate negative factor", {"negative-sample"}, 16, +}; + } // namespace param @@ -32,34 +38,59 @@ static void Exec() { Enforce(!!fmap_st, "fmap path is invalid"); const auto fmap = ReadTensor3(fmap_st); Enforce(fmap.size() > 0 && fmap[0].size() > 0, "empty fmap"); + + std::unordered_set used_blocks_map; for (auto& fmap_t : fmap) { Enforce(fmap_t.size() == fmap[0].size(), "fmap is broken"); + for (auto& fmap_f : fmap_t) { + std::copy(fmap_f.begin(), fmap_f.end(), std::inserter(used_blocks_map, used_blocks_map.end())); + } } + std::vector used_blocks; + used_blocks.reserve(used_blocks_map.size()); + std::copy(used_blocks_map.begin(), used_blocks_map.end(), std::back_inserter(used_blocks)); + std::cerr << "deb: " << used_blocks.size() << std::endl; + std::vector negatives; for (size_t t = 0; t < aprobs.size(); ++t) { const auto tidx = t % fmap.size(); - const auto bnum = aprobs[t].size(); for (size_t c = 0; c < fmap[tidx].size(); ++c) { const auto& blocks = fmap[tidx][c]; - double positive = 0, negative = 0; - for (uint32_t b = 0; b < aprobs[t].size(); ++b) { + const auto negative_sample = std::min( + args::get(param::negative_sample), + aprobs[t].size() - blocks.size()); + negatives.reserve(negative_sample+1); + negatives.clear(); + + double positive = 0; + for (const auto b : used_blocks) { if (blocks.end() != std::find(blocks.begin(), blocks.end(), b)) { positive += aprobs[t][b]; } else { - negative += aprobs[t][b]; + auto itr = std::lower_bound( + negatives.begin(), negatives.end(), aprobs[t][b], std::greater {}); + negatives.insert(itr, aprobs[t][b]); + if (negatives.size() > negative_sample) { + negatives.resize(negative_sample); + } } } + if (blocks.size() > 0) { positive /= blocks.size(); } else { positive = 1; } - if (bnum > blocks.size()) { - negative /= bnum - blocks.size(); - } else { - negative = 0; + + double negative = 0; + if (negative_sample > 0) { + negative = + std::accumulate(negatives.begin(), negatives.end(), 0.) / + negative_sample; + std::cerr << negative << std::endl; } + const auto prob = positive * (1-negative); std::cout << prob << ' '; }