Albumentations 是一个流行的 Python 库,专为图像增强任务设计,以帮助提高机器学习和深度学习模型的性能。该库高效且易于使用,支持广泛的增强技术,旨在提供快速且多样化的数据增强方法。
安装 Albumentations
要安装 Albumentations,您可以使用 pip,这是 Python 的包管理工具。在命令行中运行以下命令即可安装:
bash
pip install albumentations
确保您的 Python 环境已激活(如果使用虚拟环境),这样安装的库才会放在正确的环境中。
数据增强方法
Albumentations 提供了多种数据增强方法,包括但不限于:
- 几何变换:如旋转(Rotate)、翻转(Flip)、缩放(Scale)、裁剪(Crop)等,这些操作可以改变图像的空间结构。
- 颜色变换:如调整亮度(Brightness)、对比度(Contrast)、饱和度(Saturation),以及更复杂的操作如随机对比度(RandomBrightnessContrast)和色调变换(HueSaturationValue)。
- 噪声注入:如高斯噪声(GaussianNoise)、盐和胡椒噪声(SaltAndPepper)等,这些增强可以帮助模型学习在噪声存在的情况下进行更鲁棒的预测。
- 遮挡和遮蔽:如遮挡部分图像区域(CoarseDropout),用于模拟丢失像素的场景。
- 模糊和锐化:如高斯模糊(GaussianBlur)、运动模糊(MotionBlur)和锐化(Sharpen),用于模拟摄影中的常见效果。
使用 Albumentations 进行数据增强不仅可以扩展训练数据集,还可以帮助提高模型对新、未见过的图像数据的泛化能力。它的设计允许轻松集成到现有的数据处理流程中,并支持与其他流行的深度学习库如 PyTorch 和 TensorFlow 的无缝配合。
示例代码
下面是一个使用 Albumentations 进行图像增强的简单示例,同时实现了坐标变换。(可以使用标签可视化工具脚本检查,见我的另外一个帖子:/mp_blog/creation/editor/140671100。)
import albumentations as A from PIL import Image import numpy as np # 加载图像 image = (('path_to_image.jpg')) # 定义一个增强管道
transform = ([ (width=256, height=256), (p=0.5), (p=0.2), ]) # 应用增强
augmented_image = transform(image=image)['image']
这个例子中,我们定义了一个变换管道,包括随机裁剪、水平翻转和随机亮度对比度调整,然后将其应用到一张图像上。这种方法可以在训练神经网络时用来动态创建增强的图像数据。
以下是批量处理的python脚本思路:
代码思路:
-
初始化:在
YOLOAug
类的构造函数中,接收输入参数,包括原始图像和标签的路径、保存增强后图像和标签的路径、类别标签列表、需要增强的类别以及每个类别的目标计数。 -
数据增强:使用
albumentations
库定义了一系列的数据增强操作,包括像素级变换和空间级变换。这些操作可以增加图像的多样性,提高模型的泛化能力。 -
遍历处理:在
aug_image
方法中,遍历原始图像路径下的所有图像文件,对每张图像应用定义好的数据增强操作。 -
保存结果:对每张增强后的图像和对应的标签进行保存,使用
uuid
生成唯一的文件名以区分原始图像和增强后的图像。 -
执行:在
main
函数中创建YOLOAug
类的实例,并调用aug_image
方法来执行整个数据增强流程。
这段代码是一个用于数据增强的Python脚本,主要针对目标检测任务中的图像和标签文件进行增强。使用的数据增强库是albumentations
,这是一个非常强大的图像增强库,支持各种像素级和空间级的变换。
读取的数据集目标的格式:
源目录:
datasetorg:绝对路径:"C:\Desktop\dataset_org"
--images:绝对路径:"C:\Desktop\dataset_org\images"
----.jpg
--labels:绝对路径:"C:\Desktop\dataset_org\labels"
.....txt
目标目录:
datasetdst:绝对路径:"C:\Desktop\dataset_dst"
--images:绝对路径:"C:\Desktop\dataset_dst\images"
----.jpg
--labels:绝对路径:"C:\Desktop\dataset_dst\labels"
.......txt
.txt文件的标签格式yolo格式:
3 0.385417 0.498148 0.0625 0.033333
数据集是这样子的,同时给出了标签的格式。labels目录中的.txt中的数据是目标检测的标签数据,每一行数据分别对应一个GT框,每一行数据从左到右分别是:id, x,y,w,h。列1 - 目标类别id , 列2 - 目标中心位置x, 列3 - 目标中心位置y, 列4 - 目标宽度w,列5 - 目标高度h。x,y,w,h是小于1的浮点数,因为是经过对图像进行了归一化处理得到的值,也就是目标的真实的x,w值除以图像的宽度,y,h除以图像的高度。
以下代码基本实现了读取数据集使用albumentation数据增强,且可指定类别,具体使用哪些数据增强的方法还需要根据数据集具体分析且修改,代码中只是给出几个例子,没有展示每一个数据增强的方法。
import os
import cv2
import albumentations as A
from tqdm import tqdm
import uuid
class YOLOAug:
def __init__(self, pre_image_path, pre_label_path, aug_save_image_path, aug_save_label_path, labels, classes_to_augment, target_counts):
"""
初始化YOLOAug类的实例。
:param pre_image_path: 原始图像文件夹路径。
:param pre_label_path: 原始标签文件夹路径。
:param aug_save_image_path: 增强后图像保存路径。
:param aug_save_label_path: 增强后标签保存路径。
:param labels: 所有类别的标签列表。
:param classes_to_augment: 需要进行数据增强的类别列表。
:param target_counts: 每个类别增强后的目标数量。
"""
self.pre_image_path = pre_image_path
self.pre_label_path = pre_label_path
self.aug_save_image_path = aug_save_image_path
self.aug_save_label_path = aug_save_label_path
= labels
self.classes_to_augment = classes_to_augment
self.target_counts = target_counts
# 定义数据增强的组合,包括像素级变换和空间级变换
= ([
# #1、Pixel-level transforms(包含了颜色变换、噪声和模糊 )
# AdvancedBlur # (使用随机选择参数的广义正态滤波器)
# Blur # (模糊)
(clip_limit=2.0, tile_grid_size=(8, 8), p=0.05), # CLAHE # (限制对比度自适应直方图均衡化) ####
# ChannelDropout # (随机drop一个或多个通道)
# ChannelShuffle # (通道打乱)
(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.02), # ColorJitter # (色彩抖动【亮度、对比度、饱和度】) ####
# Defocus # (虚焦)
# Downscale # (降质)
(alpha=(0.2, 0.5), strength=(0.2, 0.7), p=0.05), # Emboss # (浮雕效果) ####
(mode='cv', by_channels=True, p=0.05), # Equalize # (直方图均衡) ####
# FDA # (Fourier-Domain-Adaptation,实现简单的风格迁移)
(alpha=0.1, p=0.02), # FancyPCA # (RGB图像色彩增强) ####
# FromFloat # (乘最大值变整型,与ToFloat相反)
# GaussNoise # (高斯噪声)
# GaussianBlur # (高斯模糊)
# GlassBlur # (玻璃模糊)
# HistogramMatching # (直方图匹配,会引起色调变化)
(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.02), # HueSaturationValue # (色调、饱和度、亮度) ####
(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.02), # ISONoise # (传感器噪声) ####
# ImageCompression # (图像压缩)
# InvertImg # (255-img)
(blur_limit=5, p=0.05), # MedianBlur # (中值滤波) ####
# MotionBlur # (运动模糊)
# MultiplicativeNoise # (乘性噪声)
# Normalize # (归一化)
# (p=0.02),# PixelDistributionAdaptation #可能是调整图像中像素分布的技术 ####
# Posterize # (色调分层)
# RGBShift # (RGB每个通道上值偏移)
(brightness_limit=0.2, contrast_limit=0.2, p=0.02),# RandomBrightnessContrast # (亮度、对比度) ####
(p=0.01),# RandomFog # (雾效果) ####
(gamma_limit=(80, 120),p=0.02),# RandomGamma # (gamma变换) ####
(slant_lower=-10, slant_upper=10, drop_length=20,p=0.02),# RandomRain # (下雨效果) ####
(num_shadows_lower=1, num_shadows_upper=3,p=0.02),# RandomShadow # 随机阴影,模拟自然环境中的光影效果 ####
(snow_point_lower=0.1, snow_point_upper=0.3,p=0.02), # RandomSnow # 雪花效果,模拟冬季的气候条件 ####
(flare_roi=(0, 0, 1, 1), angle_lower=0.3, angle_upper=0.7,p=0.02),# RandomSunFlare # (太阳耀斑效果) ####
(scale=0.1,p=0.02),# RandomToneCurve # 调整图像的色调曲线,改变图像的色彩和对比度。 ####
# RingingOvershoot # 在图像中引入振铃效果,常见于图像处理中的过度锐化。
(alpha=(0.2, 0.5), lightness=(0.5, 1.0),p=0.05),# Sharpen # (锐化) ####
# Solarize # (大于阈值反转)
# Spatter # (镜头雨点泥点飞溅效果)
# Superpixels # (超像素)
# TemplateTransform # 模板转换,通过一定的模板改变图像的一部分。
# ToFloat # (除最大值归一化,与FromFloat相反)
# ToGray # (转灰度(三通道))
# ToRGB # (灰度转三通道RGB)
# ToSepia # (加棕褐色滤镜)
# (blur_limit=(3,7),p=0.05),# UnsharpMask # (锐化) ####
# ZoomBlur # (变焦模糊)
# # 2、Spatial-level transforms(包含了几何变换和空间变换。但是没有合成,比如cutout、mosaic)
# Affine # 仿射变换,包括缩放、旋转、平移
# BBoxSafeRandomCrop # (包含所有bboxes的裁剪)
# CenterCrop # (裁剪中心区域)
# CoarseDropout # (矩形区域cutout)
# Crop # (裁切)
# CropAndPad # (裁剪或填充图像边缘)
# CropNonEmptyMaskIfExists # (裁剪+缩放,可以忽略mask部分区域)
# ElasticTransform # (弹性变换)
# Flip # (翻转)
# GridDistortion # (网格畸变)
# GridDropout # (网格状cutout)
# HorizontalFlip # (水平翻转)
# Lambda # 应用自定义函数进行图像变换,为高度自定义的处理提供接口。后续添加合成(cutout、mixup、mosaic)变换使用
# LongestMaxSize # (长边等比例缩放至指定size)
# MaskDropout # (随机抹除目标实例)
# NoOp # (无操作)
# OpticalDistortion # (光学畸变(桶形、枕形))
# PadIfNeeded # (边界填充)
# Perspective # (透视变换)
# PiecewiseAffine # (局部仿射变换,效果类似ElasticTransform,但速度很慢)
# PixelDropout # (随机丢弃像素值)
# RandomCrop # (随机裁剪)
# RandomCropFromBorders # (图像边缘裁剪,会改变尺寸)
# RandomCropNearBBox # (指定rect附近裁剪)
# RandomGridShuffle # (分块打乱)
# RandomResizedCrop # (裁剪+缩放,裁剪区域宽高比随机)
# RandomRotate90 # (随机旋转90度n次,即0°,90°,180°,270°随机旋转)
([
(r_shift_limit=50, g_shift_limit=50, b_shift_limit=50, p=0.5),
(p=0.05), # 随机排列通道
# (p=0.3), # 随机改变图像的亮度、对比度、饱和度、色调
(p=0.05), # 随机丢弃通道
], p=0.05),
# (p=0.1), # 随机缩小和放大来降低图像质量
# (p=0.2), # 压印输入图像并将结果与原始图像叠加
],
# yolo: [x_center, y_center, width, height] # 经过归一化
# min_area: 表示bbox占据的像素总个数, 当数据增强后, 若bbox小于这个值则从返回的bbox列表删除该bbox.
# min_visibility: 值域为[0,1], 如果增强后的bbox面积和增强前的bbox面积比值小于该值, 则删除该bbox
bbox_params=(format='yolo', min_area=0., min_visibility=0., label_fields=['category_id'],clip=True)
)
def aug_image(self):
"""
对原始图像进行数据增强,并将增强后的图像和标签保存到指定路径。
"""
# 确保保存增强图像和标签的目录存在
(self.aug_save_image_path, exist_ok=True)
(self.aug_save_label_path, exist_ok=True)
# 遍历原始图像文件夹中的所有图像文件
for image_filename in tqdm((self.pre_image_path), desc="Augmenting Images", unit="image"):
# 构造图像和标签的完整路径
image_path = (self.pre_image_path, image_filename)
label_path = (self.pre_label_path, image_filename.replace('.jpg', '.txt'))
# 如果标签文件不存在,则跳过
if not (label_path):
continue
# 读取图片和标签
image = (image_path)
with open(label_path, 'r') as file:
lines = ()
# 解析标签文件中的边界框信息
bboxes = [().split() for line in lines]
category_ids = [int(bbox[0]) for bbox in bboxes]
bboxes = [[float(x) for x in bbox[1:]] for bbox in bboxes]
for i in range(5): # 假设我们为每张原图生成5张增强图
# 应用数据增强
augmented = (image=image, bboxes=bboxes, category_id=category_ids)
new_image = augmented['image']
new_bboxes = augmented['bboxes']
new_category_ids = augmented['category_id']
# 使用uuid生成增强图像和标签文件的唯一名称
unique_id = uuid.uuid4().hex
new_image_filename = f"{image_filename.split('.')[0]}_{unique_id}.jpg"
new_label_filename = f"{image_filename.split('.')[0]}_{unique_id}.txt"
# 保存增强后的图片和标签
((self.aug_save_image_path, new_image_filename), new_image)
with open((self.aug_save_label_path, new_label_filename), 'w') as new_label_file:
for category_id, bbox in zip(new_category_ids, new_bboxes):
new_label_file.write(f"{category_id} {' '.join([f'{b:.6f}' for b in bbox])}\n")
def main():
"""
主函数,用于创建YOLOAug类的实例并调用数据增强方法。
"""
# 创建YOLOAug对象,设置数据增强的参数
yolo_aug = YOLOAug(
pre_image_path= , #数据增强之前的数据集image目录路径
pre_label_path= , #数据增强之前的数据集label目录路径
aug_save_image_path= , #数据增强之后的数据集image目录路径
aug_save_label_path= , #数据增强之后的数据集label目录路径
labels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
# classes_to_augment=[0,1,2,3,4,5,6,7,8,9],
# target_counts={0:30000,1:30000,2:30000,3:30000,4:30000,5:30000,6:30000,7:30000,8:30000,9:30000}
classes_to_augment = [6,7],
target_counts = {6:1000,7:1000}
)
# 调用数据增强方法
yolo_aug.aug_image()
# 程序入口,调用main函数
if __name__ == "__main__":
main()