#include "liblocky.h"

#include <assert.h>
#include <stdlib.h>


/*
 * Validates the configuration already stored in *pf (block_num, feat_bits,
 * seed) and prepares the pathfinder for blky_pathfinder_feed: clears the step
 * chain and step counter, computes the size of one step record, and allocates
 * two zero-filled score buffers of block_num doubles each.
 * Resources are released by blky_pathfinder_deinit.
 */
void blky_pathfinder_init(blky_pathfinder_t* pf) {
  assert(pf->block_num > 0);
  assert(pf->feat_bits > 0);
  assert(pf->feat_bits < 32);
  assert(pf->seed      > 0);

  /* Shift an unsigned one: `1 << 31` on a signed int is undefined behavior,
   * and feat_bits == 31 passes the asserts above. */
  assert(pf->block_num > ((uint32_t) 1 << pf->feat_bits));

  pf->step_last = NULL;
  /* feed() only records a step once this counter exceeds 1, so it must start
   * from zero regardless of what the caller left in the struct. */
  pf->steps     = 0;

  /* One step record: the struct plus block_num-1 extra uint32_t indices,
   * presumably because blky_pathfinder_step_t already declares one trailing
   * element. NOTE(review): confirm against liblocky.h. */
  pf->step_bytes =
    sizeof(blky_pathfinder_step_t) +
    sizeof(uint32_t)*(pf->block_num-1);

  pf->probs = calloc(pf->block_num, sizeof(*pf->probs));
  assert(pf->probs);

  pf->probs_prev = calloc(pf->block_num, sizeof(*pf->probs_prev));
  assert(pf->probs_prev);
}

/*
 * Releases everything owned by *pf: both score buffers and every step record
 * in the singly linked chain that starts at step_last and follows ->prev.
 * The freed pointers are cleared afterwards, so calling this twice is a
 * harmless no-op instead of a double free.
 */
void blky_pathfinder_deinit(blky_pathfinder_t* pf) {
  free(pf->probs);
  free(pf->probs_prev);
  pf->probs      = NULL;
  pf->probs_prev = NULL;

  blky_pathfinder_step_t* step = pf->step_last;
  while (step) {
    /* Read the link before freeing the node that holds it. */
    blky_pathfinder_step_t* temp = step->prev;
    free(step);
    step = temp;
  }
  pf->step_last = NULL;
}

/*
 * Advances the search by one step.
 *
 * `probs` must point to at least block_num scores for the current step.
 * After the call, pf->probs[bi] holds
 *   max(0, max over (pbi, fi) with hop(pbi, fi, seed) % block_num == bi
 *          of pf->probs_prev[pbi] + probs[bi])
 * where pf->probs_prev is the result of the previous call. Ties keep the
 * first candidate in (pbi ascending, fi ascending) order.
 *
 * The seed is advanced with blky_numeric_xorshift64 once per call, before any
 * hop is evaluated. From the second call on, a newly allocated step record is
 * pushed onto pf->step_last. Its indices[bi] is the predecessor block chosen
 * for bi, or 0 when no candidate exceeded 0.
 */
void blky_pathfinder_feed(blky_pathfinder_t* pf, const double* probs) {
  /* Rotate buffers: last call's scores become the predecessor scores, and the
   * older buffer is reused as this call's output. */
  double* temp = pf->probs;
  pf->probs      = pf->probs_prev;
  pf->probs_prev = temp;

  /* The reused buffer still holds the scores from two calls ago. Without this
   * reset they would compete with this step's candidates and produce stale
   * scores and bogus backpointers. 0 matches the zero-filled state of the
   * buffers on the first two calls.
   * NOTE(review): assumes per-block scores are non-negative. */
  for (uint32_t bi = 0; bi < pf->block_num; ++bi) {
    pf->probs[bi] = 0.0;
  }

  blky_pathfinder_step_t* step = NULL;
  if (++pf->steps > 1) {
    step = calloc(1, pf->step_bytes);
    assert(step);
  }

  /* Unsigned shift: `1 << 31` on a signed int is undefined behavior. */
  const uint32_t feat_max = (uint32_t) 1 << pf->feat_bits;
  assert(feat_max < pf->block_num);

  pf->seed = blky_numeric_xorshift64(pf->seed);

  /* Each candidate edge (pbi, fi) is evaluated once and routed to its target
   * block, instead of being re-evaluated for every possible target:
   * O(block_num * feat_max) rather than O(block_num^2 * feat_max). The
   * candidates for a given target are still visited in (pbi, fi) order, so
   * ties resolve exactly as before.
   * NOTE(review): relies on blky_numeric_hop being a pure function of its
   * arguments — confirm. */
  for (uint32_t pbi = 0; pbi < pf->block_num; ++pbi) {
    for (uint32_t fi = 0; fi < feat_max; ++fi) {
      const uint32_t bi =
        (uint32_t) (blky_numeric_hop(pbi, fi, pf->seed) % pf->block_num);

      const double sum = pf->probs_prev[pbi] + probs[bi];
      if (pf->probs[bi] < sum) {
        pf->probs[bi] = sum;
        if (step) step->indices[bi] = pbi;
      }
    }
  }

  if (step) {
    step->prev    = pf->step_last;
    pf->step_last = step;
  }
}