线性层

ydky / 2023-08-19 / 原文

线性层

线性层结构

线性连接是全连接的一种形式,但全连接不一定是线性连接。全连接层可以使用非线性激活函数,而线性连接只进行简单的线性映射。线性连接如下图:

神经网络中的使用

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
  • in_features (in) – size of each input sample
  • out_features (int) – size of each output sample
  • bias (bool) – If set to False, the layer will not learn an additive bias. Default: True

代码实现

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader

# 使用CIFAR10数据集
dataset = torchvision.datasets.CIFAR10("./dataset2", train=False, transform=torchvision.transforms.ToTensor())

dataloader = DataLoader(dataset, batch_size=64, drop_last=True)

class Baselinear(nn.Module):
    def __init__(self):
        super(Baselinear, self).__init__()
        self.linear1 = Linear(196608, 10)

    def forward(self, input):
        output = self.linear1(input)
        return output

baselinear = Baselinear()

for data in dataloader:
    imgs, targets = data
    print(imgs.shape)
    # output = torch.reshape(imgs, (1, 1, 1, -1))
    # flatten是将imgs图片进行展平,与上述的reshape作用相同
    output = torch.flatten(imgs)
    print(output.shape)
    output = baselinear(output)
    print(output.shape)

使用reshape的输出结果:

![image-20230808122914489](C:\User

使用flatten的输出结果: