【强化学习/tf/gym】(一)创建自定义gym环境

时间:2025-03-28 14:05:01
def step(self, action): reward = 0.0 # 奖励初始值为0 done = False # 该局游戏是否结束 # 首先调用方法更新球的状态 for _, ball in enumerate(self.balls): ball.update(action) # 然后处理球之间的吞噬 # 定一个要补充的球的类型列表,吃了多少球,就要补充多少球 _new_ball_types = [] # 遍历,这里就没有考虑性能问题了 for _, A_ball in enumerate(self.balls): for _, B_ball in enumerate(self.balls): # 自己,跳过 if A_ball.id == B_ball.id: continue # 先计算球A的半径 # 我们使用球的分数作为球的面积 A_radius = math.sqrt(A_ball.s / math.pi) # 计算球AB之间在x\y轴上的距离 AB_x = math.fabs(A_ball.x - B_ball.x) AB_y = math.fabs(A_ball.y - B_ball.y) # 如果AB之间在x\y轴上的距离 大于 A的半径,那么B一定在A外 if AB_x > A_radius or AB_y > A_radius: continue # 计算距离 if AB_x*AB_x + AB_y*AB_y > A_radius*A_radius: continue # 如果agent球被吃掉,游戏结束 if B_ball.t == BALL_TYPE_SELF: done = True # A吃掉B A加上B的分数 A_ball.addscore(B_ball.s) # 计算奖励 if A_ball.t == BALL_TYPE_SELF: reward += B_ball.s # 把B从列表中删除,并记录要增加一个B类型的球 _new_ball_types.append(B_ball.t) self.balls.remove(B_ball) # 补充球 for _, val in enumerate(_new_ball_types): self.balls.append(self.randball(np.int(val))) # 填充观察数据 self.state = np.vstack([ball.state() for (_, ball) in enumerate(self.balls)]) # 返回 return self.state, reward, done, {}