使用ML.NET實現德州撲克牌型分類器


導讀:ML.NET系列文章

本文將基於ML.NET v0.2預覽版,重點介紹提取特征的思路和方法,實現德州撲克牌型分類器。

先介紹一下德州撲克的基本牌型,一手完整的牌共有五張撲克,10種牌型分別是:

1. 高牌,花色和點數同時沒有相同的牌。

2. 一對,點數有且僅有兩張相同的牌。

3. 兩對,兩張相同點數的牌,加另外兩張相同點數的牌。

4. 三條,有三張同一點數的牌。

5. 順子,五張順連的牌。

6. 同花,五張同一花色的牌。

7. 葫蘆,三張同一點數的牌,加一對其他點數的牌。

8. 四條,有四張同一點數的牌。

9. 同花順,同一花色五張順連的牌。

10. 皇家同花順,最高點數是A的同花順的牌。

這一次我們將使用邏輯回歸模型,來訓練數據完成我們想要的分類模型。

准備數據集


數據來源在Poker Hand Data Set,下載鏈接為:poker-hand-testing.datapoker-hand-training-true.data。內容類似如下:

3,92,3,3,2,2,9,3,5,1
4,4,1,11,2,9,4,13,2,7,0
1,5,1,9,2,8,2,4,4,3,0
4,12,4,7,4,5,2,10,2,2,0
4,3,2,4,4,13,3,6,4,12,0
2,5,4,5,4,1,4,9,2,7,1
2,12,3,12,3,7,3,11,2,7,2
4,13,2,6,4,6,4,10,4,9,1
...

說明一下每一行的格式:

第1張花色,第1張點數,第2張花色,第2張點數,第3張花色,第3張點數,第4張花色,第4張點數,第5張花色,第5張點數,牌型

花色是1-4代表紅心,黑桃,方塊,梅花。點數1表示A,2-10保持不變,11表示J,12表示Q,13表示K。

特征分析


前幾篇數據集的內容,基本上分割好就是特征了,這一次不同,每一行的數值僅僅是元數據,也就是說,通過五張牌的花色和點數值是不能直接和牌型形成一一對應的聯系,需要先按本文開頭介紹的10種牌型的描述,找到關鍵可數值化的字段。因此,我選擇了這樣一些字段:對子數,是否是三條,是否是四條,是否是順子,是否同花。通過這5個字段值的組合,一定能判斷出牌型。

於是,我定義出我想要的數據類型:

public class PokerHandData
{
    [Column(ordinal: "0")]
    public float S1;
    [Column(ordinal: "1")]
    public float C1;
    [Column(ordinal: "2")]
    public float S2;
    [Column(ordinal: "3")]
    public float C2;
    [Column(ordinal: "4")]
    public float S3;
    [Column(ordinal: "5")]
    public float C3;
    [Column(ordinal: "6")]
    public float S4;
    [Column(ordinal: "7")]
    public float C4;
    [Column(ordinal: "8")]
    public float S5;
    [Column(ordinal: "9")]
    public float C5;
    [Column(ordinal: "10", name: "Label")]
    public float Power;
[Column(ordinal:
"11")] public float IsSameSuit; [Column(ordinal: "12")] public float IsStraight; [Column(ordinal: "13")] public float FourOfKind; [Column(ordinal: "14")] public float ThreeOfKind; [Column(ordinal: "15")] public float PairsCount; }

S表示花色,C表示點數,Power表示牌型,PairsCount表示對子數,ThreeOfKind表示是否是三條,FourOfKind表示是否是四條,IsStraight表示是否順子,IsSameSuit表示是否同花。

判斷是否同花,只需要比較S1-S5的值即可。

public float GetIsSameSuit()
{
    if (S1 == S2 && S2 == S3 && S3 == S4 && S4 == S5)
        return 1;
    else
        return 0;
}

判斷其它幾個特征,我需要一個通用方法,先統計出每一行的C1-C5,每種點數出現的次數。

