import { assertType } from "../../lib/utils";
import { isFire } from "../config";
import { CoreCalculatedObjectConcrete } from "../coreObjects";
import CoreLoadNode from "../coreObjects/coreLoadNode";
import { addWarning } from "../document/calculations-objects/warnings";
import PipeEntity, {
  fillDefaultConduitFields,
  isPipeEntity,
} from "../document/entities/conduit-entity";
import DirectedValveEntity from "../document/entities/directed-valves/directed-valve-entity";
import { NodeType } from "../document/entities/load-node-entity";
import { EntityType } from "../document/entities/types";
import CalculationEngine from "./calculation-engine";
import { TraceCalculation } from "./flight-data-recorder";
import { FlowAssignment } from "./flow-assignment";
import { CycleResult, solveFlow } from "./generic-flow-solver";
import { GlobalFDR } from "./global-fdr";
import Graph, { Edge } from "./graph";
import { IterationResult } from "./returns";
import { EdgeType, FlowEdge, FlowNode, PressurePushMode } from "./types";
import {
  FLOW_SOURCE_ROOT_NODE,
  findFireSubGroup,
  isDirectedValveDirected,
} from "./utils";

// Module-private helpers (not exported).
/**
 * Sums the signed friction pressure loss (kPa, PSD push mode) around `cycle`
 * for the flows currently held in `flowAssignment`.
 *
 * Conduits, multiway valves and directed valves that are not directed
 * contribute their loss. A directed valve that is directed contributes nothing
 * and raises a "DIRECTED_VALVE_FORBIDDEN_HERE" warning on every call. All other
 * element types contribute nothing.
 *
 * @returns the signed sum; a balanced cycle gives ~0
 */
function imbalance(
  context: CalculationEngine,
  graph: Graph<FlowNode, FlowEdge>,
  flowAssignment: FlowAssignment,
  cycle: Edge<FlowNode, FlowEdge>[],
): number {
  let total = 0;
  for (const edge of cycle) {
    const calculatable = context.globalStore.getCalculatableOrThrow(
      edge.value.uid,
    );
    const isDirectedValve = calculatable.type === EntityType.DIRECTED_VALVE;
    const directed =
      isDirectedValve &&
      isDirectedValveDirected(calculatable.entity as DirectedValveEntity);

    if (directed) {
      addWarning(context, "DIRECTED_VALVE_FORBIDDEN_HERE", [
        calculatable.entity,
      ]);
      continue;
    }
    if (
      calculatable.type !== EntityType.CONDUIT &&
      calculatable.type !== EntityType.DIRECTED_VALVE &&
      calculatable.type !== EntityType.MULTIWAY_VALVE
    ) {
      // Element types without a friction model for this solve add nothing.
      continue;
    }

    assertType<CoreCalculatedObjectConcrete>(calculatable);
    const { pressureLossKPA } = calculatable.getFrictionPressureLossKPA({
      context,
      flowLS: flowAssignment.getFlow(edge.uid, graph.sn(edge.from)),
      from: edge.from,
      to: edge.to,
      signed: true,
      pressurePushMode: PressurePushMode.PSD,
    });
    total += pressureLossKPA!;
  }
  return total;
}

function _nbits(n: number): number {
  let bits = 0;
  while (n) {
    if (n & 1) bits++;
    n = n >> 1;
  }
  return bits;
}

/**
 * Fire-system flow-rate calculation.
 *
 * Fire load nodes are grouped by `customEntityId` and then by fire sub-group.
 * For each sub-group, candidate sets of simultaneously-flowing nodes are
 * generated and each set is solved on the flow graph. Every pipe keeps the
 * largest-magnitude flow seen across those sets. The resulting fire flows are
 * added to the pipes' PSD and total peak flow rates, and the pipes are
 * re-sized. This repeats while pipe sizes keep changing.
 */
