如何理解np.mean(x,1)和 np.mean(x,axis=(0,2,3))

时间:2024-11-04 07:20:42

1. np.mean(x,1,keepdims=True)

假设 x[B, C, H, W]

结果是[B,1,H,W] :3层通道变成1层通道,这一层通道是均值

def cal_mean(x:np.ndarray) -> np.ndarray:
    B,C,H,W = x.shape
    mean = np.zeros((B,1,H,W))

    for b in range(B):
        for h in range(H):
            for w in range(W):
                total_sum = 0.0

                for c in range(C):
                    total_sum += x[b,c,h,w]
                mean[b,0,h,w]=total_sum / C
    return mean


batch,channels, height,width = 4,3,8,8
x = np.random.randint(-10,10,(batch,channels,height,width))
mean_output=cal_mean(x)

mean2 = np.mean(x,1,keepdims=True)

print(mean_output==mean2) # True......

2. np.mean(x,axis=(0,2,3),keepdims=True)

假设 x[1, C, 1, 1]:通道不变,其余维度变为1层 

import numpy as np

def cal_mean(x:np.ndarray):
    B,C,H,W = x.shape
    batch_mean = np.zeros((1,C,1,1))

    for c in range(C):
        total_sum = 0.0
        count = 0
        for b in range(B):
            for h in range(H):
                for w in range(W):
                    total_sum += x[b,c,h,w]
                    count += 1
        # 实际上,count=B*H*W
        batch_mean[0,c,0,0] = total_sum / count
    return batch_mean

mean0 = np.mean(x,axis=(0,2,3),keepdims=True)
mean1 = cal_mean(x)
print(mean0==mean1)

3. 小结

np.mean(x,1):第1维度变为1,第1维在最内循环,[B, C, H, W] -- >[B, 1, H, W] ,

np.mean(x,axis=(0,2,3)):第0、2、3维度变为1,第1维度在最外循环,[b,c,h,w] -->[1,C,1,1]