SSE圖像算法優化系列十六:經典USM銳化中的分支判斷語句SSE實現的幾種方法嘗試。


  分支判斷的語句一般來說是不太適合進行SSE優化的,因為他會破壞代碼的並行性,但是也不是所有的都是這樣的,在合適的場景中運用SSE還是能對分支預測進行一定的優化的,我們這里以某一個算法的部分代碼為例進行講解。

  在某一個版本的USM銳化算法中有這樣的一段代碼:

int IM_UnsharpMask(unsigned char *Src, unsigned char *Dest, int Width, int Height, int Stride, int Radius, int Amount, int Threshold) { int Channel = Stride / Width; if ((Src == NULL) || (Dest == NULL))                                return IM_STATUS_NULLREFRENCE; if ((Width <= 0) || (Height <= 0))                                  return IM_STATUS_INVALIDPARAMETER; if ((Channel != 1) && (Channel != 3) && (Channel != 4))             return IM_STATUS_INVALIDPARAMETER; int Status = IM_STATUS_OK; Status = IM_ExpBlur(Src, Dest, Width, Height, Stride, Radius);   //  這里標准過程是用IM_GaussBlur代替 if (Status != IM_STATUS_OK)    return Status; const float Inv255 = 1.0f / 255.0f; int *Table = (int *)malloc(511 * 256 * sizeof(int)); if (Table == NULL)    return IM_STATUS_OUTOFMEMORY; for (int Y = 0; Y < 256; Y++) { float TempUp = Amount * sqrtf(1.0f - Y * Inv255) / 100.0f; float TempDown = Amount * sqrtf(Y * Inv255) / 100.0f; for (int X = -255; X <= 255; X++) { int Diff = X; if (Diff >= Threshold) { Diff -= Threshold; Table[((X + 255) << 8) + Y] = IM_ClampToByte(int(Diff * TempUp + 0.5f) + Y); } else if (Diff < -Threshold) { Diff += Threshold; Table[((X + 255) << 8) + Y] = IM_ClampToByte(int(Diff * TempDown + 0.5f) + Y); } else { Table[((X + 255) << 8) + Y] = Y;        // 不做變化
 } } } for (int Y = 0; Y < Height * Stride; Y++)            // 分四路並行速度只有一點點提高
 { Dest[Y] = Table[((Src[Y] - Dest[Y] + 255) << 8) + Src[Y]]; } free(Table); return IM_STATUS_OK; }

  這個USM銳化的算法參考自:https://github.com/pluginguy/plugins/tree/master/USM2,源代碼中的算法還提供了對高光、暗調和中間調進行不同調節的參數,我這里對他那個代碼進行了適度的修改和簡化,並且用查找表進行了優化。這個github的作者還提供了關於高斯模糊方面的資料,是個不錯的參考點。

  上述代碼起始已經很高效了,復雜的浮點和開方計算都已經用查表的形式進行了簡化,實測一副1080P的24位圖像大處理時間大約在14.5ms左右,而其中的IM_ExpBlur耗時約有6.75ms,建立查找表花了0.75ms,后面的遍歷圖像進行查找表替換使用了7ms,注意前面的IM_ExpBlur的時間是已經進行了SSE編碼后的優化時間。

  查找表其實本身也是個耗時的工作,因為這個可能有着嚴重的cache miss,特別是查找表比較大的時候。但是查找表本身呢在目前SIMD框架下是無法使用SSE優化的(除非是16個字節的查找表,可以使用_mm_shuffle_epi8來優化),因此,如果查找表本身的建立算法並不特別復雜,是可以考慮使用SSE來對表中每個元素進行直接的實現的,鑒於此,我們來考慮上述代碼的查找表的直接SSE實現。

  為了表示清楚,我們把上述算法的非查找表方式實現的代碼整理出來如下:

