#!/usr/bin/python3

# SI 335: Computer Algorithms
# Unit 5

import sys
from heapq import heapify, heappop
from random import randrange
from copy import copy
from unit2 import swap, randomArray, sortTest, mergeSort
from unit4 import toDigits, fromDigits

def selectBySort(A, k):
    '''Selects the k'th smallest element of A (k=0 is the minimum) by
       sorting A first and then reading position k.

       NOTE(review): the return value of mergeSort is ignored, so this
       assumes unit2.mergeSort sorts its argument in place -- confirm.'''
    mergeSort(A)
    kth = A[k]
    return kth

def selectByHeap(A, k):
    '''Selects the k'th smallest element (k=0 is the minimum) with a
       binary min-heap built from a copy of A; A itself is left untouched.
       Cost: heapify is linear, then k pops of O(log n) each.'''
    heap = copy(A)
    heapify(heap)
    # Throw away the k smallest values; the root is then the answer.
    for _ in range(k):
        heappop(heap)
    return heap[0]

def partition(A):
    '''Rearranges A in place around the pivot value A[0] and returns p,
       the pivot's final index.  Afterwards A[0:p] holds values <= pivot,
       A[p] is the pivot, and A[p+1:] holds values > pivot.  Values equal
       to the pivot always end up on the left side.'''
    # A[0] is never touched by the swaps inside the loop (both indices
    # stay >= 1), so its value can be read once up front.
    pivot = A[0]
    left, right = 1, len(A) - 1
    # Invariant: A[1:left] <= pivot and A[right+1:] > pivot.
    while left <= right:
        if A[left] <= pivot:
            left += 1
        elif A[right] > pivot:
            right -= 1
        else:
            # A[left] > pivot >= A[right]: both are misplaced, exchange.
            swap(A, left, right)
    # right is the last slot holding a value <= pivot; drop the pivot there.
    swap(A, 0, right)
    return right


def choosePivot1(A):
    '''Naive pivot rule: always pick the first position, index 0.'''
    first = 0
    return first

def quickSelect1(A, k):
    '''Returns the k'th smallest element of A, counting from k=0.

       Pivot rule: choosePivot1 (always the first element).  The first
       round partitions the caller's list A in place; later rounds work on
       slice copies.  The search is a loop rather than a recursion, so very
       lopsided partitions (already-sorted input with this pivot rule, or
       many equal elements, which partition() always sends left) cannot
       exceed Python's recursion limit.

       Raises IndexError if k is not in range(len(A)).'''
    if not 0 <= k < len(A):
        raise IndexError("quickSelect1: k out of range")
    # Loop invariant: 0 <= k < len(A), and len(A) shrinks every round.
    while True:
        swap(A, 0, choosePivot1(A))
        p = partition(A)
        if p == k:
            return A[p]
        if p < k:
            # The answer is among the elements > pivot, right of index p.
            A, k = A[p+1 :], k - p - 1
        else:
            # The answer is among the first p elements (all <= pivot).
            A = A[: p]


def shuffle(A):
    '''Randomly permutes A in place (Fisher-Yates): for each position
       i = 0, 1, ..., n-1 in turn, A[i] is exchanged with A[r], where r is
       drawn uniformly from i..n-1.'''
    size = len(A)
    i = 0
    while i < size:
        swap(A, i, randrange(i, size))
        i += 1

def randomSelect(A, k):
    '''Returns the k'th smallest element (k=0 is the minimum).  A is first
       shuffled in place, so quickSelect1's fixed first-element pivot rule
       effectively picks random pivots.'''
    shuffle(A)
    answer = quickSelect1(A, k)
    return answer


def choosePivot2(A):
    '''Random pivot rule: a uniformly random index from 0 up to len(A)-1.'''
    size = len(A)
    return randrange(0, size)

def quickSelect2(A, k):
    '''Returns the k'th smallest element of A, counting from k=0.

       Pivot rule: choosePivot2 (a uniformly random position).  The first
       round partitions the caller's list A in place; later rounds work on
       slice copies.  The search is a loop rather than a recursion: even
       with random pivots, inputs with many equal elements give maximally
       lopsided partitions (partition() sends every value equal to the
       pivot left), which in recursive form could exceed Python's
       recursion limit.

       Raises IndexError if k is not in range(len(A)).'''
    if not 0 <= k < len(A):
        raise IndexError("quickSelect2: k out of range")
    # Loop invariant: 0 <= k < len(A), and len(A) shrinks every round.
    while True:
        swap(A, 0, choosePivot2(A))
        p = partition(A)
        if p == k:
            return A[p]
        if p < k:
            # The answer is among the elements > pivot, right of index p.
            A, k = A[p+1 :], k - p - 1
        else:
            # The answer is among the first p elements (all <= pivot).
            A = A[: p]