export class FireCalculations {
  /**
   * Entry point. Repeats `_determineFlowRatesSingleIteration` until a pass
   * changes no pipe size, for at most 10 passes.
   *
   * @param context - calculation engine whose `flowGraph` and global store are
   *   read and whose pipe calculations are updated in place
   */
  @TraceCalculation("Calculating fire flow rates")
  static fireFlowRates(context: CalculationEngine): void {
    // Re-sizing presumably changes friction losses, and therefore how flow
    // splits around loops. Flows are recomputed until sizing is stable.
    const FIRE_RESIZE_MAX_ITER = 10;
    const graph = context.flowGraph;

    for (let i = 0; i < FIRE_RESIZE_MAX_ITER; i++) {
      GlobalFDR.focusData(["Iteration " + i + " of " + FIRE_RESIZE_MAX_ITER]);
      const iterStart = performance.now();
      const iteration = this._determineFlowRatesSingleIteration(context, graph);
      console.log(
        "determineFlowRatesSingleIteration took",
        performance.now() - iterStart,
        "ms",
      );
      if (!iteration?.pipeSizesChanged) {
        console.log("determineFlowRatesSingleIteration converged");
        break;
      }
    }
  }

  /**
   * Returns the sets of fire load-node uids to simulate together for one
   * sub-group. Each set holds at most `numSupport` uids.
   *
   * Uses the chain heuristic (`_findAllNodeChains`). `_findAllNodeSubsets`
   * enumerates every subset instead, but it is exponential in `nodes.length`.
   */
  @TraceCalculation("Find optimal node sets")
  static findOptimalNodeSets(
    context: CalculationEngine,
    graph: Graph<FlowNode, FlowEdge>,
    bridges: Set<string>,
    nodes: CoreLoadNode[],
    numSupport: number,
  ): Set<string>[] {
    return this._findAllNodeChains(context, graph, bridges, nodes, numSupport);
  }

  /**
   * One pass of the fire calculation. It clears the previous pass's fire
   * flows, solves every fire group, then adds the resulting flows to the
   * pipes and re-sizes them.
   *
   * @returns whether any pipe size changed, plus the largest PSD flow rate
   *   among pipes that received fire flow (reported as `totalFlowRateLS`)
   */
  @TraceCalculation("Determine fire flow rates single iteration")
  static _determineFlowRatesSingleIteration(
    context: CalculationEngine,
    graph: Graph<FlowNode, FlowEdge>,
  ): IterationResult {
    const fireNodeGroups = this._groupFireNodes(context);
    // Bridge edges are excluded from all fire-flow bookkeeping below.
    const bridges = new Set(graph.findBridges().map((o) => o.value.uid));

    this._removePreviousFireFlows(context, graph, bridges);

    const result = new FlowAssignment();
    for (const group of Object.values(fireNodeGroups)) {
      if (group.length === 0) continue;

      const groupResult = new FlowAssignment();
      for (const [nodes, maxSimultaneous] of Object.values(
        this._splitIntoSubgroups(context, group),
      )) {
        // Candidate sets of nodes that flow at the same time.
        const nodeSets = this.findOptimalNodeSets(
          context,
          graph,
          bridges,
          nodes,
          maxSimultaneous,
        );

        const subGroupResult = new FlowAssignment();
        for (const needToCover of nodeSets) {
          const assignment = this._generateInitialFlowAssignment(
            context,
            graph,
            needToCover,
          );
          const solved = solveFlow<FlowNode, FlowEdge>({
            assignment,
            graph,
            escapeDelta: 0.0001,
            computeCycleImbalance: (
              flowAssignment: FlowAssignment,
              cycle: Edge<FlowNode, FlowEdge>[],
              _iteration: number,
              learningRate?: number,
            ): CycleResult =>
              this._balanceCycle(
                context,
                graph,
                flowAssignment,
                cycle,
                learningRate,
              ),
            shortCycles: true,
          });
          // Each pipe keeps the worst case over all node sets.
          this._mergeLargerMagnitude(subGroupResult, solved);
        }
        // ...and the worst case over all sub-groups.
        this._mergeLargerMagnitude(groupResult, subGroupResult);
      }

      // Different groups are combined additively, by magnitude.
      for (const [k, v] of groupResult) {
        result.addFlow(k, v[1], Math.abs(groupResult.getFlow(k, v[1])));
      }
    }

    return this._applyFireFlows(context, graph, bridges, result);
  }

  /**
   * Collects the fire load nodes that have a `customEntityId` and whose uid
   * contains ".", keyed by `customEntityId`.
   */
  private static _groupFireNodes(context: CalculationEngine): {
    [key: string]: CoreLoadNode[];
  } {
    const groups: { [key: string]: CoreLoadNode[] } = {};
    const loadNodes = Array.from(
      context.globalStore.entitiesOfType<CoreLoadNode>(EntityType.LOAD_NODE),
    ) as CoreLoadNode[];

    for (const node of loadNodes) {
      const loadNode = node.entity.node;
      // NOTE(review): only uids containing "." are grouped. These are
      // presumably the per-instance copies of a fire node; confirm.
      if (loadNode.type !== NodeType.FIRE || !node.entity.uid.includes(".")) {
        continue;
      }
      const groupId = loadNode.customEntityId;
      if (!groupId) continue;
      if (!groups[groupId]) {
        groups[groupId] = [];
      }
      groups[groupId].push(node);
    }
    return groups;
  }

  /**
   * Splits one fire group by fire sub-group. Each sub-group's nodes are paired
   * with that sub-group's maximum number of simultaneous nodes.
   */
  private static _splitIntoSubgroups(
    context: CalculationEngine,
    group: CoreLoadNode[],
  ): { [key: string]: [CoreLoadNode[], number] } {
    const subgroups: { [key: string]: [CoreLoadNode[], number] } = {};
    for (const node of group) {
      GlobalFDR.focusData([node.entity.uid]);
      const loadNode = node.entity.node;
      if (loadNode.type !== NodeType.FIRE) continue;
      const fireSubGroup = findFireSubGroup(
        context.drawing,
        loadNode.customEntityId,
        loadNode.subGroupId,
      );
      if (!(fireSubGroup.subGroupId in subgroups)) {
        subgroups[fireSubGroup.subGroupId] = [
          [],
          fireSubGroup.maxiumumSimutaneousNode,
        ];
      }
      subgroups[fireSubGroup.subGroupId][0].push(node);
    }
    return subgroups;
  }

  /**
   * Subtracts the fire flow recorded by the previous pass from every
   * non-bridge pipe's PSD and total peak flow, then resets `fireFlowRateLS`.
   * After this, those totals no longer include fire flow before this pass
   * adds it back.
   */
  private static _removePreviousFireFlows(
    context: CalculationEngine,
    graph: Graph<FlowNode, FlowEdge>,
    bridges: Set<string>,
  ): void {
    for (const edge of graph.edgeList.values()) {
      GlobalFDR.focusData([edge.value.uid]);
      if (
        edge.value.type !== EdgeType.CONDUIT ||
        bridges.has(edge.value.uid)
      ) {
        continue;
      }
      const pipe = context.globalStore.getObjectOfTypeOrThrow(
        EntityType.CONDUIT,
        edge.value.uid,
      );
      if (!isPipeEntity(pipe.entity)) continue;
      const pCalc = context.globalStore.getOrCreateCalculation(pipe.entity);
      if (!pCalc.fireFlowRateLS) continue;
      pCalc.totalPeakFlowRateLS! -= pCalc.fireFlowRateLS;
      pCalc.PSDFlowRateLS! -= pCalc.fireFlowRateLS;
      pCalc.fireFlowRateLS = 0;
    }
  }

