TensorFlow 制作自己數據集時,xml轉csv


TensorFlow 制作自己數據集時,xml轉csv千篇一律,把我拐入坑里了。

如果訓練自己的數據集只有一個類別,用網絡上的xml_to_csv,完全沒有問題,源碼如下:

# -*- coding: utf-8 -*-
import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET
 
def xml_to_csv(path):
    xml_list = []
    # 讀取注釋文件
    for xml_file in glob.glob(path + '/*.xml'):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall('object'):
            value = (root.find('filename').text + '.jpg',
                     int(root.find('size')[0].text),
                     int(root.find('size')[1].text),
                     member[0].text,
                     int(member[4][0].text),
                     int(member[4][1].text),
                     int(member[4][2].text),
                     int(member[4][3].text)
                     )
            xml_list.append(value)
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
 
    # 將所有數據分為樣本集和驗證集,一般按照3:1的比例
    train_list = xml_list[0: int(len(xml_list) * 0.67)]
    eval_list = xml_list[int(len(xml_list) * 0.67) + 1: ]
 
    # 保存為CSV格式
    train_df = pd.DataFrame(train_list, columns=column_name)
    eval_df = pd.DataFrame(eval_list, columns=column_name)
    train_df.to_csv('data/train.csv', index=None)
    eval_df.to_csv('data/eval.csv', index=None)
 
 
def main():
    path = './xml'
    xml_to_csv(path)
    print('Successfully converted xml to csv.')
 
main()

  

如果你的類別數據集,超過2類以上,再用上述源碼,覺得把所有的數據集3:1的分割,而非一個類別的3:1分割 。

對上述源碼略作調整,完美把每一類數據集按照9:1分割為訓練數據集和測試數據集,源代碼如下:

# coding: utf-8
import glob
import pandas as pd
import xml.etree.ElementTree as ET
 
classes = ["20Km_h", "no_passing_35", "no_passing", "keep_left", "keep_right", "mandatory", "straight_or_left", "passing_limits",
           "bicycles", "pedestrians", "stop", "dangerous"]
 
def xml_to_csv(path):
    train_list = []
    eval_list = []
 
    for cls in classes:
        xml_list = []
        # 讀取注釋文件
        for xml_file in glob.glob(path + '/*.xml'):
            tree = ET.parse(xml_file)
            root = tree.getroot()
            for member in root.findall('object'):
                if cls == member[0].text:
                    value = (root.find('filename').text,
                             int(root.find('size')[0].text),
                             int(root.find('size')[1].text),
                             member[0].text,
                             int(member[4][0].text),
                             int(member[4][1].text),
                             int(member[4][2].text),
                             int(member[4][3].text)
                             )
                    xml_list.append(value)
 
        for i in range(0,int(len(xml_list) * 0.9)):
            train_list.append(xml_list[i])
        for j in range(int(len(xml_list) * 0.9) + 1,int(len(xml_list))):
            eval_list.append(xml_list[j])
 
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
 
 
    # 保存為CSV格式
    train_df = pd.DataFrame(train_list, columns=column_name)
    eval_df = pd.DataFrame(eval_list, columns=column_name)
    train_df.to_csv('data/train.csv', index=None)
    eval_df.to_csv('data/eval.csv', index=None)
 
 
def main():
    # path = 'E:\\\data\\\Images'
    path = r'D:\work\PycharmPro\trafficsign\SSD_NET\data\xml_data'  # path參數更具自己xml文件所在的文件夾路徑修改
    xml_to_csv(path)
    print('Successfully converted xml to csv.')
 
 
main()

  

classes = ["20Km_h", "no_passing_35", "no_passing", "keep_left", "keep_right", "mandatory", "straight_or_left", "passing_limits", "bicycles", "pedestrians", "stop", "dangerous"]

該處需要改為自己數據集類別標簽名。


原文:https://blog.csdn.net/miao0967020148/article/details/90208139


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM