玩玩24點(中)


《玩玩24點》系列:

在上篇中,我用上位機程序遍歷了4個1~13的數的1820種組合,通過遞歸窮舉計算出其中1362組的24點接法,並轉換為二進制形式,放到單片機程序中,減少了單片機24點游戲程序的計算量,獲得了不錯的游戲體驗。

上篇的最后留了一個瘋狂暗示,但時至如今我也沒有實現出來,因為寫完上篇過后一直在准備各種比賽和考試,這兩天也在寫AVR單片機教程,一直都沒有空去管它。

寫這篇中篇的原因,是幾個沒有作業寫甚至不需要高考的同學在玩一種24點游戲的升級版——用計算器按出5個1~20的隨機整數,通過四則運算獲得不超過50的最大有理數。經過一整個晚自修的手算后,他們想起我之前寫的24點,來問我他們算出的是不是上界。

我寫算法注重可復用性,畢竟不是std::都不寫的OI。於是我很快就在上次程序的基礎上寫成了他們要的算法。

這個程序,以及人機計算能力的對比,雖然毫無懸念,但是先放一邊。我對上篇所寫的內容有一些更深的思考。

算式的可讀性

實際上這個24點程序還遠不完美。單片機經常在屏幕上輸出詭異的解法,比如10 * 12 = 120, 120 / 5 = 24,這些是不符合人類計算邏輯的,正常人想到的都是10 / 5 = 2, 2 * 12 = 24。一個可行的方法是把遞歸搜索的順序換一下,先減再加,先除后乘,在除法中優先用最大的數除以最小的數。但還是會出現12 / 5 = 12/5, 12/5 * 10 = 24這樣的式子,最根本的算法還是根據表達式建立樹,在樹上調整順序。也許4個數算24點的情況不需要這么復雜,但這是萬能的、具有可擴展性的做法(也有可能是我想多了)。

這是上篇中提出的問題與解決方案,現在我認為需要修改。

首先,對於5, 10, 12的例子,我已經找到簡單方法來使程序輸出符合人類邏輯的算式了:搜索順序改為減法、加法、結果為整數的除法、乘法、結果為分數的除法(代碼可以在后面的程序中找到,這里就不單獨放了)。在更新算法后我試玩了幾十組,發現程序給出的結果都是比較正常的,因此這個問題至少在4數24點的問題中算是解決了。

其次,作為看似更好的算法,即使我能克服學數據結構時對樹的恐懼,成功地用二叉樹表達了算式,“在樹上調整順序”的概念也是模糊的。用什么規則來調整呢?如果是整數優先,那么10 / 5可以保證,但是在新的游戲規則中,如果運算數是2, 3, 33,最優結果是99/2,程序會先計算33 * 3,再計算99 / 2,而我的思路會是33 * 1.5。那么這算什么規則呢?其他的情況呢?理不清。

所以,調整一下搜索順序,見好就收吧。

4數24點的優化

一位對計算機程序一無所知的數學競賽同學對求解24點的算法十分感興趣。在我絞盡腦汁跟他解釋通這個程序后,他認為這個算法不好,因為有大量的重復計算。

有道理。比方說1, 2, 3,原來的算法會先算1 + 2,替換為3,用3, 3遞歸調用,得到6,這是1 + 2 + 3,然后還有1 + 3 + 22 + 3 + 11, 2, 3, 4就更多了。

他提出“分治”的策略:24一定是由兩個中間結果加減乘除得到的,而每個中間結果也都是由兩個運算數得到的。在為他憑空想出分治而震驚之余,我指出這是錯的,這很顯然。

但這個想法還是有一定啟發性的。為了優化4數24點的求解算法,我想還不如枚舉出所有可能的運算結構算了:

  1. a * b * c * d

  2. a + b + c + d

  3. a * b + c + d

  4. a * b * (c + d)

  5. a * b * c + d

  6. a * (b + c + d)

  7. a * b + c * d

  8. (a + b) * (c + d)

  9. (a * b + c) * d

  10. (a + b) * c + d

其中+代表加或減,*代表乘或除。偶數序號的結構都是前一個奇數序號結構的對偶,指把加減與乘除互換,加括號保證原有的優先級。

inline bool read_bit(int c, int b)
{
    return c & (1 << b);
}

class fast_vector
{
public:
    void push_back(const Rational& r)
    {
        data[size++] = r;
    }
    Rational* begin()
    {
        return data;
    }
    Rational* end()
    {
        return data + size;
    }
private:
    Rational data[1 << max_count];
    int size = 0;
};

