首页 > 其他分享 >深度学习基本单元结构与输入输出维度解析

深度学习基本单元结构与输入输出维度解析

时间:2024-11-28 18:02:47浏览次数:3  
标签:14 32 self 输入输出 28 batch 维度 解析 size

深度学习基本单元结构与输入输出维度解析

在深度学习领域,模型的设计和结构是理解其性能和应用的关键。本文将介绍深度学习中的基本单元结构,包括卷积神经网络(CNN)、反卷积(转置卷积)、循环神经网络(RNN)、门控循环单元(GRU)和长短期记忆网络(LSTM),并详细讨论每个单元的输入和输出维度。我们将以 MNIST 数据集为例,展示这些基本单元如何组合在一起构建复杂的模型。
之前的博客:
深入理解 RNN、LSTM 和 GRU:结构、参数与应用
理解 Conv2d 和 ConvTranspose2d 的输入输出特征形状计算

1. 模型结构概述

我们构建的模型包含以下主要部分:

  • 卷积神经网络(CNN)
  • 反卷积(转置卷积)
  • 循环神经网络(RNN)
  • 门控循环单元(GRU)
  • 长短期记忆网络(LSTM)
  • 全连接层

2. 模型代码

以下是实现综合模型的代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

# 定义模型
class CombinedModel(nn.Module):
    def __init__(self):
        super(CombinedModel, self).__init__()

        # CNN 部分
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)  # 输入: (1, 28, 28) -> 输出: (32, 28, 28)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # 最大池化层
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # 输入: (32, 28, 28) -> 输出: (64, 28, 28)

        # 反卷积部分
        self.deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)  # 输入: (64, 14, 14) -> 输出: (32, 28, 28)

        # RNN 部分
        self.rnn_input_size = 32 * 14 * 14  # 输入到 RNN 的特征数
        self.rnn = nn.RNN(input_size=self.rnn_input_size, hidden_size=128, num_layers=1,
                          batch_first=True)  # 输入: (batch_size, seq_len, input_size)

        # GRU 部分
        self.gru = nn.GRU(input_size=128, hidden_size=64, num_layers=1,
                          batch_first=True)  # 输入: (batch_size, seq_len, input_size)

        # LSTM 部分
        self.lstm = nn.LSTM(input_size=64, hidden_size=32, num_layers=1,
                            batch_first=True)  # 输入: (batch_size, seq_len, input_size)

        # 全连接层
        self.fc = nn.Linear(32, 10)  # 输出: (batch_size, 10)

    def forward(self, x):
        # CNN 部分
        print(f'Input shape: {x.shape}')  # 输入形状: (batch_size, 1, 28, 28)
        x = self.pool(torch.relu(self.conv1(x)))  # 输出: (batch_size, 32, 28, 28)
        print(f'After conv1 and pool: {x.shape}')
        x = self.pool(torch.relu(self.conv2(x)))  # 输出: (batch_size, 64, 14, 14)
        print(f'After conv2 and pool: {x.shape}')

        # 反卷积部分
        x = self.deconv(x)  # 输出: (batch_size, 32, 14, 14)
        print(f'After deconv: {x.shape}')

        # 将数据展平并调整形状以输入到 RNN
        x = x.view(x.size(0), -1)  # 展平为 (batch_size, 32 * 14 * 14)
        print(f'After flattening: {x.shape}')
        x = x.unsqueeze(1)  # 添加序列长度维度,变为 (batch_size, 1, 32 * 14 * 14)
        print(f'After unsqueeze for RNN: {x.shape}')

        # RNN 部分
        x, _ = self.rnn(x)  # 输出: (batch_size, 1, 128)
        print(f'After RNN: {x.shape}')

        # GRU 部分
        x, _ = self.gru(x)  # 输出: (batch_size, 1, 64)
        print(f'After GRU: {x.shape}')

        # LSTM 部分
        x, _ = self.lstm(x)  # 输出: (batch_size, 1, 32)
        print(f'After LSTM: {x.shape}')

        # 取最后一个时间步的输出
        x = x[:, -1, :]  # 输出: (batch_size, 32)
        print(f'After selecting last time step: {x.shape}')

        # 全连接层
        x = self.fc(x)  # 输出: (batch_size, 10)
        print(f'Output shape: {x.shape}')

        return x

