from random import randint


def median(arr):
    """Return the middle element of ``arr`` in sorted order.

    This is ``sorted(arr)[len(arr) // 2]``: the true median for odd lengths
    and the *upper* of the two middle values for even lengths (no averaging,
    unlike ``numpy.median``). ``median2`` returns the same element.

    The input is left untouched; it may be any finite iterable of mutually
    comparable values.

    Raises:
        ValueError: if ``arr`` is empty.
    """
    # Sort a copy instead of calling arr.sort(), so the caller's list is not reordered.
    ordered = sorted(arr)
    if not ordered:
        raise ValueError("median() arg is an empty sequence")
    return ordered[len(ordered) // 2]


def patition(arr, low, high):
    """Partition ``arr[low:high + 1]`` in place around the pivot ``arr[low]``.

    The name is a misspelling of "partition"; it is kept so existing callers
    keep working.

    Returns the pivot's final index ``p``. Afterwards ``arr[low:p]`` holds
    values ``<= pivot`` and ``arr[p + 1:high + 1]`` holds values ``> pivot``.
    """
    pivot = arr[low]
    i = low + 1
    while i <= high:
        if arr[i] > pivot:
            # Send the larger value to the shrinking right end. Do not advance
            # i: the value swapped in from ``high`` has not been examined yet.
            arr[i], arr[high] = arr[high], arr[i]
            high -= 1
        else:
            i += 1
    # ``high`` is now the last slot holding a value <= pivot; the pivot goes there.
    arr[high], arr[low] = arr[low], arr[high]
    return high


def median_helper(arr, low, high, found_index):
    """Return the value that would sit at ``found_index`` if ``arr[low:high + 1]`` were sorted.

    Runs quickselect over the inclusive range ``[low, high]``; ``arr`` is
    reordered in place. Requires ``low <= found_index <= high``.

    Before each partition, a random element is swapped into ``arr[low]``.
    This keeps already-sorted input from degrading to quadratic time in
    expectation. The loop is iterative, so a long run of lopsided splits
    cannot hit Python's recursion limit.

    Inputs with very many equal values can still be slow, because
    ``patition`` sends every copy of the pivot value to the left side.
    """
    while True:
        pick = randint(low, high)
        arr[low], arr[pick] = arr[pick], arr[low]
        pivot_index = patition(arr, low, high)
        if pivot_index == found_index:
            return arr[found_index]
        # Keep only the side that still contains found_index.
        if pivot_index > found_index:
            high = pivot_index - 1
        else:
            low = pivot_index + 1


def median2(arr):
    """Quickselect version of ``median``: same element, expected O(n) time.

    Works on a private copy, so the caller's sequence is not reordered.

    Raises:
        ValueError: if ``arr`` is empty. This is raised explicitly rather than
            via ``assert``, so the check is not stripped under ``python -O``.
    """
    items = list(arr)
    if not items:
        raise ValueError("median2() arg is an empty sequence")
    return median_helper(items, 0, len(items) - 1, len(items) // 2)


def _self_test(max_len=2000):
    """Cross-check ``median2`` against ``median`` on random lists of length 1..max_len-2."""
    for size in range(1, max_len - 1):
        values = [randint(0, 2000) for _ in range(size)]
        expected = median(values)
        actual = median2(values)
        if expected != actual:
            raise AssertionError(
                f"median={expected!r} but median2={actual!r} for {values!r}"
            )


if __name__ == "__main__":
    _self_test()
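时间复杂度:期望 O(n)(比较次数约 2n),最坏 O(n^2)。
因为每次划分后平均只需继续处理约一半的元素:$n + \frac{n}{2} + \frac{n}{4} + \cdots \le 2n$。但如果枢轴每次都恰好是区间内的最小值或最大值(例如固定取首元素作枢轴,而输入已经有序),每轮只能排除一个元素,总代价就会退化为 $n + (n-1) + \cdots + 1 = O(n^2)$。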
下面这种写法也很直观,实测比我上面的版本还快一些(有点出乎意料)。可能的原因有两点:一是与枢轴相等的元素会被一次性全部丢弃;二是在 CPython 中顺序遍历并 append 到新列表,通常比按下标逐个交换元素的开销更小。代码如下:
import random
def quick_select(A, k):
    """Return the k-th smallest element of ``A`` (1-based: ``k=1`` is the minimum).

    Each round picks a random pivot and splits the values into those below
    and those above it. Copies of the pivot are set aside, and only the side
    that contains rank ``k`` is kept. The working list therefore shrinks
    geometrically on average, giving expected O(len(A)) time.
    ``A`` itself is never modified.

    Raises:
        IndexError: if ``A`` is empty or ``k`` is outside ``1..len(A)``.
    """
    # Validate up front; otherwise a bad k only surfaces later as
    # random.choice failing on an empty list.
    if not 1 <= k <= len(A):
        raise IndexError(f"k={k} out of range for sequence of length {len(A)}")
    while True:
        pivot = random.choice(A)  # random pivot: no fixed input is consistently bad
        below = [x for x in A if x < pivot]  # values < pivot
        above = [x for x in A if x > pivot]  # values > pivot (pivot copies dropped)
        not_above = len(A) - len(above)  # count of values <= pivot
        if k <= len(below):
            # Case 1: the answer lies among the smaller values.
            A = below
        elif k > not_above:
            # Case 2: the answer lies among the larger values; re-rank k within them.
            k -= not_above
            A = above
        else:
            # Case 3: rank k falls on a copy of the pivot.
            return pivot
期望比较次数:$C = n + \frac{n}{2} + \frac{n}{4} + \frac{n}{8} + \cdots \le 2n = O(n)$