  /**
   * Finds the flow correction `d` that drives the cycle's signed
   * friction-loss imbalance to ~0 when added to every edge of `cycle`.
   *
   * First the root is bracketed. A trial correction is doubled in the
   * positive direction until the imbalance exceeds both its current value
   * and 0. It is doubled in the negative direction until the imbalance falls
   * below both. The bracket is then bisected to a width of 1e-9.
   * `flowAssignment` is only perturbed temporarily: every trial is undone.
   *
   * Both searches are capped. Previously, a cycle whose imbalance never
   * crossed zero doubled the trial to `Infinity`, which turned the flows into
   * NaN, and then looped forever. A bracket that float precision cannot
   * narrow to 1e-9 could also bisect forever.
   *
   * @returns the lower end of the bracket scaled by `learningRate` (default 1),
   *   and the imbalance before correction. The adjustment is 0 when the cycle
   *   is already balanced or no bracket is found.
   */
  private static _balanceCycle(
    context: CalculationEngine,
    graph: Graph<FlowNode, FlowEdge>,
    flowAssignment: FlowAssignment,
    cycle: Edge<FlowNode, FlowEdge>[],
    learningRate?: number,
  ): CycleResult {
    // 2^40 L/s is far beyond any real correction and keeps trial flows finite.
    const MAX_BRACKET_DOUBLINGS = 40;
    // Bisecting a bracket of width <= 2^41 down to 1e-9 takes ~72 steps.
    const MAX_BISECTION_STEPS = 200;

    // Imbalance of the cycle with `delta` temporarily added to each edge.
    const imbalanceWith = (delta: number): number => {
      for (const edge of cycle) {
        flowAssignment.addFlow(edge.uid, graph.sn(edge.from), delta);
      }
      const value = imbalance(context, graph, flowAssignment, cycle);
      for (const edge of cycle) {
        flowAssignment.addFlow(edge.uid, graph.sn(edge.from), -delta);
      }
      return value;
    };

    const currentImbalance = imbalance(context, graph, flowAssignment, cycle);

    // Doubles a correction in `direction` until its imbalance lies beyond both
    // the current imbalance and zero on that side. Returns the correction and
    // its imbalance, or undefined if the cap is reached.
    const findBracketEnd = (direction: 1 | -1): [number, number] | undefined => {
      for (let k = 0, step = 1; k < MAX_BRACKET_DOUBLINGS; k++, step *= 2) {
        const delta = direction * step;
        const trial = imbalanceWith(delta);
        const beyond =
          direction > 0
            ? trial > currentImbalance && trial > 0
            : trial < currentImbalance && trial < 0;
        if (beyond) return [delta, trial];
      }
      return undefined;
    };

    let lower = 0;
    let upper = 0;
    let lowerImbalance = 0;

    if (currentImbalance) {
      const upperEnd = findBracketEnd(1);
      if (!upperEnd) return { suggestedAdjustment: 0, imbalance: currentImbalance };
      const lowerEnd = findBracketEnd(-1);
      if (!lowerEnd) return { suggestedAdjustment: 0, imbalance: currentImbalance };
      upper = upperEnd[0];
      [lower, lowerImbalance] = lowerEnd;
    }

    // With lowerImbalance < 0 and a monotonic imbalance, this test is
    // `trial > 0`, i.e. the root lies below `mid`.
    for (
      let k = 0;
      k < MAX_BISECTION_STEPS && upper - lower > 1e-9;
      k++
    ) {
      const mid = (upper + lower) / 2;
      const trial = imbalanceWith(mid);
      if (trial * (lowerImbalance - trial) < 0) {
        upper = mid;
      } else {
        lower = mid;
      }
    }

    return {
      suggestedAdjustment: lower * (learningRate ?? 1),
      imbalance: currentImbalance,
    };
  }

  /**
   * For every flow entry in `source` whose magnitude exceeds `target`'s entry,
   * sets `target`'s entry to the source's signed value (by adding the
   * difference). Repeated calls therefore keep the largest-magnitude flow per
   * entry.
   */
  private static _mergeLargerMagnitude(
    target: FlowAssignment,
    source: FlowAssignment,
  ): void {
    for (const [key, value] of source) {
      const before = target.getFlow(key, value[1]);
      const after = source.getFlow(key, value[1]);
      if (Math.abs(before) < Math.abs(after)) {
        target.addFlow(key, value[1], after - before);
      }
    }
  }