def choosePivot3(A, q=5):
    '''This is the median of medians algorithm; returns an index into A.

       A is cut into m = len(A)//q full groups of q consecutive elements
       (leftover tail elements are ignored).  The median of each group is
       found with quickSelect3, and the index of the median of those m
       medians is returned (the first index holding that value).  When
       there is at most one full group (m <= 1), the middle index
       len(A)//2 is returned instead.  A itself is not modified, because
       quickSelect3 only ever sees slice copies of it.

       q is a parameter that affects the complexity; can be
       any value greater than or equal to 2.'''
    n = len(A)
    m = n // q
    if m <= 1:
        # base case: too few groups for a median of medians
        return n // 2
    # Median of each group of q (slices are copies, so A is untouched).
    medians = [quickSelect3(A[i*q : (i+1)*q], q//2) for i in range(m)]
    # Find the median of medians
    mom = quickSelect3(medians, m//2)
    # mom is one of A's own elements, so index() always finds it.
    return A.index(mom)

def quickSelect3(A, k):
    '''Returns the k'th smallest element of A, counting from k=0.

       Pivot rule: choosePivot3 (median of medians).  The first round
       partitions the caller's list A in place; later rounds work on slice
       copies.  The search is a loop rather than a recursion: inputs with
       many equal elements give maximally lopsided partitions regardless of
       pivot rule (partition() sends every value equal to the pivot left),
       which in recursive form could exceed Python's recursion limit.

       Raises IndexError if k is not in range(len(A)).'''
    if not 0 <= k < len(A):
        raise IndexError("quickSelect3: k out of range")
    # Loop invariant: 0 <= k < len(A), and len(A) shrinks every round.
    while True:
        swap(A, 0, choosePivot3(A))
        p = partition(A)
        if p == k:
            return A[p]
        if p < k:
            # The answer is among the elements > pivot, right of index p.
            A, k = A[p+1 :], k - p - 1
        else:
            # The answer is among the first p elements (all <= pivot).
            A = A[: p]

def quickSort1(A):
    '''Sorts A in place with quicksort, using choosePivot1 (first element
       as pivot), and returns A.

       Only the smaller side of each partition is sorted by a recursive
       call; the larger side is handled by the surrounding loop.  That
       keeps the recursion depth O(log n) even for maximally lopsided
       partitions, e.g. already-sorted input with this pivot rule, or many
       equal elements (partition() sends every value equal to the pivot
       left).  Fully recursive on both sides, such inputs of around a
       thousand elements exceed Python's default recursion limit.'''
    lo, hi = 0, len(A)  # A[lo:hi] is the part still to be sorted
    while hi - lo > 1:
        part = A[lo:hi]
        swap(part, 0, choosePivot1(part))
        p = partition(part)
        A[lo:hi] = part
        # Now A[lo:lo+p] <= pivot, A[lo+p] is the pivot in its final
        # place, and A[lo+p+1:hi] > pivot.
        if p < (hi - lo) - p - 1:
            A[lo : lo+p] = quickSort1(A[lo : lo+p])
            lo = lo + p + 1
        else:
            A[lo+p+1 : hi] = quickSort1(A[lo+p+1 : hi])
            hi = lo + p
    return A

def quickSort2(A):
    '''Sorts A in place with quicksort, using choosePivot2 (random pivot),
       and returns A.

       Only the smaller side of each partition is sorted by a recursive
       call; the larger side is handled by the surrounding loop.  That
       keeps the recursion depth O(log n) even for maximally lopsided
       partitions.  Random pivots do not prevent those when there are many
       equal elements (partition() sends every value equal to the pivot
       left), and fully recursive they exceed Python's default recursion
       limit at around a thousand elements.'''
    lo, hi = 0, len(A)  # A[lo:hi] is the part still to be sorted
    while hi - lo > 1:
        part = A[lo:hi]
        swap(part, 0, choosePivot2(part))
        p = partition(part)
        A[lo:hi] = part
        # Now A[lo:lo+p] <= pivot, A[lo+p] is the pivot in its final
        # place, and A[lo+p+1:hi] > pivot.
        if p < (hi - lo) - p - 1:
            A[lo : lo+p] = quickSort2(A[lo : lo+p])
            lo = lo + p + 1
        else:
            A[lo+p+1 : hi] = quickSort2(A[lo+p+1 : hi])
            hi = lo + p
    return A

def quickSort3(A):
    '''Sorts A in place with quicksort, using choosePivot3 (median of
       medians), and returns A.

       Only the smaller side of each partition is sorted by a recursive
       call; the larger side is handled by the surrounding loop.  That
       keeps the recursion depth O(log n) even for maximally lopsided
       partitions.  The pivot rule cannot prevent those when there are
       many equal elements (partition() sends every value equal to the
       pivot left), and fully recursive they exceed Python's default
       recursion limit at around a thousand elements.'''
    lo, hi = 0, len(A)  # A[lo:hi] is the part still to be sorted
    while hi - lo > 1:
        part = A[lo:hi]
        swap(part, 0, choosePivot3(part))
        p = partition(part)
        A[lo:hi] = part
        # Now A[lo:lo+p] <= pivot, A[lo+p] is the pivot in its final
        # place, and A[lo+p+1:hi] > pivot.
        if p < (hi - lo) - p - 1:
            A[lo : lo+p] = quickSort3(A[lo : lo+p])
            lo = lo + p + 1
        else:
            A[lo+p+1 : hi] = quickSort3(A[lo+p+1 : hi])
            hi = lo + p
    return A


def countingSort(A, k=None, value = lambda x: x):
    '''Stable counting sort of A, done in place; returns A.

       value is a function that converts the elements of A into integer
       keys from 0 up to k-1.  Elements are ordered by key, and elements
       with equal keys keep their original relative order (this is what
       radix sort relies on).  If k is None, it is computed as the largest
       key plus 1 (0 for an empty A).'''
    if k is None:
        # Automatically determine k if it's not given.  The key is computed
        # first and then incremented: value(x+1) would add 1 to the raw
        # element, which is wrong (or a TypeError) for non-identity keys.
        k = 0
        for x in A:
            k = max(k, value(x) + 1)
    C = [0] * k  # C[v] = number of elements whose key is v
    for x in A:
        C[value(x)] += 1
    # P[v] = first output position for key v (prefix sums of C).
    P = [0]
    for v in range(1, k):
        P.append(P[v-1] + C[v-1])
    # Place elements in their original order, so equal keys stay stable.
    for x in copy(A):
        key = value(x)
        A[P[key]] = x
        P[key] += 1
    return A

def ithDigit(i, N):
    '''Returns digit i of the digit sequence N, or 0 when N has no
       position i (so shorter numbers act as if padded with zeros).'''
    try:
        digit = N[i]
    except IndexError:
        digit = 0
    return digit

def radixSort(A, d, B):
    '''Radix sort of A in place; returns A.

       Each element of A is a digit sequence; B is the base, so every digit
       is in 0..B-1, and positions missing from a sequence count as 0 (see
       ithDigit).  Makes d stable counting-sort passes, on digit positions
       0, 1, ..., d-1 in that order.
       NOTE(review): correctness requires position 0 to be the least
       significant digit -- presumably unit4.toDigits' convention; confirm.'''
    for position in range(d):
        # Key function for this pass; position is bound explicitly.
        def digitAt(N, position=position):
            return ithDigit(position, N)
        countingSort(A, B, digitAt)
    return A

def radixSortWrapper(A):
    '''Sorts a list of integers in place with base-10 radix sort, and
       returns it.

       Each integer is converted to a multi-precision digit list with
       toDigits, the digit lists are radix-sorted using d passes (d = the
       longest digit list), and the results are converted back with
       fromDigits into A.'''
    if not A:
        # Nothing to sort; this also avoids max() of an empty sequence,
        # which raises ValueError.
        return A
    mp_A = [toDigits(n) for n in A]
    d = max(len(N) for N in mp_A)
    radixSort(mp_A, d, 10)
    A[:] = [fromDigits(N) for N in mp_A]
    return A


# The rest is just for testing/debugging purposes.
def selectTest(algs = (selectBySort, selectByHeap, quickSelect1, randomSelect, quickSelect2, quickSelect3)):
    '''Checks each selection algorithm in algs against sorted().

       One random array is generated.  Each algorithm is asked for 10
       random ranks k, each time on a fresh copy of the data (several of
       the algorithms reorder their input).  Prints one FAILED line per
       algorithm that returned a wrong answer, or a single success message
       if every check passed.'''
    # NOTE(review): 1000 is CPython's usual default; this resets any larger
    # limit set earlier -- presumably intentional, so deep recursion fails.
    sys.setrecursionlimit(1000)
    allgood = True
    maxsize, maxval = 500, 1000000
    data = randomArray(maxsize, maxval)
    sortedData = sorted(data)
    for alg in algs:
        good = True
        for _ in range(10):
            # Draw k from the data actually generated, rather than assuming
            # randomArray returned exactly maxsize elements, so that
            # sortedData[k] is always a valid index.
            k = randrange(len(sortedData))
            X = alg(copy(data), k)
            # Check that X is actually the k'th smallest
            if X != sortedData[k]:
                good = False
        if not good:
            print("FAILED CHECK FOR", alg.__name__)
            allgood = False
    if allgood:
        print("Passed all selection checks")


if __name__ == '__main__':
    # Run the selection checks first, then the sorting checks.
    selectTest()
    sortingAlgorithms = (quickSort1, quickSort2, quickSort3, countingSort, radixSortWrapper)
    sortTest(sortingAlgorithms)