# 3. 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # MNIST 数据集的均值和标准差
])

# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# 4. 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CombinedModel().to(device)
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器

# 训练过程
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)  # 将数据移动到设备
        optimizer.zero_grad()  # 清空梯度
        outputs = model(images)  # 前向传播
        loss = criterion(outputs, labels)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数

        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)  # 获取预测结果
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    avg_loss = running_loss / len(train_loader)
    accuracy = 100 * correct / total
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

# 5. 评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)  # 将数据移动到设备
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)  # 获取预测结果
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

3. 每个基本单元的输入输出维度

3.1 CNN 部分

  1. 输入(batch_size, 1, 28, 28)

    • 这是 MNIST 数据集的输入形状,其中 1 表示单通道(灰度图像),28x28 是图像的高度和宽度。
  2. 卷积层 1

    • self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
    • 输入形状(batch_size, 1, 28, 28)
    • 输出形状(batch_size, 32, 28, 28)
    • 32 个特征图,空间维度保持不变。
  3. 最大池化层 1

    • 输入形状(batch_size, 32, 28, 28)
    • 输出形状(batch_size, 32, 14, 14)
    • 高度和宽度减半。
  4. 卷积层 2

    • self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
    • 输入形状(batch_size, 32, 14, 14)
    • 输出形状(batch_size, 64, 14, 14)
    • 64 个特征图,空间维度保持不变。
  5. 最大池化层 2

    • 输入形状(batch_size, 64, 14, 14)
    • 输出形状(batch_size, 64, 7, 7)
    • 高度和宽度再次减半。

3.2 反卷积部分

  1. 反卷积层
    • self.deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
    • 输入形状(batch_size, 64, 7, 7)
    • 输出形状(batch_size, 32, 14, 14)
    • 高度和宽度翻倍。

3.3 RNN 部分

  1. 展平

    • x = x.view(x.size(0), -1)
    • 输入形状(batch_size, 32, 14, 14)
    • 输出形状(batch_size, 6272) # 这里的 6272 是 32 * 14 * 14
    • 将特征图展平为一个向量。
  2. 添加序列长度维度

    • x = x.unsqueeze(1)
    • 输入形状(batch_size, 6272)
    • 输出形状(batch_size, 1, 6272)
    • 添加序列长度维度,表示只有一个时间步。
  3. RNN

    • 输入形状(batch_size, 1, 6272)
    • 输出形状(batch_size, 1, 128)
    • RNN 输出的隐藏状态,隐藏层大小为 128。

3.4 GRU 和 LSTM 部分

  1. GRU

    • 输入形状(batch_size, 1, 128)
    • 输出形状(batch_size, 1, 64)
    • GRU 输出的隐藏状态,隐藏层大小为 64。
  2. LSTM

    • 输入形状(batch_size, 1, 64)
    • 输出形状(batch_size, 1, 32)
    • LSTM 输出的隐藏状态,隐藏层大小为 32。

3.5 全连接层

  1. 全连接层
    • self.fc = nn.Linear(32, 10)
    • 输入形状(batch_size, 32)
    • 输出形状(batch_size, 10)
    • 最终输出的类别数(10 类,表示 MNIST 的数字 0-9)。

4. 可视化模型结构

from torchinfo import summary
model = CombinedModel()
summary(model, input_size=(64,1, 28, 28))

或者

import torch
import torch.nn as nn
from torchviz import make_dot
model = CombinedModel()
dummy_input = torch.randn(1, 1, 28, 28)  # (batch_size, channels, height, width)
output = model(dummy_input)
dot = make_dot(output, params=dict(model.named_parameters()))
dot.render("model_structure", format="png")  # 生成 model_structure.png

在这里插入图片描述

标签:14,32,self,输入输出,28,batch,维度,解析,size
From: https://blog.csdn.net/qq_34941290/article/details/144116741

