#!/usr/bin/python3

# SI 335: Computer Algorithms
# Unit 5

import sys
from heapq import heappush, heappop
from collections import deque
from random import randrange
from copy import copy

# Weight used throughout this file to mean "no edge" (edgeWeight) or
# "no path found yet" (dijkstraArray); it is larger than any finite weight.
infinity = float('inf')

class Graph:
    '''Abstract interface shared by every graph representation.
       Subclasses are expected to override the methods below.
       Used as-is, this class models a graph with no vertices or edges.'''
    # Both counters default to zero for the empty graph; the concrete
    # subclasses in this file overwrite them in their constructors
    # (n = number of vertices, m = number of edges).
    n = 0
    m = 0

    def getVertices(self):
        """Returns a list of all vertices in the graph (none here)."""
        return []

    def neighbors(self, u):
        """Returns a list of pairs (v,w), one per outgoing edge of node u.
        A pair (v,w) means there is an edge from u to v with weight w."""
        return []

    def edgeWeight(self, u, v):
        """Returns the weight of the edge from u to v.
        The empty graph has no edges, so this is always infinity."""
        return infinity

    def getEdges(self):
        """Returns a list of (u,v,w) triples for all edges in the graph.
        The list is assembled from getVertices() and neighbors(), so a
        subclass that overrides those two gets this method for free."""
        return [(u, v, w)
                for u in self.getVertices()
                for (v, w) in self.neighbors(u)]


class ALGraph(Graph):
    '''Graph stored as adjacency lists: one list of (v,w) pairs per vertex.'''

    def __init__(self, vertices, edges):
        """Builds the graph from its vertices and its (u,v,w) edge triples.
        Each triple is stored as a single edge going from u to v."""
        self.n = len(vertices)
        self.m = len(edges)
        self.V = vertices

        # self.AL maps every vertex to its (initially empty) list of
        # outgoing edges.
        self.AL = {vertex: [] for vertex in self.V}

        # File each edge under the vertex it leaves from.
        for (src, dst, weight) in edges:
            self.AL[src].append((dst, weight))

    def getVertices(self):
        """Returns the vertex collection that was given to the constructor."""
        return self.V

    def neighbors(self, u):
        """Returns the stored list of (v,w) pairs for edges leaving u."""
        return self.AL[u]

    def edgeWeight(self, u, v):
        """Returns the weight of the first stored edge from u to v,
        or infinity when u has no edge to v."""
        matches = (w for (other, w) in self.AL[u] if other == v)
        return next(matches, infinity)


class AMGraph(Graph):
    '''Adjacency matrix representation of a graph.

    self.AM[i][j] holds the weight of the edge from vertex number i to
    vertex number j, where the numbering is the position of the vertex in
    self.V. A missing edge is stored as infinity and the diagonal starts
    out as 0.'''

    def __init__(self, vertices, edges):
        """Builds the matrix from the vertices and the (u,v,w) edge triples.
        If the same (u,v) pair occurs more than once, the last weight wins."""
        self.n = len(vertices)
        self.m = len(edges)
        self.V = list(vertices)

        # lookup table: vertex -> its row/column index in the matrix
        self.vertind = {u: i for (i, u) in enumerate(self.V)}

        # self.AM is the actual adjacency matrix, initialized to 0 on the
        # diagonal and infinity everywhere else
        self.AM = [[infinity] * self.n for _ in range(self.n)]
        for i in range(self.n):
            self.AM[i][i] = 0

        # add each edge weight to the matrix
        for (u, v, w) in edges:
            self.AM[self.vertind[u]][self.vertind[v]] = w

    def getVertices(self):
        """Returns the list of vertices, in matrix index order."""
        return self.V

    def edgeWeight(self, u, v):
        """Returns the matrix entry for (u,v): the edge weight, infinity if
        there is no such edge, or 0 for u == v when no self-loop was given."""
        return self.AM[self.vertind[u]][self.vertind[v]]

    def neighbors(self, u):
        """Returns a list of pairs (v,w) for all outgoing edges from node u,
        in matrix index order.

        Only two kinds of entry are skipped: infinity (no edge) and the
        default 0 on the diagonal. Zero-weight and negative-weight edges to
        other vertices are therefore reported, as they are by ALGraph. A
        self-loop of weight 0 cannot be told apart from the default diagonal
        and is not reported."""
        uind = self.vertind[u]
        return [(self.V[i], w)
                for (i, w) in enumerate(self.AM[uind])
                if w < infinity and (i != uind or w != 0)]


def DFS(G, start):
    '''Returns a list of vertices in the order they are visited.

    Depth-first search from start, using an explicit stack: the neighbor
    that was pushed last is explored first. Only vertices reachable from
    start appear in the result. Vertices must be hashable.'''
    visited = []   # visit order; this is what gets returned
    seen = set()   # same vertices as visited, for constant-time lookups
    fringe = [start]
    while fringe:
        u = fringe.pop() # pops from the end of the list
        if u in seen:
            continue
        seen.add(u)
        visited.append(u)
        for (v, w) in G.neighbors(u):
            # An already-visited vertex would just be popped and ignored
            # later, so leaving it off the stack does not change the order.
            if v not in seen:
                fringe.append(v)
    return visited

def BFS(G, start):
    '''Returns a list of vertices in the order they are visited.

    Breadth-first search from start. Only vertices reachable from start
    appear in the result. Vertices must be hashable.'''
    visited = []   # visit order; this is what gets returned
    seen = set()   # same vertices as visited, for constant-time lookups
    fringe = deque([start]) # only difference from DFS: queue instead of stack
    while fringe:
        u = fringe.popleft() # pops from the beginning of the queue
        if u in seen:
            continue
        seen.add(u)
        visited.append(u)
        for (v, w) in G.neighbors(u):
            # An already-visited vertex would just be dequeued and ignored
            # later, so leaving it out does not change the order.
            if v not in seen:
                fringe.append(v)
    return visited


def dijkstraHeap(G, start):
    '''Returns a dictionary mapping each vertex reachable from start in G
    to the length of the shortest path from start to it.
    Vertices that cannot be reached from start are left out.'''
    dist = {}
    # Heap entries are (path length, vertex); the length goes first so
    # that it is what the heap orders by.
    heap = [(0, start)]
    while heap:
        (length, node) = heappop(heap)
        if node in dist:
            # Already settled through an earlier (no longer) entry.
            continue
        dist[node] = length
        for (nbr, weight) in G.neighbors(node):
            heappush(heap, (length + weight, nbr))
    return dist

def dijkstraArray(G, start):
    '''Returns a dictionary mapping every vertex of G to the length of the
    shortest path from start to it.
    Vertices that cannot be reached from start get the value infinity.'''
    # Every vertex starts out unsettled, with the best known length so far.
    pending = {vertex: infinity for vertex in G.getVertices()}
    pending[start] = 0
    settled = {}
    while pending:
        # Linear scan for the closest unsettled vertex (first one on ties).
        closest = min(pending, key=pending.get)
        length = pending.pop(closest)
        settled[closest] = length
        # Relax the edges leaving it, but only towards unsettled vertices.
        for (nbr, weight) in G.neighbors(closest):
            if nbr in pending:
                pending[nbr] = min(pending[nbr], length + weight)
    return settled


def FloydWarshall(AM):
    '''Calculates EVERY shortest path length between any two vertices
       in the original adjacency matrix graph.

       AM is a square matrix (a sequence of rows) in which AM[i][j] is the
       weight of the edge from i to j, or infinity if there is none.
       Returns a new matrix L of the same size where L[i][j] is the length
       of the shortest path from i to j. AM itself is left untouched.'''
    n = len(AM)
    # Copy row by row: a shallow copy of the outer list would still share
    # the rows with AM, and the updates below would overwrite the input.
    L = [list(row) for row in AM]
    for k in range(n):
        # After this round L[i][j] is the best length over paths whose
        # intermediate vertices all have index <= k.
        for i in range(n):
            for j in range(n):
                L[i][j] = min(L[i][j], L[i][k] + L[k][j])
    return L


def Prims(G, start):
    '''Returns the list of (u,v,w) edges of the MST grown outward from the
    given start vertex, in the order the edges were added.'''
    tree = []
    in_tree = set()
    # The heap holds candidate edges stored as (w, u, v), weight first, so
    # that the lightest edge is popped first. The start vertex is seeded
    # as the far end of a dummy edge that has no source.
    candidates = [(0, None, start)]
    while candidates:
        (weight, src, dst) = heappop(candidates)
        if dst in in_tree:
            # Both ends are already connected; this edge is not needed.
            continue
        in_tree.add(dst)
        if src is not None:
            tree.append((src, dst, weight))
        # Every edge leaving the newly added vertex becomes a candidate.
        for (nxt, nxt_weight) in G.neighbors(dst):
            heappush(candidates, (nxt_weight, dst, nxt))
    return tree


class DisjointSet:
    """A disjoint-set data structure using arrays.

    Every set is a plain list of its members, and self.sets maps each item
    to the list it currently belongs to. All members of one set share the
    very same list object, so two items are in the same set exactly when
    find() returns the identical object for both."""

    def __init__(self, items):
        """Creates a new DisjointSet, initialized with
        every item in items as a separate set by itself."""
        self.sets = {} # hash table mapping each item to its set
        for x in items:
            self.sets[x] = [x] # each item is in a set by itself

    def union(self, x, y):
        """Combines the sets containing x and y.
        Does nothing if x and y are already in the same set."""
        xset = self.sets[x]
        yset = self.sets[y]
        # Identity is the right test here: members of one set share a single
        # list object. Comparing the lists by value gives the same answer
        # but walks the whole list whenever x and y are already together.
        if xset is yset:
            return
        # always merge the smaller set into the bigger one (ties go to x's
        # set), so that fewer items have to be re-pointed
        if len(xset) >= len(yset):
            bigger, smaller = xset, yset
        else:
            bigger, smaller = yset, xset
        bigger.extend(smaller)
        for item in smaller:
            self.sets[item] = bigger

    def find(self, x):
        """Returns the set containing x (the shared list of its members)"""
        return self.sets[x]