using vector_type = fast_vector;

void all_sum(const std::vector<Rational>& data, vector_type& result)
{
    auto end = (1 << data.size()) - 1;
    for (int c = 0; c != end; ++c)
    {
        Rational sum = 0;
        bool valid = true;
        for (int b = 0; b != data.size(); ++b)
            if (!read_bit(c, b))
                sum += data[b];
        for (int b = 0; b != data.size(); ++b)
            if (read_bit(c, b))
            {
                if (sum < data[b])
                {
                    valid = false;
                    break;
                }
                sum -= data[b];
            }
        if (valid)
            result.push_back(sum);
    }
}

void all_pro(const std::vector<Rational>& data, vector_type& result)
{
    auto end = (1 << data.size()) - 1;
    for (int c = 0; c != end; ++c)
    {
        Rational pro = 1;
        bool valid = true;
        for (int b = 0; b != data.size(); ++b)
        {
            if (read_bit(c, b))
            {
                if (data[b] == 0)
                {
                    valid = false;
                    break;
                }
                pro /= data[b];
            }
            else
                pro *= data[b];
        }
        if (valid)
            result.push_back(pro);
    }
}

bool test_sum(const Rational& lhs, const Rational& rhs)
{
    if (lhs + rhs == target)
        return true;
    if (lhs < rhs && rhs - lhs == target)
        return true;
    if (rhs < lhs && lhs - rhs == target)
        return true;
    return false;
}

bool test_pro(const Rational& lhs, const Rational& rhs)
{
    if (lhs * rhs == target)
        return true;
    if (rhs != 0 && rhs / lhs == target)
        return true;
    if (lhs != 0 && lhs / rhs == target)
        return true;
    return false;
}

bool solve(int a, int b, int c, int d)
{
    std::vector<Rational> data(4);
    data[0] = a;
    data[1] = b;
    data[2] = c;
    data[3] = d;

    // a * b * c * d
    {
        vector_type pro;
        all_pro(data, pro);
        for (const auto& r : pro)
            if (r == target)
                return true;
    }

    // a + b + c + d
    {
        vector_type sum;
        all_sum(data, sum);
        for (const auto& r : sum)
            if (r == target)
                return true;
    }

    // a * b + c + d
    for (int i = 0; i != 3; ++i)
        for (int j = i + 1; j != 4; ++j)
        {
            auto pm = data;
            pm.erase(pm.begin() + j);
            pm.erase(pm.begin() + i);
            std::vector<Rational> md{ data[i], data[j] };
            vector_type pro;
            all_pro(md, pro);
            for (const auto& r : pro)
            {
                pm.push_back(r);
                vector_type sum;
                all_sum(pm, sum);
                for (const auto& r : sum)
                    if (r == target)
                        return true;
                pm.pop_back();
            }
        }

    // a * b * (c + d)
    for (int i = 0; i != 3; ++i)
        for (int j = i + 1; j != 4; ++j)
        {
            auto md = data;
            md.erase(md.begin() + j);
            md.erase(md.begin() + i);
            std::vector<Rational> pm{ data[i], data[j] };
            vector_type sum;
            all_sum(pm, sum);
            for (const auto& r : sum)
            {
                md.push_back(r);
                vector_type pro;
                all_pro(md, pro);
                for (const auto& r : pro)
                    if (r == target)
                        return true;
                md.pop_back();
            }
        }

    // a * b * c + d
    for (int i = 0; i != 4; ++i)
    {
        auto md = data;
        md.erase(md.begin() + i);
        vector_type pro;
        all_pro(md, pro);
        for (const auto& r : pro)
            if (test_sum(data[i], r))
                return true;
    }

    // a * (b + c + d)
    for (int i = 0; i != 4; ++i)
    {
        auto pm = data;
        pm.erase(pm.begin() + i);
        vector_type sum;
        all_sum(pm, sum);
        for (const auto& r : sum)
            if (test_pro(data[i], r))
                return true;
    }

    // a * b + c * d
    for (int i = 0; i != 3; ++i)
        for (int j = i + 1; j != 4; ++j)
        {
            auto md2 = data;
            md2.erase(md2.begin() + j);
            md2.erase(md2.begin() + i);
            decltype(md2) md1{ data[i], data[j] };
            vector_type pro1, pro2;
            all_pro(md1, pro1);
            all_pro(md2, pro2);
            for (const auto& r1 : pro1)
                for (const auto& r2 : pro2)
                    if (test_sum(r1, r2))
                        return true;
        }

    // (a + b) * (c + d)
    for (int i = 0; i != 3; ++i)
        for (int j = i + 1; j != 4; ++j)
        {
            auto pm2 = data;
            pm2.erase(pm2.begin() + j);
            pm2.erase(pm2.begin() + i);
            decltype(pm2) pm1{ data[i], data[j] };
            vector_type sum1, sum2;
            all_sum(pm1, sum1);
            all_sum(pm2, sum2);
            for (const auto& r1 : sum1)
                for (const auto& r2 : sum2)
                    if (test_pro(r1, r2))
                        return true;
        }

    // (a * b + c) * d
    for (int i = 0; i != 3; ++i)
        for (int j = i + 1; j != 4; ++j)
        {
            auto rest = data;
            rest.erase(rest.begin() + j);
            rest.erase(rest.begin() + i);
            std::vector<Rational> md{ data[i], data[j] };
            vector_type pro;
            all_pro(md, pro);
            for (const auto& r : pro)
            {
                for (int k = 0; k != 2; ++k)
                {
                    std::vector<Rational> pm{ r, rest[k] };
                    vector_type sum;
                    all_sum(pm, sum);
                    for (const auto& r : sum)
                        if (test_pro(r, rest[1 - k]))
                            return true;
                }
            }
        }

    // (a + b) * c + d
    for (int i = 0; i != 3; ++i)
        for (int j = i + 1; j != 4; ++j)
        {
            auto rest = data;
            rest.erase(rest.begin() + j);
            rest.erase(rest.begin() + i);
            std::vector<Rational> pm{ data[i], data[j] };
            vector_type sum;
            all_sum(pm, sum);
            for (const auto& r : sum)
            {
                for (int k = 0; k != 2; ++k)
                {
                    std::vector<Rational> md{ r, rest[k] };
                    vector_type pro;
                    all_pro(md, pro);
                    for (const auto& r : pro)
                        if (test_sum(r, rest[1 - k]))
                            return true;
                }
            }
        }

    return false;
}

int main()
{
    auto start_time = std::clock();
    int count = 0;
    for (int a = 1; a <= max_num; ++a)
        for (int b = a; b <= max_num; ++b)
            for (int c = b; c <= max_num; ++c)
                for (int d = c; d <= max_num; ++d)
                    if (solve(a, b, c, d))
                        ++count;
    std::cout << count << std::endl;
    std::cout << (static_cast<double>(std::clock()) - start_time) * 1000
        / CLOCKS_PER_SEC << "ms" << std::endl;
    return 0;
}

IntegerintRationalExpression的定義見上篇。

原算法沒有使用std::vector數據結構,由於STL的糟糕性能,我寫了個不涉及動態內存分配的fast_vector來替換存儲運算結果的std::vector;運算數的懶得改了。

算法的核心在於all_sum函數,用於求出data數組中的元素通過加減法可以得到的所有結果:

void all_sum(const std::vector<Rational>& data, vector_type& result)
{
    auto end = (1 << data.size()) - 1;
    for (int c = 0; c != end; ++c)
    {
        Rational sum = 0;
        bool valid = true;
        for (int b = 0; b != data.size(); ++b)
            if (!read_bit(c, b))
                sum += data[b];
        for (int b = 0; b != data.size(); ++b)
            if (read_bit(c, b))
            {
                if (sum < data[b])
                {
                    valid = false;
                    break;
                }
                sum -= data[b];
            }
        if (valid)
            result.push_back(sum);
    }
}

函數用一個整數c表示data數組中各元素取加號還是減號,當二進制c的第b位為0時(最低位為第0位),下標為b的元素取加號,否則取減號;c取不到0b11...1data.size()1),是因為不能所有元素都取減號。對於每個c,如果算出來的值是有效的,就把它追加到結果的數組中去。我把返回值寫成了引用參數,雖然編譯器很可能RVO(返回值優化),我還是手動寫出來以明確我提升性能的意圖。

all_pro函數類似,只不過計算的是積與商。

程序在VS2019中編譯,配置為Release、x86,在沒插電的最節能配置下的i7-7700HQ上測試,從命令行調用,優化算法的平均運行時間為55ms,而原算法為82ms,是有明顯提升的。

概率問題

