torchtext library (a text preprocessing library)


Usage reference: https://zhuanlan.zhihu.com/p/31139113

Example routine:

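# Legacy torchtext API (torchtext 0.3 to 0.8; moved to torchtext.legacy in 0.9, removed in 0.12)
from torchtext import data
from torchtext.data import TabularDataset, BucketIterator, Iterator
from torchtext.vocab import GloVe, Vectors

def get_data_iter(train_csv, test_csv, fix_length, batch_size, word2vec_dir):
    # Text column: whitespace-tokenized, lowercased, padded/truncated to fix_length; batches are (batch, fix_length)
    TEXT = data.Field(sequential=True, lower=True, fix_length=fix_length, batch_first=True)
    # Label column: already numeric in the CSV, so no vocabulary is built for it
    LABEL = data.Field(sequential=False, use_vocab=False)
    # CSV column order: label, title (ignored), text; both files share the same layout
    fields = [("label", LABEL), ("title", None), ("text", TEXT)]
    train = TabularDataset(path=train_csv, format="csv", fields=fields, skip_header=True)
    test = TabularDataset(path=test_csv, format="csv", fields=fields, skip_header=True)
    # BucketIterator groups examples of similar length; device="cpu" replaces the old device=-1 convention
    train_iter = BucketIterator(train, batch_size=batch_size, device="cpu", sort_key=lambda x: len(x.text), sort_within_batch=False, repeat=False)
    # train=False: iterate the test set in file order, without shuffling
    test_iter = Iterator(test, batch_size=batch_size, device="cpu", train=False, sort=False, sort_within_batch=False, repeat=False)
    # Alternative: load local word2vec vectors from word2vec_dir instead of GloVe
    # vectors = Vectors(name=word2vec_dir)
    # TEXT.build_vocab(train, vectors=vectors)
    TEXT.build_vocab(train, vectors=GloVe(name="6B", dim=300))
    vocab = TEXT.vocab
    return train_iter, test_iter, vocab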
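To make explicit what the Field, build_vocab and BucketIterator calls in the routine do to each CSV row, here is a simplified pure-Python sketch. It is an illustration under assumptions, not torchtext's own code: tokenize, build_vocab, pad_and_numericalize and bucket_batches are hypothetical helpers; the special tokens <unk> (index 0) and <pad> (index 1) follow the legacy Field defaults; and the pool size of 100 × batch_size follows the description in torchtext's pool docstring. Pretrained GloVe vectors are not loaded here.

```python
import random
from collections import Counter

# Simplified imitation of the legacy torchtext pipeline (assumption: mirrors the
# default behaviour of Field(lower=True, fix_length=..., batch_first=True),
# Field.build_vocab() and BucketIterator; it is NOT torchtext's implementation).

UNK, PAD = "<unk>", "<pad>"  # legacy Field defaults: index 0 and index 1


def tokenize(text):
    """Split on whitespace, then lowercase each token (lower=True)."""
    return [tok.lower() for tok in text.split()]


def build_vocab(token_lists, min_freq=1):
    """Specials first, then words by descending frequency (ties alphabetical)."""
    counter = Counter(tok for toks in token_lists for tok in toks)
    words = sorted(counter.items(), key=lambda kv: kv[0])  # alphabetical
    words.sort(key=lambda kv: kv[1], reverse=True)         # stable: by frequency
    itos = [UNK, PAD] + [w for w, c in words if c >= min_freq]
    return {w: i for i, w in enumerate(itos)}


def pad_and_numericalize(tokens, stoi, fix_length):
    """Keep the first fix_length tokens, pad with <pad>, map unknown words to <unk>."""
    tokens = tokens[:fix_length] + [PAD] * max(0, fix_length - len(tokens))
    return [stoi.get(tok, stoi[UNK]) for tok in tokens]


def bucket_batches(examples, batch_size, sort_key, rng):
    """Sort by length inside large pools, cut into batches, shuffle batch order."""
    batches = []
    pool_size = batch_size * 100
    for start in range(0, len(examples), pool_size):
        pool = sorted(examples[start:start + pool_size], key=sort_key)
        batches.extend(pool[i:i + batch_size] for i in range(0, len(pool), batch_size))
    rng.shuffle(batches)
    return batches


if __name__ == "__main__":
    texts = ["The cat sat", "A dog", "the cat and the dog played outside today"]
    token_lists = [tokenize(t) for t in texts]
    stoi = build_vocab(token_lists)
    # batch_first=True: each row is one example of exactly fix_length ids
    for toks in token_lists:
        print(pad_and_numericalize(toks, stoi, fix_length=5))
    for batch in bucket_batches(token_lists, batch_size=2, sort_key=len, rng=random.Random(0)):
        print([len(toks) for toks in batch])
```

Because TEXT pads or truncates every example to fix_length, all batches in this routine end up the same width, so BucketIterator's length-based grouping saves little padding here; it matters more when fix_length is left unset.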
