简介:本文深入解析CRNN(Convolutional Recurrent Neural Network)文字识别技术的核心原理,结合代码实现与优化策略,为开发者提供从理论到实践的全流程指导。
在OCR(Optical Character Recognition)领域,传统方法依赖人工设计的特征提取(如SIFT、HOG)和分类器(如SVM),存在对复杂场景适应性差、需要大量预处理步骤等痛点。CRNN(Convolutional Recurrent Neural Network)作为深度学习时代的代表性方案,通过融合卷积神经网络(CNN)的局部特征提取能力和循环神经网络(RNN)的时序建模能力,实现了端到端的文字识别,尤其擅长处理不定长、非规则排版的文本。
CRNN由三部分组成:
import torchimport torch.nn as nnclass CRNN(nn.Module):def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):super(CRNN, self).__init__()assert imgH % 32 == 0, 'imgH must be a multiple of 32'# CNN部分(示例为简化版)self.cnn = nn.Sequential(nn.Conv2d(nc, 64, 3, 1, 1),nn.ReLU(),nn.MaxPool2d(2, 2),# 更多卷积层...)# RNN部分(双向LSTM)self.rnn = nn.Sequential(BidirectionalLSTM(512, nh, nh),BidirectionalLSTM(nh, nh, nclass))def forward(self, input):# CNN特征提取conv = self.cnn(input)b, c, h, w = conv.size()assert h == 1, "the height of conv must be 1"conv = conv.squeeze(2) # [b, c, w]conv = conv.permute(2, 0, 1) # [w, b, c]# RNN序列建模output = self.rnn(conv)return outputclass BidirectionalLSTM(nn.Module):def __init__(self, nIn, nHidden, nOut):super(BidirectionalLSTM, self).__init__()self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)self.embedding = nn.Linear(nHidden * 2, nOut)def forward(self, input):recurrent, _ = self.rnn(input)T, b, h = recurrent.size()t_rec = recurrent.view(T * b, h)output = self.embedding(t_rec)output = output.view(T, b, -1)return output
CTC解决了“输入序列(特征图宽度)与输出序列(字符数)长度不一致”的核心问题。其核心思想是通过引入空白标签(-)和重复字符合并规则,将所有可能的路径对齐方式映射到最终标签。
y(形状为 [T, nclass],T为时间步,nclass为字符类别数)。
criterion = nn.CTCLoss() # PyTorch内置CTC损失# 假设:# - predictions: RNN输出 [T, batch_size, nclass]# - targets: 真实标签 [sum(target_lengths)]# - input_lengths: 每个样本的时间步长度 [batch_size]# - target_lengths: 每个标签的长度 [batch_size]loss = criterion(predictions, targets, input_lengths, target_lengths)
[-1, 1] 或 [0, 1]。a→1, b→2, ..., 空白→0)。
from torchvision import transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5]) # 归一化到[-1,1]])# 自定义数据集类class OCRDataset(Dataset):def __init__(self, img_paths, labels):self.img_paths = img_pathsself.labels = labelsdef __getitem__(self, idx):img = Image.open(self.img_paths[idx]).convert('L') # 转为灰度img = transform(img)label = self.labels[idx]return img, labeldef __len__(self):return len(self.img_paths)
ReduceLROnPlateau 动态调整学习率。BatchNorm2d 加速收敛。torch.nn.utils.clip_grad_norm_)。
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)for epoch in range(epochs):model.train()for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)# 假设已计算input_lengths和target_lengthsloss = criterion(output, target, input_lengths, target_lengths)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)optimizer.step()# 验证阶段计算准确率,并更新学习率val_loss = validate(model, val_loader)scheduler.step(val_loss)
torch.quantization 将FP32模型转为INT8,减少内存占用。
dummy_input = torch.randn(1, 1, 32, 100) # 假设输入为32x100的灰度图torch.onnx.export(model, dummy_input, "crnn.onnx",input_names=["input"], output_names=["output"],dynamic_axes={"input": {0: "batch_size", 3: "width"}, "output": {0: "width"}})
CRNN通过其端到端的架构设计和对不定长文本的适应性,已成为OCR领域的核心方案。开发者可通过调整网络深度、引入注意力机制或优化部署流程,进一步满足不同场景的需求。