NetTrain

NetTrain[net,{input1output1,input2output2,}]

通过给出 inputi 作为输入,使用自动选择的损失函数最小化 outputi 和网络的实际输出之间的差异训练指定的神经网络.

NetTrain[net,port1{data11,data12,},port2{},]

通过在指定端口提供训练数据训练指定神经网络.

NetTrain[net,"dataset"]

训练来自于 Wolfram 数据存储库的已命名数据集.

NetTrain[net,f]

在训练过程中调用函数 f 以产生成批的训练数据.

NetTrain[net,data,prop]

给出与训练会话的具体属性 prop 关联的数据.

NetTrain[net,data,All]

给出一个总结训练会话信息的 NetTrainResultsObject[].

更多信息和选项

  • NetTrain 用于教导神经网络识别模式,并通过根据输入数据和正确输出调整其参数来进行预测.
  • 在训练期间,使用梯度下降等优化算法调整网络的参数(例如权重和偏差),以最小化预测输出与实际输出之间的差异,从而随着时间的推移提高网络的准确性.
  • 网络的任意一个形状不固定的输入端口将从训练数据的形式推断而得,同时,如果训练数据含有 Image 对象等,将会添加 NetEncoder 对象.
  • data 可采用的形式包括:
  • "dataset"已命名数据集
    {input1output1,}输入和输出之间的 Rule 列表
    {input1,}->{output1,}输入和对应输出之间的 Rule
    {port1,,}由指定端口的输入构成的关联的列表
    port1{data11,data12,},一个关联,给出指定端口的输入列表
    Dataset[]数据集对象
    f创建训练用批数据的函数
  • 个别训练数据输入可以是标量、向量、数值数组. 如果网络附加有合适的 NetEncoder 对象,输入可以包括 Image 对象、字符串等.
  • 命名数据集一般用作神经网络应用的范例,如下所示:
  • "MNIST"60,000 分类的手写数字
    "FashionMNIST"60,000 衣服分类图像
    "CIFAR-10","CIFAR-100"50,000 真实世界对象的分类图像
    "MovieReview"10,662 条含有情感极性的影评片段
  • 如果 ResourceObject["dataset"] 不存在,在已命名数据集上训练相当于在 ResourceData["dataset","TrainingData"]ExampleData[{"MachineLearning","dataset"},"TrainingData"] 上训练. 如果已命名数据集被用于 ValidationSet 选项,则相当于 ResourceData["dataset","TestData"]ExampleData[{"MachineLearning","dataset"},"TestData"].
  • 当用规范 {input1output1,} 给出训练数据时,网络不应含有任何损失层,应该正好只有一个输入和输出端口.
  • 指定训练数据的其他格式包括 {input1,input2,}->{output1,}{port1,port2,,port1,,}.
  • 当损失层是由 NetTrain 自动附加到输出端口上时,它们的 "Target" 端口将取自训练数据,使用与原来输出端口同样的名称.
  • 支持下列选项:
  • BatchSize Automatic一次处理多少个实例
    LearningRate Automatic调整权重最小化损失的速率
    LearningRateMultipliers Automatic设定网络内的相对学习速率
    LossFunction Automatic访问输出的损失函数
    MaxTrainingRounds Automatic遍历训练数据多少次
    Method Automatic所用的训练方法
    PerformanceGoalAutomatic具有特定优势的偏好设置
    TargetDevice "CPU"执行训练的目标设备
    TimeGoal Automatic训练的秒数
    TrainingProgressMeasurements Automatic训练期间监控、跟踪和绘图的测量
    TrainingProgressCheckpointing None怎样定期保存已部分训练过的网络
    RandomSeeding1234如何内部播种伪随机生成器
    TrainingProgressFunction None训练过程中周期性调用的函数
    TrainingProgressReporting Automatic训练过程中怎样汇报进度
    TrainingStoppingCriterion None如何自动停止训练
    TrainingUpdateSchedule Automatic何时更新特定部分的网络
    ValidationSet None训练中用于计算模型的数据集
    WorkingPrecision Automatic浮点计算的精度
  • 如果没有使用 LossFunction 明显给出损失,损失函数会基于网络中的最终层或各层自动选择.
  • 如果采用默认设置 BatchSize->Automatic,批次大小将根据网络的内存要求和目标设备上可用的内存自动选择. 自动选择的最大批次大小为 64.
  • 默认设置为 MaxTrainingRounds->Automatic, 大约每 20 秒进行一次训练,但是不会超过 10,000 次.
  • 当设置为 MaxTrainingRounds->n,训练会发生 n 次,其中,一次定义为遍历整个训练数据集.
  • 可以给出下列 ValidationSet 的设置:
  • None只使用现有训练集来估计损失(缺省)
    data验证集的形式和训练数据一样
    Scaled[frac]保留部分训练集用来进行验证
    {spec,"Interval"int}指定计算验证损失的间隔
  • 对于 ValidationSet->{spec,"Interval"->int},间隔可以是整数 n,表示每 n 轮训练,或以秒、分钟、小时为单位的 Quantity 时间后计算一次验证损失.
  • 对于命名数据集,例如 "MNIST",指定 ValidationSet->Automatic 会使用对应的 "TestData" 内容元素.
  • 如果验证集已被指定,NetTrain 将返回训练期间相对于该集合给出最低验证损失的网络.
  • NetTrain[net,f] 中,函数 f 被应用于 <|"BatchSize"n,"Round"r|> 来产生形式为 {input1->output1,}<|"port1"->data,|> 的训练数据.
  • NetTrain[net,{f,"RoundLength"->n}] 可用于指定在训练时应用 f 多次产生大约 n 个范例. 默认情况下每次训练应用 f 一次.
  • 为了计算验证损失和准确性,NetTrain[net,,ValidationSet->{g,"RoundLength"->n}] 可用于指定函数 g 应该以与 NetTrain[net,{f,"RoundLength"->n}] 等价的方式应用产生大约 n 个范例.
  • TargetDevice 可取的设置包括:
  • "CPU"在 CPU 上训练
    "GPU"在 CUDA(兼容 GPU)上训练
  • "GPU" 设置被解析为 "CUDA". 目前不支持其他设置.
  • WorkingPrecision 的可能设置包括:
  • "Real32"使用单精度实数 (32-bit)
    "Real64"使用双精度实数 (64-bit)
    "Mixed"某些运算使用半精度实数
  • WorkingPrecision->"Mixed" 只支持 TargetDevice->"GPU",在某些设备上可以导致显著的性能增加.
  • NetTrain[net,data,loss,prop] 中,属性 prop 可以是下列形式之一:
  • "TrainedNet"找到的最佳的训练好的网络(默认)
    "BatchesPerRound"每轮包含多少批次
    "BatchLossList"每个批量更新的平均损失的列表
    "BatchMeasurementsLists"每个批量更新的训练度量关联的列表
    "BatchPermutation"用于填充每个批次的训练数据的索引数组
    "BatchSize"BatchSize 的有效值
    "BestValidationRound"与最终训练好的网络对应的训练回合
    "CheckpointingFiles"训练期间产生的检查点文件列表
    "ExampleLosses"训练期间每个样例接受的损失
    "ExamplesProcessed"训练期间处理的样例总数
    "FinalLearningRate"训练结束时的学习率
    "FinalNet"在训练过程中生成的最新网络,无论其在验证集或其他指标上的表现如何
    "FinalPlots"所有损失和度量图的关联
    "InitialLearningRate"训练开始时的学习率
    "LossPlot"平均训练损失的演变图
    "MeanBatchesPerSecond"每秒处理的平均批数
    "MeanExamplesPerSecond"每秒处理的输入样例的平均数
    "NetTrainInputForm"表示对 NetTrain 的始发调用的表达式
    "OptimizationMethod"使用的优化方法的名称
    "Properties"可用属性列表
    "ReasonTrainingStopped"为什么训练停止的简短描述
    "ResultsObject"NetTrainResultsObject[] 包含表格中大部分可用属性
    "RoundLoss"最新回合的平均损失
    "RoundLossList"每轮的平均损失列表
    "RoundMeasurements"最新回合训练度量的关联
    "RoundMeasurementsLists"每轮的训练度量关联的列表
    "RoundPositions"对应于每轮度量的 batch number
    "TargetDevice"用于训练的设备
    "TotalBatches"训练期间碰到的总批数
    "TotalRounds"训练的总轮数
    "TotalTrainingTime"花在训练上的总时间(以秒为单位)
    "TrainingExamples"训练集中的样例数
    "TrainingNet"准备进行训练的网络
    "TrainingUpdateSchedule"TrainingUpdateSchedule 的值
    "ValidationExamples"验证集中的样例数
    "ValidationLoss"对于最新的验证度量,在 ValidationSet 上获取的平均损失
    "ValidationLossList"每次验证度量在 ValidationSet 上的平均损失的列表
    "ValidationMeasurements"最新验证度量后 ValidationSet 上训练度量的关联
    "ValidationMeasurementsLists"每次验证度量在 ValidationSet 上的训练度量关联的列表
    "ValidationPositions"与每次验证度量对应的 batch number
    "WeightsLearningRateMultipliers"用于每个权重的学习率乘子的关联
  • 格式 <|"Property"->prop,"Form"->form,"Interval"->int|> 的关联可用于指定自定义属性,其值在训练时会被重复收集.
  • 对于自定义属性,prop 的有效设置可以是 TrainingProgressFunction 上的任何可用属性,或给定所有属性关联的用户定义函数. 表单的有效设置包括 "List""TransposedList""Plot". "Interval" 的有效设置可以是 "Batch""Round"Quantity[]. 支持的单位包括 "Batches""Rounds""Percent" 和时间单位,比如,"Seconds""Minutes""Hours".
  • NetTrain[net,data,loss,{prop1,prop2,}] 返回 propi 的结果列表.
  • NetTrain[net,data,All] 返回一个 NetTrainResultsObject[] 包含不需要显著额外计算或内存的所有属性值.
  • 使用 ValidationSet->None 的默认设置,"TrainedNet" 属性会在训练结束时生成网络. 当提供验证集时,选择最佳网络的默认标准取决于网络的类型:
  • classification net选择带有最低错误率的网络;使用最低损失打破关系
    non-classification net选择带有最低损失的网络
  • 用于选择 "TrainedNet" 属性的标准可以使用 TrainingStoppingCriterion 选项自定义.
  • 属性 "BestValidationRound" 给出了选择最终网络的精确回合.
  • Method 的可能设置包括:
  • "ADAM"使用对梯度的对角重定标不变的自适应学习速率的随机梯度下降
    "RMSProp"使用从梯度幅度的指数平滑平均值导出的自适应学习速率的随机梯度下降
    "SGD"普通的带有动量的随机梯度下降
    "SignSGD"随机梯度下降,其中,梯度幅度被丢弃
  • PerformanceGoal 的有效设置包括 Automatic"TrainingMemory""TrainingSpeed" 或目标列表组合.
  • WorkingPrecision 的有效设置包括 "Real32" 的默认值,表示单精度浮点;"Real64" 表示双精度浮点;"Mixed" 表示 "Real32" 和半精度的混合. 混合精度的训练只支持 GPU .
  • 可以用 Method{"method",opt1val1,} 来指定特定方法的子选项. 所有方法都可使用的子选项有:
  • "LearningRateSchedule"Automatic如何按照训练进程调整学习速率
    "L2Regularization"None与所有习得数组的 L2 范数关联的全局损失
    "GradientClipping"None梯度将被剪切的幅值下限
    "WeightClipping"None大于该值的幅值的权重应被截掉
  • 当设置为 "LearningRateSchedule"->f 时,将用 initial*f[batch,total] 来计算给定批次的学习速率,其中 batch 是当前的批号,total 是训练期间将要处理的总批次数,initial 是用 LearningRate 选项指定的初始训练速率. 由 f 返回的值应是 01 之间的一个数字.
  • 可以用下列形式给出子选项 "L2Regularization""GradientClipping""WeightClipping"
  • r网络中的所有权重都使用数值 r
    {lspec1r1,lspec2r2,}对网络中的特定部分 lspeci 使用数值 ri
  • LearningRateMultipliers 的同样形式给出规则 lspeciri.
  • 对于方法 "SGD",额外支持下列子选项
  • "Momentum"0.93在更新导数时保留多少前一步的结果
  • 对于方法 "ADAM",额外支持下列子选项:
  • "Beta1"0.9第一动量估计的指数衰减率
    "Beta2"0.999第二动量估计的指数衰减率
    "Epsilon"0.00001`稳定性参数
  • 对于方法 "RMSProp",额外支持下列子选项:
  • "Beta"0.95梯度幅度移动平均值的指数衰减率
    "Epsilon"0.000001稳定性参数
    "Momentum"0.9动量项
  • 对于方法 "SignSGD",支持以下其他子选项:
  • "Momentum"0.93当更新导数时,保留前一步的程度
  • 如果网络已经含有初始化过的或之前已训练过的权重,训练开始之前 NetTrain 不会重新进行初始化.

范例

打开所有单元关闭所有单元

基本范例  (6)

在 input output 上训练单层线性网络:

预测新输入的输出值:

一次做多个预测:

预测是输入的线性函数:

训练一个把输入分类为 TrueFalse 的感知器:

预测一个新输入是 True 还是 False

通过禁用 NetDecoder 获取输入为 True 的概率:

一次做多个预测:

绘制作为输入的函数的概率:

训练一个三层网络来学习二维函数:

在输入上评估网络:

绘制作为 xy 的函数的网络的预测:

训练一个预测输入序列中最大值的循环网络:

在输入上评估网络:

在序列中一个元素变化的情况下绘制网络的输出:

训练一个网络并产生一个总结训练进程的结果对象:

从结果中获取训练网络:

获取用于训练神经网络的网络:

使用命名数据集和模型训练分类手写数字:

分类不同的数字:

范围  (14)

网络  (3)

训练单层网络:

训练线性网络链:

训练网络层的有向图:

数据格式  (7)

用规则列表制定训练数据,每一个规则代表一个输入和相应的目标输出:

用规则指定训练数据,其左侧是输入列表,右侧是相应目标输出的列表:

用关联指定训练数据,其中的键为端口名称:

用关联列表指定训练数据,每一个关联表示一个训练实例:

使用命名数据集训练网络:

Dataset 指定训练数据,其中的一行为一个实例:

在训练过程中生成训练数据. 首先创建要训练的网络:

定义生成器,用规则列表的形式产生一批实例:

用特定 BatchSize 的生成器进行训练:

指定生成器每轮应该被调用 4 次,每轮产生 64 个例子:

属性  (4)

获取一个用于训练会话的 NetTrainResultsObject

查询特定属性的结果对象:

获取调用 NetTrain 的原始格式:

获取所有可用属性:

获取在训练卷积网络期间每个数组梯度幅度的演变图:

在 MNIST 数据集上训练并记录单个例子随时间的损失:

绘制一段时间内与单个示例相关的损失:

通过计算每个范例的平均损失计算最困难的范例,并接受 20 个最大的索引,例如平均损失:

显示数字 "0"、"3"、"8" 和 "9" 的错误率和损失的平均演变:

创建最后图形的自定义版本:

检查原始损失图:

只使用回合于验证度量创建新的损失图:

创建没有对数缩放的新的损失图:

选项  (27)

BatchSize  (1)

大的批次通常会增加每秒能计算的实例的数量:

小的批次带来的结果是每秒计算的实例会少一些:

取决于任务和网络,更大的批尺寸会允许更高的学习率,也可以允许更有效地使用可用的硬件.

LearningRate  (1)

用 0.01 学习率训练网络:

LearningRateMultipliers  (1)

创建并初始化一个三层网络,但只训练最后一层:

在输入上评估训练过的网络:

起始于零偏差的初始网络的第一层:

第一层偏差在训练网络中仍然为零:

第三层偏差被训练了:

LossFunction  (4)

MeanSquaredLossLayer 训练一个简单网络,在损失不是由 SoftmaxLayer 产生的时候应用缺省的损失:

在一组输入上运行训练过的网络:

指定不同的损失层,将其附加到网络的输出. 先创建损失层:

损失层接受一个输入和目标,然后给出损失:

在训练中使用该损失层:

在一组输入上运行训练过的网络:

创建一个网络,接受长度为 2 的向量,产生类别 LessGreater 中的一个:

NetTrain 会自动使用带有合适类别编码器的 CrossEntropyLossLayer 对象:

在一组输入上运行训练过的网络:

创建显式损失层,其中,目标的形式为每个类别的概率:

训练数据应由概率向量组成,而不是符号是类别:

在一组输入上运行训练过的网络:

从一个将要被训练的 "evaluation net" 开始:

创建一个 "loss net",显式计算评估网络(此处,自定义损失等价于 MeanSquaredLossLayer)的损失:

在合成数据上训练该网络,指定名为 "Loss" 的输出端口应被诠释为损失:

NetExtract 获取训练过的 "evaluation" 网络:

也可以使用 Part 语句:

在平面上绘制网络的输出:

创建一个计算输出和多个显式损失的训练网络:

该网络要求一个输入和一个目标端口来产生输出和损失:

该网络需要一个输入和目标以便产生输出和损失:

训练后在一个数据对上测量损失:

用特定输出作为损失进行训练,忽略其他输出:

保留想要的输出和输入,用 NetTake 移除其他输出:

在一组输入上运行该网络:

MaxTrainingRounds  (2)

训练一个网络,使它只使用每个实例一次:

当同时指定 MaxTrainingRoundsTimeGoal,两者中较短的一个将被使用(注意这个例子应该运行两次以避免最初的预处理开销):

Method  (2)

用带有动量的随机梯度下降法训练一个简单的网络:

用初始学习速率指定学习速率进度安排:

用正规化 (regularization) 来防止过拟合 (overfitting). 基于高斯曲线创建合成训练数据:

用和训练数据的数量有关的大量参数训练网络:

所得网络过度拟合数据,除了基本的函数还学到了噪声:

"L2Regularization" 选项会带来与权重参数的平方成正比的损失. 这将造成更稀疏的权重矩阵,因此减轻过拟合问题:

TargetDevice  (1)

如果有支持 CUDA 的卡可用,则使用默认的系统 GPU 训练网络:

如果没有可兼容的 GPU,将返回 Failure 对象:

TimeGoal  (1)

训练网络大约 5 秒:

TrainingProgressCheckpointing  (1)

对一个在 MNIST 数据集上进行训练的卷积网络取定期检查点:

列出创建的所有检查点:

导入最后一个检查点:

TrainingProgressFunction  (1)

TrainingProgressFunction 在文件中追加训练状态的信息. 创建日志文件:

定义函数以在日志文件中追加批次号和损失:

定义训练数据并进行训练:

读取日志文件:

把保存的数据存入 Dataset

绘制损失随时间变化的情况:

TrainingProgressMeasurements  (1)

检查最后验证精度并召回 FashionMNIST 上训练的 LeNet:

动画训练周期上的混淆矩阵:

TrainingProgressReporting  (6)

训练过程中交互式显示训练进度:

训练过程中定期输出训练进度:

显示一个简单的进度表示器:

自定义进度报告:

把训练处理信息编写进文件:

不汇报进度:

TrainingStoppingCriterion  (1)

当验证损失停止改善时,通过停止训练阻止过度拟合. 设置简单的网络以及某些训练和验证数据:

当验证损失停止改善时,使用 TrainingStoppingCriterion 停止训练:

如果在超过 5 轮的训练中验证损失没有改善至少 0.001,使用 TrainingStoppingCriterion 停止训练:

使用回调函数停止训练. 设置网络和数据:

当验证损失高于 1.75,则停止训练:

TrainingUpdateSchedule  (1)

通过交替更新判别器和生成器训练一个 NetGANOperator

ValidationSet  (1)

NetTrain 提供 ValidationSet 来防止过拟合. 基于高斯曲线创建混合训练数据:

用和训练数据的数量有关的大量参数训练网络:

所得网络过度拟合数据,除了基本的函数还学到了噪声:

ValidationSet 选项使 NetTrain 选择训练中实现最低验证损失的网络. NetTrain 将从训练数据中随机选择 20% 的数据作为验证集:

NetTrain 返回的结果为能最好地归纳验证集中的点的网络,用验证损失来度量. 这将惩罚过度拟合,因为出现在训练数据中的噪声与验证集中的噪声无关:

WorkingPrecision  (2)

用 64 位精度训练网络:

用 64 位精度计算训练的网络:

以混合精度训练网络,利用 NVIDIA Tensor Core 等硬件优化:

属性和关系  (2)

NetChain 对象可被用作 NetGraph 中的层:

只有一个输入和一个输出的 NetGraph 对象可被用作 NetChain 对象中的层:

可能存在的问题  (1)

默认情况下,NetTrain 使用 RandomSeeding1234,当重复调用 NetTrain,它会使用同样的随机种子初始化网络:

使用 RandomSeedingAutomatic 确保 NetTrain 的重复调用时使用不同的初始化:

互动范例  (1)

在训练网络以解决最小二乘问题的同时对解进行观察. 首先生成训练数据:

创建拟合数据的网络:

用网络当前行为的动态更新图替换默认的进度面板:

绘制经过 10 秒训练后最终得到的网络:

巧妙范例  (2)

把一幅测试图像转换成一个训练集,其中像素的位置 (x,y) 被映射成颜色值 (r,g,b)

创建一个网络,基于像素的位置预测颜色:

训练网络:

用网络预测整个原始图像:

高维的 嵌入可能会提供更好的预测:

训练该替代网络:

用新的网络预测整幅图像:

使用属性参数的关联格式绘制拟合最小二乘问题的中间曲线. 首先生成训练数据:

创建拟合数据的网络:

训练网络,指定返回的属性是每 100 轮网络计算的:

动画列表以显示随时间收敛的解:

Wolfram Research (2016),NetTrain,Wolfram 语言函数,https://reference.wolfram.com/language/ref/NetTrain.html (更新于 2022 年).

文本

Wolfram Research (2016),NetTrain,Wolfram 语言函数,https://reference.wolfram.com/language/ref/NetTrain.html (更新于 2022 年).

CMS

Wolfram 语言. 2016. "NetTrain." Wolfram 语言与系统参考资料中心. Wolfram Research. 最新版本 2022. https://reference.wolfram.com/language/ref/NetTrain.html.

APA

Wolfram 语言. (2016). NetTrain. Wolfram 语言与系统参考资料中心. 追溯自 https://reference.wolfram.com/language/ref/NetTrain.html 年

BibTeX

@misc{reference.wolfram_2024_nettrain, author="Wolfram Research", title="{NetTrain}", year="2022", howpublished="\url{https://reference.wolfram.com/language/ref/NetTrain.html}", note=[Accessed: 22-November-2024 ]}

BibLaTeX

@online{reference.wolfram_2024_nettrain, organization={Wolfram Research}, title={NetTrain}, year={2022}, url={https://reference.wolfram.com/language/ref/NetTrain.html}, note=[Accessed: 22-November-2024 ]}