随机梯度下降算法求解SVM

时间:2022-09-08 19:34:35

随机梯度下降算法求解SVM

测试代码(matlab)如下:

clear;
load E:\dataset\USPS\USPS.mat;
% data format:
% Xtr n1*dim
% Xte n2*dim
% Ytr n1*1
% Yte n2*1
% warning: labels must range from 1 to n, n is the number of labels
% other label values will make mistakes
u=unique(Ytr);
Nclass=length(u);

allw=[];allb=[];
step=0.01;C=0.1;
param.iterations=1;
param.lambda=1e-3;
param.biaScale=1;
param.t0=100;

tic;
for classname=1:1:Nclass
temp_Ytr=change_label(Ytr,classname);
[w,b] = sgd_svm(Xtr,temp_Ytr, param);
allw=[allw;w];
allb=[allb;b];
fprintf('class %d is done \n', classname);
end

[accuracy predict_label]=my_svmpredict(Xte, Yte, allw, allb);
fprintf(' accuracy is %.2f percent.\n' , accuracy*100 );
toc;

function [temp_Ytr] = change_label(Ytr,classname)
temp_Ytr=Ytr;
tep2=find(Ytr~=classname);
tep1=find(Ytr==classname);
temp_Ytr(tep2)=-1;
temp_Ytr(tep1)= 1;

function [true_W,b]=sgd_svm(X,Y,param)
% input:
% X is n*dim
% Y is n*1 (label is 1 or 0)
% output:
% true_W is dim*1 ,so the score is X*W'+b
% b is 1*1 number
iterations=param.iterations;%10
lambda=param.lambda;%1e-3
biaScale=param.biaScale;%0
t0=param.t0;%100
t=t0;

w=zeros(1,size(X,2));
bias=0;

