主頁 >  其他 > OCR -- 文本識別 -- 實踐篇

OCR -- 文本識別 -- 實踐篇

2023-06-06 09:28:26 其他

OCR -- 文本識別 -- 理論篇
本章將詳細介紹如何基于PaddleOCR完成CRNN文本識別模型的搭建、訓練、評估和預測,資料集采用 icdar 2015,其中訓練集有4468張,測驗集有2077張,

CRNN是基于CTC的演算法,CRNN是較早被提出也是目前工業界應用較多的方法,主要用于識別規則文本,有效快的預測速度,并且因為序列不對齊,不受長度的影響,所以在長文本上有很好的預測效果,中文演算法里的首選

預測原理詳解

2. 預測原理詳解

第一節中 paddleocr 加載訓練好的 CRNN 識別模型進行預測,本節將詳細介紹 CRNN 的原理及流程,

2.1 所屬類別

CRNN 是基于CTC的演算法,在理論部分介紹的分類圖中,處在如下位置,可以看出CRNN主要用于解決規則文本,基于CTC的演算法有較快的預測速度并且很好的適用長文本,因此CRNN是PPOCR選擇的中文識別演算法,

2.2 演算法詳解

CRNN 的網路結構體系如下所示,從下往上分別為卷積層、遞回層和轉錄層三部分:

1)backbone:

卷積網路作為底層的骨干網路,用于從輸入影像中提取特征序列,由于 convmax-poolingelementwise 和激活函式都作用在區域區域上,所以它們是平移不變的,因此,特征映射的每一列對應于原始影像的一個矩形區域(稱為感受野),并且這些矩形區域與它們在特征映射上對應的列從左到右的順序相同,由于CNN需要將輸入的影像縮放到固定的尺寸以滿足其固定的輸入維數,因此它不適合長度變化很大的序列物件,為了更好的支持變長序列,CRNN將backbone最后一層輸出的特征向量送到了RNN層,轉換為序列特征,

2)neck:

遞回層,在卷積網路的基礎上,構建遞回網路,將影像特征轉換為序列特征,預測每個幀的標簽分布,
RNN具有很強的捕獲序列背景關系資訊的能力,使用背景關系線索進行基于影像的序列識別比單獨處理每個像素更有效,以場景文本識別為例,寬字符可能需要幾個連續的幀來充分描述,此外,有些歧義字符在觀察其背景關系時更容易區分,其次,RNN可以將誤差差分反向傳播回卷積層,使網路可以統一訓練,第三,RNN能夠對任意長度的序列進行操作,解決了文本圖片變長的問題,CRNN使用雙層LSTM作為遞回層,解決了長序列訓練程序中的梯度消失和梯度爆炸問題,

3)head:

轉錄層,通過全連接網路和softmax激活函式,將每幀的預測轉換為最終的標簽序列,最后使用 CTC Loss 在無需序列對齊的情況下,完成CNN和RNN的聯合訓練,CTC 有一套特別的合并序列機制,LSTM輸出序列后,需要在時序上分類得到預測結果,可能存在多個時間步對應同一個類別,因此需要對相同結果進行合并,為避免合并本身存在的重復字符,CTC 引入了一個 blank 字符插入在重復字符之間,

2.2 代碼實作

整個網路結構非常簡潔,代碼實作也相對簡單,可以跟隨預測流程依次搭建模塊,本節需要完成:資料輸入、backbone搭建、neck搭建、head搭建,

【資料輸入】

資料送入網路前需要縮放到統一尺寸(3,32,320)【英文一般是(3,32,100),高度 32 效果比較好,】,并完成歸一化處理,這里省略掉訓練時需要的資料增強部分,以單張圖為例展示預處理的必須步驟(原始碼位置):

中文字符,長、寬比,1:1, 10倍 320,

import cv2
import math
import numpy as np

def resize_norm_img(img):
    """
    資料縮放和歸一化
    :param img: 輸入圖片
    """

    # 默認輸入尺寸 英文為 (3,32,100)
    imgC = 3
    imgH = 32   # 高度 32 效果比較好
    imgW = 320  # 中文每個字長寬比為1:1,防止長寬比在resize后被壓縮過小或拉伸過大,識別大概10個字左右,所以寬度 10倍 = 320

    # 圖片的真實高寬
    h, w = img.shape[:2]
    # 圖片真實長寬比
    ratio = w / float(h)

    # 按比例縮放
    if math.ceil(imgH * ratio) > imgW:
        # 如大于默認寬度,則寬度為imgW
        resized_w = imgW
    else:
        # 如小于默認寬度則以圖片真實寬為準
        resized_w = int(math.ceil(imgH * ratio))
    # 縮放
    resized_image = cv2.resize(img, (resized_w, imgH))
    resized_image = resized_image.astype('float32')
    # 歸一化
    resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    # 對寬度不足的位置,補0
    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
    padding_im[:, :, 0:resized_w] = resized_image
    # 轉置 padding 后的圖片用于可視化
    draw_img = padding_im.transpose((1,2,0))
    return padding_im, draw_img

【網路結構】

  • backbone

PaddleOCR 使用 MobileNetV3 作為骨干網路,組網順序與網路結構一致,首先定義網路中的公共模塊(原始碼位置):ConvBNLayer、ResidualUnit、make_divisible

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups=1,
                 if_act=True,
                 act=None):
        """
        卷積BN層
        :param in_channels: 輸入通道數
        :param out_channels: 輸出通道數
        :param kernel_size: 卷積核尺寸
        :parma stride: 步長大小
        :param padding: 填充大小
        :param groups: 二維卷積層的組數
        :param if_act: 是否添加激活函式
        :param act: 激活函式
        """
        super(ConvBNLayer, self).__init__()
        self.if_act = if_act
        self.act = act
        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias_attr=False)

        self.bn = nn.BatchNorm(num_channels=out_channels, act=None)

    def forward(self, x):
        # conv層
        x = self.conv(x)
        # batchnorm層
        x = self.bn(x)
        # 是否使用激活函式
        if self.if_act:
            if self.act == "relu":
                x = F.relu(x)
            elif self.act == "hardswish":
                x = F.hardswish(x)
            else:
                print("The activation function({}) is selected incorrectly.".
                      format(self.act))
                exit()
        return x

class SEModule(nn.Layer):
    def __init__(self, in_channels, reduction=4):
        """
        SE模塊
        :param in_channels: 輸入通道數
        :param reduction: 通道縮放率
        """        
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        self.conv1 = nn.Conv2D(
            in_channels=in_channels,
            out_channels=in_channels // reduction,
            kernel_size=1,
            stride=1,
            padding=0)
        self.conv2 = nn.Conv2D(
            in_channels=in_channels // reduction,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            padding=0)

    def forward(self, inputs):
        # 平均池化
        outputs = self.avg_pool(inputs)
        # 第一個卷積層
        outputs = self.conv1(outputs)
        # relu激活函式
        outputs = F.relu(outputs)
        # 第二個卷積層
        outputs = self.conv2(outputs)
        # hardsigmoid 激活函式
        outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
        return inputs * outputs


class ResidualUnit(nn.Layer):
    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 use_se,
                 act=None):
        """
        殘差層
        :param in_channels: 輸入通道數
        :param mid_channels: 中間通道數
        :param out_channels: 輸出通道數
        :param kernel_size: 卷積核尺寸
        :parma stride: 步長大小
        :param use_se: 是否使用se模塊
        :param act: 激活函式
        """ 
        super(ResidualUnit, self).__init__()
        self.if_shortcut = stride == 1 and in_channels == out_channels
        self.if_se = use_se

        self.expand_conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            if_act=True,
            act=act)
        self.bottleneck_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=int((kernel_size - 1) // 2),
            groups=mid_channels,
            if_act=True,
            act=act)
        if self.if_se:
            self.mid_se = SEModule(mid_channels)
        self.linear_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            if_act=False,
            act=None)

    def forward(self, inputs):
        x = self.expand_conv(inputs)
        x = self.bottleneck_conv(x)
        if self.if_se:
            x = self.mid_se(x)
        x = self.linear_conv(x)
        if self.if_shortcut:
            x = paddle.add(inputs, x)
        return x


