基础功能

KamalEngine 自下而上分为三层,最底层的训练引擎负责整个训练流程的管理调度,中间层定义了任务描述以及评价指标,最顶层则是各类算法的具体实现。本节内容主要介绍各个层的核心组件,并通过举例的方式说明其原理和用途。

KamalEngine 架构
KamalEngine 架构

训练引擎

训练引擎主要包括循环(Loop)、事件(Event)、回调(Callback)三类组件,并由此抽象出训练器(Trainer)。

1. 训练器(Trainer)

在KAE中,网络的训练是由 Trainer 执行的,Trainer 是整个训练的中心调度器,负责组织数据处理、模型更新以及各种模型验证和可视化功能。

每个 Trainer 包含两个主要接口,分别对应了初始化以及训练执行的功能:

  • setup:用于初始化 Trainer,将用户传入的模型、任务、数据集、超参等信息注册到训练器中
  • run:执行训练,并控制训练总迭代次数
class BasicTrainer(Engine):
def __init__( self, logger=None, tb_writer=None):
super(BasicTrainer, self).__init__(logger=logger, tb_writer=tb_writer)
def setup(self, model: torch.nn.Module, task: tasks.Task,
dataloader: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
device: torch.device=None):
def run( self, max_iter, start_iter=0, epoch_length=None):
... # enter train loop

在训练模型的过程中,我们首先初始化一个 Trainer,并调用 setup 接口来设置参数,再通过调用 run 来执行训练。在 KamalEngine 中,所有的算法接口都遵循这一约定。

trainer.setup( model=model, task=task,
dataloader=train_loader,
optimizer=optim,
device=device )
trainer.run(start_iter=0, max_iter=TOTAL_ITERS)

每个 Trainer 中都包含一个 state 属性,用于保存训练过程中的状态信息,用户可以对 state 对象进行访问或修改。该对象包含以下默认提供的状态信息:

  • iter:当前迭代轮数
  • max_iter:最大迭代轮数
  • metrics:训练过程中的评价指标字典,包括训练损失、验证结果等
  • epoch_length:每个周期的迭代次数
  • dataloader:可迭代的数据读取对象
  • seed:随机种子
  • batch:当前读取的数据
  • current_epoch:当前处于第几个周期
  • max_epoch:最大的周期数
  • current_batch_index:当前周期中数据的数据批次序号
  • max_batch_index:每个周期包含的数据批次总数
print(trainer.state.iter) # 打印当前迭代轮数
print(trainer.state.metrics) # 打印当前评价指标

2. 事件(Event)与回调(Callback)

在训练的过程中,如果用户需要执行自定义操作,如验证、可视化、记录日志等,就需要通过回调实现。回调是可调用的(Callable)函数或者 Python 对象,它需要能够接收一个 Trainer 对象作为输入,如下所示:

def my_callback(trainer):
print(trainer.state.metrics) # 打印当前评价指标

回调通常会和某一个事件(Event)进行绑定,从而在事件发生时被触发。引擎中包含8 种事件:

  • 训练前(BEFORE_RUN)

  • 训练后(AFTER_RUN)

  • Epoch之前(BEFORE_EPOCH)

  • Epoch之后(AFTER_EPOCH)

  • 每一步迭代前(BEFORE_STEP)

  • 每一步迭代后(AFTER_STEP)

  • 获取数据前(BEFORE_GET_BATCH)

  • 获取数据后(AFTER_GET_BATCH)

我们可以将回调函数注册在某个事件上,并控制回调触发的条件:

trainer.add_callback(
engine.DefaultEvents.AFTER_STEP(every=10), #每10次迭代执行一次回调
callbacks=my_callback )

由此,我们在运行 run 接口进行训练时,注册的回调函数会自动被调用。引擎中主要提供一下几类预置回调:

  • EvalAndCkpt:验证模型并保存参数
  • MetricsLogging:对模型评价指标进行记录,存储至文本文件或显示到 Tensorboard 进行可视化
  • VisualizeOutput:对模型输出可视化

任务定义与模型评价

1. 任务(Task)

在 KamalEngine 中,任务是一个 Python 对象,提供了损失函数计算和结果预测的功能。

class Task(object):
def __init__(self, name):
self.name = name
@abc.abstractmethod
def get_loss( self, outputs, targets ) -> Dict:
pass
@abc.abstractmethod
def predict(self, outputs) -> Any:
pass

在训练过程中,Task 对象根据模型输出以及数据标签计算损失,损失函数的输出为一个字典,保存了损失项的名称和数值。用户可以通过继承 kamal.tasks.Task 实现自定义任务,也可以直接使用 kamal.tasks.StandardTasks 中预设的任务描述。

2. 验证器(Evaluator)与指标(Metrics)

指标的主要作用是评估模型的性能,在不同的任务上我们所需要的指标是不同的,如在语义分割上我们通常使用 mIoU 来评价分割的效果。KamalEngine 中的指标是一个包含以下三个接口的对象:

  • update:更新指标的累积状态
  • get_results:计算累积的指标结果
  • reset:重置指标对象的状态,清除先前累积的结果
class Metric(ABC):
def __init__(self, attach_to=None):
self._attach = AttachTo(attach_to)
@abstractmethod
def update(self, pred, target):
...
@abstractmethod
def get_results(self):
...
@abstractmethod
def reset(self):
...

下面以一个具体例子来说明其用途。我们可以使用 kamal.metrics.Accuracy 这一指标计算模型的分类精度:

acc_metric = kamal.metrics.Accuracy() # 1. 初始化指标
acc_metric.reset() # 2. 重置
for data, target in dataloader:
output = model(data)
acc_metric.update(output, target) # 3. 累积预测结果
print("Accuracy=%.4f"%(acc_metric.get_results())) # 4. 计算总体精度

指标的使用包含4 个步骤,即初始化一个指标,重置内部状态,在迭代中更新状态信息,最后获得结果。然而当我们需要同时使用多个指标时,上述实现就变得比较繁琐且难以管理。KamalEngine 提供了 Evaluator 和 MetricCompose 来管理多个指标。我们首先使用 MetricCompose 组装多个指标对象,并设置合适的名称,然后将指标和测试数据一同传递给 Evaluator 用于模型评估。以下样例展示了用于语义分割任务的验证器构建:

# 语义分割指标,包含精度、mIou和混淆矩阵
confusion_matrix = metrics.ConfusionMatrix(num_classes=13, ignore_idx=255)
metric = metrics.MetricCompose({
'acc': metrics.Accuracy(), #指标名称:指标的实际对象
'cm': confusion_matrix, # 混淆矩阵
'mIoU': metrics.mIoU(confusion_matrix) # mIoU
})
# 构建验证器
evaluator = engine.evaluator.BasicEvaluator(
dataloader=val_loader, metric=metric
)
# 验证模型
results = evaluator.eval( model )

算法实现

引擎的最顶层主要包含各类算法的实现,算法实质上是对训练器(Trainer)的封装,能够与上述提及的所有组件进行交互。具体的算法内容我们在下一节中进行详细讨论。

多任务训练

KamalEngine 通过 attach_to 参数来处理多任务训练。多任务可以看作多个独立任务的组合,一个多任务网络包含了多个输出,因此我们需要指定各个损失函数、评估标准所对应的模型输出。比如在深度估计、语义分割多任务下,我们需要对深度估计的结果计算均方根误差,而对语义分割结果计算精度,这一功能可以通过简单地设置 attch_to 参数实现:

confusion_matrix = metrics.ConfusionMatrix(num_classes=13, ignore_idx=255, attach_to=0)
metric = metrics.MetricCompose({ 'acc': metrics.Accuracy(attach_to=0),
'cm': confusion_matrix,
'mIoU': metrics.mIoU(confusion_matrix),
'rmse': metrics.RootMeanSquaredError(attach_to=1)})
evaluator = engine.evaluator.BasicEvaluator( dataloader=val_loader, metric=metric)
task = [ kamal.tasks.StandardTask.segmentation(attach_to=0),
kamal.tasks.StandardTask.monocular_depth(attach_to=1) ]

上述样例定义了一个评估器(Evaluator) 和一组任务(Task),支持深度估计和语义分割多任务训练。我们通过 attach_to 参数指定对应的输出序号,在这里序号0对应了语义分割的输出,而序号1则对应深度估计的输出。除了指定序号外, attach_to 还可以接收一个可调用函数,下面给出与 attach_to=0 等价的函数实现:

def attach_to_the_first_output(outputs, targets):
return outputs[0], targets[0]

视觉任务工具包

知识重组引擎提供了一组视觉任务工具包,包括模型库、数据集以及数据增强功能。模型库和数据集提供了常用模型和数据读取的实现。数据增强模块 SyncTransforms 提供了支持多任务的同步数据增强功能,使得多个输入样本和标签能够以一致的方式进行增强。下面的样例给出了一种包括「随机裁剪」和「随机反转」的同步数据增强,其输入依次是原始图片、语义标签、深度标签。

transforms=sT.Compose([
sT.Multi( sT.Resize(240), sT.Resize(240, interpolation=Image.NEAREST), sT.Resize(240)),
sT.Sync( sT.RandomCrop(240), sT.RandomCrop(240), sT.RandomCrop(240) ),
sT.Sync( sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip() ),
sT.Multi( sT.ToTensor(), sT.ToTensor( normalize=False, dtype=torch.long ), sT.ToTensor( normalize=False, dtype=torch.float )),
sT.Multi( sT.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), sT.Lambda(lambd=lambda x: x.squeeze()), sT.Lambda( lambd=lambda x: x/1e3 ) )
])
读取并处理NYUv2数据集
读取并处理NYUv2数据集
Last updated on