def Kruskals(G):
    """Returns the list of (u,v,w) edges chosen for the MST, in the order
    they were picked (lightest first). If G is not connected, the result
    is a spanning forest instead."""
    components = DisjointSet(G.getVertices())
    # Weights go first in the tuples so that sorting orders by weight.
    by_weight = sorted((w, u, v) for (u, v, w) in G.getEdges())
    tree = []
    for (w, u, v) in by_weight:
        # Keep an edge only if it joins two separate components.
        if components.find(u) != components.find(v):
            components.union(u, v)
            tree.append((u, v, w))
    return tree


def approxVC(G):
    """Returns a set of vertices that touches every edge of G.
    Edges are scanned in order; whenever an edge has neither endpoint in
    the cover yet, both of its endpoints are added."""
    cover = set()
    for u in G.getVertices():
        for (v, w) in G.neighbors(u):
            if u in cover or v in cover:
                continue # this edge is already covered
            cover.update((u, v))
    return cover


# The rest is just for testing/debugging purposes.

# Specifications of my example graphs
def weighted(E):
    '''Turns unweighted (u,v) edges into weighted (u,v,1) edges.
    Returns them as a sorted tuple.'''
    unit_edges = [(u, v, 1) for (u, v) in E]
    unit_edges.sort()
    return tuple(unit_edges)

def directed(E):
    '''Spells out an undirected graph as directed edges: for every (u,v,w)
    in E the result holds both (u,v,w) and (v,u,w).
    Returns the edges as a sorted tuple without duplicates.'''
    both_ways = set(E) | {(v, u, w) for (u, v, w) in E}
    return tuple(sorted(both_ways))

def fromE(E):
    '''Works out the vertex set from a collection of (u,v,w) edges.
    Returns the pair (sorted tuple of vertices, E), which is the argument
    order the graph constructors take.'''
    endpoints = set()
    for (u, v, w) in E:
        endpoints.update((u, v))
    return (tuple(sorted(endpoints)), E)

# One-letter vertex names 'a' through 'm', bound to variables of the same
# name so the edge lists below stay short.
a, b, c, d, e, f, g, h, i, j, k, l, m = 'abcdefghijklm'

# ex1: edges taken exactly as listed (no reverse edges are added).
ex1 = fromE((
    (a,c,10), (a,d,22),
    (b,c,53), (b,e,45),
    (c,a,21), (c,e,33),
    (e,d,19),
))

# ex2: every listed edge is also added in the opposite direction.
ex2 = fromE(directed((
    (a,b,6), (a,c,6), (a,d,3),
    (b,d,2), (b,e,4),
    (c,d,5), (c,e,1),
    (d,e,4),
)))

# ex3: every listed edge is also added in the opposite direction.
ex3 = fromE(directed((
    (a,c,1), (c,d,6), (b,e,1),
    (c,f,4), (a,f,6), (b,f,2),
    (c,e,5), (e,d,2), (b,c,1),
)))

# match1: unweighted edges, given weight 1 and then added in both directions.
match1 = fromE(directed(weighted((
    (l,h), (h,d), (d,a), (a,b), (c,f), (f,e),
    (e,i), (j,m), (g,k), (j,e), (j,i), (g,f),
    (a,h), (d,b), (b,e), (l,m), (h,i), (c,b),
    (k,f), (m,k), (j,f), (d,i), (i,m),
))))


if __name__ == '__main__':
    # Traversal orders on ex1.
    assert BFS(ALGraph(*ex1), b) == [b, c, e, a, d]
    assert DFS(AMGraph(*ex1), a) == [a, d, c, e]

    # Both Dijkstra variants must give the same distances on both graph
    # representations of ex2.
    expected = {a: 0, b: 5, c: 6, d: 3, e: 7}
    for graph_cls in (ALGraph, AMGraph):
        for shortest_paths in (dijkstraArray, dijkstraHeap):
            assert shortest_paths(graph_cls(*ex2), a) == expected

    # All-pairs distances from the adjacency matrix of ex2; rows and
    # columns are numbered in vertex order (a=0, b=1, ...).
    ex2am = AMGraph(*ex2).AM
    L = FloydWarshall(ex2am)
    assert (L[0][1], L[0][4], L[2][1]) == (5, 7, 5)

    # Minimum spanning trees of ex3, in the order each algorithm picks
    # its edges.
    assert Prims(ALGraph(*ex3), d) == [
        (d,e,2), (e,b,1), (b,c,1), (c,a,1), (b,f,2),
    ]
    assert Kruskals(ALGraph(*ex3)) == [
        (a,c,1), (b,c,1), (b,e,1), (b,f,2), (d,e,2),
    ]

    print("All checks passed!")

# Remove the one-letter vertex names so they do not stay behind as
# module-level globals.
del a, b, c, d, e, f, g, h, i, j, k, l, m
