MATLAB环境下基于机器学习的合成数据生成方法

时间:2024-03-21 12:01:18

合成数据是通过计算机程序人工生成的数据,而不是由真实事件生成的数据。采用合成数据来增加训练数据,可以节省数据采集费用,或满足隐私要求。随着计算能力的提高和云数据存储选项的崛起,合成数据比以往更容易获取。这无疑是一个积极的发展:合成数据推动了AI解决方案的开发,从而更好地为所有终端用户服务。

以浅层神经网络(Shallow neural network)为例,运行环境为MATLAB R2021B,代码如下:

clear;
% Load the original dataset
load fisheriris.mat;
original_data = meas';
Classes=3; % Number of Classes
Target(1:50)=1;Target(51:100)=2;Target(101:150)=3;Target=Target'; % Original labels
% Normalize the data to [0, 1] range
min_val = min(original_data, [], 2);
max_val = max(original_data, [], 2);
normalized_data = (original_data - min_val) ./ (max_val - min_val);

% Define autoencoder architecture
input_size = size(normalized_data, 1);
hidden_size = 5; % Size of the compressed representation
output_size = input_size;

autoencoder = feedforwardnet(hidden_size);
% 'trainlm'	    Levenberg-Marquardt
% 'trainbr' 	Bayesian Regularization (good)
% 'trainrp'  	Resilient Backpropagation
% 'traincgf'	Fletcher-Powell Conjugate Gradient
% 'trainoss'	One Step Secant (good)
% 'traingd' 	Gradient Descent
autoencoder.trainFcn = 'trainoss'; % You can choose different training algorithms
autoencoder = train(autoencoder, normalized_data, normalized_data);

% Generate synthetic data using the trained autoencoder

num_samples = 500; % Number of generating samples

synthetic_data_normalized = rand(input_size, num_samples);
synthetic_data_normalized = autoencoder(synthetic_data_normalized);

% Denormalize synthetic data
synthetic_data = synthetic_data_normalized .* (max_val - min_val) + min_val;
synthetic_data_normalized=synthetic_data_normalized';

original_data=original_data';
synthetic_data=synthetic_data';
Syn=synthetic_data;

%% Getting labels of synthetic generated data by K-means clustering
[Lbl,C,sumd,D] = kmeans(Syn,Classes,'MaxIter',10000,...
    'Display','final','Replicates',10);
SynAll= [Syn Lbl];
SynSort = sortrows(SynAll,5);
Syn=SynSort(:,1:end-1);
Lbl=SynSort(:,end);

%% Plot data and classes
Feature1=1;
Feature2=4;
f1=meas(:,Feature1); % feature 1
f2=meas(:,Feature2); % feature 2
ff1=Syn(:,Feature1); % feature 1
ff2=Syn(:,Feature2); % feature 2
figure('units','normalized','outerposition',[0 0 1 1])
subplot(3,2,1)
area(meas, 'linewidth',1); title('Original Data');
ax = gca; ax.FontSize = 12; ax.FontWeight='bold'; grid on;
subplot(3,2,2)
area(Syn, 'linewidth',1); title('Synthetic Data');
ax = gca; ax.FontSize = 12; ax.FontWeight='bold'; grid on;
subplot(3,2,3)
gscatter(f1,f2,Target,'rgm','.',20); title('Original');
ax = gca; ax.FontSize = 12; ax.FontWeight='bold'; grid on;
subplot(3,2,4)
gscatter(ff1,ff2,Lbl,'rgm','.',20); title('Synthetic');
ax = gca; ax.FontSize = 12; ax.FontWeight='bold'; grid on;
subplot(3,2,[5 6])
histogram(meas, 'Normalization', 'probability', 'DisplayName', 'Original Data');
hold on;
histogram(Syn, 'Normalization', 'probability', 'DisplayName', 'Synthetic Data');
legend('Original','Synthetic')

%% Train and Test
% Training Synthetic dataset by SVM
Mdlsvm  = fitcecoc(Syn,Lbl); CVMdlsvm = crossval(Mdlsvm); 
SVMError = kfoldLoss(CVMdlsvm);
SVMAccAugTrain = (1 - SVMError)*100;
% Predict new samples (the whole original dataset)
[label5,score5,cost5] = predict(Mdlsvm,meas);

%% Test error and accuracy calculations
DataSize=size(meas);DataSize=DataSize(1,1);
a=0;b=0;c=0;
for i=1:DataSize
if label5(i)== 1
a=a+1;
elseif label5(i)==2
b=b+1;
else
label5(i)==3
c=c+1;
end;end;
erra=abs(a-50);errb=abs(b-50);errc=abs(c-50);
err=erra+errb+errc;TestErr=err*100/DataSize;
SVMAccAugTest=100-TestErr; % Test Accuracy

%% Train and Test Accuracy Results
AugRessvm = [' SDG Train SVM "',num2str(SVMAccAugTrain),'" SDG Test SVM"', num2str(SVMAccAugTest),'"'];
disp(AugRessvm);

出图如下:

此外,其他的机器学习的合成数据生成方法如下。

MATLAB环境下基于一维生成对抗网络1D-GAN的合成数据生成

MATLAB环境下基于差分进化算法的合成数据生成

MATLAB环境下基于序列蒙特卡罗方法的合成数据生成

MATLAB环境下基于马尔可夫链蒙特卡罗方法的合成数据生成

MATLAB环境下基于SMOTE方法的合成数据生成

MATLAB环境下基于遗传算法的合成数据生成

MATLAB环境下基于非线性自回归外生输入(ARX)模型的合成数据生成(SDG)方法

MATLAB环境下基于高斯混合模型GMM分布的合成数据生成方法

MATLAB环境下基于Vanilla GAN的合成数据生成方法

程序运行环境为MATLAB R2023A,执行基于Vanilla GAN的合成数据生成。

注意:MATLAB 版本不能低于R2023A,否则将运行失败。https://mbd.pub/o/bread/mbd-ZZyclplv

完整代码可见:MATLAB环境下基于机器学习的合成数据生成方法