用python实现基本数据结构和算法

1章:ADT抽象数据类型,定义数据和其操作

什么是ADT: 抽象数据类型,学过数据结构的应该都知道。

How to select datastructures for ADT

  1. Dose the data structure provide for the storage requirements as specified by the domain of the ADT?

  2. Does the data structure provide the data access and manipulation functionality to fully implement the ADT?

  3. Effcient implemention? based on complexity analysis.

下边代码是个简单的示例,比如实现一个简单的Bag类,先定义其具有的操作,然后我们再用类的magic method来实现这些方法:

class Bag:
    """
    constructor: 构造函数
    size
    contains
    append
    remove
    iter
    """
    def __init__(self):
        self._items = list()

    def __len__(self):
        return len(self._items)

    def __contains__(self, item):
        return item in self._items

    def add(self, item):
        self._items.append(item)

    def remove(self, item):
        assert item in self._items, 'item must in the bag'
        return self._items.remove(item)

    def __iter__(self):
        return _BagIterator(self._items)


class _BagIterator:
    """ 注意这里实现了迭代器类 """
    def __init__(self, seq):
        self._bag_items = seq
        self._cur_item = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._cur_item < len(self._bag_items):
            item = self._bag_items[self._cur_item]
            self._cur_item += 1
            return item
        else:
            raise StopIteration


b = Bag()
b.add(1)
b.add(2)
for i in b:     # for使用__iter__构建,用__next__迭代
    print(i)


"""
# for 语句等价于
i = b.__iter__()
while True:
    try:
        item = i.__next__()
        print(item)
    except StopIteration:
        break
"""

2章:array vs list

array: 定长,操作有限,但是节省内存;貌似我的生涯中还没用过,不过python3.5中我试了确实有array类,可以用import array直接导入

list: 会预先分配内存,操作丰富,但是耗费内存。我用sys.getsizeof做了实验。我个人理解很类似C++ STL里的vector,是使用最频繁的数据结构。

  • list.append: 如果之前没有分配够内存,会重新开辟新区域,然后复制之前的数据,复杂度退化

  • list.insert: 会移动被插入区域后所有元素,O(n)

  • list.pop: pop不同位置需要的复杂度不同pop(0)是O(1)复杂度,pop()首位O(n)复杂度

  • list[]: slice操作copy数据(预留空间)到另一个list

来实现一个array的ADT:

import ctypes

class Array:
    def __init__(self, size):
        assert size > 0, 'array size must be > 0'
        self._size = size
        PyArrayType = ctypes.py_object * size
        self._elements = PyArrayType()
        self.clear(None)

    def __len__(self):
        return self._size

    def __getitem__(self, index):
        assert index >= 0 and index < len(self), 'out of range'
        return self._elements[index]

    def __setitem__(self, index, value):
        assert index >= 0 and index < len(self), 'out of range'
        self._elements[index] = value

    def clear(self, value):
        """ 设置每个元素为value """
        for i in range(len(self)):
            self._elements[i] = value

    def __iter__(self):
        return _ArrayIterator(self._elements)


class _ArrayIterator:
    def __init__(self, items):
        self._items = items
        self._idx = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._idex < len(self._items):
            val = self._items[self._idx]
            self._idex += 1
            return val
        else:
            raise StopIteration

Two-Demensional Arrays

class Array2D:
    """ 要实现的方法
    Array2D(nrows, ncols):    constructor
    numRows()
    numCols()
    clear(value)
    getitem(i, j)
    setitem(i, j, val)
    """
    def __init__(self, numrows, numcols):
        self._the_rows = Array(numrows)     # 数组的数组
        for i in range(numrows):
            self._the_rows[i] = Array(numcols)

    @property
    def numRows(self):
        return len(self._the_rows)

    @property
    def NumCols(self):
        return len(self._the_rows[0])

    def clear(self, value):
        for row in self._the_rows:
            row.clear(value)

    def __getitem__(self, ndx_tuple):    # ndx_tuple: (x, y)
        assert len(ndx_tuple) == 2
        row, col = ndx_tuple[0], ndx_tuple[1]
        assert (row >= 0 and row < self.numRows and
                col >= 0 and col < self.NumCols)

        the_1d_array = self._the_rows[row]
        return the_1d_array[col]

    def __setitem__(self, ndx_tuple, value):
        assert len(ndx_tuple) == 2
        row, col = ndx_tuple[0], ndx_tuple[1]
        assert (row >= 0 and row < self.numRows and
                col >= 0 and col < self.NumCols)
        the_1d_array = self._the_rows[row]
        the_1d_array[col] = value

The Matrix ADT, m行,n列。这个最好用还是用pandas处理矩阵,自己实现比较*疼

class Matrix:
    """ 最好用pandas的DataFrame
    Matrix(rows, ncols): constructor
    numCols()
    getitem(row, col)
    setitem(row, col, val)
    scaleBy(scalar): 每个元素乘scalar
    transpose(): 返回transpose转置
    add(rhsMatrix):    size must be the same
    subtract(rhsMatrix)
    multiply(rhsMatrix)
    """
    def __init__(self, numRows, numCols):
        self._theGrid = Array2D(numRows, numCols)
        self._theGrid.clear(0)

    @property
    def numRows(self):
        return self._theGrid.numRows

    @property
    def NumCols(self):
        return self._theGrid.numCols

    def __getitem__(self, ndxTuple):
        return self._theGrid[ndxTuple[0], ndxTuple[1]]

    def __setitem__(self, ndxTuple, scalar):
        self._theGrid[ndxTuple[0], ndxTuple[1]] = scalar

    def scaleBy(self, scalar):
        for r in range(self.numRows):
            for c in range(self.numCols):
                self[r, c] *= scalar

    def __add__(self, rhsMatrix):
        assert (rhsMatrix.numRows == self.numRows and
                rhsMatrix.numCols == self.numCols)
        newMartrix = Matrix(self.numRows, self.numCols)
        for r in range(self.numRows):
            for c in range(self.numCols):
                newMartrix[r, c] = self[r, c] + rhsMatrix[r, c]

3章:Sets and Maps

除了list之外,最常用的应该就是python内置的set和dict了。

sets ADT

A set is a container that stores a collection of unique values over a given comparable domain in which the stored values have no particular ordering.

class Set:
    """ 使用list实现set ADT
    Set()
    length()
    contains(element)
    add(element)
    remove(element)
    equals(element)
    isSubsetOf(setB)
    union(setB)
    intersect(setB)
    difference(setB)
    iterator()
    """
    def __init__(self):
        self._theElements = list()

    def __len__(self):
        return len(self._theElements)

    def __contains__(self, element):
        return element in self._theElements

    def add(self, element):
        if element not in self:
            self._theElements.append(element)

    def remove(self, element):
        assert element in self, 'The element must be set'
        self._theElements.remove(element)

    def __eq__(self, setB):
        if len(self) != len(setB):
            return False
        else:
            return self.isSubsetOf(setB)

    def isSubsetOf(self, setB):
        for element in self:
            if element not in setB:
                return False
        return True

    def union(self, setB):
        newSet = Set()
        newSet._theElements.extend(self._theElements)
        for element in setB:
            if element not in self:
                newSet._theElements.append(element)
        return newSet

Maps or Dict: 键值对,python内部采用hash实现。

