星星博客 »  > 

mmaction2 性能指标相关源码解析

文章目录

    • 0. 前言
    • 1. 训练时性能指标
    • 2. 验证/测试时性能指标

0. 前言

  • 想实现一个TubeDataset,要实现性能指标相关功能。
    • 之前一直都没有仔细研究过相关源码,趁这个机会都看一下。
  • 从结构上看,性能指标相关源码可以分为:
    • 训练时性能指标:无法通过配置文件配置,每类任务(分类、定位、检测)都有固定的性能指标展示
    • 验证时性能指标:通过 EvalHook 实现,EvalHook中的核心流程与测试时完全一致。
      • 配置文件中 evaluation 选项的参数会传递到 dataset.evaluate 中。
    • 测试时性能指标:先通过 single_gpu_test/multi_gpu_test 获取预测结果,然后再调用各类数据集提供的 evaluate 函数
      • dataset.evaluate 参数通过 tools/test.py 的命令行参数指定

1. 训练时性能指标

  • 展示形式:通过 log_config 相关配置可以显示性能指标

    • log_config 本质就是使用 mmcv 中的 logger hook,在每个 iter/epoch 的结尾都会调用 log 函数展示数据。
  • 那么问题来了,要展示的数据从哪里来呢?

    • LoggerHook 的源码 中表明要从展示数据保存在 runner 中。
    • TextLoggerHook为例,要展示的内容(如epoch/iter/mode等)很多是写死的。有一些可以自定义,主要就是 runner.log_buffer.output 中保存的内容。具体参考源码
    • EpochBaseRunner为例,logger_buffer 中的数据主要更新自模型类 batch_processor/train_step/val_step 函数的输出结果,即 outputs['log_vars']。具体参考源码。
    • 要寻找 outputs['log_vars'] 终于要回到MMAction2了,最终定位到各种 head 的 loss 函数中。
  • 在MMAction2中,具体要展示的数据有哪些?换句话说,就是看看各种headloss函数都输出啥内容。

    • 行为识别任务中,显示 loss 以及 top1/top5 accuracy
    • 在时空行为检测任务中,显示 prec@top3、prec@top5、recall@top3、recall@top5
  • 结论:训练时展示的性能指标无法通过配置文件配置,都是写死的

2. 验证/测试时性能指标

  • 验证展示形式:在每一类(epoch)结束的时候,通过 logger hook 显示验证集结果。
  • 数据集验证都是通过 EvalHook 实现的,参考源码
    • 配置 Eaval Hook 就是通过配置文件中的 evaluation 设置的,如下距离
    • 该hook的主要作用就是在训练完一轮后在验证集上获取性能指标。
    • 验证的基本流程就是:先通过 single_gpu_test/multi_gpu_test 获取预测结果,然后再调用各类数据集提供的 evaluate 函数
    • 注意,验证结果都也都会保存到 runner.logger_buffer 中,参考源码
# 行为识别举例
evaluation = dict(
    interval=1, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))

# 时空行为检测举例
evaluation = dict(interval=1, save_best='mAP@0.5IOU')
  • 测试时就更加直接了,就是先 single_gpu_test/multi_gpu_testdataset.evaluate
  • 行为识别数据集的验证与测试,参考源码
    • 函数格式如下,metrics 可以是字符串或列表,支持的操作包括 'top_k_accuracy', 'mean_class_accuracy', 'mean_average_precision','mmit_mean_average_precision'
    • metrics_options是一个字典,key就是 metrics ,value是字典、就是对应函数的参数
    def evaluate(self,
                 results,
                 metrics='top_k_accuracy',
                 metric_options=dict(top_k_accuracy=dict(topk=(1, 5))),
                 logger=None,
                 **deprecated_kwargs):
  • 时空行为检测数据集的验证,参考源码
    • 目前只支持一个数值,所以 metrics 不用修改
    def evaluate(self,
                 results,
                 metrics=('mAP', ),
                 metric_options=None,
                 logger=None):

相关文章