首页 > 其他分享 >学习率调节器:深度学习训练中的关键技术

学习率调节器:深度学习训练中的关键技术

时间:2024-12-07 16:01:39浏览次数:7  
标签:loss 关键技术 调节器 torch 学习 train scheduler test

摘要: 在深度学习的训练过程中,学习率是影响模型性能的关键超参数之一。学习率调节器(Learning Rate Schedulers)是一系列用于动态调整学习率的策略,它们可以帮助模型更快地收敛,提高训练效率,并最终达到更好的性能。本文将探讨学习率调节器的重要性、常见类型以及它们在实际应用中的效果。

引言: 深度学习模型的训练是一个优化问题,目标是最小化损失函数。学习率决定了在每次迭代中参数更新的步长。一个合适的学习率可以加速训练过程,而一个过大或过小的学习率都可能导致训练效率低下或模型性能不佳。因此,如何有效地调节学习率成为了深度学习中的一个研究热点。

学习率调节器的重要性:

  1. 加速收敛: 动态调整学习率可以帮助模型在训练初期快速收敛,而在训练后期则减小步长,以更细致地逼近最优解。
  2. 避免局部最小值: 适当的学习率调整策略有助于模型跳出局部最小值,从而找到全局最小值。
  3. 提高模型性能: 通过优化学习率,可以提高模型在验证集上的性能,减少过拟合的风险。

常见的学习率调节器类型:

  1. 时间衰减(Time-based Decay): 随着训练时间的增加,学习率按照一定的衰减率逐渐减小。
  2. 步长衰减(Step Decay): 在训练过程中的特定步骤,学习率会按照预设的比率下降。
  3. 指数衰减(Exponential Decay): 学习率按照指数函数随时间衰减。
  4. 余弦退火(Cosine Annealing): 学习率按照余弦函数的变化周期性地降低。
  5. 学习率预热(Learning Rate Warmup): 在训练初期,学习率从一个较小的值逐渐增加到预定的学习率,有助于模型稳定。

学习率调节器的实际应用: 在实际应用中,学习率调节器的选择和配置需要根据具体的任务和模型结构来决定。例如,在图像识别任务中,可能需要一个较大的初始学习率来快速收敛,而在自然语言处理任务中,则可能需要一个较小的学习率来避免梯度爆炸。

指数衰减法样例

# 学习率调节器
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import random_split

# 数据生成
def f(x,y):
    return x**2+2*y**2

# 定义初始值
num_samples = 1000  # 样本数量
X = torch.rand(num_samples)  # 均匀分布
Y = torch.rand(num_samples)  # 均匀分布
Z = f(X,Y)+3*torch.randn(num_samples)  # 加上噪声

dataset = torch.stack([X,Y,Z],dim=1)

# 数据划分 按照7:3
train_size = int(0.7*len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset=dataset, lengths=[train_size, test_size])

# 将数据放入 DataLoader
train_loader = DataLoader(TensorDataset(train_dataset.dataset.narrow(1,0,2),train_dataset.dataset.narrow(1,2,1)),batch_size=32,shuffle=True)
test_loader = DataLoader(TensorDataset(test_dataset.dataset.narrow(1,0,2),test_dataset.dataset.narrow(1,2,1)),batch_size=32,shuffle=True)