int IM_UnsharpMask(unsigned char *Src, unsigned char *Dest, int Width, int Height, int Stride, int Radius, int Amount, int Threshold)
{
    int Channel = Stride / Width;
    if ((Src == NULL) || (Dest == NULL))                                  return IM_STATUS_NULLREFRENCE;
    if ((Width <= 0) || (Height <= 0))                                     return IM_STATUS_INVALIDPARAMETER;
    if ((Channel != 1) && (Channel != 3) && (Channel != 4))                 return IM_STATUS_INVALIDPARAMETER;
    int Status = IM_STATUS_OK;

    Status = IM_ExpBlur(Src, Dest, Width, Height, Stride, Radius);        //    這里標准過程是用IM_GaussBlur代替
    if (Status != IM_STATUS_OK)    return Status;

    float Adjust = Amount / 100.0f / sqrtf(255.0f);
    for (int Y = 0; Y < Height * Stride; Y++)                        
    {
        int Diff = Src[Y] - Dest[Y];
        if (Diff >= Threshold)
        {
            Dest[Y] = IM_ClampToByte(int((Diff - Threshold) * Adjust * sqrtf(255.0f - Src[Y]) + 0.5f) + Src[Y]);
        }
        else if (Diff < -Threshold)
        {
            Dest[Y] = IM_ClampToByte(int((Diff + Threshold) * Adjust * sqrtf((float)Src[Y]) + 0.5f) + Src[Y]);
        }
        else
        {
            Dest[Y] = Src[Y];        //    不做變化
        }
    }
    return IM_STATUS_OK;
}

  注意為減少計算我已經把一些重復的計算提取到Adjust變量中,其中的/sqrtf(255.0f)可以讓循環內部的sqrtf的參數少一次乘法計算,並且在后面我們還可以看到他起到了另外一個特殊的作用。運行上述代碼的同參數同照片耗時變為了55ms左右,可見查找表的優化也是很給力的。

  我注意到這段代碼已經有很久了,也一直想使用SSE優化他們,但苦於能力,一直未得良方,不過最近過年重新審視這段代碼,發現只要手指按住鍵盤,總會有新大陸發現的。

  第一方案:既然SSE不太好做分支判斷,我就把所有分支的結果都計算出來,最后再根據分支條件做數據融合不就可以了嗎,可以肯定SSE計算每個分支的速度肯定比C快,但是如果要每個分支都計算,這個增加的耗時和加速的時間比例如何呢,只有實踐才知道,於是我硬着頭皮把他們用SSE做個硬編碼,代碼如下所示:

//    實在沒有好的辦法,極端情況下把所有的分支的結果都算出來,然后在最后根據判斷條件合成,比如下面的代碼,寫出來后比原始的查找表方式也還是要快一點的。
int IM_UnsharpMask(unsigned char *Src, unsigned char *Dest, int Width, int Height, int Stride, int Radius, int Amount, int Threshold)
{
    int Channel = Stride / Width;
    if ((Src == NULL) || (Dest == NULL))                                return IM_STATUS_NULLREFRENCE;
    if ((Width <= 0) || (Height <= 0))                                    return IM_STATUS_INVALIDPARAMETER;
    if ((Channel != 1) && (Channel != 3) && (Channel != 4))                return IM_STATUS_INVALIDPARAMETER;
    int Status = IM_STATUS_OK;

    Status = IM_ExpBlur(Src, Dest, Width, Height, Stride, Radius);
    if (Status != IM_STATUS_OK)    return Status;

    const float Adjust = Amount / 100.0f / sqrt(255.0f);
    const int BlockSize = 8;
    int Block = (Height * Stride) / BlockSize;

    const __m128i Zero = _mm_setzero_si128();
    const __m128i ThresholdV = _mm_set1_epi16(Threshold);
    const __m128i MinusThresholdV = _mm_set1_epi16(-Threshold);
    const __m128i One = _mm_set1_epi16(1);
    const __m128i MinusOne = _mm_set1_epi16(-1);
    const __m128 Const255 = _mm_set1_ps(255.0f);
    const __m128 AdjustV = _mm_set1_ps(Adjust);

    for (int Y = 0; Y < Block * BlockSize; Y += BlockSize)
    {
        __m128i SrcV = _mm_unpacklo_epi8(_mm_loadl_epi64((__m128i *)(Src + Y)), Zero);
        __m128i DstV = _mm_unpacklo_epi8(_mm_loadl_epi64((__m128i *)(Dest + Y)), Zero);
        __m128 SrcL = _mm_cvtepi32_ps(_mm_unpacklo_epi8(SrcV, Zero));
        __m128 SrcH = _mm_cvtepi32_ps(_mm_unpackhi_epi8(SrcV, Zero));
        __m128i Diff = _mm_sub_epi16(SrcV, DstV);
        __m128i DiffA = _mm_add_epi16(Diff, ThresholdV);
        __m128i DiffS = _mm_sub_epi16(Diff, ThresholdV);
        __m128 DiffL = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(Diff));
        __m128 DiffH = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_srli_si128(Diff, 8)));

        __m128 UpL = _mm_mul_ps(AdjustV, _mm_sqrt_ps(_mm_sub_ps(Const255, SrcL)));
        __m128 UpH = _mm_mul_ps(AdjustV, _mm_sqrt_ps(_mm_sub_ps(Const255, SrcH)));
        __m128 DownL = _mm_mul_ps(AdjustV, _mm_sqrt_ps(SrcL));
        __m128 DownH = _mm_mul_ps(AdjustV, _mm_sqrt_ps(SrcH));

        __m128 DiffUpL = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepi16_epi32(DiffS)), UpL);
        __m128 DiffUpH = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_srli_si128(DiffS, 8))), UpH);
        __m128 DiffDownL = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepi16_epi32(DiffA)), DownL);
        __m128 DiffDownH = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_srli_si128(DiffA, 8))), DownH);

        __m128i DiffUp = _mm_adds_epi16(_mm_packs_epi32(_mm_cvtps_epi32(DiffUpL), _mm_cvtps_epi32(DiffUpH)), SrcV);
        __m128i DiffDown = _mm_adds_epi16(_mm_packs_epi32(_mm_cvtps_epi32(DiffDownL), _mm_cvtps_epi32(DiffDownH)), SrcV);

        __m128i DestV = _mm_blendv_si128(SrcV, DiffUp, _mm_cmpgt_epi16(Diff, ThresholdV));
        DestV = _mm_blendv_si128(DestV, DiffDown, _mm_cmplt_epi16(Diff, MinusThresholdV));
_mm_storel_epi64((__m128i
*)(Dest + Y), _mm_packus_epi16(DestV, Zero)); } for (int Y = Block * BlockSize; Y < Height * Stride; Y++) { int Diff = Src[Y] - Dest[Y]; if (Diff >= Threshold) { Dest[Y] = IM_ClampToByte(int((Diff - Threshold) * Adjust * sqrtf(255.0f - Src[Y]) + 0.5f) + Src[Y]); } else if (Diff < -Threshold) { Dest[Y] = IM_ClampToByte(int((Diff + Threshold) * Adjust * sqrtf(0.0f + Src[Y]) + 0.5f) + Src[Y]); } else { Dest[Y] = Src[Y]; } } return IM_STATUS_OK; }

  上述代碼基本就是普通C語言的翻譯,這里講幾個需要注意的地方。

  第一、_mm_cvtepi16_epi32這是個講signed short轉換為signed int的函數,只處理XMM寄存的低8位,如果需要將高8位也進行轉換,就必須得配合_mm_srli_si128一起使用,如果需要轉換的signed short能確認是大於等於0的,也可以使用_mm_unpacklo_epi16及_mm_unpackhi_epi16配合_mm_setzero_si128來實現,比如上面的SrcL和SrcH就是使用的這個技巧,但是如果有小於0的情況出現,一定只能用_mm_cvtepi16_epi32來實現,比如上面的DiffL和DiffH,我以前在這個上面吃過很多虧。

  第二、在計算DiffUp和DiffDown這兩個結果時,注意需要使用_mm_packs_epi32,而不是_mm_packus_epi32,因為計算結果是有負數存在的。

  第三、結果的融合這里的技巧很好,我們知道SSE4提供了兩個__m128i變量融合的函數,比如_mm_blendv_epi8,但是他要求最后的融合選項是個常數,而我們這里的融合選項是變化的,所以無法使用,我們使用了一個叫做_mm_blendv_si128的內聯函數,這個函數用一個__m128i變量作為融合參數,對128個位進行融合,其代碼如下:

