Text Recognition: CTC Loss Study Notes

CTC Loss Explained

Introduction

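In OCR and speech recognition, the input and the ground-truth (GT) output text are hard to align character by character, and producing such alignments during preprocessing is very difficult. Yet if we train the model directly without any alignment, the varying spacing between characters makes it hard for the model to converge.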

CTC (Connectionist Temporal Classification) removes the need to align input and output by hand, which makes it well suited to sequence tasks such as OCR and speech recognition.

Modeling

Given an input sequence X=[x1,x2,...,xT] and the corresponding label sequence Y=[y1,y2,...,yU], where T and U need not be equal, our task is to find a mapping from X to Y. Classifying sequential data in this way is called Temporal Classification.

Compared with conventional classification, temporal classification has the following difficulties:

  • The lengths of X and Y both vary and are generally not equal.
  • For an end-to-end model, we do not want to hand-design the alignment between X and Y.

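CTC offers a solution: for a given input sequence X, it produces an output distribution over all possible Y. From this distribution we can output the most likely result or evaluate the probability of a particular output.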

Loss: given an input sequence X, we want to maximize the posterior probability P(Y|X). P(Y|X) must be differentiable so that it can be optimized with gradient descent.
Inference: Y^{\star}=\arg\max_{Y} P(Y|X)

1.1 Alignment

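In an OCR task, the input X is an image containing "CAT" and the output Y is the text [C,A,T].
The most naive alignment splits X into time slices, outputs one character per slice, and then merges consecutive repeated characters.
This approach has two drawbacks: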

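  • It is almost impossible for every time slice of X to correspond to a character of Y, e.g. the gaps between characters in OCR or the pauses in speech;
  • It cannot handle consecutive repeated characters: for the word "APPLE", the algorithm above outputs "APLE" instead of "APPLE".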

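To solve these problems, CTC introduces a blank symbol (written _ below).
Collapsing a CTC alignment into the output consists of two steps: removing repeated characters and removing blanks.
The rules, applied in this order:

  • merge consecutive identical characters into one;
  • remove all blank characters.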

For example, for an input sequence of length 10, each of the following RNN output sequences maps to apple:

  • _aappp_ple
  • ap_p_l__ee
  • __app_ple_
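The two collapsing rules fit in a few lines of Python. This is a minimal sketch: the name `collapse` and the use of `_` as the blank character are choices made for this illustration, not part of any library.

```python
def collapse(path: str, blank: str = "_") -> str:
    """CTC collapsing: merge consecutive repeats first, then drop blanks."""
    merged = []
    for ch in path:
        # Rule 1: keep a symbol only if it differs from the previous one
        if not merged or ch != merged[-1]:
            merged.append(ch)
    # Rule 2: remove every blank
    return "".join(ch for ch in merged if ch != blank)


print(collapse("_aappp_ple"))  # apple
print(collapse("__app_ple_"))  # apple
print(collapse("aappple"))     # aple: without a blank in between, the two p's merge
```

The last call shows why the blank is needed. Without it, the double p in apple cannot survive the merging step.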

To compute P(Y|X), we then sum the probabilities of all paths (alignments) that collapse to the output, here apple.
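For a tiny input, P(Y|X) can be computed literally as described: enumerate every length-T path and sum the probabilities of those that collapse to Y. The sketch below makes several assumptions. Index 0 is the blank. Each row of the hypothetical matrix `y` is a probability distribution over V symbols. The probability of a path is the product of its per-step probabilities. The cost is exponential in T, so it is only meant as a reference.

```python
import itertools

import numpy as np


def collapse(path, blank=0):
    """Merge consecutive repeats, then drop blanks (the two CTC rules)."""
    merged = [k for i, k in enumerate(path) if i == 0 or k != path[i - 1]]
    return tuple(k for k in merged if k != blank)


def brute_force_prob(y, target, blank=0):
    """P(target | X) as the sum over all length-T paths that collapse to target.

    y: (T, V) array of per-step probabilities (each row sums to 1).
    """
    T, V = y.shape
    total = 0.0
    for path in itertools.product(range(V), repeat=T):
        if collapse(path, blank) == tuple(target):
            # probability of one path = product of its per-step probabilities
            total += np.prod([y[t, k] for t, k in enumerate(path)])
    return total


rng = np.random.default_rng(0)
y = rng.random((4, 3))
y /= y.sum(axis=1, keepdims=True)  # normalize each time step, like a softmax output
print(brute_force_prob(y, [1, 2]))
```

Every path collapses to exactly one labeling, so the probabilities of all distinct labelings sum to 1. This makes a convenient sanity check.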

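Therefore, during training we extend the GT label: a blank is added at the beginning and the end, and a blank is inserted between every pair of adjacent characters.
Let l denote the final label and l' its extended form;
then |l'| = 2|l| + 1, e.g. l=apple => l'=_a_p_p_l_e_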
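In practice the label extension is a short loop. The sketch below uses `_` as the blank, and `extend_labels` is a hypothetical helper name. The NumPy code later in this article has an equivalent `insert_blank` that works on integer labels.

```python
def extend_labels(label: str, blank: str = "_") -> str:
    """Build l' by putting a blank before, between, and after the characters of l."""
    extended = [blank]
    for ch in label:
        extended += [ch, blank]
    return "".join(extended)


l = "apple"
l_ext = extend_labels(l)
print(l_ext)                         # _a_p_p_l_e_
print(len(l_ext) == 2 * len(l) + 1)  # True
```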

As shown in the figure:

[Figure]

1.2 Path Search and Dynamic Programming

[Figures]

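In the path search shown in the figures above, define:

  • \alpha(s,t): the forward probability, at time t, of being at the s-th symbol of the extended GT string l';
  • p^{t}(z_{s}|X): the probability, read from the prediction matrix, that the output at time t is z_s, the s-th symbol of l';
  • p(l|x): the probability that the input sequence x produces the output l, which is what we want to maximize.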

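(1) If z_s is a blank, \alpha(s,t) can only be reached in two ways: from \alpha(s,t-1), where the path was already on this blank at the previous step, or from \alpha(s-1,t-1), where the path was on the preceding GT character. It cannot come from \alpha(s-2,t-1). Jumping from the previous blank straight to this one would skip the GT character at position s-1, and the path could no longer collapse to the GT. The same restriction applies when z_s is a character equal to z_{s-2}, e.g. the second p in apple. The blank between two identical characters is mandatory, otherwise they would be merged into one.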

So in this case the forward probability is:
\alpha(s,t)=(\alpha(s-1,t-1)+\alpha(s,t-1))\cdot p^{t}(z_{s}|X)

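(2) If z_s is not a blank and differs from z_{s-2}, the forward probability at this node sums over three predecessors:

  • \alpha(s-1,t-1): at the previous step the path was on the blank just before z_s;
  • \alpha(s,t-1): z_s was already emitted at the previous step and is repeated;
  • \alpha(s-2,t-1): at the previous step the path was on the preceding GT character z_{s-2}. Skipping the blank between them is allowed because the two characters differ, so they will not be merged.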

So in this case the forward probability is:
\alpha(s,t)=(\alpha(s-1,t-1)+\alpha(s,t-1)+\alpha(s-2,t-1))\cdot p^{t}(z_{s}|X)
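The two cases differ only in whether position s may be entered from s-2. The helper below returns the positions at time t-1 that feed position s of l'. Its name is hypothetical, and `_` stands for the blank. It also applies the repeated-character exception: a character equal to the one two positions back, such as the second p in apple, must pass through the blank between them. The PyTorch CPU code quoted later makes the same check by comparing `get_target_prime(..., s - 2, ...)` with `current_target_prime`.

```python
def predecessors(ext: str, s: int, blank: str = "_") -> list:
    """Positions at time t-1 from which position s of the extended label is reachable."""
    prev = [s]                      # stay on the same symbol: alpha(s, t-1)
    if s >= 1:
        prev.append(s - 1)          # advance by one: alpha(s-1, t-1)
    # skip a blank: alpha(s-2, t-1), only for a non-blank symbol that
    # differs from the symbol two positions back (case 2)
    if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
        prev.append(s - 2)
    return prev


ext = "_a_p_p_l_e_"
print(predecessors(ext, 4))  # [4, 3]     blank: case 1
print(predecessors(ext, 5))  # [5, 4]     second 'p' repeats ext[3]: no skip
print(predecessors(ext, 7))  # [7, 6, 5]  'l' differs from 'p': case 2
```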

Initialization:
\alpha(0,0)=p^{0}(blank|X)
\alpha(1,0)=p^{0}(z_{1}|X)
\alpha(s,0)=0 for s \ge 2
(At time 0 a path may start either on the leading blank or on the first GT character.)

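Finally, we compute

p(l|x)=\alpha(|l'|-1,T-1)+\alpha(|l'|-2,T-1)
This is the sum of the forward probabilities of the two bottom-right nodes in the figure: a valid path must end either on the final blank or on the final GT character.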
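Putting the recurrences, the initialization, and the final sum together gives the forward algorithm. The sketch below works in plain probability space and assumes blank index 0 and integer labels. It checks the result against exhaustive path enumeration on a tiny random input. `ctc_forward_prob` and `brute_force_prob` are hypothetical names. On long inputs the plain products underflow, which is why practical implementations work with log probabilities.

```python
import itertools

import numpy as np


def ctc_forward_prob(y, target, blank=0):
    """p(l|x) via the forward recursion; y is (T, V), target holds label indices."""
    T = y.shape[0]
    ext = [blank]
    for k in target:
        ext += [k, blank]                      # l' = _ l1 _ l2 _ ... _
    S = len(ext)
    alpha = np.zeros((T, S))                   # alpha(s, 0) = 0 for s >= 2
    alpha[0, 0] = y[0, blank]                  # start on the leading blank
    if S > 1:
        alpha[0, 1] = y[0, ext[1]]             # or on the first character
    for t in range(1, T):
        for s in range(S):
            a = alpha[t - 1, s]                # alpha(s, t-1)
            if s >= 1:
                a += alpha[t - 1, s - 1]       # alpha(s-1, t-1)
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a += alpha[t - 1, s - 2]       # alpha(s-2, t-1), case 2 only
            alpha[t, s] = a * y[t, ext[s]]
    if S == 1:                                 # empty target: only the all-blank path
        return alpha[T - 1, 0]
    return alpha[T - 1, S - 1] + alpha[T - 1, S - 2]  # end on last blank or last char


def brute_force_prob(y, target, blank=0):
    """Sum over all paths that collapse to target (exponential; tiny inputs only)."""
    total = 0.0
    for path in itertools.product(range(y.shape[1]), repeat=y.shape[0]):
        merged = [k for i, k in enumerate(path) if i == 0 or k != path[i - 1]]
        if [k for k in merged if k != blank] == list(target):
            total += np.prod([y[t, k] for t, k in enumerate(path)])
    return total


rng = np.random.default_rng(0)
y = rng.random((5, 3))
y /= y.sum(axis=1, keepdims=True)
print(ctc_forward_prob(y, [1, 1]), brute_force_prob(y, [1, 1]))  # should agree
```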

Computing the CTC loss from the forward probabilities:
maximizing p(l|x) is equivalent to minimizing the negative log-likelihood -ln(p(l|x)).

So the loss is:
-ln(p(l|x))

Simplifying the computation

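The computation involves long chains of multiplications. Each factor is a floating-point probability smaller than 1, so the product can become so small that it underflows. We therefore move the computation into the log domain, which turns the multiplications into additions.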
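A quick numeric illustration of the underflow problem, using hypothetical per-step probabilities of 0.01 over 400 time steps:

```python
import math

# 400 time steps, each contributing probability 0.01 (illustrative values)
probs = [0.01] * 400

product = 1.0
for p in probs:
    product *= p            # 1e-800 is far below the smallest float64 (~4.9e-324)
print(product)              # 0.0: the product has underflowed

log_product = sum(math.log(p) for p in probs)  # multiplication becomes addition
print(log_product)          # about -1842.07, easily representable
```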

\begin{array}{l} \log (a+b)=\log \left(a\left(1+\frac{b}{a}\right)\right)=\log a+\log \left(1+\frac{b}{a}\right) \\ =\log a+\log \left(1+\exp \left(\log \left(\frac{b}{a}\right)\right)\right)=\log a+\log (1+\exp (\log b-\log a)) \end{array}
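The identity translates directly into a helper that adds two probabilities while only ever storing their logarithms. In this sketch the name `log_add` is hypothetical. The larger term is taken as a, so log b - log a <= 0 and the exponential cannot overflow. `math.log1p` computes log(1 + x) accurately for small x.

```python
import math


def log_add(log_a: float, log_b: float) -> float:
    """Return log(a + b) from log a and log b without leaving the log domain."""
    if log_a < log_b:
        log_a, log_b = log_b, log_a      # make log_a the larger term
    if log_b == -math.inf:               # b == 0: nothing to add
        return log_a
    # log a + log(1 + exp(log b - log a)), with a non-positive exponent
    return log_a + math.log1p(math.exp(log_b - log_a))


print(log_add(math.log(0.2), math.log(0.3)))  # equals math.log(0.5) up to rounding
print(log_add(-1000.0, -1000.0))              # -1000 + log 2, although exp(-1000) == 0.0
```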

Since the final loss is -ln(p),

we can compute log(p) directly while computing the forward probabilities, i.e. store log α instead of α.

The logsumexp trick

Suppose we have N values \{x_{n}\}_{n=1}^{N} (log-probabilities in our case) and want the logarithm of the sum of their exponentials:

z=\log \sum_{n=1}^{N} \exp \left\{x_{n}\right\}

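If some x_{n} are very large or very small, computing this naively overflows or underflows. For example, for [0,1,0] the direct computation works and gives about 1.55. For [1000,1001,1000] it fails and returns inf, and for [-1000,-999,-1000] it also fails and returns -inf.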

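Solution:
\log \sum_{n=1}^{N} \exp \left\{x_{n}\right\}=a+\log \sum_{n=1}^{N} \exp \left\{x_{n}-a\right\}

Usually a is chosen as the maximum of the N values.
The largest exponent is then exactly 0, so nothing can overflow. Even if some of the remaining terms underflow, the result is still accurate.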

Proof:

\begin{aligned} y &=\log \sum_{i=1}^{n} e^{x_{i}} \\ \Rightarrow e^{y} &=\sum_{i=1}^{n} e^{x_{i}} \\ \Rightarrow e^{-a} e^{y} &=e^{-a} \sum_{i=1}^{n} e^{x_{i}} \\ \Rightarrow e^{y-a} &=\sum_{i=1}^{n} e^{x_{i}-a} \\ \Rightarrow \quad y-a &=\log \sum_{i=1}^{n} e^{x_{i}-a} \\ \Rightarrow \quad y &=a+\log \sum_{i=1}^{n} e^{x_{i}-a} \end{aligned}
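The shifted formula can be checked against the three inputs from the example above. This sketch uses NumPy so that the naive version returns inf and -inf as described. Python's `math.exp(1000)` would raise `OverflowError` instead. The function names are hypothetical.

```python
import numpy as np


def logsumexp(x):
    """Stable log(sum(exp(x))): shift by a = max(x) so the largest exponent is 0."""
    x = np.asarray(x, dtype=np.float64)
    a = np.max(x)
    if np.isneginf(a):                      # every term is log(0)
        return -np.inf
    return a + np.log(np.sum(np.exp(x - a)))


def naive_logsumexp(x):
    x = np.asarray(x, dtype=np.float64)
    with np.errstate(over="ignore", divide="ignore"):  # silence overflow / log(0) warnings
        return np.log(np.sum(np.exp(x)))


for xs in ([0, 1, 0], [1000, 1001, 1000], [-1000, -999, -1000]):
    print(xs, naive_logsumexp(xs), logsumexp(xs))
# naive: about 1.55, inf, -inf; stable: about 1.55, 1001.55, -998.45
```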

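Forward probabilities on the CTC lattice:
\alpha(s,t) is the probability, at time t, of being at the s-th symbol of the GT after blank insertion.

If z_s is a blank or z_s = z_{s-2}:
\alpha(s, t)=(\alpha(s, t-1)+\alpha(s-1, t-1)) \cdot p^{t}\left(z_{s} \mid X\right)

Otherwise:
\alpha(s, t)=(\alpha(s, t-1)+\alpha(s-1, t-1)+\alpha(s-2, t-1)) \cdot p^{t}\left(z_{s} \mid X\right)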

PyTorch's CTC forward pass computes a matrix of log forward probabilities, which is also kept for the backward pass of the loss. Its core is:

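  • For each batch element, two nested loops (over T and S) compute the log forward probabilities by dynamic programming. The terms \alpha(s, t-1) and \alpha(s-1, t-1), which both cases need, are read as la1 and la2. The symbol at position s then decides whether the third term \alpha(s-2, t-1) (la3) is included.
  • Because everything is in the log domain, underflow is avoided. Subtracting the maximum lamax before exponentiating (the logsumexp trick) also keeps the exponentials from overflowing, so the update becomes:
log_alpha_a[t][s] = std::log(std::exp(la1 - lamax) + std::exp(la2 - lamax) + std::exp(la3 - lamax)) + lamax + log_probs_a[t][current_target_prime];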
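A single-sequence Python transcription of this loop may make the index handling easier to follow. It mirrors la1, la2, la3, the lamax shift, and the two ways of finishing (empty vs. non-empty target). It is only an illustrative sketch, not PyTorch's implementation. The name `ctc_nll`, the (T, C) layout of `log_probs`, and blank index 0 are assumptions.

```python
import numpy as np


def ctc_nll(log_probs, target, blank=0):
    """Negative log-likelihood of one sequence; log_probs is (T, C)."""
    T = log_probs.shape[0]
    S = 2 * len(target) + 1

    def prime(s):
        # symbol at position s of l': blank at even positions, target[s // 2] at odd ones
        return blank if s % 2 == 0 else target[s // 2]

    neginf = -np.inf
    log_alpha = np.full((T, S), neginf)
    with np.errstate(divide="ignore"):  # np.log(0.0) -> -inf is expected here
        log_alpha[0, 0] = log_probs[0, blank]
        if len(target) > 0:
            log_alpha[0, 1] = log_probs[0, prime(1)]
        for t in range(1, T):
            for s in range(S):
                la1 = log_alpha[t - 1, s]
                la2 = log_alpha[t - 1, s - 1] if s > 0 else neginf
                # third term only if l'[s] differs from l'[s-2] (non-blank, not a repeat)
                la3 = log_alpha[t - 1, s - 2] if s > 1 and prime(s - 2) != prime(s) else neginf
                lamax = max(la1, la2, la3)
                if lamax == neginf:  # cannot do neginf - neginf
                    lamax = 0.0
                log_alpha[t, s] = (np.log(np.exp(la1 - lamax) + np.exp(la2 - lamax)
                                          + np.exp(la3 - lamax))
                                   + lamax + log_probs[t, prime(s)])
        if len(target) == 0:
            return -log_alpha[T - 1, 0]  # only the all-blank path exists
        l1, l2 = log_alpha[T - 1, S - 1], log_alpha[T - 1, S - 2]
        m = max(l1, l2)
        m = 0.0 if m == neginf else m
        return -(np.log(np.exp(l1 - m) + np.exp(l2 - m)) + m)


rng = np.random.default_rng(1)
y = rng.random((3, 3))
y /= y.sum(axis=1, keepdims=True)
lp = np.log(y)
print(ctc_nll(lp, [1, 1]), -np.log(y[0, 1] * y[1, 0] * y[2, 1]))  # only path: 1 _ 1
```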

Source code analysis:
The following excerpt is taken from the CTC loss source code in PyTorch.

//pytorch/aten/src/ATen/native/LossCTC.cpp

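// Return the symbol at position idx of the blank-padded target l': even positions are BLANK, odd position idx holds target[idx / 2]
template <typename target_t>
static inline int64_t get_target_prime(target_t *target, int64_t offset, int64_t stride, int64_t idx, int64_t BLANK)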
            {
                if (idx % 2 == 0)
                {
                    return BLANK;
                }
                else
                {
                    return target[offset + stride * (idx / 2)];
                }
            }

// Core part of ctc_loss_cpu_template
// log forward probabilities, shape [B, T, 2 * max_target_length + 1]
                Tensor log_alpha = at::empty({batch_size, log_probs.size(0), 2 * max_target_length + 1}, log_probs.options());

                Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options());

                // log_probs is [T, B, C]; permute it to [B, T, C]
                auto lpp = log_probs.permute({1, 0, 2});
                auto log_probs_a_global = lpp.accessor<scalar_t, 3>();
                auto log_alpha_a_global = log_alpha.accessor<scalar_t, 3>();
                auto targets_data = targets.data_ptr<target_t>();
                auto neg_log_likelihood_a = neg_log_likelihood.accessor<scalar_t, 1>();

                // alpha calculation for the first row, the three equations for alpha_1 above eq (6)
                // first the default
                log_alpha.narrow(1, 0, 1).fill_(neginf);
                at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
                    for (int64_t b = start; b < end; b++)
                    {
                        // for each batch element
                        int64_t input_length = input_lengths[b];
                        int64_t target_length = target_lengths[b];
                        auto log_probs_a = log_probs_a_global[b];
                        auto log_alpha_a = log_alpha_a_global[b];
                        int64_t tg_batch_offset = tg_batch_offsets[b];
                        
                        // the first two items of alpha_t above eq (6)
                        // initialize the forward probability at [t=0][s=0] (leading blank)
                        log_alpha_a[0][0] = log_probs_a[0][BLANK];
                        if (target_length > 0)
                            // [t=0][s=1] is the probability of the first target character
                            log_alpha_a[0][1] = log_probs_a[0][get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];

                        // now the loop over the inputs
                        for (int64_t t = 1; t < input_length; t++)
                        {
                            for (int64_t s = 0; s < 2 * target_length + 1; s++)
                            {
                                // compute the forward probability for each s
                                // look up which symbol sits at position s of l'
                                auto current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
                                // this loop over s could be parallel/vectorized, too, but the required items are one index apart
                                // alternatively, one might consider moving s to the outer loop to cache current_target_prime more (but then it needs to be descending)
                                // for the cuda implementation, that gave a speed boost.
                                // This is eq (6) and (7), la1,2,3 are the three summands. We keep track of the maximum for the logsumexp calculation.

                                scalar_t la1 = log_alpha_a[t - 1][s];
                                scalar_t lamax = la1;
                                scalar_t la2, la3;
                                if (s > 0)
                                {
                                    
                                    la2 = log_alpha_a[t - 1][s - 1];
                                    if (la2 > lamax)
                                        lamax = la2;
                                }
                                else
                                {
                                    la2 = neginf;
                                }
                                if ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s - 2, BLANK) !=
                                                current_target_prime))
                                {
                                    // z_s is not blank and differs from z_{s-2} (no repeat): second recurrence, include the third term
                                    la3 = log_alpha_a[t - 1][s - 2];
                                    if (la3 > lamax)
                                        lamax = la3;
                                }
                                else
                                {
                                    // z_s is blank or repeats z_{s-2}: the third term is not added
                                    la3 = neginf;
                                }
                                //添加概率最大項(xiàng)按前一個[t-1][s]
                                if (lamax == neginf) // cannot do neginf-neginf
                                    lamax = 0;

                                // compute log alpha at (t, s)
                                // this is the assignment of eq (6)
                                log_alpha_a[t][s] = std::log(std::exp(la1 - lamax) + std::exp(la2 - lamax) + std::exp(la3 - lamax)) + lamax + log_probs_a[t][current_target_prime];
                            }
                        }
                        // the likelihood is the the sum of the last two alphas, eq (8), the loss is the negative log likelihood
                        if (target_length == 0)
                        {
                            // if the target is empty then there is no preceding BLANK state and hence there is no path to merge
                            neg_log_likelihood_a[b] = -log_alpha_a[input_length - 1][0];
                        }
                        else
                        {
                            scalar_t l1 = log_alpha_a[input_length - 1][target_length * 2];
                            scalar_t l2 = log_alpha_a[input_length - 1][target_length * 2 - 1];
                            // log-sum of the two ending states (final blank and final character)
                            scalar_t m = std::max(l1, l2);
                            m = ((m == neginf) ? 0 : m);
                            scalar_t log_likelihood = std::log(std::exp(l1 - m) + std::exp(l2 - m)) + m;
                            neg_log_likelihood_a[b] = -log_likelihood;
                        }
                    }
                });

Here is a NumPy implementation of CTC in Python to aid understanding:

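import numpy as np
from collections import defaultdict  # used by prefix_beam_decode


ninf = -np.inf  # np.float no longer exists in recent NumPy versions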

def _logsumexp(a, b):
    '''
    np.log(np.exp(a) + np.exp(b))
    '''

    if a < b:
        a, b = b, a

    if b == ninf:
        return a
    else:
        return a + np.log(1 + np.exp(b - a)) 

def logsumexp(*args):
    '''
    from scipy.special import logsumexp
    logsumexp(args)
    '''
    res = args[0]
    for e in args[1:]:
        res = _logsumexp(res, e)
    return res
class CTC:
    def __init__(self):
        pass

    def forward(self):
        pass

    def alpha(self, log_y, labels):
        ## alpha: log forward probabilities; labels is the blank-extended label (blank = 0)
        T, V = log_y.shape
        L = len(labels)
        log_alpha = np.ones([T, L]) * ninf

        # init
        ## initialize the dynamic programming table
        log_alpha[0, 0] = log_y[0, labels[0]]
        log_alpha[0, 1] = log_y[0, labels[1]]

        ## log forward probability of every extended-GT position at every step
        for t in range(1, T):
            for i in range(L):
                s = labels[i]

                a = log_alpha[t - 1, i]
                if i - 1 >= 0:
                    a = logsumexp(a, log_alpha[t - 1, i - 1])
                ## if the current symbol is not blank and differs from labels[i - 2], also add the state two positions back
                if i - 2 >= 0 and s != 0 and s != labels[i - 2]:
                    a = logsumexp(a, log_alpha[t - 1, i - 2])

                log_alpha[t, i] = a + log_y[t, s]

        return log_alpha


    def beta(self, log_y, labels):
        ## log backward probabilities (like alpha, beta includes log_y at time t)
        T, V = log_y.shape
        L = len(labels)
        log_beta = np.ones([T, L]) * ninf

        # init
        log_beta[-1, -1] = log_y[-1, labels[-1]]
        log_beta[-1, -2] = log_y[-1, labels[-2]]

        for t in range(T - 2, -1, -1):
            for i in range(L):
                s = labels[i]

                a = log_beta[t + 1, i]
                if i + 1 < L:
                    a = logsumexp(a, log_beta[t + 1, i + 1])
                if i + 2 < L and s != 0 and s != labels[i + 2]:
                    a = logsumexp(a, log_beta[t + 1, i + 2])

                log_beta[t, i] = a + log_y[t, s]

        return log_beta

    def backward(self, log_y, labels):
        T, V = log_y.shape
        L = len(labels)

        log_alpha = self.alpha(log_y, labels)
        log_beta = self.beta(log_y, labels)
        log_p = logsumexp(log_alpha[-1, -1], log_alpha[-1, -2])
        ## gradient of ln p(l|x) w.r.t. y at every time step, in the log domain
        log_grad = np.ones([T, V]) * ninf
        for t in range(T):
            for s in range(V):
                lab = [i for i, c in enumerate(labels) if c == s]
                for i in lab:
                    log_grad[t, s] = logsumexp(log_grad[t, s],
                                            log_alpha[t, i] + log_beta[t, i])
                log_grad[t, s] -= 2 * log_y[t, s]

        log_grad -= log_p
        return log_grad

    def predict(self):
        pass

    def ctc_prefix(self):
        pass

    def ctc_beamsearch(self):
        pass

    def alpha_vanilla(self, y, labels):
        T, V = y.shape  # T,time step, V: probs
        L = len(labels) # label length
        alpha = np.zeros([T, L])

        # init
        alpha[0, 0] = y[0, labels[0]]
        alpha[0, 1] = y[0, labels[1]]

        for t in range(1, T):
            for i in range(L):
                s = labels[i]

                a = alpha[t - 1, i]
                if i - 1 >= 0:
                    a += alpha[t - 1, i - 1]
                if i - 2 >= 0 and s != 0 and s != labels[i - 2]:
                    a += alpha[t - 1, i - 2]

                alpha[t, i] = a * y[t, s]

        return alpha

    def beta_vanilla(self, y, labels):
        ## backward probabilities in plain probability space (not in the log domain)
        T, V = y.shape
        L = len(labels)
        beta = np.zeros([T, L])

        # init
        beta[-1, -1] = y[-1, labels[-1]]
        beta[-1, -2] = y[-1, labels[-2]]

        for t in range(T - 2, -1, -1):
            for i in range(L):
                s = labels[i]

                a = beta[t + 1, i]
                if i + 1 < L:
                    a += beta[t + 1, i + 1]
                if i + 2 < L and s != 0 and s != labels[i + 2]:
                    a += beta[t + 1, i + 2]

                beta[t, i] = a * y[t, s]

        return beta

    def gradient(self, y, labels):
        T, V = y.shape
        L = len(labels)

        alpha = self.alpha_vanilla(y, labels)
        beta = self.beta_vanilla(y, labels)  # plain-probability beta to match alpha_vanilla
        p = alpha[-1, -1] + alpha[-1, -2]

        grad = np.zeros([T, V])
        for t in range(T):
            for s in range(V):
                lab = [i for i, c in enumerate(labels) if c == s]
                for i in lab:
                    grad[t, s] += alpha[t, i] * beta[t, i]
                grad[t, s] /= y[t, s] ** 2

        grad /= p
        return grad



def check_grad(y, labels, w=-1, v=-1, toleration=1e-3):
    # Compare the analytic gradient of ln p(l|x) w.r.t. y[w, v] with a central
    # finite difference; labels is the blank-extended label (e.g. insert_blank(...))
    ctc = CTC()
    grad_1 = ctc.gradient(y, labels)[w, v]

    delta = 1e-10
    original = y[w, v]

    y[w, v] = original + delta
    alpha = ctc.alpha_vanilla(y, labels)
    log_p1 = np.log(alpha[-1, -1] + alpha[-1, -2])

    y[w, v] = original - delta
    alpha = ctc.alpha_vanilla(y, labels)
    log_p2 = np.log(alpha[-1, -1] + alpha[-1, -2])

    y[w, v] = original

    grad_2 = (log_p1 - log_p2) / (2 * delta)
    if np.abs(grad_1 - grad_2) > toleration:
        print('[%d, %d]:%.2e' % (w, v, np.abs(grad_1 - grad_2)))


def remove_blank(labels, blank=0):
    new_labels = []

    # combine duplicate
    previous = None
    for l in labels:
        if l != previous:
            new_labels.append(l)
            previous = l

    # remove blank
    new_labels = [l for l in new_labels if l != blank]

    return new_labels

def insert_blank(labels, blank=0):
    new_labels = [blank]
    for l in labels:
        new_labels += [l, blank]
    return new_labels

def greedy_decode(y, blank=0):
    raw_rs = np.argmax(y, axis=1)
    rs = remove_blank(raw_rs, blank)
    return raw_rs, rs

def beam_decode(y, beam_size=10):
    T, V = y.shape
    log_y = np.log(y)

    beam = [([], 0)]
    for t in range(T):  # for every timestep
        new_beam = []
        for prefix, score in beam:
            for i in range(V):  # for every state
                new_prefix = prefix + [i]
                new_score = score + log_y[t, i]

                new_beam.append((new_prefix, new_score))

        # top beam_size
        new_beam.sort(key=lambda x: x[1], reverse=True)
        beam = new_beam[:beam_size]

    return beam

def prefix_beam_decode(y, beam_size=10, blank=0):
    T, V = y.shape
    log_y = np.log(y)

    beam = [(tuple(), (0, ninf))]  # blank, non-blank
    for t in range(T):  # for every timestep
        new_beam = defaultdict(lambda : (ninf, ninf))

        for prefix, (p_b, p_nb) in beam:
            for i in range(V):  # for every state
                p = log_y[t, i]

                if i == blank:  # propose a blank
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)
                    continue
                else:  # extend with non-blank
                    end_t = prefix[-1] if prefix else None

                    # exntend current prefix
                    new_prefix = prefix + (i,)
                    new_p_b, new_p_nb = new_beam[new_prefix]
                    if i != end_t:
                        new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
                    else:
                        new_p_nb = logsumexp(new_p_nb, p_b + p)
                    new_beam[new_prefix] = (new_p_b, new_p_nb)

                    # keep current prefix
                    if i == end_t:
                        new_p_b, new_p_nb = new_beam[prefix]
                        new_p_nb = logsumexp(new_p_nb, p_nb + p)
                        new_beam[prefix] = (new_p_b, new_p_nb)

        # top beam_size
        beam = sorted(new_beam.items(), key=lambda x : logsumexp(*x[1]), reverse=True)
        beam = beam[:beam_size]

    return beam
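The `gradient` and `backward` methods rely on the standard forward-backward identity. Suppose alpha and beta both include y at time t, as in `alpha_vanilla` and `beta_vanilla`. Then, for every t, summing alpha(s,t)·beta(s,t)/y_t(l'_s) over s gives p(l|x). Restricting that sum to the positions with l'_s = k gives the y_t(k)^2 · ∂p/∂y_t(k) term those methods compute. The sketch below checks the identity numerically on a random input. The helper names, blank index 0, and random data are assumptions for illustration.

```python
import numpy as np


def extend(target, blank=0):
    """l' = blank, l1, blank, l2, ..., blank."""
    ext = [blank]
    for k in target:
        ext += [k, blank]
    return ext


def alpha_beta(y, ext, blank=0):
    """Forward and backward tables in probability space (non-empty target assumed).

    Both tables include y[t, ext[s]], so alpha[t, s] * beta[t, s] counts it twice.
    """
    T, S = y.shape[0], len(ext)
    alpha, beta = np.zeros((T, S)), np.zeros((T, S))
    alpha[0, 0], alpha[0, 1] = y[0, ext[0]], y[0, ext[1]]
    beta[-1, -1], beta[-1, -2] = y[-1, ext[-1]], y[-1, ext[-2]]
    for t in range(1, T):
        for s in range(S):
            a = alpha[t - 1, s] + (alpha[t - 1, s - 1] if s >= 1 else 0.0)
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a += alpha[t - 1, s - 2]
            alpha[t, s] = a * y[t, ext[s]]
    for t in range(T - 2, -1, -1):
        for s in range(S):
            b = beta[t + 1, s] + (beta[t + 1, s + 1] if s + 1 < S else 0.0)
            if s + 2 < S and ext[s] != blank and ext[s] != ext[s + 2]:
                b += beta[t + 1, s + 2]
            beta[t, s] = b * y[t, ext[s]]
    return alpha, beta


rng = np.random.default_rng(0)
y = rng.random((6, 4))
y /= y.sum(axis=1, keepdims=True)
ext = extend([1, 2, 2])
alpha, beta = alpha_beta(y, ext)
p = alpha[-1, -1] + alpha[-1, -2]
# for every t, sum over s of alpha * beta / y recovers p(l|x)
per_t = [sum(alpha[t, s] * beta[t, s] / y[t, ext[s]] for s in range(len(ext)))
         for t in range(y.shape[0])]
print(p, per_t)
```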
© Copyright belongs to the author. For reprints or content collaboration, please contact the author.
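Platform statement: the article content (including any images or videos) was uploaded and published by the author and represents only the author's own views. Jianshu is an information publishing platform and only provides information storage services.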