for k=1:1:iterations
for i=1:1:size(X,1)
t=t+1;
alpha = (1.0/(lambda*t));
if(Y(i)*(X(i,:)*w'+bias)<1)
bias=bias+alpha*Y(i)*biaScale;
w=w+alpha*Y(i,1).*X(i,:);
end
end
end
b=bias;
true_W=w;

function [accuracy predict_label]=my_svmpredict(Xte, Yte, allw, allb)
% allw is nclass * dim
% allb is nclass * 1
% Yte must range from 1 to nclass, other label values will make mistakes
score = Xte * allw'+repmat(allb',[size(Bte,1),1]);
[bb c]=sort(score,2,'descend');
predict_label=c(:,1);
temp = predict_label((predict_label-Yte)==0);
right=size( temp,1 );
accuracy=right/size(Yte,1);

随机梯度下降算法求解SVM的更多相关文章

  1. 监督学习:随机梯度下降算法(sgd)和批梯度下降算法(bgd)

    线性回归 首先要明白什么是回归.回归的目的是通过几个已知数据来预测另一个数值型数据的目标值. 假设特征和结果满足线性关系,即满足一个计算公式h(x),这个公式的自变量就是已知的数据x,函数值h(x)就 ...

  2. 监督学习——随机梯度下降算法(sgd)和批梯度下降算法(bgd)

    线性回归 首先要明白什么是回归.回归的目的是通过几个已知数据来预测另一个数值型数据的目标值. 假设特征和结果满足线性关系,即满足一个计算公式h(x),这个公式的自变量就是已知的数据x,函数值h(x)就 ...

  3. tensorflow随机梯度下降算法使用滑动平均模型

    在采用随机梯度下降算法训练神经网络时,使用滑动平均模型可以提高最终模型在测试集数据上的表现.在Tensflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模 ...

  4. Logistic回归Cost函数和J&lpar;θ&rpar;的推导(二)----梯度下降算法求解最小值

    前言 在上一篇随笔里,我们讲了Logistic回归cost函数的推导过程.接下来的算法求解使用如下的cost函数形式: 简单回顾一下几个变量的含义: 表1 cost函数解释 x(i) 每个样本数据点在 ...

  5. 机器学习---用python实现最小二乘线性回归算法并用随机梯度下降法求解 (Machine Learning Least Squares Linear Regression Application SGD)

    在<机器学习---线性回归(Machine Learning Linear Regression)>一文中,我们主要介绍了最小二乘线性回归算法以及简单地介绍了梯度下降法.现在,让我们来实践 ...

  6. 机器学习算法(优化)之一:梯度下降算法、随机梯度下降(应用于线性回归、Logistic回归等等)

    本文介绍了机器学习中基本的优化算法—梯度下降算法和随机梯度下降算法,以及实际应用到线性回归.Logistic回归.矩阵分解推荐算法等ML中. 梯度下降算法基本公式 常见的符号说明和损失函数 X :所有 ...

  7. 梯度下降算法对比(批量下降&sol;随机下降&sol;mini-batch)

    大规模机器学习: 线性回归的梯度下降算法:Batch gradient descent(每次更新使用全部的训练样本) 批量梯度下降算法(Batch gradient descent): 每计算一次梯度 ...

  8. 对数几率回归法(梯度下降法,随机梯度下降与牛顿法)与线性判别法&lpar;LDA&rpar;

    本文主要使用了对数几率回归法与线性判别法(LDA)对数据集(西瓜3.0)进行分类.其中在对数几率回归法中,求解最优权重W时,分别使用梯度下降法,随机梯度下降与牛顿法. 代码如下: #!/usr/bin ...

  9. 随机梯度下降(Stochastic gradient descent)和 批量梯度下降(Batch gradient descent )的公式对比、实现对比&lbrack;转&rsqb;

    梯度下降(GD)是最小化风险函数.损失函数的一种常用方法,随机梯度下降和批量梯度下降是两种迭代求解思路,下面从公式和实现的角度对两者进行分析,如有哪个方面写的不对,希望网友纠正. 下面的h(x)是要拟 ...

随机推荐

  1. IT干货

    最近翻看Redis相关的中文书籍时,发现了很多错误,包括翻译错误及理论错误,因此想搜集一些相关的外文书籍看看.以下几个链接,内容大同小异,均可免费下载相关的英文书籍PDF版,内容涵盖了IT的方方面面. ...

  2. java类型转换

    //java类型转换public class Demo2 { public static void main(String[] args){ int num1 = 55; int num2 =77; ...

  3. n全排列输出和 n个数的组合&lpar;数字范围a~b&rpar;

    n全排列输出: int WPermutation(int num, bool bRepeat) num表示num全排列 bRepeat标志是否产生重复元素的序列. int Permutation(in ...

  4. jQuery的常用事件

    1.$(document).ready() $(document).ready()是jQuery中响应JavaScript内置的onload事件并执行任务的一种典型方式.它和onload具有类似的效果 ...

  5. 【转载】jxl操作excel 字体 背景色 合并单元格 列宽等 &period;

    package com.email.jav; import java.io.File;import java.io.IOException;import java.net.URL; import jx ...

  6. 【Python实战】使用Python连接Teradata数据库???未完成

    1.安装Python 方法详见:[Python 05]Python开发环境搭建 2.安装Teradata客户端ODBC驱动 安装包地址:TTU下载地址 (1)安装TeraGSS和tdicu(ODBC依 ...

  7. 好用的UI设计工具

    51yuansu  好用的在线画UI图工具    51yuansu.com processon.com processon在线画图工具,程序流程图及UI设计原型图,脑图等 draw.io的PC版画图工 ...

  8. php 基于redis计数器类

    本文引自网络 Redis是一个开源的使用ANSI C语言编写.支持网络.可基于内存亦可持久化的日志型.Key-Value数据库,并提供多种语言的API. 本文将使用其incr(自增),get(获取), ...

  9. ubuntu16&period;04更新内核--使用4&period;6以上的内核会让用A卡的Dell电脑更快--及卸载多余内核

    tips:我自己就是Dell的A卡电脑,用16.04桌面感觉不如fedora流畅,后来手动升级到4.6.2内核,发现可以和fedora与windows一般桌面操作流畅度. 我试过了4.7的开发版内核, ...

  10. ASP&period;NET MVC使用IoC

    也许你会问ASP.NET MVC为什么会爱上IoC? 相爱的理由常常很简单,就像一首歌中所唱——“只为相遇那一个眼神”. 而ASP.NET MVC爱上IoC只为IoC能实现MVC控制器的依赖注入. 下 ...