  /**
   * Applies this pass's fire flows to every non-bridge pipe with non-zero
   * flow. The magnitude from `result` is added to its PSD and total peak flow
   * and recorded as `fireFlowRateLS`, and the pipe is re-sized for its new PSD
   * flow.
   *
   * @returns whether any pipe's nominal diameter changed, and the largest PSD
   *   flow among the updated pipes
   */
  private static _applyFireFlows(
    context: CalculationEngine,
    graph: Graph<FlowNode, FlowEdge>,
    bridges: Set<string>,
    result: FlowAssignment,
  ): IterationResult {
    let pipeSizesChanged = false;
    let totalFlowRateLS = 0;

    for (const edge of graph.edgeList.values()) {
      if (
        edge.value.type !== EdgeType.CONDUIT ||
        bridges.has(edge.value.uid)
      ) {
        continue;
      }
      const pipe = context.globalStore.getObjectOfTypeOrThrow(
        EntityType.CONDUIT,
        edge.value.uid,
      );
      const flowRateLS = Math.abs(
        result.getFlow(edge.uid, graph.sn(edge.from)),
      );
      if (!flowRateLS || !isPipeEntity(pipe.entity)) continue;

      const pCalc = context.globalStore.getOrCreateCalculation(pipe.entity);
      pCalc.totalPeakFlowRateLS! += flowRateLS;
      pCalc.PSDFlowRateLS! += flowRateLS;
      pCalc.fireFlowRateLS = flowRateLS;

      const origSize = pCalc.realNominalPipeDiameterMM;
      const filled = fillDefaultConduitFields(context, pipe.entity);
      context.sizePipeForFlowRate(pipe.entity, [
        [pCalc.PSDFlowRateLS, filled.conduit.maximumVelocityMS!],
      ]);
      if (
        origSize !== pCalc.realNominalPipeDiameterMM &&
        pCalc.realNominalPipeDiameterMM
      ) {
        console.log(
          "change size:",
          origSize,
          pCalc.realNominalPipeDiameterMM,
        );
        pipeSizesChanged = true;
      }
      totalFlowRateLS = Math.max(totalFlowRateLS, pCalc.PSDFlowRateLS!);
    }

    return {
      pipeSizesChanged,
      totalFlowRateLS,
    };
  }

  /**
   * Builds the starting flow assignment for one scenario. The walk is a DFS
   * from `FLOW_SOURCE_ROOT_NODE`. The first time it reaches a fire node in
   * `needToCover`, that node adds its sub-group's `continuousFlowRateLS` to
   * every conduit edge on the current DFS path.
   *
   * The previous version ran an extra full DFS first whose results were
   * discarded; it has been removed.
   */
  @TraceCalculation("Generate intial flow assignment")
  static _generateInitialFlowAssignment(
    context: CalculationEngine,
    graph: Graph<FlowNode, FlowEdge>,
    needToCover: Set<string>,
  ): FlowAssignment {
    const assignment = new FlowAssignment();
    if (!graph.edgeList.size) return assignment;

    const visited = new Set<FlowNode>();
    // Conduit edges on the current DFS path from the root.
    const path: Edge<FlowNode, FlowEdge>[] = [];

    graph.dfsRecursive(
      FLOW_SOURCE_ROOT_NODE,
      (node) => {
        if (visited.has(node)) return;
        visited.add(node);
        if (node.connectable === "FLOW_SOURCE_ROOT") return;

        const object = context.globalStore.getObjectOfType(
          EntityType.LOAD_NODE,
          node.connectable,
        );
        if (!object) return;
        const entity = object.entity;
        if (
          entity.node.type !== NodeType.FIRE ||
          !needToCover.has(entity.uid)
        ) {
          return;
        }

        const flow = findFireSubGroup(
          context.drawing,
          entity.node.customEntityId,
          entity.node.subGroupId,
        ).continuousFlowRateLS;
        for (const edge of path) {
          if (
            !context.globalStore.getObjectOfType(
              EntityType.CONDUIT,
              edge.value.uid,
            )
          ) {
            continue;
          }
          assignment.addFlow(edge.uid, graph.sn(edge.from), flow);
        }
      },
      undefined,
      (edge) => {
        if (edge.value.type === EdgeType.CONDUIT) {
          path.push(edge);
        }
      },
      (edge) => {
        if (edge.value.type === EdgeType.CONDUIT) {
          path.pop();
        }
      },
    );

    return assignment;
  }

  /**
   * Exhaustive alternative to `_findAllNodeChains`. Returns every subset of
   * `nodes` (empty set included) with at most `numSupport` members, as sets of
   * entity uids, in bitmask order.
   *
   * NOTE(review): the bitmask uses `1 << nodes.length`, which overflows for 31
   * or more nodes. The 2^n enumeration is only practical for small groups.
   * `graph` is unused but kept for interface compatibility.
   */
  @TraceCalculation("Finding all node subsets", (_, no, nu) =>
    no.map((n) => n.uid).concat([nu.toString()]),
  )
  static _findAllNodeSubsets(
    graph: Graph<FlowNode, FlowEdge>,
    nodes: CoreLoadNode[],
    numSupport: number,
  ): Set<string>[] {
    const subsetCount = 1 << nodes.length;
    const subsets: Set<string>[] = [];
    for (let mask = 0; mask < subsetCount; mask++) {
      if (_nbits(mask) > numSupport) continue;
      const subset = new Set<string>();
      for (let bit = 0; bit < nodes.length; bit++) {
        if (mask & (1 << bit)) {
          subset.add(nodes[bit].entity.uid);
        }
      }
      subsets.push(subset);
    }
    return subsets;
  }

  /**
   * Chain heuristic for choosing which fire nodes flow together.
   *
   * 1. "Cycle nodes" are the endpoints of non-bridge conduits in fire systems.
   * 2. For each cycle node, list the fire nodes from `nodes` reachable from it
   *    without crossing a non-bridge conduit incident to that node. The list
   *    is sorted by `continuousFlowRateLS`, largest first.
   * 3. Start a DFS from each cycle node whose list is non-empty. The DFS
   *    crosses no bridge and never takes two fitting-flow edges in a row.
   *    Visited nodes are released on backtrack, so it presumably walks paths
   *    rather than a spanning tree. Each listed node met on the way adds its
   *    fire nodes to a running chain capped at `numSupport`, and a snapshot of
   *    the chain is recorded. Additions are undone when the conduit edge
   *    leading to them is backtracked. Fire nodes already in the chain are not
   *    re-recorded, so backtracking cannot remove an ancestor's contribution.
   *
   * @returns the distinct recorded chains as sets of load-node uids, in
   *   first-seen order
   */
  @TraceCalculation("Finding all node chains")
  static _findAllNodeChains(
    context: CalculationEngine,
    graph: Graph<FlowNode, FlowEdge>,
    bridges: Set<string>,
    nodes: CoreLoadNode[],
    numSupport: number,
  ): Set<string>[] {
    const candidateNodes = new Set<CoreLoadNode>(nodes);
    const needToCovers: Set<string>[] = [];

    // Step 1: endpoints of non-bridge fire-system pipes.
    const cycleNodes = new Set<FlowNode>();
    for (const edge of graph.edgeList.values()) {
      GlobalFDR.focusData([edge.value.uid]);
      if (edge.value.type !== EdgeType.CONDUIT) continue;
      const pipe = context.globalStore.get(edge.value.uid);
      if (
        bridges.has(pipe.entity.uid) ||
        !isPipeEntity(pipe.entity) ||
        !isFire(context.drawing.metadata.flowSystems[pipe.entity.systemUid])
      ) {
        continue;
      }
      cycleNodes.add(edge.from);
      cycleNodes.add(edge.to);
    }

    // Step 2: fire nodes hanging off each cycle node, largest flow first.
    const flowRateReduces = new Map<string, { flow: number; node: string }[]>();
    for (const cycleNode of cycleNodes) {
      GlobalFDR.focusData([cycleNode.connectable, cycleNode.connection]);
      const flowRates: { flow: number; node: string }[] = [];
      graph.dfs(
        cycleNode,
        (node) => {
          if (node.connectable === "FLOW_SOURCE_ROOT") return;
          const object = context.globalStore.getObjectOfType(
            EntityType.LOAD_NODE,
            node.connectable,
          );
          if (!object || !candidateNodes.has(object)) return;
          const entity = object.entity;
          if (entity.node.type !== NodeType.FIRE) return;
          const fireSubGroup = findFireSubGroup(
            context.drawing,
            entity.node.customEntityId,
            entity.node.subGroupId,
          );
          flowRates.push({
            flow: fireSubGroup.continuousFlowRateLS,
            node: entity.uid,
          });
        },
        undefined,
        (edge) => {
          // Do not traverse non-bridge conduits incident to this cycle node.
          if (
            edge.value.type === EdgeType.CONDUIT &&
            !bridges.has(edge.value.uid) &&
            (edge.from.connectable === cycleNode.connectable ||
              edge.to.connectable === cycleNode.connectable)
          ) {
            return true;
          }
        },
      );
      flowRates.sort((a, b) => b.flow - a.flow);
      flowRateReduces.set(cycleNode.connectable, flowRates);
    }

    // Step 3: grow chains by walking from each cycle node with fire nodes.
    const seenCycleNodes = new Set<string>();
    for (const start of cycleNodes) {
      GlobalFDR.focusData([start.connectable, start.connection]);
      if (seenCycleNodes.has(start.connectable)) continue;
      seenCycleNodes.add(start.connectable);
      if (flowRateReduces.get(start.connectable)?.length === 0) continue;

      const seen = new Set<string>();
      const seenEdges = new Set<string>();
      const chain = new Set<string>();
      // Undo log: an "edge" marker per entered conduit, followed by the fire
      // nodes added to `chain` while inside it.
      const stack: (
        | { type: "edge"; edge: Edge<FlowNode, FlowEdge> }
        | { type: "node"; node: string }
      )[] = [];
      const edgeStack: Edge<FlowNode, FlowEdge>[] = [];

      graph.dfsRecursive(
        start,
        (node) => {
          if (node.connectable === "FLOW_SOURCE_ROOT") return true;
          if (chain.size >= numSupport) return true;
          const reduces = flowRateReduces.get(node.connectable);
          if (reduces?.length) {
            for (const reduce of reduces) {
              if (chain.size >= numSupport) break;
              // Already contributed by an ancestor frame: recording it again
              // would let this frame's backtrack remove it.
              if (chain.has(reduce.node)) continue;
              chain.add(reduce.node);
              stack.push({ type: "node", node: reduce.node });
            }
            needToCovers.push(new Set(chain));
          }
        },
        (node) => {
          seen.delete(graph.sn(node));
        },
        (edge) => {
          edgeStack.push(edge);
          // Never pass through two fitting-flow edges in a row.
          if (
            edge.value.type === EdgeType.FITTING_FLOW &&
            edgeStack.length > 1 &&
            edgeStack[edgeStack.length - 2].value.type === EdgeType.FITTING_FLOW
          ) {
            return true;
          }
          if (edge.value.type === EdgeType.CONDUIT) {
            if (bridges.has(edge.value.uid)) return true;
            stack.push({ type: "edge", edge });
          }
        },
        (edge) => {
          edgeStack.pop();
          seenEdges.delete(edge.uid);
          if (
            edge.value.type === EdgeType.FITTING_FLOW ||
            bridges.has(edge.value.uid)
          ) {
            return;
          }
          if (edge.value.type === EdgeType.CONDUIT) {
            // Undo everything added since this conduit was entered.
            while (stack.length > 0) {
              const top = stack.pop()!;
              if (top.type === "edge") {
                if (top.edge.value.uid === edge.value.uid) break;
              } else {
                chain.delete(top.node);
              }
            }
          }
        },
        seen,
        seenEdges,
      );
    }

    // Deduplicate by content and keep first-seen order. Members are unique
    // strings, so the sorted member list is a canonical key.
    const unique = new Map<string, Set<string>>();
    for (const needToCover of needToCovers) {
      const key = JSON.stringify(Array.from(needToCover).sort());
      if (!unique.has(key)) {
        unique.set(key, needToCover);
      }
    }
    const result = Array.from(unique.values());
    console.log("size", result.length);
    return result;
  }
}