在一篇研究24點游戲的文章中,有這樣一句話:

其實還有一個原因,就是有解的概率太小了。4個數字的話也就大約80%的題能算,如果算上人頭牌,可解的題就只有75%了。

沒錯,在1820種可能的4數組合中,有1362種有解,比例為74.8%。

但是注意,我說的是“比例”而不是“概率”,這兩者是有區別的。要計算“有解的概率”,必須先確定出題的方式。

如果是從1820道題目的題庫中等概率地選擇一道,類似與上篇中提到的單片機程序一樣,這樣每一道題被選中都是古典概型中的基本事件,有解概率就是74.8%。

如果是從52張撲克牌中等概率地選擇4張,那么概率就不是74.8%,因為每一種題目出現的概率是不相等的。比如,6, 6, 6, 6出現的概率為$1 / C_{52}^{4} $,而1, 2, 3, 4出現的概率為$4! / C_{52}^{4} $,兩者相差24倍。每一種4數的有序排列都是古典概型中的基本事件,有解概率需要重新計算。

std::set<std::vector<Integer>> solution;
int solved = 0;
int total = 0;
int card[4];
std::vector<Integer> comb(4);
for (card[0] = 0; card[0] != 49; ++card[0])
    for (card[1] = card[0] + 1; card[1] != 50; ++card[1])
        for (card[2] = card[1] + 1; card[2] != 51; ++card[2])
            for (card[3] = card[2] + 1; card[3] != 52; ++card[3])
            {
                ++total;
                for (int i = 0; i != 4; ++i)
                    comb[i] = card[i] / 4 + 1;
                if (solution.find(comb) != solution.end())
                    ++solved;
            }
std::cout << solved << " / " << total << std::endl;

其中,solution已經保存了有解的4數組合。程序的輸出為:

217817 / 270725

這個比例為80.5%,也是這種模型下有解的概率。

新款50點游戲

50點游戲的規則是,用5個1~20的整數通過四則運算得到不超過50的最大有理數。

為什么是50呢?如果是48的話,我想你也會問為什么是48的。唯一的一點道理,他們說,是這樣比較考驗一個人對數字的感覺。

上回4數的算法並不局限於4數,參數都可以通過全局變量來調整。50點相對於24點還改變了輸出結果的規則,但只需要修改遞歸出口的條件和操作。在那個程序的基礎上,50點很快就寫好了。

// return whether the branch has found a better solution
bool solve(Integer count, const Rational* data, const Rational target, Rational* max, Expression* expr)
{
    // assume data is ordered
    if (count == 1)
    {
        if (*data <= target && *data > *max)
        {
            *max = *data;
            return true;
        }
        else
            return false;
    }
    auto end = data + count;
    auto before_end = end - 1;
    --count;
    Rational new_data[max_count - 1];
    auto new_end = new_data + count;
    bool optimize = false;

    // -
    for (auto lhs = data + 1; lhs != end; ++lhs)
        for (auto rhs = data; rhs != lhs; ++rhs)
        {
            auto dst = new_data;
            for (auto src = data; src != end; ++src)
                if (src != lhs && src != rhs)
                    *dst++ = *src;
            *dst = *lhs - *rhs;
            Expression temp(*lhs, '-', *rhs, *dst);
            if (temp.rhs == 0)
            {
                std::swap(temp.lhs, temp.rhs);
                temp.op = '+';
            }
            std::sort(new_data, new_end);
            if (solve(count, new_data, target, max, expr + 1))
            {
                optimize = true;
                *expr = temp;
            }
        }

    // +
    for (auto lhs = data; lhs != before_end; ++lhs)
        for (auto rhs = lhs + 1; rhs != end; ++rhs)
        {
            auto dst = new_data;
            for (auto src = data; src != end; ++src)
                if (src != lhs && src != rhs)
                    *dst++ = *src;
            *dst = *lhs + *rhs;
            Expression temp(*lhs, '+', *rhs, *dst);
            std::sort(new_data, new_end);
            if (solve(count, new_data, target, max, expr + 1))
            {
                optimize = true;
                *expr = temp;
            }
        }

    // / integer
    struct
    {
        const Rational* lhs;
        const Rational* rhs;
        Rational res;
    } div_frac[max_count * (max_count - 1)];
    Integer frac_size = 0;
    for (auto lhs = before_end; lhs != data - 1; --lhs)
        for (auto rhs = data; rhs != end; ++rhs)
        {
            if (lhs == rhs || *rhs == Rational(0))
                continue;
            auto res = *lhs / *rhs;
            if (res.den != 1)
            {
                div_frac[frac_size].lhs = lhs;
                div_frac[frac_size].rhs = rhs;
                div_frac[frac_size++].res = res;
                continue;
            }
            auto dst = new_data;
            for (auto src = data; src != end; ++src)
                if (src != lhs && src != rhs)
                    *dst++ = *src;
            *dst = res;
            Expression temp(*lhs, '/', *rhs, *dst);
            if (temp.rhs == 1)
            {
                if (Rational(1) < temp.lhs)
                    std::swap(temp.lhs, temp.rhs);
                temp.op = '*';
            }
            std::sort(new_data, new_end);
            if (solve(count, new_data, target, max, expr + 1))
            {
                optimize = true;
                *expr = temp;
            }
        }

    // *
    for (auto lhs = data; lhs != before_end; ++lhs)
        for (auto rhs = lhs + 1; rhs != end; ++rhs)
        {
            auto dst = new_data;
            for (auto src = data; src != end; ++src)
                if (src != lhs && src != rhs)
                    *dst++ = *src;
            *dst = *lhs * *rhs;
            Expression temp(*lhs, '*', *rhs, *dst);
            std::sort(new_data, new_end);
            if (solve(count, new_data, target, max, expr + 1))
            {
                optimize = true;
                *expr = temp;
            }
        }

    // / fraction
    for (Integer i = 0; i != frac_size; ++i)
    {
        auto dst = new_data;
        for (auto src = data; src != end; ++src)
            if (src != div_frac[i].lhs && src != div_frac[i].rhs)
                *dst++ = *src;
        *dst = div_frac[i].res;
        Expression temp(*div_frac[i].lhs, '/', *div_frac[i].rhs, *dst);
        std::sort(new_data, new_end);
        if (solve(count, new_data, target, max, expr + 1))
        {
            optimize = true;
            *expr = temp;
        }
    }

    return optimize;
}

Rational test(Rational* operand, const Rational target, std::ostream& os = std::cout)
{
    Expression expr[4];
    Rational result = 0;
    solve(5, operand, target, &result, expr);
    for (int i = 0; i != 4; ++i)
        os << operand[i] << ", ";
    os << operand[4] << ": ";
    os << result;
    if (result.den != 1)
        os << " = " << (double)result.num / result.den;
    os << std::endl;
    for (const auto& e : expr)
        os << '\t' << e << std::endl;
    return result;
}

同樣,Integerint的類型別名,使程序可以處理100以內的整數;RationalExpression類的定義見上篇,不過去掉了對除法和取模運算的計數。

DFS依然有些難理解。24點中solve函數返回該路徑下是否能計算出24,如果得到true,則調用者solve本身把當前操作的表達式寫入expr數組,並直接return true,一路返回到test,並輸出解法。但是,50點不能把50設置為唯一的目標,而是在每次獲得結果時更新最優解。solve函數返回該路徑下能否找到更優的解,如果為true,則調用者solve同樣把當前操作的表達式寫入expr數組,但不返回,而是繼續試探下一路徑。

不返回是比較好理解的,因為找到的不一定是最優解。不過如果找到50,則可以一路返回到底,避免不必要的搜索。由於可以算出50的輸入占一大部分,這種優化可以顯著加速全部輸入的窮舉,which原本需要十分鍾。不過這一點是我一分鍾前剛想出來的,還沒放進代碼。

無論遞歸深度,當路徑中有更優解時,就立即更新expr數組,是正確的算法。這是因為,每一層的遞歸都只負責一個Expression空間,不同層互不干擾,因此這個寫入只會覆蓋本次調用中上一次寫入或之前調用寫入的表達式,其對應的結果沒有當前找到的優,因此可以放心覆蓋。由於從遞歸出口到最初調用的每一層調用都能得知這個最優解,因此最后獲得的表達式是完整的。

對於單組輸入,這個算法是NP的,對於所有輸入而言更是。所以運算數個數、范圍和運算符都受到嚴格限制,而且我感覺這個問題不會有P的算法。

5個1~20的數共有\(C_{20 + 5 - 1}^{5} = 42504\)種組合(沒有同一個數最多4個的限制),全部求解一遍需要十分鍾。下篇應該會解決一個規則更復雜的問題,由於5數50點已經跑得夠慢了,我決定這個寒假里學習並發。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM