ST-GCN:自建行为识别数据集训练指南

作者:KAKAKA2024.08.29 17:56浏览量:28

简介:本文详细介绍了如何使用ST-GCN(时空图卷积网络)训练自建的行为识别数据集,包括数据准备、转换、模型配置、训练及测试等步骤,为初学者提供了清晰易懂的实践指导。

ST-GCN:自建行为识别数据集训练指南

引言

行为识别是计算机视觉领域的一个重要研究方向,广泛应用于视频监控、人机交互、体育分析等场景。ST-GCN(Spatial Temporal Graph Convolutional Networks)作为一种基于图卷积的行为识别模型,因其能够有效地捕捉人体关节间的空间和时间依赖关系而备受关注。本文将详细介绍如何使用ST-GCN训练自建的行为识别数据集。

一、环境配置与代码下载

首先,你需要下载ST-GCN的官方实现代码,并配置好相应的环境。可以通过以下命令克隆ST-GCN的GitHub仓库,并安装必要的依赖项:

  1. git clone https://github.com/yysijie/st-gcn.git
  2. cd st-gcn
  3. pip install -r requirements.txt
  4. cd torchlight
  5. python setup.py install
  6. cd ..

二、准备行为数据

在训练之前,你需要准备自己的行为数据集。数据集应包含视频文件以及对应的标注信息(如关键帧的人体姿态、行为标签等)。你可以参考常用的行为识别数据集(如Kinetics-skeleton、NTU RGB+D等)的格式来整理你的数据。

三、数据转换

ST-GCN的训练代码提供了数据转换的脚本(如kinetics_gendata.py),用于将原始数据集转换为模型训练所需的格式(如npy和pkl文件)。然而,由于你的数据集是自建的,因此你需要修改这些脚本以适配你的数据集格式。

主要修改点包括:

  • 数据读取路径
  • 关键点个数
  • 观测人数(num_person_in)和输出人数(num_person_out)
  • 最大帧数(max_frame)

例如,在kinetics_gendata.py中,你可能需要修改以下参数:

  1. num_person_in = 5 # 观察前5个人
  2. num_person_out = 2 # 选择分数最高的2个人
  3. max_frame = 300 # 最大帧数
  4. num_joints = 18 # 关键点个数(根据你的数据集可能不同)

同时,你还需要在feeder_kinetics.py中修改相应的参数,以确保数据加载的正确性。

四、添加Layout

如果你的数据集使用的人体姿态关键点个数与ST-GCN默认的不同(默认为18个关键点),你需要在graph.py中的get_edge函数中添加一个新的Layout。这个Layout定义了关键点之间的连接关系,是图卷积网络的关键部分。

  1. elif layout == 'my_pose':
  2. self.num_node = 20 # 假设你的数据集有20个关键点
  3. self_link = [(i, i) for i in range(self.num_node)]
  4. # 定义关键点之间的连接关系...
  5. self.edge = self_link + neighbor_link

五、修改训练参数

train.yaml配置文件中,你需要修改与你的数据集相关的参数,如数据路径(data_pathlabel_path)、行为类别数(num_class)、Layout类别(layout)等。此外,你还可以根据需要调整学习率、batch size、迭代次数等训练参数。

六、开始训练

一切准备就绪后,你可以使用以下命令开始训练模型:

  1. python main.py recognition -c config/st_gcn/kinetics-skeleton/train.yaml

注意将配置文件路径替换为你自己的配置文件路径。

训练过程中,模型会定期保存(默认每10个epoch保存一次),你可以在work_dir目录下找到保存的模型文件。

七、模型测试

训练完成后,你可以使用测试集对模型进行测试。测试过程与训练过程类似,但你需要使用测试配置文件(如test.yaml)并指定模型权重文件。

  1. python main.py recognition -c config/st_gcn/kinetics-skeleton/test.yaml --weights path_to_your_model.pth