static Dictionary<int, int> GetValueCountsOfCondition(IEnumerable<int> cards)
{
    var dic = new Dictionary<int, int>();

    foreach (var item in cards)
    {
        if (dic.ContainsKey(item))
        {
            dic[item] += 1;
        }
        else
        {
            dic.Add(item, 1);
        }
    }
    return dic;
}

然后再按特征涵義計算值。

public float GetFourOfKind()
{
    return GetCountOfCondition(4);
}

public float GetThreeOfKind()
{
    return GetCountOfCondition(3);
}

public float GetPairsCount()
{
    return GetCountOfCondition(2);
}

private IEnumerable<int> GetCards()
{
    if (cards == null)
    {
        cards = new[] { Convert.ToInt32(C1), Convert.ToInt32(C2), Convert.ToInt32(C3), Convert.ToInt32(C4), Convert.ToInt32(C5) };
    }

    return cards;
}

private float GetCountOfCondition(int target)
{
    if (valueCounts == null)
    {
        valueCounts = GetValueCountsOfCondition(GetCards());
    }

    return valueCounts.Count(i => i.Value == target);
}

判斷是否為順子的方法,簡單而直接,就是看間隔差是不是為1,或者最高點有A剩下的必須是10、J、Q、K,都算順子。

public float GetIsStraight()
{
    var keys = GetCards().ToArray();
    Array.Sort(keys);
    if (keys[1] - keys[0] == keys[2] - keys[1] && keys[2] - keys[1] == keys[3] - keys[2] && keys[3] - keys[2] == keys[4] - keys[3] && keys[4] - keys[3] == 1)
    {
        return 1;
    }
    else if (keys[0] == 1 && keys[1] == 10 && keys[2] == 11 && keys[3] == 12 && keys[4] == 13)
    {
        return 1;
    }
    else
    {
        return 0;
    }
}

加載數據


這次由於使用了ML.NET v0.2,該版本的LearningPipeline新增了一種支持集合類型的數據源。因此,我將示范一種全新的載入數據集的方法,先以文件載入元數據,然后直接初始化特征的值。

static IEnumerable<PokerHandData> LoadData(string path)
{
    using (var environment = new TlcEnvironment())
    {
        var pokerHandData = new List<PokerHandData>();
        var textLoader = new Microsoft.ML.Data.TextLoader(path).CreateFrom<PokerHandData>(useHeader: false, separator: ',', trimWhitespace: false);
        var experiment = environment.CreateExperiment();
        var output = textLoader.ApplyStep(null, experiment) as ILearningPipelineDataStep;

        experiment.Compile();
        textLoader.SetInput(environment, experiment);
        experiment.Run();

        var data = experiment.GetOutput(output.Data);

        using (var cursor = data.GetRowCursor((a => true)))
        {
            var getters = new ValueGetter<float>[]{
                cursor.GetGetter<float>(0),
                cursor.GetGetter<float>(1),
                cursor.GetGetter<float>(2),
                cursor.GetGetter<float>(3),
                cursor.GetGetter<float>(4),
                cursor.GetGetter<float>(5),
                cursor.GetGetter<float>(6),
                cursor.GetGetter<float>(7),
                cursor.GetGetter<float>(8),
                cursor.GetGetter<float>(9),
                cursor.GetGetter<float>(10)
            };

            while (cursor.MoveNext())
            {
                float value0 = 0;
                float value1 = 0;
                float value2 = 0;
                float value3 = 0;
                float value4 = 0;
                float value5 = 0;
                float value6 = 0;
                float value7 = 0;
                float value8 = 0;
                float value9 = 0;
                float value10 = 0;
                getters[0](ref value0);
                getters[1](ref value1);
                getters[2](ref value2);
                getters[3](ref value3);
                getters[4](ref value4);
                getters[5](ref value5);
                getters[6](ref value6);
                getters[7](ref value7);
                getters[8](ref value8);
                getters[9](ref value9);
                getters[10](ref value10);

                var hands = new PokerHandData()
                {
                    S1 = value0,
                    C1 = value1,
                    S2 = value2,
                    C2 = value3,
                    S3 = value4,
                    C3 = value5,
                    S4 = value6,
                    C4 = value7,
                    S5 = value8,
                    C5 = value9,
                    Power = value10
                };
                hands.Init();
                pokerHandData.Add(hands);
            }
        }

        return pokerHandData;
    }
}