class Map:
    """ Map ADT list implemention
    Map()
    length()
    contains(key)
    add(key, value)
    remove(key)
    valudOf(key)
    iterator()
    """
    def __init__(self):
        self._entryList = list()

    def __len__(self):
        return len(self._entryList)

    def __contains__(self, key):
        ndx = self._findPosition(key)
        return ndx is not None

    def add(self, key, value):
        ndx = self._findPosition(key)
        if ndx is not None:
            self._entryList[ndx].value = value
            return False
        else:
            entry = _MapEntry(key, value)
            self._entryList.append(entry)
            return True

    def valueOf(self, key):
        ndx = self._findPosition(key)
        assert ndx is not None, 'Invalid map key'
        return self._entryList[ndx].value

    def remove(self, key):
        ndx = self._findPosition(key)
        assert ndx is not None, 'Invalid map key'
        self._entryList.pop(ndx)

    def __iter__(self):
        return _MapIterator(self._entryList)

    def _findPosition(self, key):
        for i in range(len(self)):
            if self._entryList[i].key == key:
                return i
        return None


class _MapEntry:    # or use collections.namedtuple('_MapEntry', 'key,value')
    def __init__(self, key, value):
        self.key = key
        self.value = value

The multiArray ADT, 多维数组,一般是使用一个一维数组模拟,然后通过计算下标获取元素

class MultiArray:
    """ row-major or column-marjor ordering, this is row-major ordering
    MultiArray(d1, d2, ...dn)
    dims():   the number of dimensions
    length(dim): the length of given array dimension
    clear(value)
    getitem(i1, i2, ... in), index(i1,i2,i3) = i1*(d2*d3) + i2*d3 + i3
    setitem(i1, i2, ... in)
    计算下标:index(i1,i2,...in) = i1*f1 + i2*f2 + ... + i(n-1)*f(n-1) + in*1
    """
    def __init__(self, *dimensions):
        # Implementation of MultiArray ADT using a 1-D # array,数组的数组的数组。。。
        assert len(dimensions) > 1, 'The array must have 2 or more dimensions'
        self._dims = dimensions
        # Compute to total number of elements in the array
        size = 1
        for d in dimensions:
            assert d > 0, 'Dimensions must be > 0'
            size *= d
        # Create the 1-D array to store the elements
        self._elements = Array(size)
        # Create a 1-D array to store the equation factors
        self._factors = Array(len(dimensions))
        self._computeFactors()

    @property
    def numDims(self):
        return len(self._dims)

    def length(self, dim):
        assert dim > 0 and dim < len(self._dims), 'Dimension component out of range'
        return self._dims[dim-1]

    def clear(self, value):
        self._elements.clear(value)

    def __getitem__(self, ndxTuple):
        assert len(ndxTuple) == self.numDims, 'Invalid # of array subscripts'
        index = self._computeIndex(ndxTuple)
        assert index is not None, 'Array subscript out of range'
        return self._elements[index]

    def __setitem__(self, ndxTuple, value):
        assert len(ndxTuple) == self.numDims, 'Invalid # of array subscripts'
        index = self._computeIndex(ndxTuple)
        assert index is not None, 'Array subscript out of range'
        self._elements[index] = value

    def _computeIndex(self, ndxTuple):
        # using the equation: i1*f1 + i2*f2 + ... + in*fn
        offset = 0
        for j in range(len(ndxTuple)):
            if ndxTuple[j] < 0 or ndxTuple[j] >= self._dims[j]:
                return None
            else:
                offset += ndexTuple[j] * self._factors[j]
        return offset

4章:Algorithm Analysis

一般使用大O标记法来衡量算法的平均时间复杂度, 1 < log(n) < n < nlog(n) < n^2 < n^3 < a^n。 了解常用数据结构操作的平均时间复杂度有利于使用更高效的数据结构,当然有时候需要在时间和空间上进行衡量,有些操作甚至还会退化,比如list的append操作,如果list空间不够,会去开辟新的空间,操作复杂度退化到O(n),有时候还需要使用均摊分析(amortized)


5章:Searching and Sorting

排序和查找是最基础和频繁的操作,python内置了in操作符和bisect二分操作模块实现查找,内置了sorted方法来实现排序操作。二分和快排也是面试中经常考到的,本章讲的是基本的排序和查找。

def binary_search(sorted_seq, val):
    """ 实现标准库中的bisect.bisect_left """
    low = 0
    high = len(sorted_seq) - 1
    while low <= high:
        mid = (high + low) // 2
        if sorted_seq[mid] == val:
            return mid
        elif val < sorted_seq[mid]:
            high = mid - 1
        else:
            low = mid + 1
    return low

  def bubble_sort(seq):  # O(n^2), n(n-1)/2 = 1/2(n^2 + n)
      n = len(seq)
      for i in range(n-1):
          for j in range(n-1-i):    # 这里之所以 n-1 还需要 减去 i 是因为每一轮冒泡最大的元素都会冒泡到最后,无需再比较
              if seq[j] > seq[j+1]:
                  seq[j], seq[j+1] = seq[j+1], seq[j]


def select_sort(seq):
    """可以看作是冒泡的改进,每次找一个最小的元素交换,每一轮只需要交换一次"""
    n = len(seq)
    for i in range(n-1):
        min_idx = i    # assume the ith element is the smallest
        for j in range(i+1, n):
            if seq[j] < seq[min_idx]:   # find the minist element index
                min_idx = j
        if min_idx != i:    # swap
            seq[i], seq[min_idx] = seq[min_idx], seq[i]


def insertion_sort(seq):
    """ 每次挑选下一个元素插入已经排序的数组中,初始时已排序数组只有一个元素"""
    n = len(seq)
    for i in range(1, n):
        value = seq[i]    # save the value to be positioned
        # find the position where value fits in the ordered part of the list
        pos = i
        while pos > 0 and value < seq[pos-1]:
            # Shift the items to the right during the search
            seq[pos] = seq[pos-1]
            pos -= 1
        seq[pos] = value


def merge_sorted_list(listA, listB):
    """ 归并两个有序数组 """
    new_list = list()
    a = b = 0
    while a < len(listA) and b < len(listB):
        if listA[a] < listB[b]:
            new_list.append(listA[a])
            a += 1
        else:
            new_list.append(listB[b])
            b += 1

    while a < len(listA):
        new_list.append(listA[a])
        a += 1

    while b < len(listB):
        new_list.append(listB[b])
        b += 1

    return new_list

6章: Linked Structure

list是最常用的数据结构,但是list在中间增减元素的时候效率会很低,这时候linked list会更适合,缺点就是获取元素的平均时间复杂度变成了O(n)

# 单链表实现
class ListNode:
    def __init__(self, data):
        self.data = data
        self.next = None


def travsersal(head, callback):
    curNode = head
    while curNode is not None:
        callback(curNode.data)
        curNode = curNode.next


def unorderdSearch(head, target):
    curNode = head
    while curNode is not None and curNode.data != target:
        curNode = curNode.next
    return curNode is not None


# Given the head pointer, prepend an item to an unsorted linked list.
def prepend(head, item):
    newNode = ListNode(item)
    newNode.next = head
    head = newNode


# Given the head reference, remove a target from a linked list
def remove(head, target):
    predNode = None
    curNode = head
    while curNode is not None and curNode.data != target:
        # 寻找目标
        predNode = curNode
        curNode = curNode.data
    if curNode is not None:
        if curNode is head:
            head = curNode.next
        else:
            predNode.next = curNode.next

