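// ternarySearchForGlobalMin works on convex (more generally, unimodal) functions: ones with a single minimum and no other local minima.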
import { SentryError } from "../../lib/sentry-error";

/**
 * Evaluates `fn` at both ends of [low, high] and at its two interior
 * third-points, returning the sample points together with their values.
 */
function sampleThirds(
  fn: (num: number, lastIteration?: boolean) => number,
  low: number,
  high: number,
): {
  mid1: number;
  mid2: number;
  mv0: number;
  mv1: number;
  mv2: number;
  mv3: number;
} {
  const mid1 = (low * 2 + high) / 3;
  const mid2 = (low + high * 2) / 3;
  // Call order (mid1, mid2, low, high) is kept stable in case fn has side effects.
  const mv1 = fn(mid1);
  const mv2 = fn(mid2);
  const mv0 = fn(low);
  const mv3 = fn(high);
  return { mid1, mid2, mv0, mv1, mv2, mv3 };
}

/**
 * Finds the minimum of a unimodal (e.g. convex) function by ternary search.
 *
 * Phase 1 (only when `low` and/or `high` is omitted): the missing bound is
 * defaulted so the window is [-10, 10] when both are missing, or 20 wide next
 * to the bound the caller gave. The window is then doubled on each missing side
 * while `fn` looks monotone across the four sample points, at most 51 times per
 * side. A bound supplied by the caller is never moved.
 *
 * Phase 2: `maxIterations` ternary-search narrowing steps.
 *
 * @param options.fn - Function to minimise. Its second argument is passed as
 *   `true` only on the diagnostic call made just before an invariant error.
 * @param options.low - Lower bound of the search window (optional).
 * @param options.high - Upper bound of the search window (optional).
 * @param options.maxIterations - Number of narrowing steps (default 100).
 * @returns `value`, the midpoint of the final window, and `score`, which is
 *   `fn(value)`. If expansion hits its cap on both sides, returns
 *   `{ value: 0, score: fn(0) }`.
 * @throws SentryError if, after the window has been seen to bracket a minimum,
 *   an end point becomes lower than its neighbouring interior point. That means
 *   `fn` is not unimodal on the window.
 */
export function ternarySearchForGlobalMin(options: {
  fn: (num: number, lastIteration?: boolean) => number;
  low?: number;
  high?: number;
  maxIterations?: number;
}): {
  value: number;
  score: number;
} {
  const { fn } = options;
  const maxIterations = options.maxIterations ?? 100;
  const maxExpansions = 51;
  let { low, high } = options;

  // Phase 1. Establish a window; only sides the caller left open may expand.
  const expandLow = low === undefined;
  const expandHigh = high === undefined;
  if (high === undefined) {
    high = low === undefined ? 10 : low + 20;
  }
  if (low === undefined) {
    low = high - 20;
  }

  let lowEscaped = false;
  if (expandLow) {
    // fn non-decreasing across the window: the minimum may lie left of `low`.
    lowEscaped = true;
    for (let step = 0; step < maxExpansions; step++) {
      const { mv0, mv1, mv2, mv3 } = sampleThirds(fn, low, high);
      if (!(mv0 <= mv1 && mv1 <= mv2 && mv2 <= mv3)) {
        lowEscaped = false;
        break;
      }
      low = low - (high - low);
    }
  }

  let highEscaped = false;
  if (expandHigh) {
    // fn non-increasing across the window: the minimum may lie right of `high`.
    highEscaped = true;
    for (let step = 0; step < maxExpansions; step++) {
      const { mv0, mv1, mv2, mv3 } = sampleThirds(fn, low, high);
      if (!(mv0 >= mv1 && mv1 >= mv2 && mv2 >= mv3)) {
        highEscaped = false;
        break;
      }
      high = high + (high - low);
    }
  }

  if (lowEscaped && highEscaped) {
    // fn never turned upward on either side within the expansion cap.
    return {
      value: 0,
      score: fn(0),
    };
  }

  // Phase 2. Narrow the window. Once it brackets a minimum
  // (fn(low) >= fn(mid1) and fn(mid2) <= fn(high)), every later window keeps
  // that shape for a unimodal fn, so a violation means fn is not unimodal.
  let previousIter: number[] = [];
  let bracketed = false;

  for (let i = 0; i < maxIterations; i++) {
    const { mid1, mid2, mv0, mv1, mv2, mv3 } = sampleThirds(fn, low, high);

    if (!bracketed && mv0 >= mv1 && mv2 <= mv3) {
      bracketed = true;
    }

    if (bracketed && mv1 > mv0) {
      fn(mid1, true);
      throw new SentryError(
        "mv0 is less than mv1",
        {},
        {
          mv0,
          mv1,
          mv2,
          mv3,
          previousIter: previousIter.join(", "),
          i,
        },
      );
    }
    if (bracketed && mv2 > mv3) {
      // NOTE(review): this diagnostic call uses mid1 although the failing
      // sample is mid2 — confirm which point was intended.
      fn(mid1, true);
      throw new SentryError(
        "mv3 is less than mv2",
        {},
        {
          mv0,
          mv1,
          mv2,
          mv3,
          previousIter: previousIter.join(", "),
          i,
          low,
          high,
        },
      );
    }

    previousIter = [mv0, mv1, mv2, mv3];

    if (mv1 < mv2) {
      high = mid2;
    } else {
      low = mid1;
    }
  }

  const value = (high + low) / 2;
  return {
    value,
    score: fn(value),
  };
}

/**
 * Bisection search for the boundary of a monotone predicate.
 *
 * `fn(x)` should return `true` when `x` is too high and `false` when it is too
 * low or equal. It should be monotone: false for every x below the boundary
 * and true for every x above it.
 *
 * A missing bound is found by doubling outward, starting from -1 for `low` and
 * from 1 for `high`. If the predicate never flips, doubling stops at ±Infinity
 * instead of looping forever, and the result is ±Infinity.
 *
 * @param fn - Predicate; `true` means "too high".
 * @param options.low - Known lower bound (optional).
 * @param options.high - Known upper bound (optional).
 * @param options.maxIterations - Number of bisection steps (default 100).
 * @returns `value`, the midpoint of the final bracket, and `score`, which is
 *   `fn(value)`.
 */
export function binarySearch(
  // true means too high, false means too low or eq
  fn: (mid: number) => boolean,
  options?: {
    low?: number;
    high?: number;
    maxIterations?: number;
  },
) {
  const maxIterations = options?.maxIterations ?? 100;

  let low = -1;
  if (options?.low == null) {
    // Still too high: push the lower bound further down. Stop at -Infinity,
    // where doubling can no longer make progress.
    while (fn(low) && Number.isFinite(low)) {
      low = low * 2;
    }
  } else {
    low = options?.low;
  }

  let high = 1;
  if (options?.high == null) {
    // Still too low: push the upper bound further up. Stop at +Infinity.
    while (!fn(high) && Number.isFinite(high)) {
      high = high * 2;
    }
  } else {
    high = options?.high;
  }

  for (let i = 0; i < maxIterations; i++) {
    const mid = (low + high) / 2;
    if (fn(mid)) {
      high = mid;
    } else {
      low = mid;
    }
  }

  const value = (high + low) / 2;
  return {
    value,
    score: fn(value),
  };
}
