[Computer Graphics] Summary of the 3DIT Training Data

Date: 2024-11-12 07:18:47

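3D Implicit Transporter (3DIT) uses the PartNet-Mobility dataset, whereas I want to use the Shape2Motion dataset. However, 3DIT trains on all categories mixed together. So that I do not forget the details when I come back to the 3DIT data, I am recording some notes on its point cloud training data here, which should make it easier to train 3DIT as a baseline later.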

1. Training files

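  1. ./core/datasets/splits/train_partnet.txt shows that all categories are mixed together for training. The meaning of each field in a line is explained in the next item.
  2. In ./core/datasets/train_dataset.py, the line obj_name, articulate_type, instance_name, name_index = self.filenames[index].split(' ') shows that each line contains: the object category name, the articulation type (R for revolute, T for prismatic), the instance name, and the index of the point cloud frame.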
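To make the line format concrete, here is a minimal sketch of parsing one split line. The sample line is made up for illustration and is not copied from train_partnet.txt; parse_split_line and SplitEntry are hypothetical names, and converting name_index to int is my assumption (it is later used to index the frame array).

```python
from typing import NamedTuple


class SplitEntry(NamedTuple):
    obj_name: str         # object category name
    articulate_type: str  # "R" = revolute, "T" = prismatic
    instance_name: str    # instance identifier
    name_index: int       # index of the point cloud frame


def parse_split_line(line: str) -> SplitEntry:
    """Split one space-separated line into its four fields."""
    obj_name, articulate_type, instance_name, name_index = line.strip().split(' ')
    return SplitEntry(obj_name, articulate_type, instance_name, int(name_index))


# Hypothetical line, for illustration only.
entry = parse_split_line("laptop R 10211 15")
print(entry)
```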

2. Point cloud data analysis

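  1. After reading a point cloud file, pc has shape [31,], which can be read as 31 frames sampled for one object.
  2. Each pc[i] has shape [n, 4], where n is the number of sampled points. The first three columns are the (x, y, z) coordinates; the meaning of the last column is not yet clear.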
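The layout described above can be mimicked with a synthetic stand-in. This is only a sketch: how the real file is loaded is not covered in these notes, so the array below is generated randomly, and the per-frame point counts are an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for a loaded file: 31 frames, each an (n, 4) array.
# The real loading call and the meaning of the 4th column are not known here.
num_frames = 31
pc = np.empty(num_frames, dtype=object)
for i in range(num_frames):
    n = int(rng.integers(1000, 2000))  # n may differ per frame (assumption)
    pc[i] = rng.uniform(-1.0, 1.0, size=(n, 4))

print(pc.shape)       # (31,)
xyz = pc[0][:, :3]    # keep only the (x, y, z) columns of frame 0
print(xyz.shape)
```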

3. train_dataset.py, the training dataset file

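  1. Note that for synthetic data, the class to look at is class Articulated_Obj_Syn(BaseDataset).
  2. For BaseDataset, the main thing to note is that it defines:
self.on_occupancy_num = config_public_params["on_occupancy_num"] # number of on-surface points
self.off_occupancy_num = config_public_params["off_occupancy_num"] # number of off-surface points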
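  3. The first step of __getitem__ is get_input. Starting from a middle index (middle), it randomly generates two frame indices, start and end, one before and one after it. Once the object's point cloud file is loaded, these indices are used to fetch pc[start] and pc[end], i.e. the point clouds of the start frame and the end frame.
# return the xyz coordinates of the pre, middle and last frames,
# followed by the object name, instance name and articulation type
# (np.float was removed in NumPy 1.24; np.float64 is the equivalent dtype)
return data[name_index_1][:, :3].astype(np.float64), \
        data[name_index][:, :3].astype(np.float64), \
        data[name_index_2][:, :3].astype(np.float64), \
        None, None, \
        obj_name, instance_name, articulate_type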
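The index selection in this step could look roughly like the sketch below. This is not the repo's code: sample_frame_indices is a hypothetical helper, and I am assuming the start frame is drawn from the frames up to and including middle and the end frame from middle onward.

```python
import numpy as np


def sample_frame_indices(middle, num_frames, rng):
    """Pick one frame index at or before `middle` and one at or after it."""
    start = int(rng.integers(0, middle + 1))     # in [0, middle]
    end = int(rng.integers(middle, num_frames))  # in [middle, num_frames - 1]
    return start, end


rng = np.random.default_rng(0)
num_frames = 31   # frames per object, as observed above
middle = 15       # arbitrary middle index for the demo
start, end = sample_frame_indices(middle, num_frames, rng)
print(start, middle, end)
```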
  4. The second step is prepare_input_data(), which mainly samples the point cloud down to the number of points specified in the config.
# sample
# print(f"self.max_down_sample:{self.max_down_sample}") # 1
pc_start = self.prepare_input_data(pc_start_ori, self.max_down_sample)
pc_middle = self.prepare_input_data(pc_middle_ori, 1)
pc_end = self.prepare_input_data(pc_end_ori, self.max_down_sample)
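I have not copied the body of prepare_input_data here, so the following is only a sketch of the general idea of sampling a cloud to a fixed size. sample_fixed_size is a hypothetical helper, the target of 1024 points is an assumed value, and the role of the second argument (self.max_down_sample) in the real function is not modeled.

```python
import numpy as np


def sample_fixed_size(points, num_points, rng):
    """Randomly pick `num_points` rows; repeat rows only if the cloud is too small."""
    replace = points.shape[0] < num_points
    idx = rng.choice(points.shape[0], size=num_points, replace=replace)
    return points[idx]


rng = np.random.default_rng(0)
cloud = rng.uniform(-1.0, 1.0, size=(5000, 3))
small_cloud = rng.uniform(-1.0, 1.0, size=(300, 3))

sampled = sample_fixed_size(cloud, 1024, rng)           # subsample without repeats
upsampled = sample_fixed_size(small_cloud, 1024, rng)   # needs repeats
print(sampled.shape, upsampled.shape)
```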
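  5. The third step is data augmentation, mainly rotation and jitter. The same random z-axis rotation is applied to all frames, and the rotation is tolerated because the pipeline processes the clouds with PointNet / PointNet++, so a rotated input is not a problem.
if self.config_aug["do_aug"]: # defaults to True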
    # rotate
    z_angle = np.random.uniform() * self.config_aug["rotate_angle"] / 180.0 * (np.pi)
    angles_2d = [0, 0, z_angle]

    # rotate pre, middle and final by the same random angle about the z axis
    pc_start = atomic_rotate(pc_start, angles_2d)
    pc_middle = atomic_rotate(pc_middle, angles_2d) # sampled (cropped) point cloud
    pc_middle_ori = atomic_rotate(pc_middle_ori, angles_2d) # original point cloud
    pc_end = atomic_rotate(pc_end, angles_2d)

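    # jitter (Gaussian noise) -> random perturbation
    sigma, clip = self.config_aug["sigma"], self.config_aug["clip"]
    # NOTE: np.random.uniform(n, 3) treats its two arguments as (low, high) and returns
    # a single scalar, so each jitter below is one scalar offset, not per-point Gaussian
    # noise. np.random.randn(n, 3) would draw an (n, 3) array of Gaussian samples.
    jitter1 = np.clip(sigma * np.random.uniform(pc_start.shape[0], 3), -1 * clip, clip)
    jitter2 = np.clip(sigma * np.random.uniform(pc_middle.shape[0], 3), -1 * clip, clip)
    jitter3 = np.clip(sigma * np.random.uniform(pc_end.shape[0], 3), -1 * clip, clip)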
    pc_start += jitter1
    pc_middle += jitter2
    pc_end += jitter3
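For reference, the two augmentations in this step can be sketched in isolation as follows. rotate_z is a hypothetical stand-in for atomic_rotate with angles [0, 0, z_angle] (I have not checked how the repo implements it), gaussian_jitter draws per-point Gaussian noise as the comment in the code suggests was intended, and the values for rotate_angle, sigma and clip are assumed.

```python
import numpy as np


def rotate_z(points, angle_rad):
    """Rotate an (n, 3) point cloud about the z axis by angle_rad."""
    c, s = np.cos(angle_rad), np.sin(angle_rad)
    rot = np.array([[c, -s, 0.0],
                    [s, c, 0.0],
                    [0.0, 0.0, 1.0]])
    return points @ rot.T


def gaussian_jitter(points, sigma, clip, rng):
    """Add per-point Gaussian noise, clipped to [-clip, clip]."""
    noise = np.clip(sigma * rng.standard_normal(points.shape), -clip, clip)
    return points + noise


rng = np.random.default_rng(0)
pc = rng.uniform(-0.5, 0.5, size=(1024, 3))

rotate_angle = 180.0  # degrees, assumed config value
z_angle = rng.uniform() * rotate_angle / 180.0 * np.pi
rotated = rotate_z(pc, z_angle)
jittered = gaussian_jitter(rotated, sigma=0.01, clip=0.05, rng=rng)
print(rotated.shape, jittered.shape)
```

Because the same z_angle would be applied to every frame, the relative motion between the start, middle and end clouds is unchanged by the rotation.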
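  6. The fourth step normalizes and rescales the point clouds.
# normalize
# NOTE: np.maximum / np.minimum take two input arrays; a third positional argument is
# treated as `out`, so pc_end does not contribute to the bounds as written.
# np.maximum.reduce([...]) / np.minimum.reduce([...]) would combine all three.
bound_max = np.maximum(pc_start.max(0), pc_middle.max(0), pc_end.max(0)) # per-axis maximum of the bounding box
bound_min = np.minimum(pc_start.min(0), pc_middle.min(0), pc_end.min(0)) # per-axis minimum of the bounding box
center = (bound_min + bound_max) / 2 # center of the bounding box
scale = (bound_max - bound_min).max() # / (1 + self.padding)

# rescale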
pc_start = (pc_start - center) / scale
pc_middle = (pc_middle - center) / scale
pc_middle_ori = (pc_middle_ori - center) / scale
pc_end = (pc_end - center) / scale
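The normalization can be checked on toy data with the sketch below. normalize_clouds is a hypothetical helper that takes the joint bounding box over all clouds with np.maximum.reduce / np.minimum.reduce (so every cloud contributes), and the three input clouds are random stand-ins.

```python
import numpy as np


def normalize_clouds(clouds):
    """Center all clouds on their joint bounding box and divide by its longest side."""
    bound_max = np.maximum.reduce([c.max(0) for c in clouds])
    bound_min = np.minimum.reduce([c.min(0) for c in clouds])
    center = (bound_min + bound_max) / 2
    scale = (bound_max - bound_min).max()
    return [(c - center) / scale for c in clouds], center, scale


rng = np.random.default_rng(0)
pc_start = rng.uniform(-1.0, 1.0, size=(1024, 3))
pc_middle = rng.uniform(-0.8, 1.2, size=(1024, 3))
pc_end = rng.uniform(-1.5, 0.5, size=(1024, 3))

normalized, center, scale = normalize_clouds([pc_start, pc_middle, pc_end])
stacked = np.concatenate(normalized, axis=0)
print(stacked.min(0), stacked.max(0))
```

With this scaling, the longest axis of the joint bounding box spans exactly one unit and every coordinate lies within [-0.5, 0.5].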
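  7. The fifth step generates occupancy labels for the query points: points sampled on the surface get label 1, and a config-specified number of off-surface points sampled at random get label 0.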
occup_coords, occup_labels = self.prepare_occupancy_data(pc_middle_ori, self.sampling_mode)

↓↓↓

def prepare_occupancy_data(self, pcd_data, sampling_mode='random'):
    # print(f"sampling_mode:{sampling_mode}") # random
    can_repeat = True if self.on_occupancy_num > pcd_data.shape[0] else False
    
    # sample some on-surface points
    rand_idcs_on = np.random.choice(pcd_data.shape[0], 
                                    size=self.on_occupancy_num, 
                                    replace=can_repeat)
    on_surface_coords = pcd_data[rand_idcs_on]
    # occupancy label of on-surface points is set to 1 (True)
    on_surface_labels = np.ones(self.on_occupancy_num)

    if sampling_mode == 'random':
        # sample off-surface points uniformly within the bound
        # self.bound defaults to 0.5, i.e. sampling inside a cube with side length 1
        off_surface_x = np.random.uniform(-self.bound, self.bound, 
                                            size=(self.off_occupancy_num, 1))
        off_surface_y = np.random.uniform(-self.bound, self.bound, 
                                            size=(self.off_occupancy_num, 1))
        off_surface_z  = np.random.uniform(-self.bound, self.bound, 
                                            size=(self.off_occupancy_num, 1))
        off_surface_coords = np.concatenate((off_surface_x, off_surface_y, off_surface_z), 
                                            axis=1)
        # occupancy label of off-surface points is set to 0 (False)
        off_surface_labels = np.zeros(self.off_occupancy_num)
    else:
        grid = make_3d_grid([-0.5, 0.5, 15], [-0.5, 0.5, 15], [-0.5, 0.5, 15])
        rand_idcs_on = np.random.choice(grid.shape[0], 
                                    size=self.off_occupancy_num)
        off_surface_coords = grid[rand_idcs_on]
        off_surface_labels = np.zeros(self.off_occupancy_num)
    
    coords = np.concatenate((on_surface_coords, off_surface_coords), axis=0)
    labels = np.concatenate((on_surface_labels, off_surface_labels), axis=0)

    # shuffle the query points and their labels together
    rix = np.random.permutation(coords.shape[0])
    coords = coords[rix]
    labels = labels[rix]

    return coords, labels
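The non-random branch relies on make_3d_grid, whose body is not shown in these notes. Judging only from the call, each argument looks like [low, high, number of steps]; under that assumption, a grid with the same shape could be built as in this sketch. make_grid_sketch is a hypothetical name, not the repo's function, and off_occupancy_num = 2048 is an assumed value.

```python
import numpy as np


def make_grid_sketch(x_range, y_range, z_range):
    """Each range is (low, high, steps). Returns an array of shape (sx * sy * sz, 3)."""
    axes = [np.linspace(lo, hi, int(steps))
            for lo, hi, steps in (x_range, y_range, z_range)]
    xx, yy, zz = np.meshgrid(*axes, indexing="ij")
    return np.stack([xx, yy, zz], axis=-1).reshape(-1, 3)


grid = make_grid_sketch([-0.5, 0.5, 15], [-0.5, 0.5, 15], [-0.5, 0.5, 15])

rng = np.random.default_rng(0)
off_occupancy_num = 2048  # assumed config value
idx = rng.choice(grid.shape[0], size=off_occupancy_num)  # with replacement by default
off_surface_coords = grid[idx]
off_surface_labels = np.zeros(off_occupancy_num)
print(grid.shape)  # (3375, 3)
```

In this mode the off-surface queries come from a fixed 15 x 15 x 15 lattice, and since the indices are drawn with replacement, the same lattice point can appear more than once.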

  8. The sixth step returns the results.

4. Value range of the sampled point clouds

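The point cloud xyz coordinates should lie within [-1, +1]. After the normalization in step four, which divides by the longest bounding-box extent, the coordinates fall roughly within [-0.5, 0.5], matching the default self.bound of 0.5.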