7章:Stacks

栈也是计算机里用得比较多的数据结构,栈是一种后进先出的数据结构,可以理解为往一个桶里放盘子,先放进去的会被压在地下,拿盘子的时候,后放的会被先拿出来。

class Stack:
    """ Stack ADT, using a python list
    Stack()
    isEmpty()
    length()
    pop(): assert not empty
    peek(): assert not empty, return top of non-empty stack without removing it
    push(item)
    """
    def __init__(self):
        self._items = list()

    def isEmpty(self):
        return len(self) == 0

    def __len__(self):
        return len(self._items)

    def peek(self):
        assert not self.isEmpty()
        return self._items[-1]

    def pop(self):
        assert not self.isEmpty()
        return self._items.pop()

    def push(self, item):
        self._items.append(item)


class Stack:
    """ Stack ADT, use linked list
    使用list实现很简单,但是如果涉及大量push操作,list的空间不够时复杂度退化到O(n)
    而linked list可以保证最坏情况下仍是O(1)
    """
    def __init__(self):
        self._top = None    # top节点, _StackNode or None
        self._size = 0    # int

    def isEmpty(self):
        return self._top is None

    def __len__(self):
        return self._size

    def peek(self):
        assert not self.isEmpty()
        return self._top.item

    def pop(self):
        assert not self.isEmpty()
        node = self._top
        self.top = self._top.next
        self._size -= 1
        return node.item

    def _push(self, item):
        self._top = _StackNode(item, self._top)
        self._size += 1


class _StackNode:
    def __init__(self, item, link):
        self.item = item
        self.next = link

8章:Queues

队列也是经常使用的数据结构,比如发送消息等,celery可以使用redis提供的list实现消息队列。 本章我们用list和linked list来实现队列和优先级队列。

class Queue:
    """ Queue ADT, use list。list实现,简单但是push和pop效率最差是O(n)
    Queue()
    isEmpty()
    length()
    enqueue(item)
    dequeue()
    """
    def __init__(self):
        self._qList = list()

    def isEmpty(self):
        return len(self) == 0

    def __len__(self):
        return len(self._qList)

    def enquue(self, item):
        self._qList.append(item)

    def dequeue(self):
        assert not self.isEmpty()
        return self._qList.pop(0)


from array import Array    # Array那一章实现的Array ADT
class Queue:
    """
    circular Array ,通过头尾指针实现。list内置append和pop复杂度会退化,使用
    环数组实现可以使得入队出队操作时间复杂度为O(1),缺点是数组长度需要固定。
    """
    def __init__(self, maxSize):
        self._count = 0
        self._front = 0
        self._back = maxSize - 1
        self._qArray = Array(maxSize)

    def isEmpty(self):
        return self._count == 0

    def isFull(self):
        return self._count == len(self._qArray)

    def __len__(self):
        return len(self._count)

    def enqueue(self, item):
        assert not self.isFull()
        maxSize = len(self._qArray)
        self._back = (self._back + 1) % maxSize     # 移动尾指针
        self._qArray[self._back] = item
        self._count += 1

    def dequeue(self):
        assert not self.isFull()
        item = self._qArray[self._front]
        maxSize = len(self._qArray)
        self._front = (self._front + 1) % maxSize
        self._count -= 1
        return item

class _QueueNode:
    def __init__(self, item):
        self.item = item


class Queue:
    """ Queue ADT, linked list 实现。为了改进环型数组有最大数量的限制,改用
    带有头尾节点的linked list实现。
    """
    def __init__(self):
        self._qhead = None
        self._qtail = None
        self._qsize = 0

    def isEmpty(self):
        return self._qhead is None

    def __len__(self):
        return self._count

    def enqueue(self, item):
        node = _QueueNode(item)    # 创建新的节点并用尾节点指向他
        if self.isEmpty():
            self._qhead = node
        else:
            self._qtail.next = node
        self._qtail = node
        self._qcount += 1

    def dequeue(self):
        assert not self.isEmpty(), 'Can not dequeue from an empty queue'
        node = self._qhead
        if self._qhead is self._qtail:
            self._qtail = None
        self._qhead = self._qhead.next    # 前移头节点
        self._count -= 1
        return node.item


class UnboundedPriorityQueue:
    """ PriorityQueue ADT: 给每个item加上优先级p,高优先级先dequeue
    分为两种:
    - bounded PriorityQueue: 限制优先级在一个区间[0...p)
    - unbounded PriorityQueue: 不限制优先级

    PriorityQueue()
    BPriorityQueue(numLevels): create a bounded PriorityQueue with priority in range
        [0, numLevels-1]
    isEmpty()
    length()
    enqueue(item, priority): 如果是bounded PriorityQueue, priority必须在区间内
    dequeue(): 最高优先级的出队,同优先级的按照FIFO顺序

    - 两种实现方式:
    1.入队的时候都是到队尾,出队操作找到最高优先级的出队,出队操作O(n)
    2.始终维持队列有序,每次入队都找到该插入的位置,出队操作是O(1)
    (注意如果用list实现list.append和pop操作复杂度会因内存分配退化)
    """
    from collections import namedtuple
    _PriorityQEntry = namedtuple('_PriorityQEntry', 'item, priority')

    # 采用方式1,用内置list实现unbounded PriorityQueue
    def __init__(self):
        self._qlist = list()

    def isEmpty(self):
        return len(self) == 0

    def __len__(self):
        return len(self._qlist)

    def enqueue(self, item, priority):
        entry = UnboundedPriorityQueue._PriorityQEntry(item, priority)
        self._qlist.append(entry)

    def deque(self):
        assert not self.isEmpty(), 'can not deque from an empty queue'
        highest = self._qlist[0].priority
        for i in range(len(self)):    # 出队操作O(n),遍历找到最高优先级
            if self._qlist[i].priority < highest:
                highest = self._qlist[i].priority
        entry = self._qlist.pop(highest)
        return entry.item


class BoundedPriorityQueue:
    """ BoundedPriorityQueue ADT,用linked list实现。上一个地方提到了 BoundedPriorityQueue
    但是为什么需要 BoundedPriorityQueue呢? BoundedPriorityQueue 的优先级限制在[0, maxPriority-1]
    对于 UnboundedPriorityQueue,出队操作由于要遍历寻找优先级最高的item,所以平均
    是O(n)的操作,但是对于 BoundedPriorityQueue,用队列数组实现可以达到常量时间,
    用空间换时间。比如要弹出一个元素,直接找到第一个非空队列弹出 元素就可以了。
    (小数字代表高优先级,先出队)

    qlist
    [0] -> ["white"]
    [1]
    [2] -> ["black", "green"]
    [3] -> ["purple", "yellow"]
    """
    # Implementation of the bounded Priority Queue ADT using an array of #
    # queues in which the queues are implemented using a linked list.
    from array import Array    #  第二章定义的ADT

    def __init__(self, numLevels):
        self._qSize = 0
        self._qLevels = Array(numLevels)
        for i in range(numLevels):
            self._qLevels[i] = Queue()    # 上一节讲到用linked list实现的Queue

    def isEmpty(self):
        return len(self) == 0

    def __len__(self):
        return len(self._qSize)

    def enqueue(self, item, priority):
        assert priority >= 0 and priority < len(self._qLevels), 'invalid priority'
        self._qLevel[priority].enquue(item)    # 直接找到 priority 对应的槽入队

    def deque(self):
        assert not self.isEmpty(), 'can not deque from an empty queue'
        i = 0
        p = len(self._qLevels)
        while i < p and not self._qLevels[i].isEmpty():    # 找到第一个非空队列
            i += 1
        return self._qLevels[i].dequeue()

