NumPy使用手记 (1)

时间:2021-06-21 08:05:03

1) 巧用 where函数

  where函数是numpy的内置,也是一个非常有用的函数,提供了快速并且灵活的计算功能。


def f_norm_1(data, estimate):
   residule = 0
   for row_index in range(data.shape[0]):
     for column_index in range(data.shape[1]):
       if data[row_index][column_index] != 0:
         residule += (data[row_index][column_index] - estimate[row_index][column_index]) ** 2
   return residule


def f_norm_2(data, estimate) 

    return sum(where(data != 0, (data-estimate) **2, 0))

 

这两段代码完成同样的功能,计算两个矩阵的差,然后将残差进行平方,注意,因为我需要的是考虑矩阵稀疏性,所以不能用内置的norm,函数1是我用普通的python写的,不太复杂,对于规模10*10的矩阵,计算200次耗时0.15s,函数2使用了where函数和sum函数,这两个函数都是为向量计算优化过的,不仅简介,而且耗时仅0.03s, 快了有五倍,不仅如此,有同学将NumPy和matlab做过比较,NumPy稍快一些,这已经是很让人兴奋的结果。