diff --git a/conv/aprob_fprob.cc b/conv/aprob_fprob.cc index d30b627..4babb93 100644 --- a/conv/aprob_fprob.cc +++ b/conv/aprob_fprob.cc @@ -33,20 +33,21 @@ static void Exec() { std::ifstream fmap_st {args::get(param::fmap)}; Enforce(!!fmap_st, "fmap path is invalid"); - const auto fmap = ReadMatrix(fmap_st); - Enforce(fmap.size() > 0, "empty fmap"); - for (auto& idxs : fmap) { - Enforce(idxs.size() > 0, "fmap has empty item"); + const auto fmap = ReadTensor3(fmap_st); + Enforce(fmap.size() > 0 && fmap[0].size() > 0, "empty fmap"); + for (auto& fmap_t : fmap) { + Enforce(fmap_t.size() == fmap[0].size(), "fmap is broken"); } for (size_t t = 0; t < aprobs.size(); ++t) { - Enforce(aprobs[t].size() <= fmap.size(), "unmatched aprobs and fmap"); - for (size_t c = 0; c < fmap.size(); ++c) { + const auto tidx = t % fmap.size(); + for (size_t c = 0; c < fmap[tidx].size(); ++c) { double sum = 0; - for (auto i : fmap[c]) { + for (auto i : fmap[tidx][c]) { + Enforce(i < aprobs[t].size(), "aprob has no enough columns"); sum += aprobs[t][i]; } - std::cout << sum / fmap[c].size() << ' '; + std::cout << sum / fmap[tidx][c].size() << ' '; } std::cout << '\n'; } diff --git a/conv/common.hh b/conv/common.hh index 40a28d4..e635a68 100644 --- a/conv/common.hh +++ b/conv/common.hh @@ -17,7 +17,7 @@ inline void Enforce(bool eval, const std::string& msg) { } template -std::vector> ReadMatrix(std::istream& st) noexcept { +auto ReadMatrix(std::istream& st) noexcept { std::vector> ret; std::string line; @@ -28,6 +28,23 @@ std::vector> ReadMatrix(std::istream& st) noexcept { } return ret; } +template +auto ReadTensor3(std::istream& st) noexcept { + std::vector>> ret(1); + + std::string line; + while (std::getline(st, line)) { + if (line == "----") { + ret.push_back({}); + } else { + std::istringstream sst {line}; + ret.back().emplace_back(std::istream_iterator {sst}, + std::istream_iterator {}); + } + } + return ret; +} + // ---- MP4 utilities diff --git a/conv/feat_block.cc b/conv/feat_block.cc index 60af16d..695a42f 100644 --- a/conv/feat_block.cc +++ b/conv/feat_block.cc @@ -27,17 +27,16 @@ ValueFlag fmap { static void Exec() { std::ifstream fmap_st {args::get(param::fmap)}; Enforce(!!fmap_st, "fmap path is invalid"); - const auto fmap = ReadMatrix(fmap_st); - Enforce(fmap.size() > 0, "empty fmap"); - for (auto& idxs : fmap) { - Enforce(idxs.size() > 0, "fmap has empty item"); + const auto fmap = ReadTensor3(fmap_st); + Enforce(fmap.size() > 0 && fmap[0].size() > 0, "empty fmap"); + for (auto& fmap_t : fmap) { + Enforce(fmap_t.size() == fmap[0].size(), "fmap is broken"); } - size_t feat; - while (std::cin >> feat) { - Enforce(feat < fmap.size(), "feat overflow"); - - for (const auto idx : fmap[feat]) { + for (size_t feat, t = 0; std::cin >> feat; ++t) { + const auto tidx = t % fmap.size(); + Enforce(feat < fmap[tidx].size(), "feat overflow"); + for (const auto idx : fmap[tidx][feat]) { std::cout << idx << ' '; } std::cout << '\n';