9章:Advanced Linked Lists

之前曾经介绍过单链表,一个链表节点只有data和next字段,本章介绍高级的链表。

Doubly Linked List,双链表,每个节点多了个prev指向前一个节点。双链表可以用来编写文本编辑器的buffer。

class DListNode:
    def __init__(self, data):
        self.data = data
        self.prev = None
        self.next = None


def revTraversa(tail):
    curNode = tail
    while cruNode is not None:
        print(curNode.data)
        curNode = curNode.prev


def search_sorted_doubly_linked_list(head, tail, probe, target):
    """ probing technique探查法,改进直接遍历,不过最坏时间复杂度仍是O(n)
    searching a sorted doubly linked list using the probing technique
    Args:
        head (DListNode obj)
        tail (DListNode obj)
        probe (DListNode or None)
        target (DListNode.data): data to search
    """
    if head is None:    # make sure list is not empty
        return False
    if probe is None:    # if probe is null, initialize it to first node
        probe = head

    # if the target comes before the probe node, we traverse backward, otherwise
    # traverse forward
    if target < probe.data:
        while probe is not None and target <= probe.data:
            if target == probe.dta:
                return True
            else:
                probe = probe.prev
    else:
        while probe is not None and target >= probe.data:
            if target == probe.data:
                return True
            else:
                probe = probe.next
    return False


def insert_node_into_ordered_doubly_linekd_list(value):
    """ 最好画个图看,链表操作很容易绕晕,注意赋值顺序"""
    newnode = DListNode(value)

    if head is None:    # empty list
        head = newnode
        tail = head

    elif value < head.data:    # insert before head
        newnode.next = head
        head.prev = newnode
        head = newnode

    elif value > tail.data:    # insert after tail
        newnode.prev = tail
        tail.next = newnode
        tail = newnode

    else:    # insert into middle
        node = head
        while node is not None and node.data < value:
            node = node.next
        newnode.next = node
        newnode.prev = node.prev
        node.prev.next = newnode
        node.prev = newnode

循环链表

def travrseCircularList(listRef):
    curNode = listRef
    done = listRef is None
    while not None:
        curNode = curNode.next
        print(curNode.data)
        done = curNode is listRef   # 回到遍历起始点


def searchCircularList(listRef, target):
    curNode = listRef
    done = listRef is None
    while not done:
        curNode = curNode.next
        if curNode.data == target:
            return True
        else:
            done = curNode is listRef or curNode.data > target
    return False


def add_newnode_into_ordered_circular_linked_list(listRef, value):
    """ 插入并维持顺序
    1.插入空链表;2.插入头部;3.插入尾部;4.按顺序插入中间
    """
    newnode = ListNode(value)
    if listRef is None:    # empty list
        listRef = newnode
        newnode.next = newnode

    elif value < listRef.next.data:    # insert in front
        newnode.next = listRef.next
        listRef.next = newnode

    elif value > listRef.data:    # insert in back
        newnode.next = listRef.next
        listRef.next = newnode
        listRef = newnode

    else:    # insert in the middle
        preNode = None
        curNode = listRef
        done = listRef is None
        while not done:
            preNode = curNode
            preNode = curNode.next
            done = curNode is listRef or curNode.data > value

        newnode.next = curNode
        preNode.next = newnode

利用循环双端链表我们可以实现一个经典的缓存失效算法,lru:

# -*- coding: utf-8 -*-

class Node(object):
    def __init__(self, prev=None, next=None, key=None, value=None):
        self.prev, self.next, self.key, self.value = prev, next, key, value


class CircularDoubleLinkedList(object):
    def __init__(self):
        node = Node()
        node.prev, node.next = node, node
        self.rootnode = node

    def headnode(self):
        return self.rootnode.next

    def tailnode(self):
        return self.rootnode.prev

    def remove(self, node):
        if node is self.rootnode:
            return
        else:
            node.prev.next = node.next
            node.next.prev = node.prev

    def append(self, node):
        tailnode = self.tailnode()
        tailnode.next = node
        node.next = self.rootnode
        self.rootnode.prev = node


class LRUCache(object):
    def __init__(self, maxsize=16):
        self.maxsize = maxsize
        self.cache = {}
        self.access = CircularDoubleLinkedList()
        self.isfull = len(self.cache) >= self.maxsize

    def __call__(self, func):
        def wrapper(n):
            cachenode = self.cache.get(n)
            if cachenode is not None:  # hit
                self.access.remove(cachenode)
                self.access.append(cachenode)
                return cachenode.value
            else:  # miss
                value = func(n)
                if not self.isfull:
                    tailnode = self.access.tailnode()
                    newnode = Node(tailnode, self.access.rootnode, n, value)
                    self.access.append(newnode)
                    self.cache[n] = newnode

                    self.isfull = len(self.cache) >= self.maxsize
                    return value
                else:  # full
                    lru_node = self.access.headnode()
                    del self.cache[lru_node.key]
                    self.access.remove(lru_node)
                    tailnode = self.access.tailnode()
                    newnode = Node(tailnode, self.access.rootnode, n, value)
                    self.access.append(newnode)
                    self.cache[n] = newnode
                return value
        return wrapper


@LRUCache()
def fib(n):
    if n <= 2:
        return 1
    else:
        return fib(n - 1) + fib(n - 2)


for i in range(1, 35):
    print(fib(i))

10章:Recursion

Recursion is a process for solving problems by subdividing a larger problem into smaller cases of the problem itself and then solving the smaller, more trivial parts.

递归函数:调用自己的函数

# 递归函数:调用自己的函数,看一个最简单的递归函数,倒序打印一个数
def printRev(n):
    if n > 0:
        print(n)
        printRev(n-1)


printRev(3)    # 从10输出到1


# 稍微改一下,print放在最后就得到了正序打印的函数
def printInOrder(n):
    if n > 0:
        printInOrder(n-1)
        print(n)    # 之所以最小的先打印是因为函数一直递归到n==1时候的最深栈,此时不再
                    # 递归,开始执行print语句,这时候n==1,之后每跳出一层栈,打印更大的值

printInOrder(3)    # 正序输出

Properties of Recursion: 使用stack解决的问题都能用递归解决

  • A recursive solution must contain a base case; 递归出口,代表最小子问题(n == 0退出打印)

  • A recursive solution must contain a recursive case; 可以分解的子问题

  • A recursive solution must make progress toward the base case. 递减n使得n像递归出口靠近

Tail Recursion: occurs when a function includes a single recursive call as the last statement of the function. In this case, a stack is not needed to store values to te used upon the return of the recursive call and thus a solution can be implemented using a iterative loop instead.

# Recursive Binary Search

def recBinarySearch(target, theSeq, first, last):
    # 你可以写写单元测试来验证这个函数的正确性
    if first > last:    # 递归出口1
        return False
    else:
        mid = (first + last) // 2
        if theSeq[mid] == target:
            return True    # 递归出口2
        elif theSeq[mid] > target:
            return recBinarySearch(target, theSeq, first, mid - 1)
        else:
            return recBinarySearch(target, theSeq, mid + 1, last)

11章:Hash Tables

基于比较的搜索(线性搜索,有序数组的二分搜索)最好的时间复杂度只能达到O(logn),利用hash可以实现O(1)查找,python内置dict的实现方式就是hash,你会发现dict的key必须要是实现了 __hash____eq__ 方法的。

Hashing: hashing is the process of mapping a search a key to a limited range of array indeices with the goal of providing direct access to the keys.

hash方法有个hash函数用来给key计算一个hash值,作为数组下标,放到该下标对应的槽中。当不同key根据hash函数计算得到的下标相同时,就出现了冲突。解决冲突有很多方式,比如让每个槽成为链表,每次冲突以后放到该槽链表的尾部,但是查询时间就会退化,不再是O(1)。还有一种探查方式,当key的槽冲突时候,就会根据一种计算方式去寻找下一个空的槽存放,探查方式有线性探查,二次方探查法等,cpython解释器使用的是二次方探查法。还有一个问题就是当python使用的槽数量大于预分配的2/3时候,会重新分配内存并拷贝以前的数据,所以有时候dict的add操作代价还是比较高的,牺牲空间但是可以始终保证O(1)的查询效率。如果有大量的数据,建议还是使用bloomfilter或者redis提供的HyperLogLog。

如果你感兴趣,可以看看这篇文章,介绍c解释器如何实现的python dict对象: Python dictionary implementation 。我们使用Python来实现一个类似的hash结构。

import ctypes

class Array:  # 第二章曾经定义过的ADT,这里当做HashMap的槽数组使用
    def __init__(self, size):
        assert size > 0, 'array size must be > 0'
        self._size = size
        PyArrayType = ctypes.py_object * size
        self._elements = PyArrayType()
        self.clear(None)

    def __len__(self):
        return self._size

    def __getitem__(self, index):
        assert index >= 0 and index < len(self), 'out of range'
        return self._elements[index]

    def __setitem__(self, index, value):
        assert index >= 0 and index < len(self), 'out of range'
        self._elements[index] = value

    def clear(self, value):
        """ 设置每个元素为value """
        for i in range(len(self)):
            self._elements[i] = value

    def __iter__(self):
        return _ArrayIterator(self._elements)


class _ArrayIterator:
    def __init__(self, items):
        self._items = items
        self._idx = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._idx < len(self._items):
            val = self._items[self._idx]
            self._idx += 1
            return val
        else:
            raise StopIteration


class HashMap:
    """ HashMap ADT实现,类似于python内置的dict
    一个槽有三种状态:
    1.从未使用 HashMap.UNUSED。此槽没有被使用和冲突过,查找时只要找到UNUSEd就不用再继续探查了
    2.使用过但是remove了,此时是 HashMap.EMPTY,该探查点后边的元素扔可能是有key
    3.槽正在使用 _MapEntry节点
    """

    class _MapEntry:    # 槽里存储的数据
        def __init__(self, key, value):
            self.key = key
            self.value = value

    UNUSED = None    # 没被使用过的槽,作为该类变量的一个单例,下边都是is 判断
    EMPTY = _MapEntry(None, None)     # 使用过但是被删除的槽

    def __init__(self):
        self._table = Array(7)    # 初始化7个槽
        self._count = 0
        # 超过2/3空间被使用就重新分配,load factor = 2/3
        self._maxCount = len(self._table) - len(self._table) // 3

    def __len__(self):
        return self._count

    def __contains__(self, key):
        slot = self._findSlot(key, False)
        return slot is not None

    def add(self, key, value):
        if key in self:    # 覆盖原有value
            slot = self._findSlot(key, False)
            self._table[slot].value = value
            return False
        else:
            slot = self._findSlot(key, True)
            self._table[slot] = HashMap._MapEntry(key, value)
            self._count += 1
            if self._count == self._maxCount:    # 超过2/3使用就rehash
                self._rehash()
            return True

    def valueOf(self, key):
        slot = self._findSlot(key, False)
        assert slot is not None, 'Invalid map key'
        return self._table[slot].value

    def remove(self, key):
        """ remove操作把槽置为EMPTY"""
        assert key in self, 'Key error %s' % key
        slot = self._findSlot(key, forInsert=False)
        value = self._table[slot].value
        self._count -= 1
        self._table[slot] = HashMap.EMPTY
        return value

    def __iter__(self):
        return _HashMapIteraotr(self._table)

    def _slot_can_insert(self, slot):
        return (self._table[slot] is HashMap.EMPTY or
                self._table[slot] is HashMap.UNUSED)

    def _findSlot(self, key, forInsert=False):
        """ 注意原书有错误,代码根本不能运行,这里我自己改写的
        Args:
            forInsert (bool): if the search is for an insertion
        Returns:
            slot or None
        """
        slot = self._hash1(key)
        step = self._hash2(key)
        _len = len(self._table)

        if not forInsert:    # 查找是否存在key
            while self._table[slot] is not HashMap.UNUSED:
                # 如果一个槽是UNUSED,直接跳出
                if self._table[slot] is HashMap.EMPTY:
                    slot = (slot + step) % _len
                    continue
                elif self._table[slot].key == key:
                    return slot
                slot = (slot + step) % _len
            return None

        else:    # 为了插入key
            while not self._slot_can_insert(slot):    # 循环直到找到一个可以插入的槽
                slot = (slot + step) % _len
            return slot

    def _rehash(self):    # 当前使用槽数量大于2/3时候重新创建新的table
        origTable = self._table
        newSize = len(self._table) * 2 + 1    # 原来的2*n+1倍
        self._table = Array(newSize)

        self._count = 0
        self._maxCount = newSize - newSize // 3

        # 将原来的key value添加到新的table
        for entry in origTable:
            if entry is not HashMap.UNUSED and entry is not HashMap.EMPTY:
                slot = self._findSlot(entry.key, True)
                self._table[slot] = entry
                self._count += 1

    def _hash1(self, key):
        """ 计算key的hash值"""
        return abs(hash(key)) % len(self._table)

    def _hash2(self, key):
        """ key冲突时候用来计算新槽的位置"""
        return 1 + abs(hash(key)) % (len(self._table)-2)


class _HashMapIteraotr:
    def __init__(self, array):
        self._array = array
        self._idx = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._idx < len(self._array):
            if self._array[self._idx] is not None and self._array[self._idx].key is not None:
                key = self._array[self._idx].key
                self._idx += 1
                return key
            else:
                self._idx += 1
        else:
            raise StopIteration


def print_h(h):
    for idx, i in enumerate(h):
        print(idx, i)
    print('\n')


def test_HashMap():
    """ 一些简单的单元测试,不过测试用例覆盖不是很全面 """
    h = HashMap()
    assert len(h) == 0
    h.add('a', 'a')
    assert h.valueOf('a') == 'a'
    assert len(h) == 1

    a_v = h.remove('a')
    assert a_v == 'a'
    assert len(h) == 0

    h.add('a', 'a')
    h.add('b', 'b')
    assert len(h) == 2
    assert h.valueOf('b') == 'b'
    b_v = h.remove('b')
    assert b_v == 'b'
    assert len(h) == 1
    h.remove('a')
    assert len(h) == 0

    n = 10
    for i in range(n):
        h.add(str(i), i)
    assert len(h) == n
    print_h(h)
    for i in range(n):
        assert str(i) in h
    for i in range(n):
        h.remove(str(i))
    assert len(h) == 0