其中PokerHandData類增加一個初始化的方法。

public void Init()
{
    IsSameSuit = GetIsSameSuit();
    IsStraight = GetIsStraight();
    FourOfKind = GetFourOfKind();
    ThreeOfKind = GetThreeOfKind();
    PairsCount = GetPairsCount();
}

訓練模型


預測的結構定義,以計分為目標,float[]類型表示是對每一種牌型有一個得分,分值越高屬於那一種牌型的概率越大。

public class PokerHandPrediction
{
    [ColumnName("Score")]
    public float[] PredictedPower;
}

模型的選擇是LogisticRegressionClassifier,CollectionDataSource就是用來創建集合類型數據載入的對象。而特征的指定不再是全部字段,而是之前增加的那幾個。

public static PredictionModel<PokerHandData, PokerHandPrediction> Train(IEnumerable<PokerHandData> data)
{
    var pipeline = new LearningPipeline();
    var collection = CollectionDataSource.Create(data);
    pipeline.Add(collection);
    pipeline.Add(new ColumnConcatenator("Features", "IsSameSuit", "IsStraight", "FourOfKind", "ThreeOfKind", "PairsCount"));
    pipeline.Add(new LogisticRegressionClassifier());
    var model = pipeline.Train<PokerHandData, PokerHandPrediction>();
    return model;
}

預測結果


首先,對預測的得分,我們需要判斷一個概率傾向。

static string GetPower(float[] nums)
{
    var index = -1;
    var last = 0F;
    for (int i = 0; i < nums.Length; i++)
    {
        if (nums[i] > last)
        {
            index = i;
            last = nums[i];
        }
    }
var suit = string.Empty; switch (index) { case 0: suit = "高牌"; break; case 1: suit = "一對"; break; case 2: suit = "兩對"; break; case 3: suit = "三條"; break; case 4: suit = "順子"; break; case 5: suit = "同花"; break; case 6: suit = "葫蘆"; break; case 7: suit = "四條"; break; case 8: suit = "同花順"; break; case 9: suit = "皇家同花順"; break; } return suit; }

最后就是進行預測的部分了。

public static void Predict(PredictionModel<PokerHandData, PokerHandPrediction> model)
{
    var test1 = new PokerHandData
    {
        S1 = 1,
        C1 = 2,
        S2 = 1,
        C2 = 3,
        S3 = 3,
        C3 = 4,
        S4 = 4,
        C4 = 5,
        S5 = 2,
        C5 = 6
    };

    var test2 = new PokerHandData
    {
        S1 = 4,
        C1 = 5,
        S2 = 1,
        C2 = 5,
        S3 = 3,
        C3 = 5,
        S4 = 2,
        C4 = 12,
        S5 = 4,
        C5 = 7
    };
    test1.Init();
    test2.Init();
    IEnumerable<PokerHandData> pokerHands = new[]
    {
        test1,
        test2
    };
    IEnumerable<PokerHandPrediction> predictions = model.Predict(pokerHands);
    Console.WriteLine();
    Console.WriteLine("PokerHand Predictions");
    Console.WriteLine("---------------------");

    var pokerHandsAndPredictions = pokerHands.Zip(predictions, (pokerHand, prediction) => (pokerHand, prediction));
    foreach (var (pokerHand, prediction) in pokerHandsAndPredictions)
    {
        Console.WriteLine($"PokerHand: {ShowHand(pokerHand)} | Prediction: { GetPower(prediction.PredictedPower)}");
    }
    Console.WriteLine();

}

創建項目的步驟請參看本文開頭導讀給出的文章鏈接,不再贅述,運行結果如下:

最后放出源代碼文件:下載

希望讀者們保持對ML.NET的持續關注,相信新的特性一定能實現更復雜有趣的場景。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM