package nl.camilstaps.cs.graphs;

import java.util.*;
import java.util.stream.Collectors;

/**
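 * A Graph is a list of Nodes, each of which keeps a list of its neighbours. Its main operation is computing the size
 * of a maximum independent set with Robson's algorithm.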
 *
 * @author Camil Staps
 */
public class Graph<T> {
    // The Nodes currently in the graph.
    private final List<Node<T>> nodes;
    // When removing nodes, it's possible to create restore points to add them back later on. Each entry lists the
    // Nodes removed since the matching restorePoint() call, in removal order; the first (top) entry is the newest.
    private final Deque<List<Node<T>>> restorePoints = new ArrayDeque<>();

    /**
     * Create an empty Graph.
     */
    public Graph() {
        this.nodes = new ArrayList<>();
        restorePoint();
    }

    /**
     * Create a Graph with the given Nodes. The list is copied: the algorithms in this class sort and modify their
     * node list, and that should not leak into (or fail on an unmodifiable) caller-owned list.
     *
     * @param nodes the Nodes of the Graph
     */
    public Graph(List<Node<T>> nodes) {
        this.nodes = new ArrayList<>(nodes);
        restorePoint();
    }

    /**
     * Computes the size of the maximum independent set using Robson's algorithm.
     *
     * Nodes are removed from the graph during the computation, but all of them are re-added before this method
     * returns. The order of the node list and of neighbourhood lists may differ afterwards.
     *
     * @see #maximumIndependentSetSizeWith1From(Node, Node)
     * @see #maximumIndependentSetSizeWith2From(Collection)
     *
     * @return the size of the maximum independent set
     */
    public int maximumIndependentSetSize() {
        if (nodes.size() <= 1) {
            return nodes.size();
        }

        int miss;
        restorePoint();

        Graph<T> component = findStronglyConnectedComponent();
        if (component.size() == size()) {
            // Case: connected graph. Pick a nice Node A (low degree) and a neighbour B with high degree.
            Collections.sort(nodes);
            Node<T> A = nodes.get(0);
            Collections.sort(A.getNeighbourhood(), Collections.reverseOrder());
            Node<T> B = A.getNeighbourhood().get(0);

            int try_1, try_2;
            switch (A.getDegree()) {
                case 1:
                    // One neighbour; pick A and continue with the rest of the Graph.
                    removeNodes(A.getInclusiveNeighbourhood());
                    miss = 1 + maximumIndependentSetSize();
                    break;
                case 2:
                    // Two neighbours: if they are connected pick A and continue with the rest of the Graph. If not,
                    // try both picking the neighbours and picking A with two of its second order neighbours.
                    Node<T> B_ = A.getNeighbourhood().get(1);
                    if (B.getNeighbourhood().contains(B_)) {
                        removeNodes(A.getInclusiveNeighbourhood());
                        miss = 1 + maximumIndependentSetSize();
                    } else {
                        restorePoint();
                        removeNodes(B.getInclusiveNeighbourhood());
                        removeNodes(B_.getInclusiveNeighbourhood());
                        try_1 = 2 + maximumIndependentSetSize();
                        restore();
                        removeNodes(A.getInclusiveNeighbourhood());
                        try_2 = 1 + maximumIndependentSetSizeWith2From(A.getSecondNeighbourhood());
                        miss = Math.max(try_1, try_2);
                    }
                    break;
                case 3:
                    // Three neighbours: either we pick two of the neighbours, or we pick A.
                    removeNode(A);
                    try_1 = maximumIndependentSetSizeWith2From(A.getNeighbourhood());
                    removeNodes(A.getNeighbourhood());
                    try_2 = 1 + maximumIndependentSetSize();
                    miss = Math.max(try_1, try_2);
                    break;
                default:
                    // If A dominates his neighbour we may safely remove that neighbour. Otherwise, we try both with
                    // and without the neighbour.
                    if (A.dominates(B)) {
                        removeNode(B);
                        miss = maximumIndependentSetSize();
                    } else {
                        removeNode(B);
                        try_1 = maximumIndependentSetSize();
                        removeNodes(B.getNeighbourhood());
                        try_2 = 1 + maximumIndependentSetSize();
                        miss = Math.max(try_1, try_2);
                    }
                    break;
            }
        } else {
            // Case: at least two connected components. Compute for the first component, and for the rest.
            miss = component.maximumIndependentSetSize();
            removeNodes(component.nodes);
            miss += maximumIndependentSetSize();
        }

        restore();
        return miss;
    }

    /**
     * Compute the maximum independent set size if it should contain (at least) one of two nodes.
     *
     * This is a helper function for {@link #maximumIndependentSetSizeWith2From(Collection)}, which is a helper
     * function for {@link #maximumIndependentSetSize()}, which uses Robson's algorithm. The two nodes may be given
     * in any order; Robson's case analysis needs d(s_1) &lt;= d(s_2), so they are swapped when necessary.
     *
     * @param s_1 The first node
     * @param s_2 The second node
     * @return the maximum independent set size
     */
    private int maximumIndependentSetSizeWith1From(Node<T> s_1, Node<T> s_2) {
        // Callers do not order the pair, and the recursive call below can change degrees, so normalise here.
        if (s_1.getDegree() > s_2.getDegree()) {
            return maximumIndependentSetSizeWith1From(s_2, s_1);
        }

        int miss;
        restorePoint();

        if (s_1.getDegree() <= 1) {
            miss = maximumIndependentSetSize();
        } else if (s_1.getNeighbourhood().contains(s_2)) {
            if (s_1.getDegree() <= 3) {
                miss = maximumIndependentSetSize();
            } else {
                int try_1, try_2;
                restorePoint();
                removeNodes(s_1.getInclusiveNeighbourhood());
                try_1 = maximumIndependentSetSize();
                restore();
                removeNodes(s_2.getInclusiveNeighbourhood());
                try_2 = maximumIndependentSetSize();
                miss = Math.max(try_1, try_2) + 1;
            }
        } else if (!s_1.neighbourhoodsDisjoint(s_2)) {
            removeNodes(s_1.neighbourhoodIntersection(s_2));
            miss = maximumIndependentSetSizeWith1From(s_1, s_2);
        } else if (s_2.getDegree() == 2) {
            // Here d(s_1) == d(s_2) == 2, so s_1 has exactly the two neighbours E and F.
            Node<T> E = s_1.getNeighbourhood().get(0), F = s_1.getNeighbourhood().get(1);
            if (E.getNeighbourhood().contains(F)) {
                removeNodes(s_1.getInclusiveNeighbourhood());
                miss = 1 + maximumIndependentSetSize();
            } else {
                boolean subset = true;
                for (Node<T> n : E.getNeighbourhood()) {
                    if (n != s_1 && !s_2.getNeighbourhood().contains(n)) {
                        subset = false;
                    }
                }
                for (Node<T> n : F.getNeighbourhood()) {
                    if (n != s_1 && !s_2.getNeighbourhood().contains(n)) {
                        subset = false;
                    }
                }
                if (subset) {
                    // Pick E, F and s_2; N[E] and N[F] are covered by N[s_1] and N[s_2].
                    removeNodes(s_1.getInclusiveNeighbourhood());
                    removeNodes(s_2.getInclusiveNeighbourhood());
                    miss = 3 + maximumIndependentSetSize();
                } else {
                    int try_1, try_2;
                    // Restore before the second option: E and F must be in the graph for their (updated)
                    // inclusive neighbourhoods to describe exactly what picking them removes.
                    restorePoint();
                    removeNodes(s_1.getInclusiveNeighbourhood());
                    try_1 = 1 + maximumIndependentSetSize();
                    restore();
                    removeNodes(E.getInclusiveNeighbourhood());
                    removeNodes(F.getInclusiveNeighbourhood());
                    removeNodes(s_2.getInclusiveNeighbourhood());
                    try_2 = 3 + maximumIndependentSetSize();
                    miss = Math.max(try_1, try_2);
                }
            }
        } else {
            // Either s_2 is in the set, or s_1 is (and s_2 is not), in which case at least two neighbours of s_2
            // are in the set. Each option picks one of s_1 / s_2 that the recursive call does not count: + 1.
            int try_1, try_2;
            restorePoint();
            removeNodes(s_2.getInclusiveNeighbourhood());
            try_1 = 1 + maximumIndependentSetSize();
            restore();
            removeNodes(s_1.getInclusiveNeighbourhood());
            removeNode(s_2);
            try_2 = 1 + maximumIndependentSetSizeWith2From(s_2.getNeighbourhood());
            miss = Math.max(try_1, try_2);
        }

        restore();
        return miss;
    }

    /**
     * Compute the maximum independent set size if at least two Nodes of some set should be in it.
     *
     * This is a helper function for {@link #maximumIndependentSetSize()}, which uses Robson's algorithm.
     *
     * @see #maximumIndependentSetSizeWith1From(Node, Node)
     *
     * @param Scol the set of which two Nodes should be in the maximum independent set
     * @return the size of the maximum independent set, or 0 if two Nodes of the set cannot be picked
     */
    private int maximumIndependentSetSizeWith2From(Collection<Node<T>> Scol) {
        List<Node<T>> S = new ArrayList<>(Scol);
        int miss;
        restorePoint();

        if (S.size() < 2) {
            // Less than two Nodes to pick from; this is impossible
            miss = 0;
        } else if (S.size() == 2) {
            // Two Nodes to pick from; try to pick both and fail otherwise.
            if (S.get(0).getNeighbourhood().contains(S.get(1))) {
                miss = 0;
            } else {
                removeNodes(S.get(0).getInclusiveNeighbourhood());
                removeNodes(S.get(1).getInclusiveNeighbourhood());
                miss = 2 + maximumIndependentSetSize();
            }
        } else if (S.size() == 3) {
            // Three Nodes to pick from. An ugly case distinction.
            Node<T> s_1 = S.get(0), s_2 = S.get(1), s_3 = S.get(2);

            // If there is a Node with degree 0, we add it to the independent set and choose 1 Node from the other two
            if (s_1.getDegree() == 0) {
                removeNode(s_1);
                miss = 1 + maximumIndependentSetSizeWith1From(s_2, s_3);
            } else if (s_2.getDegree() == 0) {
                removeNode(s_2);
                miss = 1 + maximumIndependentSetSizeWith1From(s_1, s_3);
            } else if (s_3.getDegree() == 0) {
                removeNode(s_3);
                miss = 1 + maximumIndependentSetSizeWith1From(s_1, s_2);
            }
            // If the Nodes are connected we can't choose two in an independent set.
            else if (s_1.getNeighbourhood().contains(s_2) &&
                    s_2.getNeighbourhood().contains(s_3) &&
                    s_3.getNeighbourhood().contains(s_1)) {
                miss = 0;
            }
            // If there is at least one edge from s_1
            else if (s_1.getNeighbourhood().contains(s_2) || s_1.getNeighbourhood().contains(s_3)) {
                Node<T> s_i, s_j, s_k;
                s_i = s_j = s_k = null;

                // If there are two edges among S, say sj-si-sk (si in the middle), we pick sj and sk.
                if (s_1.getNeighbourhood().contains(s_2) && s_1.getNeighbourhood().contains(s_3)) {
                    s_i = s_1; s_j = s_2; s_k = s_3;
                } else if (s_2.getNeighbourhood().contains(s_1) && s_2.getNeighbourhood().contains(s_3)) {
                    s_i = s_2; s_j = s_1; s_k = s_3;
                } else if (s_3.getNeighbourhood().contains(s_1) && s_3.getNeighbourhood().contains(s_2)) {
                    s_i = s_3; s_j = s_1; s_k = s_2;
                }

                // Check if we managed to find two edges. If not, there must be at least one edge, say si-sj. We then
                // choose sk and either si or sj. These are the cases that s_1 != sk.
                if (s_i != null) {                                      // Case two edges
                    removeNodes(s_j.getInclusiveNeighbourhood());
                    removeNodes(s_k.getInclusiveNeighbourhood());
                    miss = 2 + maximumIndependentSetSize();
                } else if (s_1.getNeighbourhood().contains(s_2)) {      // Case one edge; s_1 - s_2
                    removeNodes(s_3.getInclusiveNeighbourhood());
                    miss = 1 + maximumIndependentSetSizeWith1From(s_1, s_2);
                } else {                                                // Case one edge; s_1 - s_3 (implicit)
                    removeNodes(s_2.getInclusiveNeighbourhood());
                    miss = 1 + maximumIndependentSetSizeWith1From(s_1, s_3);
                }
            }
            // Final case with one edge; between s_2 and s_3: pick s_1 and either s_2 or s_3.
            else if (s_2.getNeighbourhood().contains(s_3)) {
                removeNodes(s_1.getInclusiveNeighbourhood());
                miss = 1 + maximumIndependentSetSizeWith1From(s_2, s_3);
            }
            // When the neighbourhoods of two nodes si, sj aren't disjoint, we can remove the intersection, because
            // either si or sj is in the independent set and so the intersection is not. One shared neighbour is
            // removed per step; the recursive call handles any remaining ones.
            else if (!s_1.neighbourhoodsDisjoint(s_2)) {
                removeNode(s_1.neighbourhoodIntersection(s_2).get(0));
                miss = maximumIndependentSetSizeWith2From(Scol);
            } else if (!s_1.neighbourhoodsDisjoint(s_3)) {
                removeNode(s_1.neighbourhoodIntersection(s_3).get(0));
                miss = maximumIndependentSetSizeWith2From(Scol);
            } else if (!s_2.neighbourhoodsDisjoint(s_3)) {
                removeNode(s_2.neighbourhoodIntersection(s_3).get(0));
                miss = maximumIndependentSetSizeWith2From(Scol);
            }
            // If s_1 has degree 1, we pick it and either s_2 or s_3
            else if (s_1.getDegree() == 1) {
                removeNodes(s_1.getInclusiveNeighbourhood());
                miss = 1 + maximumIndependentSetSizeWith1From(s_2, s_3);
            }
            // If all else fails, we try both picking s_1 with either s_2 or s_3, and picking s_2 and s_3 and at least
            // two neighbours of s_1 (Note: Robson has forgotten to add 2 for s_2 and s_3 in this step).
            else {
                int try_1, try_2;
                restorePoint();
                removeNodes(s_1.getInclusiveNeighbourhood());
                try_1 = 1 + maximumIndependentSetSizeWith1From(s_2, s_3);
                restore();
                removeNodes(s_2.getInclusiveNeighbourhood());
                removeNodes(s_3.getInclusiveNeighbourhood());
                removeNode(s_1);
                try_2 = 2 + maximumIndependentSetSizeWith2From(s_1.getNeighbourhood());
                miss = Math.max(try_1, try_2);
            }
        } else if (S.size() == 4) {
            // Four Nodes to pick from. If there's a node with degree <= 3 in the Graph, it's more efficient to just
            // compute normally. If not, we try recursively both with and without the first Node in S.
            if (nodes.stream().anyMatch(n -> n.getDegree() <= 3)) {
                miss = maximumIndependentSetSize();
            } else {
                Node<T> s_1 = S.get(0);
                int try_1, try_2;
                removeNode(s_1);
                S.remove(s_1);
                try_1 = maximumIndependentSetSizeWith2From(S);
                removeNodes(s_1.getNeighbourhood());
                try_2 = 1 + maximumIndependentSetSize();
                miss = Math.max(try_1, try_2);
            }
        } else {
            miss = maximumIndependentSetSize();
        }

        restore();
        return miss;
    }

    /**
     * Find a connected component of the graph (named "strongly connected" historically; neighbour relations are
     * symmetric here). This may return itself, if the graph is empty. It is not guaranteed that the smallest /
     * largest component is returned. The method simply returns the one that contains the first node (if it exists).
     *
     * Each Node is collected exactly once: Nodes are marked as seen when they are queued, not when they are visited,
     * so a Node adjacent to several visited Nodes is not queued (and collected) more than once.
     *
     * @return some connected component, as a new Graph (or this Graph if it is empty)
     */
    private Graph<T> findStronglyConnectedComponent() {
        if (nodes.isEmpty()) {
            return this;
        }

        Node<T> start = nodes.get(0);
        List<Node<T>> seen = new ArrayList<>();
        Queue<Node<T>> queue = new ArrayDeque<>();
        seen.add(start);
        queue.add(start);
        while (!queue.isEmpty()) {
            Node<T> n = queue.remove();
            for (Node<T> x : n.getNeighbourhood()) {
                if (!seen.contains(x)) {
                    seen.add(x);
                    queue.add(x);
                }
            }
        }

        return new Graph<>(seen);
    }

    /**
     * Add a Node, and update its neighbours to include it as a neighbour.
     * @param node the Node to add
     */
    public void addNode(Node<T> node) {
        this.nodes.add(node);
        reconnect(node);
        // The Node is now part of the graph; a later restore() must not add it a second time.
        restorePoints.peek().remove(node);
    }

    /**
     * Get a Node by its index. Note: some internal methods may reorder the list; only use this directly after adding
     * nodes.
     * @param i the index in the list
     * @return the Node
     */
    public Node<T> getNode(int i) {
        return this.nodes.get(i);
    }

    /**
     * Remove a node, and remove its reference from its neighbours that are still in the graph. Does nothing if the
     * node is not currently in the graph, so it is never recorded (and later restored) twice.
     *
     * @see #removeNodes(Collection)
     *
     * @param node the Node
     */
    private void removeNode(Node<T> node) {
        if (!nodes.remove(node)) {
            return;
        }
        node.getNeighbourhood().stream()
                .filter(nodes::contains)
                .forEach(n -> n.getNeighbourhood().remove(node));
        restorePoints.peek().add(node);
    }

    /**
     * Remove a list of nodes, and update all references. Nodes that are not currently in the graph (or that occur
     * more than once in {@code ns}) are ignored, so every removed Node is recorded exactly once for restore().
     *
     * @see #removeNode(Node)
     *
     * @param ns the Nodes to remove; this may be a live neighbourhood list, it is snapshotted before any change
     */
    private void removeNodes(Collection<Node<T>> ns) {
        List<Node<T>> present = new ArrayList<>();
        for (Node<T> n : ns) {
            if (nodes.contains(n) && !present.contains(n)) {
                present.add(n);
            }
        }
        nodes.removeAll(present);
        for (Node<T> n : nodes) {
            n.getNeighbourhood().removeAll(present);
        }
        restorePoints.peek().addAll(present);
    }

    /**
     * Create a restore point to return back to after removing some Nodes.
     *
     * @see #restore()
     * @see #restorePoints
     */
    private void restorePoint() {
        restorePoints.push(new ArrayList<>());
    }

    /**
     * Go back to the last restore point, re-adding all Nodes that were removed since then, newest first.
     *
     * @see #restorePoint()
     * @see #restorePoints
     */
    private void restore() {
        List<Node<T>> removed = restorePoints.pop();
        for (int i = removed.size() - 1; i >= 0; i--) {
            Node<T> node = removed.get(i);
            nodes.add(node);
            reconnect(node);
        }
    }

    /**
     * Add {@code node} back to the neighbourhood of each of its neighbours that is in the graph (if missing there).
     *
     * @param node a Node that has just been (re-)added to {@link #nodes}
     */
    private void reconnect(Node<T> node) {
        node.getNeighbourhood().stream()
                .filter(nodes::contains)
                .filter(n -> !n.getNeighbourhood().contains(node))
                .forEach(n -> n.getNeighbourhood().add(node));
    }

    /**
     * The size of the graph is the number of Nodes.
     * @return the number of Nodes.
     */
    public int size() {
        return nodes.size();
    }

    /**
     * Represent the Graph as a String
     * @return a String representation of the Graph
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("Graph:");
        for (Node<T> n : nodes) {
            sb.append("\n  ").append(n).append(":   ");
            for (Node<T> n2 : n.getNeighbourhood()) {
                sb.append(" - ").append(n2);
            }
        }
        return sb.toString();
    }
}