12章: Advanced Sorting

第5章介绍了基本的排序算法,本章介绍高级排序算法。

归并排序(mergesort): 分治法

def merge_sorted_list(listA, listB):
    """ 归并两个有序数组,O(max(m, n)) ,m和n是数组长度"""
    print('merge left right list', listA, listB, end='')
    new_list = list()
    a = b = 0
    while a < len(listA) and b < len(listB):
        if listA[a] < listB[b]:
            new_list.append(listA[a])
            a += 1
        else:
            new_list.append(listB[b])
            b += 1

    while a < len(listA):
        new_list.append(listA[a])
        a += 1

    while b < len(listB):
        new_list.append(listB[b])
        b += 1

    print(' ->', new_list)
    return new_list


def mergesort(theList):
    """ O(nlogn),log层调用,每层n次操作
    mergesort: divided and conquer 分治
    1. 把原数组分解成越来越小的子数组
    2. 合并子数组来创建一个有序数组
    """
    print(theList)    # 我把关键步骤打出来了,你可以运行下看看整个过程
    if len(theList) <= 1:    # 递归出口
        return theList
    else:
        mid = len(theList) // 2

        # 递归分解左右两边数组
        left_half = mergesort(theList[:mid])
        right_half = mergesort(theList[mid:])

        # 合并两边的有序子数组
        newList = merge_sorted_list(left_half, right_half)
        return newList

""" 这是我调用一次打出来的排序过程
[10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
[10, 9, 8, 7, 6]
[10, 9]
[10]
[9]
merge left right list [10] [9] -> [9, 10]
[8, 7, 6]
[8]
[7, 6]
[7]
[6]
merge left right list [7] [6] -> [6, 7]
merge left right list [8] [6, 7] -> [6, 7, 8]
merge left right list [9, 10] [6, 7, 8] -> [6, 7, 8, 9, 10]
[5, 4, 3, 2, 1]
[5, 4]
[5]
[4]
merge left right list [5] [4] -> [4, 5]
[3, 2, 1]
[3]
[2, 1]
[2]
[1]
merge left right list [2] [1] -> [1, 2]
merge left right list [3] [1, 2] -> [1, 2, 3]
merge left right list [4, 5] [1, 2, 3] -> [1, 2, 3, 4, 5]
"""

快速排序

def quicksort(theSeq, first, last):    # average: O(nlog(n))
    """
    quicksort :也是分而治之,但是和归并排序不同的是,采用选定主元(pivot)而不是从中间
    进行数组划分
    1. 第一步选定pivot用来划分数组,pivot左边元素都比它小,右边元素都大于等于它
    2. 对划分的左右两边数组递归,直到递归出口(数组元素数目小于2)
    3. 对pivot和左右划分的数组合并成一个有序数组
    """
    if first < last:
        pos = partitionSeq(theSeq, first, last)
        # 对划分的子数组递归操作
        quicksort(theSeq, first, pos - 1)
        quicksort(theSeq, pos + 1, last)


def partitionSeq(theSeq, first, last):
    """ 快排中的划分操作,把比pivot小的挪到左边,比pivot大的挪到右边"""
    pivot = theSeq[first]
    print('before partitionSeq', theSeq)

    left = first + 1
    right = last

    while True:
        # 找到第一个比pivot大的
        while left <= right and theSeq[left] < pivot:
            left += 1

        # 从右边开始找到比pivot小的
        while right >= left and theSeq[right] >= pivot:
            right -= 1

        if right < left:
            break
        else:
            theSeq[left], theSeq[right] = theSeq[right], theSeq[left]

    # 把pivot放到合适的位置
    theSeq[first], theSeq[right] = theSeq[right], theSeq[first]

    print('after partitionSeq {}: {}\t'.format(theSeq, pivot))
    return right    # 返回pivot的位置


def test_partitionSeq():
    l = [0,1,2,3,4]
    assert partitionSeq(l, 0, len(l)-1) == 0
    l = [4,3,2,1,0]
    assert partitionSeq(l, 0, len(l)-1) == 4
    l = [2,3,0,1,4]
    assert partitionSeq(l, 0, len(l)-1) == 2

test_partitionSeq()


def test_quicksort():
    def _is_sorted(seq):
        for i in range(len(seq)-1):
            if seq[i] > seq[i+1]:
                return False
        return True

    from random import randint
    for i in range(100):
        _len = randint(1, 100)
        to_sort = []
        for i in range(_len):
            to_sort.append(randint(0, 100))
        quicksort(to_sort, 0, len(to_sort)-1)    # 注意这里用了原地排序,直接更改了数组
        print(to_sort)
        assert _is_sorted(to_sort)

test_quicksort()

利用快排中的partitionSeq操作,我们还能实现另一个算法,nth_element,快速查找一个无序数组中的第k大元素

def nth_element(seq, beg, end, k):
    if beg == end:
        return seq[beg]
    pivot_index = partitionSeq(seq, beg, end)
    if pivot_index == k:
        return seq[k]
    elif pivot_index > k:
        return nth_element(seq, beg, pivot_index-1, k)
    else:
        return nth_element(seq, pivot_index+1, end, k)

def test_nth_element():
    from random import shuffle
    n = 10
    l = list(range(n))
    shuffle(l)
    print(l)
    for i in range(len(l)):
        assert nth_element(l, 0, len(l)-1, i) == i

test_nth_element()

13章: Binary Tree

The binary Tree: 二叉树,每个节点做多只有两个子节点

class _BinTreeNode:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


# 三种depth-first遍历
def preorderTrav(subtree):
    """ 先(根)序遍历"""
    if subtree is not None:
        print(subtree.data)
        preorderTrav(subtree.left)
        preorderTrav(subtree.right)


def inorderTrav(subtree):
    """ 中(根)序遍历"""
    if subtree is not None:
        preorderTrav(subtree.left)
        print(subtree.data)
        preorderTrav(subtree.right)


def postorderTrav(subtree):
    """ 后(根)序遍历"""
    if subtree is not None:
        preorderTrav(subtree.left)
        preorderTrav(subtree.right)
        print(subtree.data)


# 宽度优先遍历(bradth-First Traversal): 一层一层遍历, 使用queue
def breadthFirstTrav(bintree):
    from queue import Queue    # py3
    q = Queue()
    q.put(bintree)
    while not q.empty():
        node = q.get()
        print(node.data)
        if node.left is not None:
            q.put(node.left)
        if node.right is not None:
            q.put(node.right)


class _ExpTreeNode:
    __slots__ = ('element', 'left', 'right')

    def __init__(self, data):
        self.element = data
        self.left = None
        self.right = None

    def __repr__(self):
        return '<_ExpTreeNode: {} {} {}>'.format(
            self.element, self.left, self.right)

from queue import Queue
class ExpressionTree:
    """
    表达式树: 操作符存储在内节点操作数存储在叶子节点的二叉树。(符号树真难打出来)
        *
       / \
      +   -
     / \  / \
     9  3 8   4
    (9+3) * (8-4)

    Expression Tree Abstract Data Type,可以实现二元操作符
    ExpressionTree(expStr): user string as constructor param
    evaluate(varDict): evaluates the expression and returns the numeric result
    toString(): constructs and retutns a string represention of the expression

    Usage:
        vars = {'a': 5, 'b': 12}
        expTree = ExpressionTree("(a/(b-3))")
        print('The result = ', expTree.evaluate(vars))
    """

    def __init__(self, expStr):
        self._expTree = None
        self._buildTree(expStr)

    def evaluate(self, varDict):
        return self._evalTree(self._expTree, varDict)

    def __str__(self):
        return self._buildString(self._expTree)

    def _buildString(self, treeNode):
        """ 在一个子树被遍历之前添加做括号,在子树被遍历之后添加右括号 """
        # print(treeNode)
        if treeNode.left is None and treeNode.right is None:
            return str(treeNode.element)    # 叶子节点是操作数直接返回
        else:
            expStr = '('
            expStr += self._buildString(treeNode.left)
            expStr += str(treeNode.element)
            expStr += self._buildString(treeNode.right)
            expStr += ')'
            return expStr

    def _evalTree(self, subtree, varDict):
        # 是不是叶子节点, 是的话说明是操作数,直接返回
        if subtree.left is None and subtree.right is None:
            # 操作数是合法数字吗
            if subtree.element >= '0' and subtree.element <= '9':
                return int(subtree.element)
            else:    # 操作数是个变量
                assert subtree.element in varDict, 'invalid variable.'
                return varDict[subtree.element]
        else:    # 操作符则计算其子表达式
            lvalue = self._evalTree(subtree.left, varDict)
            rvalue = self._evalTree(subtree.right, varDict)
            print(subtree.element)
            return self._computeOp(lvalue, subtree.element, rvalue)

    def _computeOp(self, left, op, right):
        assert op
        op_func = {
            '+': lambda left, right: left + right,    # or import operator, operator.add
            '-': lambda left, right: left - right,
            '*': lambda left, right: left * right,
            '/': lambda left, right: left / right,
            '%': lambda left, right: left % right,
        }
        return op_func[op](left, right)

    def _buildTree(self, expStr):
        expQ = Queue()
        for token in expStr:    # 遍历表达式字符串的每个字符
            expQ.put(token)
        self._expTree = _ExpTreeNode(None)    # 创建root节点
        self._recBuildTree(self._expTree, expQ)

    def _recBuildTree(self, curNode, expQ):
        token = expQ.get()
        if token == '(':
            curNode.left = _ExpTreeNode(None)
            self._recBuildTree(curNode.left, expQ)

            # next token will be an operator: + = * / %
            curNode.element = expQ.get()
            curNode.right = _ExpTreeNode(None)
            self._recBuildTree(curNode.right, expQ)

            # the next token will be ')', remmove it
            expQ.get()

        else:  # the token is a digit that has to be converted to an int.
            curNode.element = token


vars = {'a': 5, 'b': 12}
expTree = ExpressionTree("((2*7)+8)")
print(expTree)
print('The result = ', expTree.evaluate(vars))

Heap(堆):二叉树最直接的一个应用就是实现堆。堆就是一颗完全二叉树,最大堆的非叶子节点的值都比孩子大,最小堆的非叶子结点的值都比孩子小。 python内置了heapq模块帮助我们实现堆操作,比如用内置的heapq模块实现个堆排序:

# 使用python内置的heapq实现heap sort
def heapsort(iterable):
    from heapq import heappush, heappop
    h = []
    for value in iterable:
        heappush(h, value)
    return [heappop(h) for i in range(len(h))]

但是一般实现堆的时候实际上并不是用数节点来实现的,而是使用数组实现,效率比较高。为什么可以用数组实现呢?因为完全二叉树的性质, 可以用下标之间的关系表示节点之间的关系,MaxHeap的docstring中已经说明了

class MaxHeap:
    """
    Heaps:
    完全二叉树,最大堆的非叶子节点的值都比孩子大,最小堆的非叶子结点的值都比孩子小
    Heap包含两个属性,order property 和 shape property(a complete binary tree),在插入
    一个新节点的时候,始终要保持这两个属性
    插入操作:保持堆属性和完全二叉树属性, sift-up 操作维持堆属性
    extract操作:只获取根节点数据,并把树最底层最右节点copy到根节点后,sift-down操作维持堆属性

    用数组实现heap,从根节点开始,从上往下从左到右给每个节点编号,则根据完全二叉树的
    性质,给定一个节点i, 其父亲和孩子节点的编号分别是:
        parent = (i-1) // 2
        left = 2 * i + 1
        rgiht = 2 * i + 2
    使用数组实现堆一方面效率更高,节省树节点的内存占用,一方面还可以避免复杂的指针操作,减少
    调试难度。

    """

    def __init__(self, maxSize):
        self._elements = Array(maxSize)    # 第二章实现的Array ADT
        self._count = 0

    def __len__(self):
        return self._count

    def capacity(self):
        return len(self._elements)

    def add(self, value):
        assert self._count < self.capacity(), 'can not add to full heap'
        self._elements[self._count] = value
        self._count += 1
        self._siftUp(self._count - 1)
        self.assert_keep_heap()    # 确定每一步add操作都保持堆属性

    def extract(self):
        assert self._count > 0, 'can not extract from an empty heap'
        value = self._elements[0]    # save root value
        self._count -= 1
        self._elements[0] = self._elements[self._count]    # 最右下的节点放到root后siftDown
        self._siftDown(0)
        self.assert_keep_heap()
        return value

    def _siftUp(self, ndx):
        if ndx > 0:
            parent = (ndx - 1) // 2
            # print(ndx, parent)
            if self._elements[ndx] > self._elements[parent]:    # swap
                self._elements[ndx], self._elements[parent] = self._elements[parent], self._elements[ndx]
                self._siftUp(parent)    # 递归

    def _siftDown(self, ndx):
        left = 2 * ndx + 1
        right = 2 * ndx + 2
        # determine which node contains the larger value
        largest = ndx
        if (left < self._count and
            self._elements[left] >= self._elements[largest] and
            self._elements[left] >= self._elements[right]):  # 原书这个地方没写实际上找的未必是largest
            largest = left
        elif right < self._count and self._elements[right] >= self._elements[largest]:
            largest = right
        if largest != ndx:
            self._elements[ndx], self._elements[largest] = self._elements[largest], self._elements[ndx]
            self._siftDown(largest)

    def __repr__(self):
        return ' '.join(map(str, self._elements))

    def assert_keep_heap(self):
        """ 我加了这个函数是用来验证每次add或者extract之后,仍保持最大堆的性质"""
        _len = len(self)
        for i in range(0, int((_len-1)/2)):    # 内部节点(非叶子结点)
            l = 2 * i + 1
            r = 2 * i + 2
            if l < _len and r < _len:
                assert self._elements[i] >= self._elements[l] and self._elements[i] >= self._elements[r]

def test_MaxHeap():
    """ 最大堆实现的单元测试用例 """
    _len = 10
    h = MaxHeap(_len)
    for i in range(_len):
        h.add(i)
        h.assert_keep_heap()
    for i in range(_len):
        # 确定每次出来的都是最大的数字,添加的时候是从小到大添加的
        assert h.extract() == _len-i-1

test_MaxHeap()

def simpleHeapSort(theSeq):
    """ 用自己实现的MaxHeap实现堆排序,直接修改原数组实现inplace排序"""
    if not theSeq:
        return theSeq
    _len = len(theSeq)
    heap = MaxHeap(_len)
    for i in theSeq:
        heap.add(i)
    for i in reversed(range(_len)):
        theSeq[i] = heap.extract()
    return theSeq


def test_simpleHeapSort():
    """ 用一些测试用例证明实现的堆排序是可以工作的 """
    def _is_sorted(seq):
        for i in range(len(seq)-1):
            if seq[i] > seq[i+1]:
                return False
        return True

    from random import randint
    assert simpleHeapSort([]) == []
    for i in range(1000):
        _len = randint(1, 100)
        to_sort = []
        for i in range(_len):
            to_sort.append(randint(0, 100))
        simpleHeapSort(to_sort)    # 注意这里用了原地排序,直接更改了数组
        assert _is_sorted(to_sort)


test_simpleHeapSort()

14章: Search Trees

二叉差找树性质:对每个内部节点V, 1. 所有key小于V.key的存储在V的左子树。 2. 所有key大于V.key的存储在V的右子树 对BST进行中序遍历会得到升序的key序列

class _BSTMapNode:
    __slots__ = ('key', 'value', 'left', 'right')

    def __init__(self, key, value):
        self.key = key
        self.value = value
        self.left = None
        self.right = None

    def __repr__(self):
        return '<{}:{}> left:{}, right:{}'.format(
            self.key, self.value, self.left, self.right)

    __str__ = __repr__


class BSTMap:
    """ BST,树节点包含key可payload。用BST来实现之前用hash实现过的Map ADT.
    性质:对每个内部节点V,
    1.对于节点V,所有key小于V.key的存储在V的左子树。
    2.所有key大于V.key的存储在V的右子树
    对BST进行中序遍历会得到升序的key序列
    """
    def __init__(self):
        self._root = None
        self._size = 0
        self._rval = None     # 作为remove的返回值

    def __len__(self):
        return self._size

    def __iter__(self):
        return _BSTMapIterator(self._root, self._size)

    def __contains__(self, key):
        return self._bstSearch(self._root, key) is not None

    def valueOf(self, key):
        node = self._bstSearch(self._root, key)
        assert node is not None, 'Invalid map key.'
        return node.value

    def _bstSearch(self, subtree, target):
        if subtree is None:    # 递归出口,遍历到树底没有找到key或是空树
            return None
        elif target < subtree.key:
            return self._bstSearch(subtree.left, target)
        elif target > subtree.key:
            return self._bstSearch(subtree.right, target)
        return subtree    # 返回引用

    def _bstMinumum(self, subtree):
        """ 顺着树一直往左下角递归找就是最小的,向右下角递归就是最大的 """
        if subtree is None:
            return None
        elif subtree.left is None:
            return subtree
        else:
            return subtree._bstMinumum(self, subtree.left)

    def add(self, key, value):
        """ 添加或者替代一个key的value, O(N) """
        node = self._bstSearch(self._root, key)
        if node is not None:    # if key already exists, update value
            node.value = value
            return False
        else:   # insert a new entry
            self._root = self._bstInsert(self._root, key, value)
            self._size += 1
            return True

    def _bstInsert(self, subtree, key, value):
        """ 新的节点总是插入在树的叶子结点上 """
        if subtree is None:
            subtree = _BSTMapNode(key, value)
        elif key < subtree.key:
            subtree.left = self._bstInsert(subtree.left, key, value)
        elif key > subtree.key:
            subtree.right = self._bstInsert(subtree.right, key, value)
        # 注意这里没有else语句了,应为在被调用处add函数里先判断了是否有重复key
        return subtree

    def remove(self, key):
        """ O(N)
        被删除的节点分为三种:
        1.叶子结点:直接把其父亲指向该节点的指针置None
        2.该节点有一个孩子: 删除该节点后,父亲指向一个合适的该节点的孩子
        3.该节点有俩孩子:
            (1)找到要删除节点N和其后继S(中序遍历后该节点下一个)
            (2)复制S的key到N
            (3)从N的右子树中删除后继S(即在N的右子树中最小的)
        """
        assert key in self, 'invalid map key'
        self._root = self._bstRemove(self._root, key)
        self._size -= 1
        return self._rval

    def _bstRemove(self, subtree, target):
        # search for the item in the tree
        if subtree is None:
            return subtree
        elif target < subtree.key:
            subtree.left = self._bstRemove(subtree.left, target)
            return subtree
        elif target > subtree.key:
            subtree.right = self._bstRemove(subtree.right, target)
            return subtree

        else:    # found the node containing the item
            self._rval = subtree.value
            if subtree.left is None and subtree.right is None:
                # 叶子node
                return None
            elif subtree.left is None or subtree.right is None:
                # 有一个孩子节点
                if subtree.left is not None:
                    return subtree.left
                else:
                    return subtree.right
            else:   # 有俩孩子节点
                successor = self._bstMinumum(subtree.right)
                subtree.key = successor.key
                subtree.value = successor.value
                subtree.right = self._bstRemove(subtree.right, successor.key)
                return subtree

    def __repr__(self):
        return '->'.join([str(i) for i in self])

    def assert_keep_bst_property(self, subtree):
        """ 写这个函数为了验证add和delete操作始终维持了bst的性质 """
        if subtree is None:
            return
        if subtree.left is not None and subtree.right is not None:
            assert subtree.left.value <= subtree.value
            assert subtree.right.value >= subtree.value
            self.assert_keep_bst_property(subtree.left)
            self.assert_keep_bst_property(subtree.right)

        elif subtree.left is None and subtree.right is not None:
            assert subtree.right.value >= subtree.value
            self.assert_keep_bst_property(subtree.right)

        elif subtree.left is not None and subtree.right is None:
            assert subtree.left.value <= subtree.value
            self.assert_keep_bst_property(subtree.left)


class _BSTMapIterator:
    def __init__(self, root, size):
        self._theKeys = Array(size)
        self._curItem = 0
        self._bstTraversal(root)
        self._curItem = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._curItem < len(self._theKeys):
            key = self._theKeys[self._curItem]
            self._curItem += 1
            return key
        else:
            raise StopIteration

    def _bstTraversal(self, subtree):
        if subtree is not None:
            self._bstTraversal(subtree.left)
            self._theKeys[self._curItem] = subtree.key
            self._curItem += 1
            self._bstTraversal(subtree.right)


def test_BSTMap():
    l = [60, 25, 100, 35, 17, 80]
    bst = BSTMap()
    for i in l:
        bst.add(i)

def test_HashMap():
    """ 之前用来测试用hash实现的map,改为用BST实现的Map测试 """
    # h = HashMap()
    h = BSTMap()
    assert len(h) == 0
    h.add('a', 'a')
    assert h.valueOf('a') == 'a'
    assert len(h) == 1

    a_v = h.remove('a')
    assert a_v == 'a'
    assert len(h) == 0

    h.add('a', 'a')
    h.add('b', 'b')
    assert len(h) == 2
    assert h.valueOf('b') == 'b'
    b_v = h.remove('b')
    assert b_v == 'b'
    assert len(h) == 1
    h.remove('a')

    assert len(h) == 0

    _len = 10
    for i in range(_len):
        h.add(str(i), i)
    assert len(h) == _len
    for i in range(_len):
        assert str(i) in h
    for i in range(_len):
        print(len(h))
        print('bef', h)
        _ = h.remove(str(i))
        assert _ == i
        print('aft', h)
        print(len(h))
    assert len(h) == 0

test_HashMap()