Malware Multi-class Classification from Windows API Sequences with BERT
0x00 Dataset
https://github.com/ocatak/malware_api_class
I came across this dataset by chance. It contains 8 malware families, with the sample counts shown in the table below.
Malware Family | Samples | Description |
---|---|---|
Spyware | 832 | enables a user to obtain covert information about another's computer activities by transmitting data covertly from their hard drive. |
Downloader | 1001 | share the primary functionality of downloading content. |
Trojan | 1001 | misleads users of its true intent. |
Worms | 1001 | spreads copies of itself from computer to computer. |
Adware | 379 | hides on your device and serves you advertisements. |
Dropper | 891 | surreptitiously carries viruses, back doors and other malicious software so they can be executed on the compromised machine. |
Virus | 1001 | designed to spread from host to host and has the ability to replicate itself. |
Backdoor | 1001 | a technique in which a system security mechanism is bypassed undetectably to access a computer or its data. |
Each sample is a Windows OS API call trace generated by Cuckoo Sandbox; the dataset contains 340 distinct APIs in total. A sample looks like this:
ldrloaddll ldrgetprocedureaddress ldrloaddll ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrloaddll ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrloaddll ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrloaddll ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrloaddll ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrloaddll ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrloaddll ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrloaddll ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrloaddll ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrloaddll ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrloaddll ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress regopenkeyexa regopenkeyexa regopenkeyexa ntopenkey ntqueryvaluekey ntclose ntopenkey ntqueryvaluekey ntclose ntclose ntqueryattributesfile ntqueryattributesfile ntqueryattributesfile ntqueryattributesfile loadstringa ntallocatevirtualmemory ntallocatevirtualmemory loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa loadstringa ldrgetdllhandle ldrgetprocedureaddress ldrgetdllhandle ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrgetprocedureaddress ldrloaddll ldrgetprocedureaddress ldrunloaddll findfirstfileexw copyfilea regcreatekeyexa regsetvalueexa regclosekey createprocessinternalw ntclose ntclose ntclose ntfreevirtualmemory 
ntterminateprocess ntterminateprocess ntclose ntclose ntclose ntclose ntclose ntclose ntclose ldrunloaddll ntopenkey ntqueryvaluekey ntclose ntclose ntclose ntclose ntterminateprocess
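Before doing any modeling, it is worth sanity-checking the dataset. Below is a minimal sketch that counts the distinct APIs, assuming the calls are stored one space-separated API sequence per line (the same layout that load_data in section 0x02 expects); the file name is only a placeholder:
from collections import Counter

api_counter = Counter()
with open('all_analysis_data.txt', 'r', encoding='utf-8') as f:   # placeholder path
    for line in f:
        api_counter.update(line.split())

print('distinct APIs:', len(api_counter))      # should be around 340
print(api_counter.most_common(5))              # the most frequent API calls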
0x01 BERT
For word embeddings, word2vec, GloVe, and fastText are all available options. I have been working with BERT-family models recently, so I wanted to try how well BERT performs in NLP applications for the security domain.
Loading the BERT model
Step 1: download the model. I personally prefer building deep-learning models with PyTorch, so I downloaded the PyTorch version of the pre-trained BERT model. Loading BERT requires three files: vocab.txt (used to tokenize the text and build the inputs), plus pytorch_model.bin and config.json (used to load the pre-trained BERT weights).
# vocab file downloads
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
# pre-trained model weight downloads
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
Step 2: a text-embedding example.
From text to ids
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

# load the tokenizer from vocab.txt and the pre-trained weights from the ./bert/ directory
tokenizer = BertTokenizer.from_pretrained('./bert/vocab.txt')
bert = BertModel.from_pretrained('./bert/')

content = "this is an apple, this is a pen"
CLS = '[CLS]'
token = tokenizer.tokenize(content)                  # wordpiece tokenization
token = [CLS] + token                                # prepend the [CLS] token
token_ids = tokenizer.convert_tokens_to_ids(token)   # tokens -> vocabulary ids
From ids to word embeddings and classification
The BERT model's forward call expects the following inputs (a small example of building all three follows this list):
input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True
- input_ids: the text tokenized with BERT's vocab and converted to ids.
- token_type_ids (optional): the sentence id of each token, either 0 or 1 (0 means the token belongs to the first sentence, 1 to the second). Shape: (batch_size, sequence_length). For example:
  [CLS] this is an apple ? [SEP] this is a pen . [SEP]   <- tokens
    0    0   0  0    0   0   0    1   1  1  1  1   1     <- token_type_ids
- attention_mask: 1 for the non-padding positions of input_ids and 0 for the padding positions, so that attention ignores the padded tokens.
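A minimal sketch of constructing all three inputs for the sentence pair above; it assumes the same ./bert/vocab.txt as before, and max_len is just an illustrative padding length:
import torch
from pytorch_pretrained_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('./bert/vocab.txt')
tokens_a = tokenizer.tokenize('this is an apple?')
tokens_b = tokenizer.tokenize('this is a pen.')

tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)

max_len = 16                                                        # illustrative padding length
attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids))
token_type_ids = token_type_ids + [0] * (max_len - len(token_type_ids))
input_ids = input_ids + [0] * (max_len - len(input_ids))

input_ids = torch.tensor([input_ids])                               # (batch_size=1, max_len)
token_type_ids = torch.tensor([token_type_ids])
attention_mask = torch.tensor([attention_mask])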
list_12_outs, pooled = bert(torch.tensor([token_ids]), output_all_encoded_layers=True)
Because output_all_encoded_layers=True, the outputs of all 12 Transformer layers are returned and stored in the list list_12_outs; each tensor in the list has shape [batch_size, sequence_length, hidden_size].
pooled is a tensor of shape [batch_size, hidden_size]: the hidden state of the first token [CLS] from the last Transformer layer, passed through the pooler's dense + tanh transform. It summarizes the whole input sentence.
bert_embedding, pooled = bert(torch.tensor([token_ids]), output_all_encoded_layers=False)
When output_all_encoded_layers=False, the first return value bert_embedding has shape [batch_size, sequence_len, 768], where 768 acts as the embedding dimension (for BERT-base).
pooled is still the whole-sentence representation and can be fed directly into a classification layer:
# inside the model's __init__ (768 is the hidden size of BERT-base; use 1024 for BERT-large)
self.classifier = nn.Linear(768, 2)

# inside forward()
out = self.classifier(pooled)
return out
If you want to combine BERT with more models, take the sequence embedding tensor and keep stacking layers on top of it, for example:
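As a minimal sketch of this idea (the module name and dimensions here are hypothetical, not the final model of section 0x03), the sequence embedding can feed a BiLSTM whose last state goes to a linear classifier:
import torch
import torch.nn as nn
from pytorch_pretrained_bert import BertModel

class BertWithLSTM(nn.Module):
    # hypothetical example: stack a BiLSTM classifier on top of BERT's sequence output
    def __init__(self, bert_dir='./bert/', hidden_dim=256, num_class=2):
        super(BertWithLSTM, self).__init__()
        self.bert = BertModel.from_pretrained(bert_dir)
        self.lstm = nn.LSTM(768, hidden_dim, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(hidden_dim * 2, num_class)

    def forward(self, input_ids, attention_mask):
        # [batch, seq_len, 768] sequence output of the last encoder layer
        embedding, _ = self.bert(input_ids, attention_mask=attention_mask,
                                 output_all_encoded_layers=False)
        lstm_out, _ = self.lstm(embedding)
        # classify from the last time step of the BiLSTM
        return self.classifier(lstm_out[:, -1, :])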
Should BERT's parameters be trained?
Freeze all parameters:
for param in self.bert.parameters():
    param.requires_grad_(False)
If you classify directly on the pooled output, I recommend freezing all of BERT's parameters except the pooler layer; in my experiments this gives better results:
for name, param in self.bert.named_parameters():
    if name.startswith('pooler'):
        continue
    else:
        param.requires_grad_(False)
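To confirm the selective freeze did what was intended, you can list what is still trainable (same self.bert as above):
trainable = [name for name, p in self.bert.named_parameters() if p.requires_grad]
print(trainable)   # should contain only the pooler.* parameters
print(sum(p.numel() for p in self.bert.parameters() if p.requires_grad), 'trainable weights')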
0x02 Data Preprocessing
The training and test sets are split 8:2, and the 8:2 ratio is enforced strictly within every malware family. BERT accepts at most 512 tokens per input, so consecutive repeats of the same API in a sample are collapsed into a single occurrence. Even after this deduplication the API sequences are still long, so I decided to keep the first 1020 tokens of each sample, split them into two 510-token segments, and wrap each segment with [CLS] and [SEP], which yields exactly two 512-token inputs.
import torch
from tqdm import tqdm

def load_data(max_sequnce, data_file, label_file):
    CLS, SEP, PAD = 101, 102, 0  # ids of tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]', '[PAD]'])
    api_list = open(data_file, 'r', encoding='utf-8').readlines()
    lab_list = open(label_file, 'r', encoding='utf-8').readlines()
    # collect (ids, mask) pairs per class, then split 8:2, e.g. Trojan: [(ids, mask), (ids, mask)]
    collected_by_label = {
        "Trojan": [],
        "Backdoor": [],
        "Downloader": [],
        "Worms": [],
        "Spyware": [],
        "Adware": [],
        "Dropper": [],
        "Virus": []
    }
    train_input_ids = []
    train_input_mak = []
    train_input_lab = []
    test_input_ids = []
    test_input_mak = []
    test_input_lab = []
    for index in tqdm(range(len(lab_list))):
        last_api = ''
        simple_api = []
        label = lab_list[index].strip()  # drop the trailing \n
        api = api_list[index].strip().replace('\t', ' ').replace('\s', ' ').replace('\xa0', ' ')
        while '  ' in api:
            api = api.replace('  ', ' ')  # collapse multiple spaces
        # collapse consecutive repeats of the same API into one
        for i in api.split(' '):
            if i != last_api:
                simple_api.append(i)
                last_api = i
        # api -> ids
        ids = []
        for j in simple_api:
            ids += api_index[j]  # api_index: mapping from API name to token ids, defined outside this snippet
        if len(ids) > max_sequnce - 4:  # max_sequnce is 1024, so [CLS]/[SEP] are added twice
            ids = ids[:(max_sequnce - 4)]
            ids = [CLS] + ids[:510] + [SEP] + [CLS] + ids[510:] + [SEP]
            mask = [1] * len(ids)
        elif len(ids) > 510:
            ids = [CLS] + ids[:510] + [SEP] + [CLS] + ids[510:] + [SEP]
            mask = [1] * len(ids)
        else:
            ids = [CLS] + ids + [SEP]
            mask = [1] * len(ids)
        if len(ids) <= max_sequnce:
            ids = ids + [PAD] * (max_sequnce - len(ids))
            mask = mask + [0] * (max_sequnce - len(mask))
        collected_by_label[label].append((ids, mask))
    # split each class 8:2 and merge into train/test
    for label, data in tqdm(collected_by_label.items()):
        label = label_index[label]  # "Trojan" -> [0,0,0,0,0,0,0,1]
        train = data[:len(data) // 10 * 8]
        test = data[len(data) // 10 * 8:]
        for ids, mask in train:
            train_input_ids.append(ids)
            train_input_mak.append(mask)
            train_input_lab.append(label)
        for ids, mask in test:
            test_input_ids.append(ids)
            test_input_mak.append(mask)
            test_input_lab.append(label)
    train_input_ids = torch.tensor(train_input_ids, dtype=torch.int64)
    train_input_mak = torch.tensor(train_input_mak, dtype=torch.int64)
    train_input_lab = torch.tensor(train_input_lab, dtype=torch.int64)
    test_input_ids = torch.tensor(test_input_ids, dtype=torch.int64)
    test_input_mak = torch.tensor(test_input_mak, dtype=torch.int64)
    test_input_lab = torch.tensor(test_input_lab, dtype=torch.int64)
    return train_input_ids, train_input_mak, train_input_lab, test_input_ids, test_input_mak, test_input_lab
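A hypothetical usage of load_data, wrapping the returned tensors in DataLoaders (the file names and max_sequnce=1024 are assumptions, not fixed by the dataset):
import torch
from torch.utils.data import TensorDataset, DataLoader

(train_ids, train_mask, train_lab,
 test_ids, test_mask, test_lab) = load_data(1024, 'all_analysis_data.txt', 'labels.txt')

train_loader = DataLoader(TensorDataset(train_ids, train_mask, train_lab), batch_size=16, shuffle=True)
test_loader = DataLoader(TensorDataset(test_ids, test_mask, test_lab), batch_size=16, shuffle=False)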
0x03 Model Architecture and Code
(figure: model architecture)
As shown in the figure, the two 512-token input sequences are each embedded by BERT; the two BERT outputs are concatenated (torch.cat), the concatenated tensor is fed into a BiLSTM layer, and finally into a fully connected softmax layer. The softmax layer has eight neurons, one per class. In the actual model, a Highway layer is added after the BiLSTM layer; it helps gradients flow back into the BiLSTM layer.
Model definition code:
class Highway(nn.Module):
    def __init__(self, input_dim, num_layers=1):
        super(Highway, self).__init__()
        # each layer projects to 2*input_dim: one half is the nonlinear transform, the other half the gate
        self._layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)])
        for layer in self._layers:
            # bias the gate towards carrying the input through at the start of training
            layer.bias[input_dim:].data.fill_(1)

    def forward(self, inputs):
        current_inputs = inputs
        for layer in self._layers:
            linear_part = current_inputs
            projected_inputs = layer(current_inputs)
            nonlinear_part, gate = projected_inputs.chunk(2, dim=-1)
            nonlinear_part = torch.relu(nonlinear_part)
            gate = torch.sigmoid(gate)
            current_inputs = gate * linear_part + (1 - gate) * nonlinear_part
        return current_inputs
class Bert_HBiLSTM(nn.Module):
    """
    Bert_HBiLSTM: BERT embedding -> BiLSTM -> Highway -> fully connected classifier
    """
    def __init__(self, config):
        super(Bert_HBiLSTM, self).__init__()
        self.bert = config.bert
        self.config = config
        # BERT is used as a frozen feature extractor
        for name, param in self.bert.named_parameters():
            param.requires_grad_(False)
        self.lstm = nn.LSTM(config.embedding_dim, config.hidden_dim, num_layers=config.num_layers,
                            batch_first=True, bidirectional=True)
        self.drop = nn.Dropout(config.drop_rate)
        self.highway = Highway(config.hidden_dim * 2, 1)
        self.hidden2one = nn.Linear(config.hidden_dim * 2, 1)
        self.relu = nn.ReLU()
        self.sequence2numclass = nn.Linear(config.max_sequnce, config.num_class)

    def forward(self, word_input, input_mask):
        # split the 1024-token input into two 512-token segments
        word_input_last = word_input[:, 512:]
        word_input = word_input[:, :512]
        input_mask_last = input_mask[:, 512:]
        input_mask = input_mask[:, :512]
        # embed both segments with the same (frozen) BERT
        word_input, _ = self.bert(word_input, attention_mask=input_mask, output_all_encoded_layers=False)
        word_input_last, _ = self.bert(word_input_last, attention_mask=input_mask_last,
                                       output_all_encoded_layers=False)
        input_mask.requires_grad = False
        input_mask_last.requires_grad = False
        # zero out the embeddings of padding positions
        word_input = word_input * (input_mask.unsqueeze(-1).float())
        word_input_last = word_input_last * (input_mask_last.unsqueeze(-1).float())
        cat_input = torch.cat([word_input, word_input_last], dim=1)
        # bert -> bilstm -> highway
        lstm_out, _ = self.lstm(cat_input)
        output = self.highway(lstm_out)
        output = self.drop(output)
        # hidden_dim*2 -> 1 per position, then sequence length -> num_class
        output = self.hidden2one(output)
        output = output.squeeze(-1)
        output = self.sequence2numclass(output)
        output = F.log_softmax(output, dim=1)
        return output
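Finally, a minimal training-loop sketch. It assumes a config object like the one the model expects (with config.bert set to a loaded BertModel), the DataLoaders from the previous section, and one-hot labels as produced by load_data; it is an illustration, not the repository's exact training code:
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Bert_HBiLSTM(config).to(device)
criterion = nn.NLLLoss()                    # the model already outputs log_softmax
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

for epoch in range(10):
    model.train()
    for ids, mask, label in train_loader:
        ids, mask = ids.to(device), mask.to(device)
        target = label.argmax(dim=1).to(device)   # one-hot label -> class index
        optimizer.zero_grad()
        loss = criterion(model(ids, mask), target)
        loss.backward()
        optimizer.step()
    print('epoch', epoch, 'loss', loss.item())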
The complete code and data are available on GitHub: https://github.com/bitterzzZZ/Bert-malware-classification