【deep learning学习笔记】注释yusugomori的LR代码 --- LogisticRegression.cpp

时间:2022-10-21 09:51:51

模型实现代码,关键是train函数和predict函数,都很容易。

#include <iostream>
#include <string>
#include <math.h>
#include "LogisticRegression.h"
using namespace std; LogisticRegression::LogisticRegression(
int size, // N
int in, // n_in
int out // n_out
)
{
N = size;
n_in = in;
n_out = out; // initialize W, b
// W[n_out][n_in], b[n_out]
W = new double*[n_out];
for(int i=0; i<n_out; i++)
W[i] = new double[n_in];
b = new double[n_out]; for(int i=0; i<n_out; i++)
{
for(int j=0; j<n_in; j++)
{
W[i][j] = 0;
}
b[i] = 0;
}
} LogisticRegression::~LogisticRegression()
{
for(int i=0; i<n_out; i++)
delete[] W[i];
delete[] W;
delete[] b;
} void LogisticRegression::train (
int *x, // the input from input nodes in training set
int *y, // the output from output nodes in training set
double lr // the learning rate
)
{
// the probability of P(y|x)
double *p_y_given_x = new double[n_out];
// the tmp variable which is not necessary being an array
double *dy = new double[n_out]; // step 1: calculate the output of softmax given input
for(int i=0; i<n_out; i++)
{
// initialize
p_y_given_x[i] = 0;
for(int j=0; j<n_in; j++)
{
// the weight of networks
p_y_given_x[i] += W[i][j] * x[j];
}
// the bias
p_y_given_x[i] += b[i];
}
// the softmax value
softmax(p_y_given_x); // step 2: update the weight of networks
// w_new = w_old + learningRate * differential (导数)
// = w_old + learningRate * x (1{y_i=y} - p_yi_given_x)
// = w_old + learningRate * x * (y - p_y_given_x)
for(int i=0; i<n_out; i++)
{
dy[i] = y[i] - p_y_given_x[i];
for(int j=0; j<n_in; j++)
{
W[i][j] += lr * dy[i] * x[j] / N;
}
b[i] += lr * dy[i] / N;
}
delete[] p_y_given_x;
delete[] dy;
} void LogisticRegression::softmax (double *x)
{
double max = 0.0;
double sum = 0.0; // step1: get the max in the X vector
for(int i=0; i<n_out; i++)
if(max < x[i])
max = x[i];
// step 2: normalization and softmax
// normalize -- 'x[i]-max', it's not necessary in traditional LR.
// I wonder why it appears here?
for(int i=0; i<n_out; i++)
{
x[i] = exp(x[i] - max);
sum += x[i];
}
for(int i=0; i<n_out; i++)
x[i] /= sum;
} void LogisticRegression::predict(
int *x, // the input from input nodes in testing set
double *y // the calculated softmax probability
)
{
// get the softmax output value given the current networks
for(int i=0; i<n_out; i++)
{
y[i] = 0;
for(int j=0; j<n_in; j++)
{
y[i] += W[i][j] * x[j];
}
y[i] += b[i];
} softmax(y);
}

【deep learning学习笔记】注释yusugomori的LR代码 --- LogisticRegression.cpp的更多相关文章

  1. 【deep learning学习笔记】注释yusugomori的LR代码 --- LogisticRegression&period;h

    继续看yusugomori的代码,看逻辑回归.在DBN(Deep Blief Network)中,下面几层是RBM,最上层就是LR了.关于回归.二类回归.以及逻辑回归,资料就是前面转的几篇.套路就是设 ...

  2. 【deep learning学习笔记】注释yusugomori的DA代码 --- dA&period;h

    DA就是“Denoising Autoencoders”的缩写.继续给yusugomori做注释,边注释边学习.看了一些DA的材料,基本上都在前面“转载”了.学习中间总有个疑问:DA和RBM到底啥区别 ...

  3. 【deep learning学习笔记】注释yusugomori的RBM代码 --- 头文件

    百度了半天yusugomori,也不知道他是谁.不过这位老兄写了deep learning的代码,包括RBM.逻辑回归.DBN.autoencoder等,实现语言包括c.c++.java.python ...

  4. &lbrack;置顶&rsqb;&NewLine; Deep Learning 学习笔记

    一.文章来由 好久没写原创博客了,一直处于学习新知识的阶段.来新加坡也有一个星期,搞定签证.入学等杂事之后,今天上午与导师确定了接下来的研究任务,我平时基本也是把博客当作联机版的云笔记~~如果有写的不 ...

  5. Deep Learning 学习笔记(8):自编码器&lpar; Autoencoders &rpar;

    之前的笔记,算不上是 Deep Learning, 只是为理解Deep Learning 而需要学习的基础知识, 从下面开始,我会把我学习UFDL的笔记写出来 #主要是给自己用的,所以其他人不一定看得 ...

  6. 【deep learning学习笔记】Recommending music on Spotify with deep learning

    主要内容: Spotify是个类似酷我音乐的音乐站点.做个性化音乐推荐和音乐消费.作者利用deep learning结合协同过滤来做音乐推荐. 详细内容: 1. 协同过滤 基本原理:某两个用户听的歌曲 ...

  7. Neural Networks and Deep Learning学习笔记ch1 - 神经网络

    近期開始看一些深度学习的资料.想学习一下深度学习的基础知识.找到了一个比較好的tutorial,Neural Networks and Deep Learning,认真看完了之后觉得收获还是非常多的. ...

  8. paper 149&colon;Deep Learning 学习笔记(一)

     1. 直接上手篇 *李宏毅教授写的,<1天搞懂深度学习> slideshare的链接: http://www.slideshare.net/tw_dsconf/ss-62245351? ...

  9. Deep Learning 学习笔记——第9章

    总览: 本章所讲的知识点包括>>>> 1.描述卷积操作 2.解释使用卷积的原因 3.描述pooling操作 4.卷积在实践应用中的变化形式 5.卷积如何适应输入数据 6.CNN ...

随机推荐

  1. js,jq新增元素 ,on绑定事件无效

    在jquery1.7之后,建议使用on来绑定事件. $('.upload a').on('click',function(){ $(this).remove(); }) 在DOM渲染的时候,也就是ht ...

  2. nginx设置反向代理后,页面上的js css文件无法加载

    问题现象: nginx配置反向代理后,网页可以正常访问,但是页面上的js css文件无法加载,页面样式乱了. (1)nginx配置如下: (2)域名访问:js css文件无法加载: (3)IP访问:j ...

  3. 使用Jquery&plus;EasyUI 进行框架项目开发案例讲解之五 模块(菜单)管理源码分享

    http://www.cnblogs.com/huyong/p/3454012.html 使用Jquery+EasyUI 进行框架项目开发案例讲解之五  模块(菜单)管理源码分享    在上四篇文章 ...

  4. MySQL查询测试经验

    测试表geoinfo,整个表超过1100万行,表结构: CREATE TABLE `geoinfo` ( `objectid` ) NOT NULL AUTO_INCREMENT , `latitud ...

  5. Android学习----AndroidManifest&period;xml文件解析

    一个Android应用程序的结构: 一.关于AndroidManifest.xml AndroidManifest.xml 是每个android程序中必须的文件.它位于整个项目的根目录,描述了pack ...

  6. Network in Network

     论文要点: 用更有效的非线性函数逼近器(MLP,multilayer perceptron)代替 GLM 以增强局部模型的抽象能力.抽象能力指的模型中特征是对于同一概念的变体的不变形. 使用 gl ...

  7. 如何在Anaconda中实现多版本python共存

    anaconda中Python版本是3.5,因为爬虫原因,需要Python2.7版本,因此,希望能在anaconda中Python3和Python2共存. 1. 打开Anaconda Prompt,可 ...

  8. 计算机编码--c语言中输出float的十六进制和二进制编码

    c语言中没有可以直接打印float类型数据的二进制或者十六进制编码的输出格式, 因此,需要单独给个函数,如下: unsigned int float2hexRepr(float* a){ unsign ...

  9. Linux 基础知识(一) shell的&amp&semi;&amp&semi;和&vert;&vert; 简单使用

    shell 在执行某个命令的时候,会返回一个返回值,该返回值保存在 shell 变量 $? 中.当 $? == 0 时,表示执行成功:当 $? == 1 时,表示执行失败.  有时候,下一条命令依赖前 ...

  10. mysql复杂查询&lpar;一&rpar;

    所谓复杂查询,指涉及多个表.具有嵌套等复杂结构的查询.这里简要介绍典型的几种复杂查询格式. 一.连接查询 连接是区别关系与非关系系统的最重要的标志.通过连接运算符可以实现多个表查询.连接查询主要包括内 ...