神经网络 torch.nn---Non-Linear Activations (ReLU)

时间:2024-06-11 09:20:30

ReLU — PyTorch 2.3 documentation

torch.nn - PyTorch中文文档 (pytorch-cn.readthedocs.io)

非线性变换的目的

  • 非线性变换的目的是为神经网络引入一些非线性特征,使其训练出一些符合各种曲线或各种特征的模型。

  • 换句话来说,如果模型都是直线特征的话,它的泛化能力会不够好

torch.nn.ReLU

torch.nn.ReLU(inplace=False)torch.nn.modules.activation — PyTorch 2.3 documentation

inplace参数:

  • inplace=True,则会自动替换输入时的变量参数。如:input=-1,ReLU(input,implace=True),那么输出后,input=output=0

  • inplace=True,则不替换输入时的变量参数。如:input=-1,ReLU(input,implace=True),那么输出后,input=-1,output=0

作用:

  • input <= 0, output = 0
  • input  >  0,   output = input

计算公式:

程序代码:

示例1:

import torch
from torch import nn
from torch.nn import ReLU

input =torch.tensor([
    [1, -0.5],
    [-1, 3]
])
print(input.shape)

input = torch.reshape(input,(-1,1,2,2))
print(input.shape)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.relu1 = ReLU()  #inplace bool   原数据是否被替换

    def forward(self, input):
        output = self.relu1(input)
        return output

tudui = Tudui()
output = tudui(input)
print(output)

输出:

示例2:

import torchvision
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=64)
# shuffle 是否打乱   False不打乱
# drop_last 最后一轮数据不够时,是否舍弃 true舍弃

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.sigmoid1 = Sigmoid()  #inplace bool   原数据是否被替换

    def forward(self, input):
        output = self.sigmoid1(input)
        return output

tudui = Tudui()
step = 1
writer = SummaryWriter('logs')
for data in dataloader:
    imgs, targets = data
    writer.add_images('inputs',imgs,step)

    outputs = tudui(imgs)
    writer.add_images("outputs",outputs,step)
    step += 1

writer.close()

在TensorBoard上看输出内容: