implement negative-sample option to conv/aprob_fprob
This commit is contained in:
parent
c83a0716c1
commit
dd4f0262d2
@ -1,5 +1,7 @@
|
|||||||
|
#include <algorithm>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <numeric>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include <args.hxx>
|
#include <args.hxx>
|
||||||
@ -21,6 +23,10 @@ ValueFlag<std::string> fmap {
|
|||||||
parser, "path", "feature map file path", {"fmap"},
|
parser, "path", "feature map file path", {"fmap"},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
ValueFlag<size_t> negative_sample {
|
||||||
|
parser, "samples", "number of samples used to calculate negative factor", {"negative-sample"}, 16,
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace param
|
} // namespace param
|
||||||
|
|
||||||
|
|
||||||
@ -32,34 +38,59 @@ static void Exec() {
|
|||||||
Enforce(!!fmap_st, "fmap path is invalid");
|
Enforce(!!fmap_st, "fmap path is invalid");
|
||||||
const auto fmap = ReadTensor3<uint32_t>(fmap_st);
|
const auto fmap = ReadTensor3<uint32_t>(fmap_st);
|
||||||
Enforce(fmap.size() > 0 && fmap[0].size() > 0, "empty fmap");
|
Enforce(fmap.size() > 0 && fmap[0].size() > 0, "empty fmap");
|
||||||
|
|
||||||
|
std::unordered_set<uint32_t> used_blocks_map;
|
||||||
for (auto& fmap_t : fmap) {
|
for (auto& fmap_t : fmap) {
|
||||||
Enforce(fmap_t.size() == fmap[0].size(), "fmap is broken");
|
Enforce(fmap_t.size() == fmap[0].size(), "fmap is broken");
|
||||||
|
for (auto& fmap_f : fmap_t) {
|
||||||
|
std::copy(fmap_f.begin(), fmap_f.end(), std::inserter(used_blocks_map, used_blocks_map.end()));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
std::vector<uint32_t> used_blocks;
|
||||||
|
used_blocks.reserve(used_blocks_map.size());
|
||||||
|
std::copy(used_blocks_map.begin(), used_blocks_map.end(), std::back_inserter(used_blocks));
|
||||||
|
std::cerr << "deb: " << used_blocks.size() << std::endl;
|
||||||
|
|
||||||
|
std::vector<double> negatives;
|
||||||
for (size_t t = 0; t < aprobs.size(); ++t) {
|
for (size_t t = 0; t < aprobs.size(); ++t) {
|
||||||
const auto tidx = t % fmap.size();
|
const auto tidx = t % fmap.size();
|
||||||
const auto bnum = aprobs[t].size();
|
|
||||||
for (size_t c = 0; c < fmap[tidx].size(); ++c) {
|
for (size_t c = 0; c < fmap[tidx].size(); ++c) {
|
||||||
const auto& blocks = fmap[tidx][c];
|
const auto& blocks = fmap[tidx][c];
|
||||||
|
|
||||||
double positive = 0, negative = 0;
|
const auto negative_sample = std::min(
|
||||||
for (uint32_t b = 0; b < aprobs[t].size(); ++b) {
|
args::get(param::negative_sample),
|
||||||
|
aprobs[t].size() - blocks.size());
|
||||||
|
negatives.reserve(negative_sample+1);
|
||||||
|
negatives.clear();
|
||||||
|
|
||||||
|
double positive = 0;
|
||||||
|
for (const auto b : used_blocks) {
|
||||||
if (blocks.end() != std::find(blocks.begin(), blocks.end(), b)) {
|
if (blocks.end() != std::find(blocks.begin(), blocks.end(), b)) {
|
||||||
positive += aprobs[t][b];
|
positive += aprobs[t][b];
|
||||||
} else {
|
} else {
|
||||||
negative += aprobs[t][b];
|
auto itr = std::lower_bound(
|
||||||
|
negatives.begin(), negatives.end(), aprobs[t][b], std::greater {});
|
||||||
|
negatives.insert(itr, aprobs[t][b]);
|
||||||
|
if (negatives.size() > negative_sample) {
|
||||||
|
negatives.resize(negative_sample);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (blocks.size() > 0) {
|
if (blocks.size() > 0) {
|
||||||
positive /= blocks.size();
|
positive /= blocks.size();
|
||||||
} else {
|
} else {
|
||||||
positive = 1;
|
positive = 1;
|
||||||
}
|
}
|
||||||
if (bnum > blocks.size()) {
|
|
||||||
negative /= bnum - blocks.size();
|
double negative = 0;
|
||||||
} else {
|
if (negative_sample > 0) {
|
||||||
negative = 0;
|
negative =
|
||||||
|
std::accumulate(negatives.begin(), negatives.end(), 0.) /
|
||||||
|
negative_sample;
|
||||||
|
std::cerr << negative << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto prob = positive * (1-negative);
|
const auto prob = positive * (1-negative);
|
||||||
std::cout << prob << ' ';
|
std::cout << prob << ' ';
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user