# 定义模型
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2,8)  # 输入2维,输出8维
        self.fc2 = nn.Linear(8,1)  # 输入8维,输出1维

    def forward(self,x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# 设置超参数
num_epochs = 100  # 训练轮数
learning_rate = 0.01  # 学习率

# 定义损失函数
loss_fn = nn.MSELoss()

# 通过一个训练对比有无学习率调节器的效果
for with_scheduler in [False, True]:
    # 定义训练和测试误差数组
    train_losses = []
    test_losses = []

    # 定义模型
    model = Model()

    # 定义优化器
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # 定义学习率调节器
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    # 训练模型
    for epoch in range(num_epochs):
        # 训练
        model.train()
        train_loss = 0

        # 遍历训练集
        for inputs,targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs,targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # 计算loss并记录到训练误差数组
        train_loss /= len(train_loader)
        train_losses.append(train_loss)

        # 验证
        model.eval()
        test_loss = 0

        # 遍历测试集
        with torch.no_grad():
            for inputs,targets in test_loader:
                outputs = model(inputs)
                loss = loss_fn(outputs,targets)
                test_loss += loss.item()

        # 计算loss并记录到测试误差数组
        test_loss /= len(test_loader)
        test_losses.append(test_loss)


        # 学习率调节器
        if with_scheduler:
            scheduler.step()

    # 绘制训练和测试误差
    plt.figure(figsize=(8,4))
    plt.plot(range(num_epochs),train_losses,label='train')
    plt.plot(range(num_epochs),test_losses,label='test')
    plt.title('Learning rate scheduler' if with_scheduler else 'No learning rate scheduler')
    plt.legend()
    plt.show()

# 常见的学习率调节器:
# 学习率衰减:scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # 每5轮学习率衰减0.1
# 余弦退火:scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0.001)  # 学习率以余弦函数形式衰减,T_max为总训练轮数
# 指数退火:scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)  # 学习率以指数形式衰减,gamma为衰减率
# 多项式衰减:scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)  # 在30轮和80轮时学习率衰减0.1

运行结果可视化:

结论: 学习率调节器是深度学习训练中不可或缺的工具,它们通过动态调整学习率来优化训练过程。选择合适的学习率调节策略可以显著提高模型的性能和训练效率。随着深度学习技术的不断发展,新的学习率调节器策略也在不断涌现,为深度学习模型的训练提供了更多的选择。

标签:loss,关键技术,调节器,torch,学习,train,scheduler,test
From: https://blog.csdn.net/weixin_47012180/article/details/144311648

相关文章

  • RocketMq学习-Producer(三)
    一、Producer启动流程DefaultMQProducer设置了NamesrvAddr地址,需要从nameserver获取broker信息publicstaticvoidmain(String[]args)throwsMQClientException,InterruptedException{ System.setProperty("mqself.home","F:\\rocketmq"); DefaultMQProducerp......
  • 超全致远OA整套视频学习教程及二次开发技转攻略(火)
    引言    致远OA作为国内领先的办公自动化系统,凭借其强大的功能和灵活的二次开发能力,成为众多企业数字化转型的首选平台。为了帮助广大开发者和企业用户更好地掌握致远OA的实施与二次开发技术,我精心准备了超全致远OA视频学习教程及二次开发技转攻略,全面覆盖应用实施、功......
  • 斜率优化dp学习笔记
    斜率优化dp,主要用于转移式子长这样,或者可以经过一定的变形变成这样的式子:由$$f_i=f_j+……$$变化到\[b=y-kx\]其中\(b\)只与\(i\)有关,\(y\)只与\(j\)有关,\(kx\)是一个二次项,其中\(k\)只与\(i\)有关,\(x\)只与\(j\)有关。这里以求最小值为例,即\(f_i=\min\{f_j+……......
  • 2024-2025-1 20241314 《计算机基础与程序设计》第十一周学习总结
    2024-2025-120241314《计算机基础与程序设计》第十一周学习总结作业信息这个作业属于哪个课程2024-2025-1-计算机基础与程序设计这个作业要求在哪里2024-2025-1计算机基础与程序设计第十一周作业这个作业的目标计算机网络......
  • 多项式学习笔记
    多项式学习笔记目录多项式学习笔记多项式乘法逆多项式乘法逆给出\(F(x)\),求\(G(x)\)使得\(F(x)G(x)\equiv1(\bmodx^n)\)。首先\(G_0(x)=\frac{1}{F_0(x)}\),然后考虑倍增,用\(\bmodx^{\left\lceil\frac{n}{2}\right\rceil}\)的答案推\(\bmodx^n\)的答案:\[......
  • MySQL语句学习第三篇_数据库
    MySQL语句学习第三篇_数据库专栏记录MySQL的学习,感谢大家观看。本章的专栏......
  • tidyverse学习笔记——Data Transformation篇
    DataTransformationAssumethatflightsisatibblewith336,776rowsand19columns.RowsOperatorsfliterfliter()keepsrowsbasedonthevaluesofthecolumns.flights|>fliter(a==1&b>1|c==1|d%in%c(1,2))arrangearrange(......
  • 吉林大学2024机器学习A期末知识点归纳(第二章,线性回归)
            首先,要理解,线性模型是机器学习中的一种模型。公式就如图所示。而当我们输入样本,最终得到的是一个数,也就是我们所谓的预测结果y_hat。(它是监督学习,所以使用的数据集都是有数据标签y的)。        但如果到此为止,我们就无法对模型进行修改,迭代。静态的......
  • HarmonyOS学习:项目进度列表实战
    示例图一、布局单位在我们布局中,经常会采用px来作为布局的一个尺寸参考单位,这个单位在浏览器中已经是布局的标准。在鸿蒙开发中,提出了一些新的单位用于布局。物理像素:一般用px来表示。逻辑像素:在布局的时候,底层针对物理像素和屏幕的尺寸关系进行了转化的中间层。分辨率:......
  • HarmonyOS学习Day03
    #学习视频:bilibili蜗牛学苑#一、ArkTS实战ArkTS主要负责页面上数据维护、交互、以及基础属性的使用。ArkUI负责页面的布局。ArkTS负责对组件的数据、事件等等进行维护组件的参数在使用ArkUI进行布局的时候,组件采用括号的方式来引入使用Column(参数){//存放子元素}......