相关文章

  • 基于时间维度优化“开源 AI 智能名片 S2B2C 商城小程序”运营策略:提升触达与转化效能
    摘要:随着数字化商业生态的蓬勃发展,“开源AI智能名片S2B2C商城小程序”融合前沿技术与创新商业模式,为企业营销与业务拓展带来新机遇。本文聚焦于用户时间场景维度,深入剖析如何依据不同时段用户行为特征,精准适配运营策略,优化推送机制、功能服务呈现等内容,类比音乐产品及外卖......
  • YASKAWA安川机器人DX100轴控制基板维修解析知识
    ASKAWA安川机器人DX100轴控制基板的维修是一项复杂而精细的工作,要求具备丰富的知识和实践经验。通过与子锐机器人维修联系,希望能企业提供一些有益的参考和帮助,在面对轴板故障时能够迅速准确地找到问题所在并妥善处理。一、YASKAWA安川机器人维修步骤与方法1、故障诊断:通过YASKA......
  • 7. Spring Cloud Sleuth+ZipKin 链路监控的配置详细解析
    7.SpringCloudSleuth+ZipKin链路监控的配置详细解析@目录7.SpringCloudSleuth+ZipKin链路监控的配置详细解析前言:1.SpringCloudSleuth+ZipKin的概述1.1Sleuth/ZipKin是什么?1.2Sleuth和Zipkin的简单关系图:1.3Sleuth工作原理解析2.Sleuth+ZipKin的......
  • TextIn文档解析表格处理模型优化,显著提升表格解析性能
    近期,TextIn通用文档解析最新推出表格处理优化版本。此前版本中,表格解析处理针对有线表格与无线表格预先分类,并基于框线进行模型预测。在运行过程中,我们发现,分类错误问题对表格解析准确率有负面影响。本次优化主要改善了表格识别效果,以统一方案替代有线表格与无线表格分类处理方法,......
  • 6款办公软件全解析:于团队项目可视化管理价值几何?
    在工程管理中,任务可视化工具至关重要。它能提升团队协作效率,增强项目透明度。市面上有诸多工具,如Trello、板栗看板等,各有特色。本深度评测将从功能、易用性、适用场景等多方面,对比分析这些工具,为工程管理团队选出最优解,办公轻松无压力!1、专业项目管理类•PrimaveraP6:作为一款......
  • 骑行抗风噪最好的蓝牙耳机是哪款?精选5大骑行耳机实测解析!
    在快节奏的现代生活中,骑行不仅是一种便捷的出行方式,更是一种健康的生活态度。无论是城市通勤还是户外探险,一副好的蓝牙耳机都能为骑行者带来更好的听觉体验。然而,骑行时面临的最大挑战之一就是风噪问题。强风不仅会影响通话质量,还会降低音乐的清晰度,破坏整体的听觉享受。为了......
  • 这些不同类型的 DNS 记录承担着不同的职责,确保域名能够正确地解析到对应的服务、设备
    DNS(域名系统,DomainNameSystem)是用于将域名(如www.example.com)解析为IP地址的系统,它通过一系列的DNS记录来实现这一过程。不同类型的DNS记录对应不同的功能,下面是常见的几种DNS记录类型:1. A记录(AddressRecord)功能:将域名解析为IPv4地址。示例:CopyCodeexample......
  • git merge底层原理解析
    日常工作中常常会有这样的合并需求:现在我在A分支上,我想把B分支的内容合并上来。合并步骤如下所示1.确保在A分支上运行以下命令,确认当前处于A分支:gitbranch当前分支前会有一个*标记。如果不在A分支上,可以通过以下命令切换:gitcheckoutA2.合并B分支到A......
  • 宠物领养新趋势:SpringBoot技术解析
    第5章系统实现编程人员在搭建的开发环境中,运用编程技术实现本系统设计的各个操作权限的功能。在本节中,就展示部分操作权限的功能与界面。5.1管理员功能实现5.1.1宠物领养管理图5.1即为编码实现的宠物领养管理界面,管理员在该界面中发布需要领养的宠物的资料,可以对宠......
  • iOS系统资源调度机制解析
    在开发高性能iOS应用时,深入了解并合理利用iOS系统的资源调度机制至关重要。资源调度涉及到线程的创建与管理、任务的分配与执行、以及进程优先级的调整等多个方面。本文将重点介绍iOS系统中的核心资源调度机制——GrandCentralDispatch(GCD),并深入探讨其在多线程管理和性能优化中......