简介:本文详细解析如何使用🤗 Transformers库微调Vision Transformer(ViT)模型进行图像分类,涵盖数据准备、模型选择、训练配置及优化技巧,助力开发者快速实现高性能图像分类器。
在计算机视觉领域,Vision Transformer(ViT)凭借其自注意力机制和全局信息捕捉能力,已成为图像分类任务的重要模型。然而,直接使用预训练ViT模型在新数据集上表现可能受限,而从头训练则需大量计算资源。此时,微调(Fine-tuning)成为高效提升模型性能的关键技术。本文将详细介绍如何使用🤗 Transformers库(Hugging Face Transformers)微调ViT模型,从数据准备、模型选择到训练优化,提供可落地的技术指南。
ViT将图像分割为固定大小的patch(如16×16),通过线性投影转换为序列化的token,输入Transformer编码器进行自注意力计算。与CNN相比,ViT无需依赖局部感受野,而是通过全局注意力捕捉长距离依赖,尤其适合处理复杂场景或细粒度分类任务。
安装依赖库:
pip install transformers torchvision datasets evaluate
支持torchvision.datasets.ImageFolder或自定义Dataset类,需确保数据目录结构如下:
data/train/class1/img1.jpg...class2/...val/class1/...
使用torchvision.transforms增强数据多样性:
from torchvision import transformstrain_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
选择预训练模型(如google/vit-base-patch16-224):
from transformers import ViTForImageClassification, ViTFeatureExtractormodel = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224",num_labels=10, # 分类类别数ignore_mismatched_sizes=True # 允许调整分类头)feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
1e-5至3e-5),避免破坏预训练权重。32或64)。AdamW(带权重衰减的Adam):optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
#### 训练循环实现使用`Trainer`类简化训练流程:```pythonfrom transformers import Trainer, TrainingArgumentstraining_args = TrainingArguments(output_dir="./results",num_train_epochs=10,per_device_train_batch_size=32,per_device_eval_batch_size=64,logging_dir="./logs",logging_steps=10,evaluation_strategy="epoch",save_strategy="epoch",load_best_model_at_end=True)trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=val_dataset,compute_metrics=compute_metrics # 自定义评估函数)trainer.train()
实现准确率计算:
import numpy as npfrom evaluate import loadaccuracy_metric = load("accuracy")def compute_metrics(p):logits, labels = ppredictions = np.argmax(logits, axis=-1)return accuracy_metric.compute(references=labels, predictions=predictions)
训练完成后保存模型:
model.save_pretrained("./saved_model")feature_extractor.save_pretrained("./saved_model")
对ViT的不同层设置差异化学习率:
no_decay = ["bias", "LayerNorm.weight"]optimizer_grouped_parameters = [{"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],"weight_decay": 0.01,"lr": 2e-5},{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],"weight_decay": 0.0,"lr": 2e-5}]optimizer = AdamW(optimizer_grouped_parameters)
启用fp16加速训练并减少显存占用:
training_args = TrainingArguments(fp16=True,# 其他参数...)
google/vit-large-patch16-224。在皮肤癌分类任务中,微调ViT-Base模型:
在钢板表面缺陷检测中,结合ViT与轻量化设计:
AdamW与SGD+动量的对比。通过🤗 Transformers库,开发者可以高效完成ViT模型的微调,快速适应各类图像分类场景。未来,随着ViT与多模态模型的融合,微调技术将进一步推动计算机视觉的边界。