def make_divisible(v, divisor=8, min_value=https://www.cnblogs.com/vipsoft/p/None):"""
    確保被8整除
    """
    if min_value is None:
        min_value = https://www.cnblogs.com/vipsoft/p/divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

利用公共模塊搭建骨干網路

class MobileNetV3(nn.Layer):
    def __init__(self,
                 in_channels=3,
                 model_name='small',
                 scale=0.5,
                 small_stride=None,
                 disable_se=False,
                 **kwargs):
        super(MobileNetV3, self).__init__()
        self.disable_se = disable_se

        small_stride = [1, 2, 2, 2]

        if model_name == "small":
            cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, True, 'relu', (small_stride[0], 1)],
                [3, 72, 24, False, 'relu', (small_stride[1], 1)],
                [3, 88, 24, False, 'relu', 1],
                [5, 96, 40, True, 'hardswish', (small_stride[2], 1)],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 120, 48, True, 'hardswish', 1],
                [5, 144, 48, True, 'hardswish', 1],
                [5, 288, 96, True, 'hardswish', (small_stride[3], 1)],
                [5, 576, 96, True, 'hardswish', 1],
                [5, 576, 96, True, 'hardswish', 1],
            ]
            cls_ch_squeeze = 576
        else:
            raise NotImplementedError("mode[" + model_name +
                                      "_model] is not implemented!")

        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
        assert scale in supported_scale, \
            "supported scales are {} but input scale is {}".format(supported_scale, scale)

        inplanes = 16
        # conv1
        self.conv1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=make_divisible(inplanes * scale),
            kernel_size=3,
            stride=2,
            padding=1,
            groups=1,
            if_act=True,
            act='hardswish')
        i = 0
        block_list = []
        inplanes = make_divisible(inplanes * scale)
        for (k, exp, c, se, nl, s) in cfg:
            se = se and not self.disable_se
            block_list.append(
                ResidualUnit(
                    in_channels=inplanes,
                    mid_channels=make_divisible(scale * exp),
                    out_channels=make_divisible(scale * c),
                    kernel_size=k,
                    stride=s,
                    use_se=se,
                    act=nl))
            inplanes = make_divisible(scale * c)
            i += 1
        self.blocks = nn.Sequential(*block_list)

        self.conv2 = ConvBNLayer(
            in_channels=inplanes,
            out_channels=make_divisible(scale * cls_ch_squeeze),
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            if_act=True,
            act='hardswish')

        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.out_channels = make_divisible(scale * cls_ch_squeeze)

    def forward(self, x):
        x = self.conv1(x)
        x = self.blocks(x)
        x = self.conv2(x)
        x = self.pool(x)
        return x
# 圖片輸入骨干網路
backbone = MobileNetV3()
# 將numpy資料轉換為Tensor
input_data = https://www.cnblogs.com/vipsoft/p/paddle.to_tensor([padding_im])
# 骨干網路輸出
feature = backbone(input_data)
# 查看feature map的緯度
print("backbone output:", feature.shape)
  • neck

neck 部分將backbone輸出的視覺特征圖轉換為1維向量輸入送到 LSTM 網路中,輸出序列特征( 原始碼位置 ):

class Im2Seq(nn.Layer):
    def __init__(self, in_channels, **kwargs):
        """
        影像特征轉換為序列特征
        :param in_channels: 輸入通道數
        """ 
        super().__init__()
        self.out_channels = in_channels

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == 1
        x = x.squeeze(axis=2)
        x = x.transpose([0, 2, 1])  # (NWC)(batch, width, channels)
        return x

class EncoderWithRNN(nn.Layer):
    def __init__(self, in_channels, hidden_size):
        super(EncoderWithRNN, self).__init__()
        self.out_channels = hidden_size * 2
        self.lstm = nn.LSTM(
            in_channels, hidden_size, direction='bidirectional', num_layers=2)

    def forward(self, x):
        x, _ = self.lstm(x)
        return x


class SequenceEncoder(nn.Layer):
    def __init__(self, in_channels, hidden_size=48, **kwargs):
        """
        序列編碼
        :param in_channels: 輸入通道數
        :param hidden_size: 隱藏層size
        """ 
        super(SequenceEncoder, self).__init__()
        self.encoder_reshape = Im2Seq(in_channels)

        self.encoder = EncoderWithRNN(
            self.encoder_reshape.out_channels, hidden_size)
        self.out_channels = self.encoder.out_channels

    def forward(self, x):
        x = self.encoder_reshape(x)
        x = self.encoder(x)
        return x
neck = SequenceEncoder(in_channels=288)
sequence = neck(feature)
print("sequence shape:", sequence.shape)  # sequence shape:[1, 80, 96]  96 = hidden_size=48 * 2
  • head

預測頭部分由全連接層和softmax組成,用于計算序列特征時間步上的標簽概率分布,本示例僅支持模型識別小寫英文字母和數字(26+10)36個類別(原始碼位置):

class CTCHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 **kwargs):
        """
        CTC 預測層
        :param in_channels: 輸入通道數
        :param out_channels: 輸出通道數
        """
        super(CTCHead, self).__init__()
        self.fc = nn.Linear(
            in_channels,
            out_channels)

        # 思考:out_channels 應該等于多少?  fc(x) 的 out_channels = 37 因為有一個空字符 =  (26+10)36個類別 + 1 = 37
        self.out_channels = out_channels

    def forward(self, x):
        predicts = self.fc(x)
        result = predicts

        if not self.training:
            predicts = F.softmax(predicts, axis=2)
            result = predicts

        return result

在網路隨機初始化的情況下,輸出結果是無序的,經過SoftMax之后,可以得到各時間步上的概率最大的預測結果,其中:pred_id 代表預測的標簽ID,pre_scores 代表預測結果的置信度:

ctc_head = CTCHead(in_channels=96, out_channels=37)
predict = ctc_head(sequence)
print("predict shape:", predict.shape)
result = F.softmax(predict, axis=2)
pred_id = paddle.argmax(result, axis=2)
pred_socres = paddle.max(result, axis=2)
print("pred_id:", pred_id)
print("pred_scores:", pred_socres)
  • 后處理

識別網路最侄訓傳的結果是各個時間步上的最大索引值,最終期望的輸出是對應的文字結果,因此CRNN的后處理是一個解碼程序,主要邏輯如下:

def decode(text_index, text_prob=None, is_remove_duplicate=False):
    """ convert text-index into text-label. """
    character = "-0123456789abcdefghijklmnopqrstuvwxyz" # 沒有從字典取字符,因為比較簡單,就直接寫在這了
    result_list = []
    # 忽略tokens [0] 代表ctc中的blank位
    ignored_tokens = [0]
    batch_size = len(text_index)
    for batch_idx in range(batch_size):
        char_list = []
        conf_list = []
        for idx in range(len(text_index[batch_idx])):
            if text_index[batch_idx][idx] in ignored_tokens:
                continue
            # 合并blank之間相同的字符
            if is_remove_duplicate:
                # only for predict
                if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                        batch_idx][idx]:
                    continue
            # 將解碼結果存在char_list內
            char_list.append(character[int(text_index[batch_idx][
                idx])])
            # 記錄置信度
            if text_prob is not None:
                conf_list.append(text_prob[batch_idx][idx])
            else:
                conf_list.append(1)
        text = ''.join(char_list)
        # 輸出結果
        result_list.append((text, np.mean(conf_list)))
    return result_list

以 head 部分隨機初始化預測出的結果為例,進行解碼得到:

pred_id = paddle.argmax(result, axis=2)
pred_socres = paddle.max(result, axis=2)
print(pred_id)
decode_out = decode(pred_id, pred_socres)
print("decode out:", decode_out)

上述步驟完成了網路的搭建,也實作了一個簡單的前向預測程序,

沒有經過訓練的網路無法正確預測結果,因此需要定義損失函式、優化策略,將整個網路run起來,下面將詳細介紹網路訓練原理,

3. 訓練原理詳解

3.1 準備訓練資料

PaddleOCR 支持兩種資料格式:

  • lmdb 用于訓練以lmdb格式存盤的資料集(LMDBDataSet);
  • 通用資料 用于訓練以文本檔案存盤的資料集(SimpleDataSet);

本次只介紹通用資料格式讀取

訓練資料的默認存盤路徑是 ./train_data, 執行以下命令解壓資料:

cd /home/aistudio/work/train_data/ && tar xf ic15_data.tar

解壓完成后,訓練圖片都在同一個檔案夾內,并有一個txt檔案(rec_gt_train.txt)記錄圖片路徑和標簽,txt檔案里的內容如下:

" 影像檔案名         影像標注資訊 "

train/word_1.png	Genaxis Theatre
train/word_2.png	[06]
...

注意: txt檔案中默認將圖片路徑和圖片標簽用 \t 分割,如用其他方式分割將造成訓練報錯,

資料集應有如下檔案結構:

|-train_data
  |-ic15_data
    |- rec_gt_train.txt
    |- train
        |- word_001.png
        |- word_002.jpg
        |- word_003.jpg
        | ...
    |- rec_gt_test.txt
    |- test
        |- word_001.png
        |- word_002.jpg
        |- word_003.jpg
        | ...

確認組態檔中的資料路徑是否正確,以 rec_icdar15_train.yml為例:

Train:
  dataset:
    name: SimpleDataSet
    # 訓練資料根目錄
    data_dir: ./train_data/ic15_data/
    # 訓練資料標簽
    label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]  # [3,32,320]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 256 # 第一次除錯可以改小一點,防止一起來就崩了
    drop_last: True
    num_workers: 8
    use_shared_memory: False

Eval:
  dataset:
    name: SimpleDataSet
    # 評估資料根目錄
    data_dir: ./train_data/ic15_data
    # 評估資料標簽
    label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]  # 要得訓練的設定值一致
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 4
    use_shared_memory: False

3.2 資料預處理

送入網路的訓練資料,需要保證一個batch內維度一致,同時為了不同維度之間的特征在數值上有一定的比較性,需要對資料做統一尺度縮放歸一化

為了增加模型的魯棒性,抑制過擬合提升泛化性能,需要實作一定的資料增廣

  • 縮放和歸一化

第二節中已經介紹了相關內容,這是圖片送入網路之前的最后一步操作,呼叫 resize_norm_img 完成圖片縮放、padding和歸一化,

  • 資料增廣

PaddleOCR中實作了多種資料增廣方式,如:顏色反轉、隨機切割、仿射變化、隨機噪聲等等,這里以簡單的隨機切割為例,更多增廣方式可參考:rec_img_aug.py

def get_crop(image):
    """
    random crop
    """
    import random
    h, w, _ = image.shape
    top_min = 1
    top_max = 8
    top_crop = int(random.randint(top_min, top_max))
    top_crop = min(top_crop, h - 1)
    crop_img = image.copy()
    ratio = random.randint(0, 1)
    if ratio:
        crop_img = crop_img[top_crop:h, :, :]
    else:
        crop_img = crop_img[0:h - top_crop, :, :]
    return crop_img
# 讀圖
raw_img = cv2.imread("/home/aistudio/work/word_1.png")
plt.figure()
plt.subplot(2,1,1)
# 可視化原圖
plt.imshow(raw_img)
# 隨機切割
crop_img = get_crop(raw_img)
plt.subplot(2,1,2)
# 可視化增廣圖
plt.imshow(crop_img)
plt.show()

image

3.3 訓練主程式

模型訓練的入口代碼是 train.py,它展示了訓練中所需的各個模塊: build dataloader, build post process, build model , build loss, build optim, build metric,將各部分串聯后即可開始訓練:

  • 構建 dataloader

訓練模型需要將資料組成指定數目的 batch ,并在訓練程序中依次 yield 出來,本例中呼叫了 PaddleOCR 中實作的 SimpleDataSet

基于原始代碼稍作修改,其回傳單條資料的主要邏輯如下

def __getitem__(data_line, data_dir):
    import os
    mode = "train"
    delimiter = '\t' # label 設定的時候,前面是圖片的路徑,所以要 \t 進行分隔后面真實的標簽
    try:
        substr = data_line.strip("\n").split(delimiter)
        file_name = substr[0]
        label = substr[1]
        img_path = os.path.join(data_dir, file_name)
        data = https://www.cnblogs.com/vipsoft/p/{'img_path': img_path, 'label': label}
        if not os.path.exists(img_path):
            raise Exception("{} does not exist!".format(img_path))
        with open(data['img_path'], 'rb') as f:
            img = f.read()
            data['image'] = img
        # 預處理操作,先注釋掉
        # outs = transform(data, self.ops)
        outs = data
    except Exception as e:
        print("When parsing line {}, error happened with msg: {}".format(
                data_line, e))
        outs = None
    return outs

假設當前輸入的標簽為 train/word_1.png Genaxis Theatre, 訓練資料的路徑為 /home/aistudio/work/train_data/ic15_data/, 決議出的結果是一個字典,里面包含 img_path label image 三個欄位:

data_line = "train/word_1.png	Genaxis Theatre"
data_dir = "/home/aistudio/work/train_data/ic15_data/"

item = __getitem__(data_line, data_dir)
print(item)

實作完單條資料回傳邏輯后,呼叫 padde.io.Dataloader 即可把資料組合成batch,具體可參考 build_dataloader

  • build model

    build model 即搭建主要網路結構,具體細節如《2.3 代碼實作》所述,本節不做過多介紹,各模塊代碼可參考modeling

  • build loss

    CRNN 模型的損失函式為 CTC loss, 飛槳集成了常用的 Loss 函式,只需呼叫實作即可:

import paddle.nn as nn
class CTCLoss(nn.Layer):
    def __init__(self, use_focal_loss=False, **kwargs):
        super(CTCLoss, self).__init__()
        # blank 是 ctc 的無意義連接符
        self.loss_func = nn.CTCLoss(blank=0, reduction='none')

    def forward(self, predicts, batch):
        if isinstance(predicts, (list, tuple)):
            predicts = predicts[-1]
        # 轉置模型 head 層的預測結果,沿channel層排列
        predicts = predicts.transpose((1, 0, 2)) #[80,1,37]
        N, B, _ = predicts.shape
        preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
        labels = batch[1].astype("int32")
        label_lengths = batch[2].astype('int64')
        # 計算損失函式
        loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
        loss = loss.mean()
        return {'loss': loss}
  • build post process

具體細節同樣在《2.3 代碼實作》有詳細介紹,實作邏輯與之前一致,

  • build optim

優化器使用 Adam , 同樣呼叫飛槳API: paddle.optimizer.Adam

  • build metric

metric 部分用于計算模型指標,PaddleOCR的文本識別中,將整句預測正確判斷為預測正確,因此準確率計算主要邏輯如下:

def metric(preds, labels):    
    correct_num = 0
    all_num = 0
    norm_edit_dis = 0.0
    for (pred), (target) in zip(preds, labels):
        pred = pred.replace(" ", "")     # 如果嚴格預測的話,把這代碼注釋掉,比較時,空格敏感的預測
        target = target.replace(" ", "")
        if pred == target:
            correct_num += 1
        all_num += 1
    correct_num += correct_num
    all_num += all_num
    return {
        'acc': correct_num / all_num,
    }
preds = ["aaa", "bbb", "ccc", "123", "456"]
labels = ["aaa", "bbb", "ddd", "123", "444"]
acc = metric(preds, labels)
print("acc:", acc)
# 五個預測結果中,完全正確的有3個,因此準確率應為0.6

將以上各部分組合起來,即是完整的訓練流程:


def main(config, device, logger, vdl_writer):
    # init dist environment
    if config['Global']['distributed']:
        dist.init_parallel_env()

    global_config = config['Global']

    # build dataloader
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if len(train_dataloader) == 0:
        logger.error(
            "No Images in train dataset, please ensure\n" +
            "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
            +
            "\t2. The annotation file and path in the configuration file are provided normally."
        )
        return

    if config['Eval']:
        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config['Architecture']["Models"]:
                config['Architecture']["Models"][key]["Head"][
                    'out_channels'] = char_num
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])
    if config['Global']['distributed']:
        model = paddle.DataParallel(model)

    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=config['Global']['epoch_num'],
        step_each_epoch=len(train_dataloader),
        parameters=model.parameters())

    # build metric
    eval_class = build_metric(config['Metric'])
    # load pretrain model
    pre_best_model_dict = load_model(config, model, optimizer)
    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info('valid dataloader has {} iters'.format(
            len(valid_dataloader)))

    use_amp = config["Global"].get("use_amp", False)
    if use_amp:
        AMP_RELATED_FLAGS_SETTING = {
            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
            'FLAGS_max_inplace_grad_add': 8,
        }
        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config["Global"].get("scale_loss", 1.0)
        use_dynamic_loss_scaling = config["Global"].get(
            "use_dynamic_loss_scaling", False)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
    else:
        scaler = None

    # start train
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler)

4. 完整訓練任務

4.1 啟動訓練

PaddleOCR 識別任務與檢測任務類似,是通過組態檔傳輸引數的,

要進行完整的模型訓練,首先需要下載整個專案并安裝相關依賴:

# 克隆PaddleOCR代碼
#!git clone https://gitee.com/paddlepaddle/PaddleOCR
# 修改代碼運行的默認目錄為 /home/aistudio/PaddleOCR
import os
os.chdir("/home/aistudio/PaddleOCR")
# 安裝PaddleOCR第三方依賴
!pip install -r requirements.txt

創建軟鏈,將訓練資料放在PaddleOCR專案下:

!ln -s /home/aistudio/work/train_data/ /home/aistudio/PaddleOCR/

下載預訓練模型:

為了加快收斂速度,建議下載訓練好的模型在 icdar2015 資料上進行 finetune

!cd PaddleOCR/
# 下載MobileNetV3的預訓練模型
!wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar
# 解壓模型引數
!tar -xf pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar

啟動訓練命令很簡單,指定好組態檔即可,另外在命令列中可以通過 -o 修改組態檔中的引數值,啟動訓練命令如下所示

其中:

  • Global.pretrained_model: 加載的預訓練模型路徑
  • Global.character_dict_path : 字典路徑(這里只支持26個小寫字母+數字)
  • Global.eval_batch_step : 評估頻率,[0,200] 從第0次開始計算,每200次評估一次
  • Global.epoch_num: 總訓練輪數
# -o 修改默認的引數值,命令列修改優先級會更高
!python3 tools/train.py -c configs/rec/rec_icdar15_train.yml \
   -o Global.pretrained_model=rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy \
   Global.character_dict_path=ppocr/utils/ic15_dict.txt \
   Global.eval_batch_step=[0,200] \
   Global.epoch_num=40 \
   Global.use_gpu=false

根據組態檔中設定的的 save_model_dir 欄位,會有以下幾種引數被保存下來:

output/rec/ic15
├── best_accuracy.pdopt              # 每次評估時,拿到的最優評估結果
├── best_accuracy.pdparams
├── best_accuracy.states
├── config.yml                       # 命令列修改后的引數(當前訓練的引數)
├── iter_epoch_3.pdopt               # 每3個epoch保存一次,會看到 epoch_3、epoch_6、epoch_9 訓練的中間狀態,可以在組態檔中修改保存的頻度
├── iter_epoch_3.pdparams
├── iter_epoch_3.states
├── latest.pdopt
├── latest.pdparams
├── latest.states
└── train.log

其中 best_accuracy.* 是評估集上的最優模型;iter_epoch_x.* 是以 save_epoch_step 為間隔保存下來的模型;latest.* 是最后一個epoch的模型,
總結:

如果需要訓練自己的資料需要修改:

  1. 訓練和評估資料路徑(必須)
  2. 字典路徑(必須)
  3. 預訓練模型 (可選)
  4. 學習率、image shape、網路結構(可選)

4.2 模型評估

評估資料集可以通過 configs/rec/rec_icdar15_train.yml 修改Eval中的 label_file_path 設定,

這里默認使用 icdar2015 的評估集,加載剛剛訓練好的模型權重:

!python tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints=output/rec/ic15/best_accuracy \
        Global.character_dict_path=ppocr/utils/ic15_dict.txt

評估后,可以看到訓練模型在驗證集上的精度,

PaddleOCR支持訓練和評估交替進行, 可在 configs/rec/rec_icdar15_train.yml 中修改 eval_batch_step 設定評估頻率,默認每2000個iter評估一次,評估程序中默認將最佳acc模型,保存為 output/rec/ic15/best_accuracy

如果驗證集很大,測驗將會比較耗時,建議減少評估次數,或訓練完再進行評估,

4.3 預測

使用 PaddleOCR 訓練好的模型,可以通過以下腳本進行快速預測,

預測圖片:

默認預測圖片存盤在 infer_img 里,通過 -o Global.checkpoints 加載訓練好的引數檔案:

!python tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints=output/rec/ic15/best_accuracy Global.character_dict_path=ppocr/utils/ic15_dict.txt

得到輸入影像的預測結果:

infer_img: doc/imgs_words_en/word_19.png
        result: slow	0.8795223

image

原文:AI Studio學習『動手學OCR·十講』https://aistudio.baidu.com/aistudio/course/introduce/25207?sharedLesson=2077537&sharedType=2&sharedUserId=2631487&ts=1685512885510

轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/554382.html

標籤:其他

上一篇:ChatGpt玩法(一):prompt使用的重要性

下一篇:返回列表

標籤雲
其他(160407) Python(38206) JavaScript(25475) Java(18198) C(15237) 區塊鏈(8270) C#(7972) AI(7469) 爪哇(7425) MySQL(7234) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5873) 数组(5741) R(5409) Linux(5346) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4582) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2434) ASP.NET(2403) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) .NET技术(1981) 功能(1967) HtmlCss(1952) Web開發(1951) C++(1929) python-3.x(1918) 弹簧靴(1913) xml(1889) PostgreSQL(1879) .NETCore(1863) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • 網閘典型架構簡述

    網閘架構一般分為兩種:三主機的三系統架構網閘和雙主機的2+1架構網閘。 三主機架構分別為內端機、外端機和仲裁機。三機無論從軟體和硬體上均各自獨立。首先從硬體上來看,三機都用各自獨立的主板、記憶體及存盤設備。從軟體上來看,三機有各自獨立的作業系統。這樣能達到完全的三機獨立。對于“2+1”系統,“2”分為 ......

    uj5u.com 2020-09-10 02:00:44 more
  • 如何從xshell上傳檔案到centos linux虛擬機里

    如何從xshell上傳檔案到centos linux虛擬機里及:虛擬機CentOs下執行 yum -y install lrzsz命令,出現錯誤:鏡像無法找到軟體包 前言 一、安裝lrzsz步驟 二、上傳檔案 三、遇到的問題及解決方案 總結 前言 提示:其實很簡單,往虛擬機上安裝一個上傳檔案的工具 ......

    uj5u.com 2020-09-10 02:00:47 more
  • 一、SQLMAP入門

    一、SQLMAP入門 1、判斷是否存在注入 sqlmap.py -u 網址/id=1 id=1不可缺少。當注入點后面的引數大于兩個時。需要加雙引號, sqlmap.py -u "網址/id=1&uid=1" 2、判斷文本中的請求是否存在注入 從文本中加載http請求,SQLMAP可以從一個文本檔案中 ......

    uj5u.com 2020-09-10 02:00:50 more
  • Metasploit 簡單使用教程

    metasploit 簡單使用教程 浩先生, 2020-08-28 16:18:25 分類專欄: kail 網路安全 linux 文章標簽: linux資訊安全 編輯 著作權 metasploit 使用教程 前言 一、Metasploit是什么? 二、準備作業 三、具體步驟 前言 Msfconsole ......

    uj5u.com 2020-09-10 02:00:53 more
  • 游戲逆向之驅動層與用戶層通訊

    驅動層代碼: #pragma once #include <ntifs.h> #define add_code CTL_CODE(FILE_DEVICE_UNKNOWN,0x800,METHOD_BUFFERED,FILE_ANY_ACCESS) /* 更多游戲逆向視頻www.yxfzedu.com ......

    uj5u.com 2020-09-10 02:00:56 more
  • 北斗電力時鐘(北斗授時服務器)讓網路資料更精準

    北斗電力時鐘(北斗授時服務器)讓網路資料更精準 北斗電力時鐘(北斗授時服務器)讓網路資料更精準 京準電子科技官微——ahjzsz 近幾年,資訊技術的得了快速發展,互聯網在逐漸普及,其在人們生活和生產中都得到了廣泛應用,并且取得了不錯的應用效果。計算機網路資訊在電力系統中的應用,一方面使電力系統的運行 ......

    uj5u.com 2020-09-10 02:01:03 more
  • 【CTF】CTFHub 技能樹 彩蛋 writeup

    ?碎碎念 CTFHub:https://www.ctfhub.com/ 筆者入門CTF時時剛開始刷的是bugku的舊平臺,后來才有了CTFHub。 感覺不論是網頁UI設計,還是題目質量,賽事跟蹤,工具軟體都做得很不錯。 而且因為獨到的金幣制度的確讓人有一種想去刷題賺金幣的感覺。 個人還是非常喜歡這個 ......

    uj5u.com 2020-09-10 02:04:05 more
  • 02windows基礎操作

    我學到了一下幾點 Windows系統目錄結構與滲透的作用 常見Windows的服務詳解 Windows埠詳解 常用的Windows注冊表詳解 hacker DOS命令詳解(net user / type /md /rd/ dir /cd /net use copy、批處理 等) 利用dos命令制作 ......

    uj5u.com 2020-09-10 02:04:18 more
  • 03.Linux基礎操作

    我學到了以下幾點 01Linux系統介紹02系統安裝,密碼啊破解03Linux常用命令04LAMP 01LINUX windows: win03 8 12 16 19 配置不繁瑣 Linux:redhat,centos(紅帽社區版),Ubuntu server,suse unix:金融機構,證券,銀 ......

    uj5u.com 2020-09-10 02:04:30 more
  • 05HTML

    01HTML介紹 02頭部標簽講解03基礎標簽講解04表單標簽講解 HTML前段語言 js1.了解代碼2.根據代碼 懂得挖掘漏洞 (POST注入/XSS漏洞上傳)3.黑帽seo 白帽seo 客戶網站被黑帽植入劫持代碼如何處理4.熟悉html表單 <html><head><title>TDK標題,描述 ......

    uj5u.com 2020-09-10 02:04:36 more
最新发布
  • OCR -- 文本識別 -- 實踐篇

    [OCR -- 文本識別 -- 理論篇](https://www.cnblogs.com/vipsoft/p/17445583.html) 本章將詳細介紹如何基于PaddleOCR完成CRNN文本識別模型的搭建、訓練、評估和預測。資料集采用 icdar 2015,其中訓練集有4468張,測驗集有20 ......

    uj5u.com 2023-06-06 09:28:26 more
  • ChatGpt玩法(一):prompt使用的重要性

    本文介紹了ChatGpt中prompt使用的重要性,并從prompt的生成、prompt重要性等角度進行了分析。希望讀者能夠通過本文更好地理解并使用ChatGpt中的prompt功能。否則你將和ChatGPT一起開始滾雪球,他回答的不是你想要的答案,你問的他也不了解;更多關于ChatGPT的玩法,關... ......

    uj5u.com 2023-06-06 09:27:04 more
  • GPT時代,一定要收藏的結構化提問技巧

    有一種被稱為“SMART”的結構化提問方法,可以幫助你更好的組織和明確提出的問題。**“SMART”**是一個縮寫,它代表了以下幾個關鍵元素: **S:Specific(具體)** 確保問題具體明確,避免模糊或含糊不清的表達,明確你想要得到的具體資訊或解決的問題。 **M:Measurable(可衡 ......

    uj5u.com 2023-06-06 09:26:50 more
  • 必讀:學習C語言編程的路線圖

    學習C語言編程,可以豐富編程思維的訓練和經驗,以下是一些學習C語言編程的路線圖: 1. **設定開發環境:**在計算機上安裝C編譯器。GNU編譯器集合(GCC)是一個流行的選擇,適用于Windows、macOS和Linux等各種作業系統。安裝IDE編程環境,推薦VSCode,有豐富的插件可以輔助編程 ......

    uj5u.com 2023-06-06 09:26:42 more
  • 身未動心已遠,AI帶你流浪地球

    摘要:我們提供了一鍵運行的notebook AI作畫 Dreambooth 生成自定義主體,可以在ModelArts平臺上除錯開發自己的文生圖模型。 本文分享自華為云社區《DreamBooth+LoRA微調生成主體》,作者: 杜甫蓋房子 。 DreamBooth+LoRA微調生成主體 文生圖風靡一時 ......

    uj5u.com 2023-06-06 09:26:30 more
  • 從0到1:如何建立一個大規模多語言代碼生成預訓練模型

    國產AI輔助編程工具**[CodeGeeX](https://codegeex.cn/)**是一個使用AI大模型為基座的輔助編程工具,幫助開發人員更快的撰寫代碼。可以自動完成整個函式的撰寫,只需要根據注釋或Tab按鍵即可。它已經在Java、JavaScript和Python等二十多種語言上進行了訓練 ......

    uj5u.com 2023-06-06 09:26:03 more
  • 1500個心理測驗大全題庫ACCESS資料庫

    今天收集到了一個結構很簡單的心理測驗題庫,所謂的結構很簡單是區別于《常用心理測驗精選題庫ACCESS資料庫》、《心理測驗性格測驗大全ACCESS資料庫》等資料庫一個測驗專案需做幾道題然后給出解釋,表與表需要關聯而言的,這個心理測驗很簡單,給出一個測驗題目,你回答ABCD即給出解釋。 最大的選項數是選 ......

    uj5u.com 2023-06-06 09:25:52 more
  • 大模型核心技術原理: Transformer架構詳解

    在大模型發展歷程中,有兩個比較重要點:第一,Transformer 架構。它是模型的底座,但 Transformer 不等于大模型,但大模型的架構可以基于 Transformer;第二,GPT。嚴格意義上講,GPT 可能不算是一個模型,更像是一種預訓練范式,它本身模型架構是基于 Transforme ......

    uj5u.com 2023-06-06 09:25:40 more
  • 云原生周刊:開發人員使用 GPT-4 的 30 種重要方法 | 2023-6-5

    OpenAI 最新的大型語言模型 GPT-4 有非常多的用途,那么,作為開發人員,應該如何去使用它,來幫助自己作業呢? 在最近的 [Hacker News 問答](https://news.ycombinator.com/item?id=36037559)中,很多開發人員參與了討論和分享。 也有人進 ......

    uj5u.com 2023-06-06 09:25:22 more
  • 解鎖邊緣無限可能,天翼云助力教育數字化轉型!

    5月16日-17日,由教育部教育管理資訊中心主辦、中國電信協辦的2023年教育系統網路安全作業探討會及網路安全展在貴陽召開。會議以“強化網路安全公共支撐能力,護航國家教育數字化戰略行動”為主題,來自各省級教育行政部門、部屬高校、部直屬單位的網路安全和資訊化相關嘉賓共同出席了會議。 天翼云要客行業中心 ......

    uj5u.com 2023-06-06 09:25:13 more