static inline __m128i _mm_blendv_si128(__m128i x, __m128i y, __m128i mask)
{
    return _mm_or_si128(_mm_andnot_si128(mask, x), _mm_and_si128(mask, y));
}

  當mask的某一位為0時,選擇x中的對應位的值,否則選擇y中對應位的值。

  這個函數正是我需要的,而且恰好前幾天在瀏覽文章:A few missing SSE intrinsics發現了他,有的時候真的覺得處處留心皆學問啊。

  這時我們來看下上面的融合的代碼:__m128i DestV = _mm_blendv_si128(SrcV, DiffUp, _mm_cmpgt_epi16(Diff, ThresholdV));

  后面的_mm_cmpgt_epi16的比較函數會返回一個__m128i變量,當Diff > Threshold時,對應的16位數據為0xFFFF,否則為0,這樣我們使用_mm_blendv_si128融合時,滿足條件的部分結果就為DiffUp了,其他部分還保持SrcV不變。

  接着 DestV = _mm_blendv_si128(DestV, DiffDown, _mm_cmplt_epi16(Diff, MinusThresholdV)); 使用Diff < -Threshold作為判斷條件,因為該條件和Diff > Threshold不可能同時成立,所以_mm_cmplt_epi16的返回結果中的為true的部分和_mm_cmpgt_epi16返回的true部分的值不可能重疊,因此,再次執行_mm_blendv_si128混合的值就是我們融合的正確結果。

  那么我們最關心的速度來了,經過測試,上述算法對1080P彩色圖能達到約14ms的執行速度,和查找表的C語言版本速度差不多,唯一的優勢就是運算時少占用了一部分內存。但是同時也說明SSE的計算能力真的不是蓋的,算一算,正正的SSE執行時間實際上只有14-6.75 =7.25ms,而不用查找表的C代碼的用時為55-6.75=48.25ms,達到了進7倍的提速比,但這就是我們的終點了嗎?

  第二方案:我們在仔細觀察下Diff > Threshold和Diff < -Threshold時計算的不同,第一個不同是Diff > Threshold時使用了Diff - Threshold,而Diff < -Threshold時使用了Diff + Threshold;第二個不同為Diff > Threshold時使用了255.0f - Src[Y]作為開平方的算式,而Diff < -Threshold時使用了 Src[Y]。關於第一個不同,我們可以看到僅僅是個符號位不同,如果在Threshold前面根據不同的條件加個符號位在進行乘法不就可以了,也就是說如果我們根據Diff和Threshold的關系構建一個-1和1的中間變量,則可以把他們寫在一個式子里,那這樣的符號為要如何構建呢?

  自然而然我們又想到了上述方法的_mm_blendv_si128,簡單的方式如下所示:

__m128i Sign = _mm_blendv_si128(Zero, MinusOne, _mm_cmpgt_epi16(Diff, ThresholdV));
        Sign = _mm_blendv_si128(Sign, One, _mm_cmplt_epi16(Diff, MinusThresholdV));

  Zero,MinusOne,One這個還需不需要解釋,上面的代碼還需不需要解釋?

  第二個不同,我們這樣看,我們把它們放在一起 255.0f - Src[Y]  |  Src[Y],稍微改寫一下255 - Src[Y]  | 0 -  Src[Y],后面的+和-可以用類似前面的同樣的方法處理,我們還需處理255和0,如果我們能夠根據判斷條件構造出255 和 0這樣的序列,那是不是就解決問題了,如何構造?

  前面說過,_mm_cmpgt_epi16會返回0xFFFF和0,看成unsigned short類型則為65535和0, 如果我們把這個返回結果右移8位,是不是就變為了255和0呢,明白了嗎?

  最后我們注意一點,當-Threshold < Diff <Threshold時,我們的返回的是原圖像的值,那在這種情況下是不是有問題呢,其實也不會,我們注意到此條件下Sign對應的符號位為0,而_mm_cmpgt_epi16返回的那部分數據也為0,也就是說此時對應的sqrt參數為0,那么作為乘法的一部分,整個前面的算式就為0,結果返回的恰好是原值。

  我們還來在說下前面的符號問題,正或者負某個數,直接用符號位加乘法固然是可以實現的,但是有么有其他的方式更好的實現呢,翻一番SSE的手冊,我們會發現有_mm_sign_epi8 、_mm_sign_epi16 、_mm_sign_epi32 這樣的函數,他們是干什么的呢,我們以_mm_sign_epi16為例,看看他的文檔說明:

extern __m128i _mm_sign_epi16 (__m128i a, __m128i b); 
Negate packed words in a if corresponding sign in b is less than zero. 
Interpreting a, b, and r as arrays of signed 16-bit integers: 
for (i = 0; i < 8; i++)
{ 
    if (b[i] < 0) 
    { 
        r[i] = -a[i]; 
    } 
    else if (b[i] == 0)
    { 
        r[i] = 0; 
    } 
    else 
    { 
        r[i] = a[i]; 
    } 
}

  什么意思,就是以參數b的符號位來決定a的值,當b為負數是,對a求反,當b為0時,a也為0,否則a值保持不變。這不就可以直接實現上述的符號位的問題了嗎?

  說了那么多,我貼出代碼大家看一看:

int IM_UnsharpMask(unsigned char *Src, unsigned char *Dest, int Width, int Height, int Stride, int Radius, int Amount, int Threshold)
{
    int Channel = Stride / Width;
    if ((Src == NULL) || (Dest == NULL))                                return IM_STATUS_NULLREFRENCE;
    if ((Width <= 0) || (Height <= 0))                                    return IM_STATUS_INVALIDPARAMETER;
    if ((Channel != 1) && (Channel != 3) && (Channel != 4))                return IM_STATUS_INVALIDPARAMETER;
    int Status = IM_STATUS_OK;

    Status = IM_ExpBlur(Src, Dest, Width, Height, Stride, Radius);
    if (Status != IM_STATUS_OK)    return Status;

    const float Adjust = Amount / 100.0f / sqrt(255.0f);
    const int BlockSize = 8;
    int Block = (Height * Stride) / BlockSize;

    const __m128i Zero = _mm_setzero_si128();
    const __m128i ThresholdV = _mm_set1_epi16(Threshold);
    const __m128i MinusThresholdV = _mm_set1_epi16(-Threshold);
    const __m128i MinusOne = _mm_set1_epi16(-1);
    const __m128 AdjustV = _mm_set1_ps(Adjust);
    const __m128i One = _mm_set1_epi16(1);
    for (int Y = 0; Y < Block * BlockSize; Y += BlockSize)
    {
        __m128i SrcV = _mm_unpacklo_epi8(_mm_loadl_epi64((__m128i *)(Src + Y)), Zero);
        __m128i DstV = _mm_unpacklo_epi8(_mm_loadl_epi64((__m128i *)(Dest + Y)), Zero);
        __m128i Diff = _mm_sub_epi16(SrcV, DstV);                                                //    int Diff = Src[Y] - Dest[Y];
        
        //    當Diff > ThresholdV時,Sign設置為負數,當Diff < -ThresholdV時,Sign設置為正數,
        //    介於-ThresholdV和ThresholdV之間時為0,這里One和MinusOne只是取得一個代表性的值

        __m128i SignA = _mm_cmpgt_epi16(Diff, ThresholdV);                           
        __m128i SignB = _mm_cmplt_epi16(Diff, MinusThresholdV);                        
        __m128i Sign = _mm_blendv_si128(Zero, MinusOne, SignA);
        Sign = _mm_blendv_si128(Sign, One, SignB);
            
        //    Diff 為不同值時,NewDiff需要帶上不同符號,利用上面的Sign配合_mm_sign_epi16能很好的解決問題
        __m128i NewDiff = _mm_add_epi16(Diff, _mm_sign_epi16(ThresholdV, Sign));

        //    _mm_cmpgt_epi16返回0xfffff和0兩種值,我們這里需要的是0xff和0,因此需要進行下移位,注意此時在Diff < Threshold(Sign為0或者1時)
        //    _mm_add_epi16的第一個參數都是0,而第二個參數對於Sign為0的情況則也返回0,這樣0+0正好為0,Sqrt后也為0,對結果正好沒有影響(巧合還是天意?)
        __m128i NewPower = _mm_add_epi16(_mm_srli_epi16(SignA, 8), _mm_sign_epi16(SrcV, Sign));

        //    注意這里有負數存在,則必須用這種強制轉換函數
        __m128 NewDiffL = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(NewDiff));                                    
        __m128 NewDiffH = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_srli_si128(NewDiff, 8)));

        //    都是正數就可以這樣轉化
        __m128 NewPowerL = _mm_cvtepi32_ps(_mm_unpacklo_epi16(NewPower, Zero));                            
        __m128 NewPowerH = _mm_cvtepi32_ps(_mm_unpackhi_epi16(NewPower, Zero));

        //    按公式計算結果
        __m128 DstL = _mm_mul_ps(_mm_mul_ps(AdjustV, NewDiffL), _mm_sqrt_ps(NewPowerL));
        __m128 DstH = _mm_mul_ps(_mm_mul_ps(AdjustV, NewDiffH), _mm_sqrt_ps(NewPowerH));

        //    合成到16位的結果,注意這里不要用_mm_packus_epi32,因為后面還有一個加法要進行
        __m128i Result = _mm_packs_epi32(_mm_cvtps_epi32(DstL), _mm_cvtps_epi32(DstH));                    

        //    合成到8位的結果,注意這要用抗飽和的加法_mm_adds_epi16
        _mm_storel_epi64((__m128i *)(Dest + Y), _mm_packus_epi16(_mm_adds_epi16(Result, SrcV), Zero));
    }

    for (int Y = Block * BlockSize; Y < Height * Stride; Y++)
    {
        int Diff = Src[Y] - Dest[Y];
        if (Diff >= Threshold)
        {
            Dest[Y] = IM_ClampToByte(int((Diff - Threshold) * Adjust * sqrtf(255.0f - Src[Y]) + 0.5f) + Src[Y]);
        }
        else if (Diff < -Threshold)
        {
            Dest[Y] = IM_ClampToByte(int((Diff + Threshold) * Adjust * sqrtf(0.0f + Src[Y]) + 0.5f) + Src[Y]);
        }
        else
        {
            Dest[Y] = IM_ClampToByte(int(Diff * Adjust * sqrtf(0.0f + 0.0f) + 0.5f) + Src[Y]);        //    不做變化
        }
    }

    return IM_STATUS_OK;
}

  最后回到我們關心的速度問題上去,經過上述優化后能達到的速度平均值在11.5ms左右,比查找表版本的還要快了3ms左右。

  實際上上述求Sign的過程還有更為簡單的優化過程的,想通了也很有道理,這個留個讀者自行去研究,大概能加快0.4ms左右的速度。

  關於分支預測的SSE優化,目前我掌握的技巧也就這么多,管件還是要看算法本身,有的時候要脫離原始算法,為了能用SSE而稍微改變下算法的外表。這就各位神仙各顯神通了,當然有很多分支預測由於太復雜還是不能夠用SIMD指令優化的。

  最后說一句,關於Photoshop的標准USM銳化並不是使用的上述算法,其原理應該說比上面的還要簡單,但也不是網絡上流行的那個計算公式,我已經通過測試推到得到了和其一模一樣的計算式,這里不提,不過呢,為什么非要一樣呢,這里的這個算法也是不錯的。

  算法Demo下載地址:https://files.cnblogs.com/files/Imageshop/SSE_Optimization_Demo.rar

 

      

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM