diff options
Diffstat (limited to 'script/utils.py')
| -rw-r--r-- | script/utils.py | 1484 |
1 files changed, 1484 insertions, 0 deletions
# diff --git a/script/utils.py b/script/utils.py
# new file mode 100644
# index 0000000..66f84e7
# --- /dev/null
# +++ b/script/utils.py
# @@ -0,0 +1,1484 @@
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
import torch.nn.functional as F
import numpy as np
from sklearn.feature_extraction.text import HashingVectorizer
from torch.autograd import Variable
import torch.nn.functional as F
import pandas
import random
import time
import argparse
import collections
from torch.nn.utils.rnn import PackedSequence

from sklearn import metrics

from collections import Counter

import statistics

ERR_TOO_SHORT_SEQ = -1

# Columns selected from the raw ';;'-separated dump before preprocessing.
TRIMMED_COL_NAMES = [
    'ATTACK_ID',
    'DIRECTION',
    'SEQ',
    'ACK',
    'DATAOFF',
    'FLAGS',
    'WINDOW',
    'CHKSUM',
    'URGPTR',
    'SK_STATE',
    'PAYLOAD_LEN',
    'IP_LEN',
    'IP_TTL',
    'IP_IHL',
    'IP_CHKSUM',
    'IP_VERSION',
    'IP_TOS',
    'IP_ID',  # has no effect: never turned into a feature
    'IP_OPT_NON_STANDARD',
    'TCP_OPT_MSS',
    'TCP_OPT_TSVAL',
    'TCP_OPT_TSECR',
    'TCP_OPT_WSCALE',
    'TCP_OPT_UTO',
    'TCP_OPT_MD5HEADER',
    'TCP_OPT_NON_STANDARD',
    'TCP_TIMESTAMP',
    'ARRIVAL_TIMESTAMP',
    'HTTP_Method',
    'HTTP_Version',
    'HTTP_Path',
    'HTTP_Header',
    'HTTP_Params',
]

# Flag letter -> position in the multi-hot TCP flag vector.
TCP_FLAGS_MAP = {
    "F": 0,
    "S": 1,
    "R": 2,
    "P": 3,
    "A": 4,
    "U": 5,
    "E": 6,
    "C": 7,
}

# Value (as string) -> position in the one-hot vector; '-1' has its own slot.
IP_VERSION_MAP = {
    '4': 0,
    '6': 1,
    '-1': 2,
}

TCP_OPT_MD5HEADER_MAP = {
    '0': 0,
    '1': 1,
    '-1': 2,
}

# One out of every TRAIN_TEST_SPLIT connections goes to the test set.
TRAIN_TEST_SPLIT = 10

# https://elixir.bootlin.com/linux/latest/source/net/netfilter/nf_conntrack_proto_tcp.c
nf_conntrack_states = [
    "SYN_SENT",
    "SYN_RECV",
    "ESTABLISHED",
    "FIN_WAIT",
    "CLOSE_WAIT",
    "LAST_ACK",
    "TIME_WAIT",
    "CLOSE",
    "SYN_SENT2",
]


class MyKitsunePacket(object):
    """Per-packet record that can be dumped as one tab-separated line.

    Each constructor argument is coerced on assignment: the timestamp to
    ``float``, the frame length and ports to ``int``, the addresses to
    ``str``.
    """

    def __init__(self, frame_time_epoch, frame_len, eth_src,
                 eth_dst, ip_src, ip_dst,
                 tcp_sport, tcp_dport,
                 debug=False):
        # `debug` is accepted but not used by this class.
        self.frame_time_epoch = float(frame_time_epoch)
        self.frame_len = int(frame_len)
        self.eth_src = str(eth_src)
        self.eth_dst = str(eth_dst)
        self.ip_src = str(ip_src)
        self.ip_dst = str(ip_dst)
        self.tcp_sport = int(tcp_sport)
        self.tcp_dport = int(tcp_dport)

    def get_dump_str(self, conn_idx=None, packet_idx=None):
        """Return the packet as a tab-separated line.

        When ``conn_idx`` is given, ``conn_idx`` and ``packet_idx`` are
        prepended as the first two fields.  Eleven empty trailing fields
        are always appended.
        """
        fields = [str(self.frame_time_epoch), str(self.frame_len),
                  str(self.eth_src), str(self.eth_dst),
                  str(self.ip_src), str(self.ip_dst),
                  str(self.tcp_sport), str(self.tcp_dport)]
        if conn_idx is not None:
            fields = [str(conn_idx), str(packet_idx)] + fields
        return '\t'.join(fields + [''] * 11)


class MyPacket(object):
    """Container for the TCP/IP (and optional HTTP/Kitsune) fields of a packet.

    The constructor stores every argument unchanged on the instance under
    the same name.  The HTTP arguments default to a fresh empty list per
    instance.
    """

    def __init__(self, src_ip, src_port,
                 dst_ip, dst_port, seq,
                 ack, dataoff, flags,
                 window, chksum, urgptr,
                 timestamp, payload_len, sk_state,
                 filename, ip_len, ip_ttl, ip_ihl,
                 ip_chksum, ip_version, ip_tos, ip_id, ip_opt_non_standard,
                 tcp_opt_mss, tcp_opt_tsval, tcp_opt_tsecr,
                 tcp_opt_wscale, tcp_opt_uto, tcp_opt_md5header,
                 tcp_opt_non_standard, tcp_timestamp, arrival_timestamp,
                 kitsune_frame_time_epoch=None, kitsune_frame_len=None,
                 kitsune_eth_src=None, kitsune_eth_dst=None, kitsune_ip_src=None,
                 kitsune_ip_dst=None, kitsune_tcp_sport=None, kitsune_tcp_dport=None,
                 http_method=None, http_version=None, http_path=None,
                 http_header=None, query_params=None):
        # HTTP features.  The defaults used to be shared `[]` literals, so
        # mutating one packet's list leaked into every other packet built
        # with the default; each instance now gets its own list.
        self.http_method = [] if http_method is None else http_method
        self.http_version = [] if http_version is None else http_version
        self.http_path = [] if http_path is None else http_path
        self.query_params = [] if query_params is None else query_params
        self.http_header = [] if http_header is None else http_header

        self.src_ip = src_ip
        self.src_port = src_port
        self.dst_ip = dst_ip
        self.dst_port = dst_port
        self.seq = seq
        self.ack = ack
        self.dataoff = dataoff
        self.flags = flags
        self.window = window
        self.chksum = chksum
        self.urgptr = urgptr
        self.timestamp = timestamp
        self.payload_len = payload_len
        self.sk_state = sk_state
        self.filename = filename
        self.ip_len = ip_len
        self.ip_ttl = ip_ttl
        self.ip_ihl = ip_ihl
        self.ip_chksum = ip_chksum
        self.ip_version = ip_version
        self.ip_tos = ip_tos
        self.ip_id = ip_id
        self.ip_opt_non_standard = ip_opt_non_standard
        self.tcp_opt_mss = tcp_opt_mss
        self.tcp_opt_tsval = tcp_opt_tsval
        self.tcp_opt_tsecr = tcp_opt_tsecr
        self.tcp_opt_wscale = tcp_opt_wscale
        self.tcp_opt_uto = tcp_opt_uto
        self.tcp_opt_md5header = tcp_opt_md5header
        self.tcp_opt_non_standard = tcp_opt_non_standard
        self.tcp_timestamp = tcp_timestamp
        self.arrival_timestamp = arrival_timestamp
        self.kitsune_frame_time_epoch = kitsune_frame_time_epoch
        self.kitsune_frame_len = kitsune_frame_len
        self.kitsune_eth_src = kitsune_eth_src
        self.kitsune_eth_dst = kitsune_eth_dst
        self.kitsune_ip_src = kitsune_ip_src
        self.kitsune_ip_dst = kitsune_ip_dst
        self.kitsune_tcp_sport = kitsune_tcp_sport
        self.kitsune_tcp_dport = kitsune_tcp_dport

    def set_sk_state(self, sk_state):
        """Overwrite the stored ``sk_state``."""
        self.sk_state = sk_state

    def get_attack_id(self):
        """Return ``src_ip;;src_port;;dst_ip;;dst_port``."""
        attack_id = ';;'.join(
            [self.src_ip, str(self.src_port), self.dst_ip, str(self.dst_port)])
        return attack_id

    def get_tuple_id(self):
        """Return ``("src_ip,src_port", "dst_ip,dst_port")``."""
        src = ','.join([self.src_ip, str(self.src_port)])
        dst = ','.join([self.dst_ip, str(self.dst_port)])
        return src, dst

    def get_reverse_attack_id(self):
        """Return ``dst_ip,dst_port,src_ip,src_port``.

        NOTE(review): this joins with ',' while get_attack_id() joins with
        ';;', so the two strings can never compare equal -- confirm which
        separator the callers expect before unifying them.
        """
        reverse_attack_id = ','.join(
            [self.dst_ip, str(self.dst_port), self.src_ip, str(self.src_port)])
        return reverse_attack_id

    def get_attack_packet_id(self):
        """Return ``dataoff;;flags;;window;;chksum;;urgptr``."""
        attack_packet_id = ';;'.join([str(self.dataoff), str(self.flags), str(
            self.window), str(self.chksum), str(self.urgptr)])
        return attack_packet_id

    def get_filename(self):
        """Return the stored ``filename``."""
        return self.filename

    def get_hash(self):
        """Return a ';;'-joined string of the main header fields.

        NOTE(review): ``timestamp`` appears twice in a row.  It is kept
        as-is because removing it would change every existing hash string.
        """
        return ';;'.join([str(self.src_ip), str(self.src_port), str(self.dst_ip),
                          str(self.dst_port), str(self.seq), str(self.ack),
                          str(self.dataoff), str(self.flags), str(self.window),
                          str(self.chksum), str(self.urgptr),
                          str(self.timestamp), str(self.timestamp),
                          str(self.payload_len), str(self.sk_state),
                          str(self.filename), str(self.ip_len),
                          str(self.ip_ttl), str(self.ip_ihl),
                          str(self.ip_chksum)])

    def get_data_str(self, idx, packet_idx, direction=None):
        """Return the packet as one ';;'-separated dataset row.

        The third field is ``str(direction)`` when ``direction`` is truthy,
        otherwise get_attack_id() (the 4-tuple identifier).

        NOTE(review): the test is on truthiness, so ``direction=0`` also
        selects the attack id -- confirm callers never pass 0.
        """
        third_field = self.get_attack_id() if not direction else str(direction)
        return ';;'.join([
            str(idx), str(packet_idx), third_field, str(self.seq),
            str(self.ack), self.get_attack_packet_id(), str(self.sk_state),
            # str() added: join() raised TypeError for non-str timestamps.
            str(self.payload_len), str(self.timestamp), str(self.ip_len),
            str(self.ip_ttl), str(self.ip_ihl), str(self.ip_chksum),
            str(self.ip_version), str(self.ip_tos), str(self.ip_id),
            str(self.ip_opt_non_standard),
            str(self.tcp_opt_mss), str(self.tcp_opt_tsval),
            str(self.tcp_opt_tsecr), str(self.tcp_opt_wscale),
            str(self.tcp_opt_uto), str(self.tcp_opt_md5header),
            str(self.tcp_opt_non_standard), str(self.tcp_timestamp),
            str(self.arrival_timestamp), str(self.http_method),
            str(self.http_version), str(self.http_path),
            str(self.http_header), str(self.query_params),
        ])

    def get_kitsune_str(self, idx, pkt_idx):
        """Return ``idx``, ``pkt_idx`` and the Kitsune fields, tab-separated."""
        return '\t'.join([str(idx), str(pkt_idx), str(self.kitsune_frame_time_epoch), str(self.kitsune_frame_len),
                          str(self.kitsune_eth_src), str(self.kitsune_eth_dst),
                          str(self.kitsune_ip_src), str(self.kitsune_ip_dst),
                          str(self.kitsune_tcp_sport), str(self.kitsune_tcp_dport)])

    def print_debug(self):
        """Print the main fields, then block on input() until Enter is pressed."""
        print("Dumping packet fields...")
        print("%s:%d -> %s:%d" %
              (self.src_ip, self.src_port, self.dst_ip, self.dst_port))
        print("SEQ: %s" % self.seq)
        print("ACK: %s" % self.ack)
        print("Data offset: %d" % self.dataoff)
        print("TCP flags: %s" % self.flags)
        print("Window: %d" % self.window)
        print("Checksum: %s" % self.chksum)
        print("Urgent pointer: %s" % self.urgptr)
        print("Timestamp: %s" % self.timestamp)
        print("Payload length: %d" % self.payload_len)
        print("sk_state: %s" % str(self.sk_state))
        print("Filename: %s" % self.filename)
        print("IP length: %s" % str(self.ip_len))
        print("IP TTL: %s" % str(self.ip_ttl))
        print("IP IHL: %s" % str(self.ip_ihl))
        print("IP Checksum: %s" % str(self.ip_chksum))
        print("IP Version: %s" % str(self.ip_version))
        print("IP TOS: %s" % str(self.ip_tos))
        print("IP ID: %s" % str(self.ip_id))
        input("Dump ended.")


# Copied from torch's official implementation, with return value
# being a tuple that contains gate states (vs.
only hidden states) +def gru_cell(input, hidden, w_ih, w_hh, b_ih, b_hh): + gi = torch.mm(input, w_ih.t()) + b_ih + gh = torch.mm(hidden, w_hh.t()) + b_hh + i_r, i_i, i_n = gi.chunk(3, 1) + h_r, h_i, h_n = gh.chunk(3, 1) + + resetgate = torch.sigmoid(i_r + h_r) + inputgate = torch.sigmoid(i_i + h_i) + newgate = torch.tanh(i_n + resetgate * h_n) + hy = newgate + inputgate * (hidden - newgate) + + return hy, resetgate, inputgate + + +# Also copied from torch's official FastRNN benchmark, with additional gates returned +def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh): + # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] + hx, cx = hidden + gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh + + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + cy = (forgetgate * cx) + (ingate * cellgate) + hy = outgate * torch.tanh(cy) + + return hy, cy, ingate, forgetgate, cellgate, outgate + + +class GRUCell(nn.modules.rnn.RNNCellBase): + def __init__(self, input_size, hidden_size, bias=True): + super(GRUCell, self).__init__( + input_size, hidden_size, bias, num_chunks=3) + + def forward(self, x, hx): + # type: (Tensor, Optional[Tensor]) -> Tensor + self.check_forward_input(x) + self.check_forward_hidden(x, hx, '') + return gru_cell( + x, hx, + self.weight_ih, self.weight_hh, + self.bias_ih, self.bias_hh, + ) + + +class LSTMCell(nn.modules.rnn.RNNCellBase): + def __init__(self, input_size, hidden_size, bias=True): + super(LSTMCell, self).__init__( + input_size, hidden_size, bias, num_chunks=4) + + def forward(self, x, hx): + # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor] + self.check_forward_input(x) + self.check_forward_hidden(x, hx[0], '') + self.check_forward_hidden(x, hx[1], '') + return lstm_cell( + x, hx, + self.weight_ih, 
self.weight_hh, + self.bias_ih, self.bias_hh, + ) + + +class GRUModel(nn.Module): + def __init__(self, input_size, hidden_size, output_size, num_layers, device, bidirectional): + super(GRUModel, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.num_layers = num_layers + self.device = device + self.bidirectional = bidirectional + + print("===== GRUModel args =====") + print("input_size: %s" % str(input_size)) + print("hidden_size: %s" % str(hidden_size)) + print("output_size: %s" % str(output_size)) + print("num_layers: %s" % str(num_layers)) + print("device: %s" % str(device)) + + self.gru_in = GRUCell(input_size, hidden_size) + self.gru_middle = GRUCell(hidden_size, hidden_size) + if bidirectional: + self.fc = nn.Linear(hidden_size * 2, output_size) + else: + self.fc = nn.Linear(hidden_size, output_size) + self.dropout = nn.Dropout(p=0.1) + + def forward(self, inputs): + is_packed = isinstance(inputs, PackedSequence) + if is_packed: + inputs, batch_sizes, sorted_indices, unsorted_indices = inputs + max_batch_size = batch_sizes[0] + max_batch_size = int(max_batch_size) + + # These states need to be returned + outputs = [] + gates = [] + hn = [] + + # Temporary states + hs = [] + + # Initialize hidden states + for layer_idx in range(self.num_layers): + hs.append(self.init_hidden()) + + for seq_idx in range(inputs.size(1)): + curr_seq = inputs[:, seq_idx, :] + + # Stacked GRU + for layer_idx in range(self.num_layers): + if layer_idx == 0: # input layer + hs[layer_idx], resetgate, inputgate = self.gru_in( + curr_seq, hs[layer_idx]) + else: # non-input layer + hs[layer_idx], resetgate, inputgate = self.gru_middle( + hs[layer_idx-1], hs[layer_idx]) + + outputs.append(hs[-1]) + + gates.append([resetgate.detach(), inputgate.detach()]) + hn.append(hs[-1].detach()) + + if self.bidirectional: + # Temporary states + hs2 = [] + + # Initialize hidden states + for layer_idx in range(self.num_layers): + 
hs2.append(self.init_hidden()) + + for seq_idx in reversed(range(inputs.size(1))): + forward_seq_idx = inputs.size(1) - seq_idx - 1 + curr_seq = inputs[:, seq_idx, :] + + # Stacked GRU + for layer_idx in range(self.num_layers): + if layer_idx == 0: # input layer + hs2[layer_idx], resetgate, inputgate = self.gru_in( + curr_seq, hs2[layer_idx]) + else: # non-input layer + hs2[layer_idx], resetgate, inputgate = self.gru_middle( + hs2[layer_idx-1], hs2[layer_idx]) + + outputs[forward_seq_idx] = torch.cat( + (outputs[forward_seq_idx], hs2[-1]), 1) + gates[forward_seq_idx] = [torch.cat((gates[forward_seq_idx][0], resetgate.detach( + )), 1), torch.cat((gates[forward_seq_idx][1], inputgate.detach()), 1)] + hn[forward_seq_idx] = torch.cat( + (hn[forward_seq_idx], hs2[-1].detach()), 1) + + for idx in range(len(outputs)): + outputs[idx] = self.fc(outputs[idx]) + outputs = torch.stack(outputs, dim=1) + + return [outputs, gates, hn] + + def init_hidden(self, batch_size=1): + if torch.cuda.is_available(): + h0 = Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)).to( + self.device, dtype=torch.float) + else: + h0 = Variable(torch.zeros(self.num_layers, + batch_size, self.hidden_size)) + + return h0[0, :, :] + + +class LSTMModel(nn.Module): + def __init__(self, input_size, hidden_size, output_size, num_layers, device, bidirectional): + super(LSTMModel, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.num_layers = num_layers + self.device = device + self.bidirectional = bidirectional + + print("===== LSTMModel args =====") + print("input_size: %s" % str(input_size)) + print("hidden_size: %s" % str(hidden_size)) + print("output_size: %s" % str(output_size)) + print("num_layers: %s" % str(num_layers)) + print("device: %s" % str(device)) + + self.lstm_in = LSTMCell(input_size, hidden_size) + self.lstm_middle = LSTMCell(hidden_size, hidden_size) + if bidirectional: + self.fc = 
nn.Linear(hidden_size * 2, output_size) + else: + self.fc = nn.Linear(hidden_size, output_size) + self.dropout = nn.Dropout(p=0.1) + + def forward(self, inputs): + # These states need to be returned + outputs = [] + gates = [] + hn = [] + + # Temporary states + hs = [] + cs = [] + + # Initialize hidden states + for layer_idx in range(self.num_layers): + hs.append(self.init_hidden()) + cs.append(self.init_hidden()) + + for seq_idx in range(inputs.size(1)): + curr_seq = inputs[:, seq_idx, :] + + # Stacked LSTM + for layer_idx in range(self.num_layers): + if layer_idx == 0: # input layer + hs[layer_idx], cs[layer_idx], inputgate, forgetgate, cellgate, outgate = self.lstm_in( + curr_seq, (hs[layer_idx], cs[layer_idx])) + hs[layer_idx] = self.dropout(hs[layer_idx]) + elif layer_idx != self.num_layers - 1: # non-input layer + hs[layer_idx], cs[layer_idx], inputgate, forgetgate, cellgate, outgate = self.lstm_middle( + hs[layer_idx-1], (hs[layer_idx], cs[layer_idx])) + hs[layer_idx] = self.dropout(hs[layer_idx]) + else: + hs[layer_idx], cs[layer_idx], inputgate, forgetgate, cellgate, outgate = self.lstm_middle( + hs[layer_idx-1], (hs[layer_idx], cs[layer_idx])) + + outputs.append(hs[-1]) + + gates.append([inputgate.detach(), forgetgate.detach(), + cellgate.detach(), outgate.detach()]) + hn.append(cs[-1].detach()) + + if self.bidirectional: + # Temporary states + hs2 = [] + cs2 = [] + + # Initialize hidden states + for layer_idx in range(self.num_layers): + hs2.append(self.init_hidden()) + cs2.append(self.init_hidden()) + + for seq_idx in reversed(range(inputs.size(1))): + forward_seq_idx = inputs.size(1) - seq_idx - 1 + curr_seq = inputs[:, seq_idx, :] + + # Stacked LSTM + for layer_idx in range(self.num_layers): + if layer_idx == 0: # input layer + hs2[layer_idx], cs2[layer_idx], inputgate, forgetgate, cellgate, outgate = self.lstm_in( + curr_seq, (hs2[layer_idx], cs2[layer_idx])) + hs2[layer_idx] = self.dropout(hs2[layer_idx]) + elif layer_idx != self.num_layers - 1: # 
non-input layer + hs2[layer_idx], cs2[layer_idx], inputgate, forgetgate, cellgate, outgate = self.lstm_middle( + hs2[layer_idx-1], (hs2[layer_idx], cs2[layer_idx])) + hs2[layer_idx] = self.dropout(hs2[layer_idx]) + else: + hs2[layer_idx], cs2[layer_idx], inputgate, forgetgate, cellgate, outgate = self.lstm_middle( + hs2[layer_idx-1], (hs2[layer_idx], cs2[layer_idx])) + + outputs[forward_seq_idx] = torch.cat( + (outputs[forward_seq_idx], hs2[-1]), 1) + gates[forward_seq_idx] = [torch.cat((gates[forward_seq_idx][0], inputgate.detach()), 1), torch.cat((gates[forward_seq_idx][1], forgetgate.detach( + )), 1), torch.cat((gates[forward_seq_idx][2], cellgate.detach()), 1), torch.cat((gates[forward_seq_idx][3], outgate.detach()), 1)] + hn[forward_seq_idx] = torch.cat( + (hn[forward_seq_idx], cs2[-1].detach()), 1) + + for idx in range(len(outputs)): + outputs[idx] = self.fc(outputs[idx]) + outputs = torch.stack(outputs, dim=1) + + return [outputs, gates, hn] + + def init_hidden(self, batch_size=1): + if torch.cuda.is_available(): + h0 = Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)).to( + self.device, dtype=torch.float) + else: + h0 = Variable(torch.zeros(self.num_layers, + batch_size, self.hidden_size)) + + return h0[0, :, :] + + +class AEModel(nn.Module): + def __init__(self, input_size, bottleneck_size=5, model_type='mid'): + super(AEModel, self).__init__() + self.input_size = input_size + self.model_type = model_type + if self.model_type == 'small': + l1 = int(float(input_size)/3) + l2 = bottleneck_size + print("[INFO][Model] Input: %d ---> L1: %d ---> L2: %d --> L3: %d --> Output: %d" % + (input_size, l1, l2, l1, input_size)) + self.fc1 = nn.Linear(input_size, l1) + self.fc2 = nn.Linear(l1, l2) + self.fc3 = nn.Linear(l2, l1) + self.fc4 = nn.Linear(l1, input_size) + elif self.model_type == 'mid': + l1 = int(float(input_size)/1.5) + l2 = int(float(input_size)/3) + l3 = bottleneck_size + print("[INFO][Model] Input: %d ---> L1: %d ---> L2: %d --> L3: 
%d --> L4: %d --> L5: %d --> Output: %d" % + (input_size, l1, l2, l3, l2, l1, input_size)) + self.fc1 = nn.Linear(input_size, l1) + self.fc2 = nn.Linear(l1, l2) + self.fc3 = nn.Linear(l2, l3) + self.fc4 = nn.Linear(l3, l2) + self.fc5 = nn.Linear(l2, l1) + self.fc6 = nn.Linear(l1, input_size) + elif self.model_type == 'large': + l1 = int(float(input_size)/1.5) + l2 = int(float(input_size)/2.5) + l3 = int(float(input_size)/5) + l4 = bottleneck_size + print("[INFO][Model] Input: %d ---> L1: %d ---> L2: %d ---> L3: %d ---> L4: %d ---> L5: %d --> L6: %d --> L7: %d --> Output: %d" % + (input_size, l1, l2, l3, l4, l3, l2, l1, input_size)) + self.fc1 = nn.Linear(input_size, l1) + self.fc2 = nn.Linear(l1, l2) + self.fc3 = nn.Linear(l2, l3) + self.fc4 = nn.Linear(l3, l4) + self.fc5 = nn.Linear(l4, l3) + self.fc6 = nn.Linear(l3, l2) + self.fc7 = nn.Linear(l2, l1) + self.fc8 = nn.Linear(l1, input_size) + + def encode(self, x): + if self.model_type == 'small': + h1 = F.relu(self.fc1(x)) + h2 = self.fc2(h1) + return h2 + elif self.model_type == 'mid': + h1 = F.relu(self.fc1(x)) + h2 = F.relu(self.fc2(h1)) + h3 = self.fc3(h2) + return h3 + elif self.model_type == 'large': + h1 = F.relu(self.fc1(x)) + h2 = F.relu(self.fc2(h1)) + h3 = F.relu(self.fc3(h2)) + h4 = self.fc4(h3) + return h4 + + def decode(self, z): + if self.model_type == 'small': + h1 = F.relu(self.fc3(z)) + h2 = self.fc4(h1) + return torch.sigmoid(h2) + elif self.model_type == 'mid': + h1 = F.relu(self.fc4(z)) + h2 = F.relu(self.fc5(h1)) + h3 = self.fc6(h2) + return torch.sigmoid(h3) + elif self.model_type == 'large': + h1 = F.relu(self.fc5(z)) + h2 = F.relu(self.fc6(h1)) + h3 = F.relu(self.fc7(h2)) + h4 = self.fc8(h3) + return torch.sigmoid(h4) + + def forward(self, x): + h = self.encode(x.view(-1, self.input_size)) + r = self.decode(h) + return r + + +def read_dataset(path, batch_size=1, preprocess=False, debug=False, cutoff=-1, seq_cutoff=-1, split_train_test=False, stats=None, shuffle=False, 
add_additional_features=False, use_conn_id=False): + def parse_flags(flags): + flags_lst = [0] * len(TCP_FLAGS_MAP) + if not isinstance(flags, str): + return flags_lst + flags_set = set(flags) + for flag, idx in TCP_FLAGS_MAP.items(): + if flag in flags_set: + flags_lst[idx] = 1 + return flags_lst + + def parse_ip_version(ip_version): + ip_version_lst = [0] * len(IP_VERSION_MAP) + for version, idx in IP_VERSION_MAP.items(): + if int(version) == ip_version: + ip_version_lst[idx] = 1 + return ip_version_lst + + def parse_md5header(md5header): + md5header_lst = [0] * len(TCP_OPT_MD5HEADER_MAP) + for md5_state, idx in TCP_OPT_MD5HEADER_MAP.items(): + if int(md5_state) == md5header: + md5header_lst[idx] = 1 + return md5header_lst + + def rescale(ori_val, stats): + maxn, minn, mean = stats['max'], stats['min'], stats['mean'] + if maxn == minn: + if ori_val < minn: + return -0.1 + elif ori_val > maxn: + return 1.1 + else: + return 0.0 + else: + return (float(ori_val - minn) / (maxn - minn)) + + def summarize(dataframe, col_name, numeral_system=10, debug=True): + if numeral_system != 10: + x = dataframe[col_name].tolist()[0] + col_list = [int(str(r), numeral_system) + for r in dataframe[col_name].tolist()] + else: + col_list = dataframe[col_name].tolist() + col_stats = {'max': max(col_list), 'min': min( + col_list), 'mean': sum(col_list)/float(len(col_list))} + return col_stats + + def add_oor_feature(bounds, val, records): + maxn, minn, mean = bounds['max'], bounds['min'], bounds['mean'] + if val < minn or val > maxn: + records.append(1.0) + else: + records.append(0.0) + + def preprocess(attack_records, numeric_stats, sk_labels_map, debug=False, add_additional_features=False): + preprocessed_records = [] + labels = [] + + for idx, row in attack_records.iterrows(): + curr_record = [] + + if use_conn_id: + curr_record.append(int(row['ATTACK_ID'])) + + if 'DIRECTION' in row: + curr_record.append(float(row['DIRECTION'])) + + if 'SEQ' in row: + rescaled_seq = 
rescale(int(row['SEQ']), numeric_stats['SEQ']) + curr_record.append(rescaled_seq) + + if 'ACK' in row: + rescaled_ack = rescale(int(row['ACK']), numeric_stats['ACK']) + curr_record.append(rescaled_ack) + + if 'DATAOFF' in row: + rescaled_dataoff = rescale( + int(row['DATAOFF']), numeric_stats['DATAOFF']) + curr_record.append(rescaled_dataoff) + if add_additional_features: + add_oor_feature( + numeric_stats['DATAOFF'], row['DATAOFF'], curr_record) + + if 'FLAGS' in row: + curr_record.extend(parse_flags(row['FLAGS'])) + + if 'WINDOW' in row: + rescaled_window = rescale( + int(row['WINDOW']), numeric_stats['WINDOW']) + curr_record.append(rescaled_window) + if add_additional_features: + add_oor_feature( + numeric_stats['WINDOW'], row['WINDOW'], curr_record) + + if 'CHKSUM' in row: + curr_record.append(float(row['CHKSUM'])) + + if 'URGPTR' in row: + rescaled_urg = rescale( + int(str(row['URGPTR'])), numeric_stats['URGPTR']) + curr_record.append(rescaled_urg) + if add_additional_features: + add_oor_feature( + numeric_stats['URGPTR'], row['URGPTR'], curr_record) + + labels.append(sk_labels_map[row['SK_STATE']]) + + if 'PAYLOAD_LEN' in row: + rescaled_payload_len = rescale( + int(row['PAYLOAD_LEN']), numeric_stats['PAYLOAD_LEN']) + curr_record.append(rescaled_payload_len) + if add_additional_features: + add_oor_feature( + numeric_stats['PAYLOAD_LEN'], row['PAYLOAD_LEN'], curr_record) + + if 'IP_LEN' in row: + rescaled_ip_len = rescale( + int(row['IP_LEN']), numeric_stats['IP_LEN']) + curr_record.append(rescaled_ip_len) + if add_additional_features: + add_oor_feature( + numeric_stats['IP_LEN'], row['IP_LEN'], curr_record) + + if 'IP_TTL' in row: + rescaled_ip_ttl = rescale( + int(row['IP_TTL']), numeric_stats['IP_TTL']) + curr_record.append(rescaled_ip_ttl) + if add_additional_features: + add_oor_feature( + numeric_stats['IP_TTL'], row['IP_TTL'], curr_record) + + if 'IP_IHL' in row: + rescaled_ip_ihl = rescale( + int(row['IP_IHL']), numeric_stats['IP_IHL']) + 
curr_record.append(rescaled_ip_ihl) + add_oor_feature( + numeric_stats['IP_IHL'], row['IP_IHL'], curr_record) + + if add_additional_features: + if row['IP_IHL'] + row['DATAOFF'] + row['PAYLOAD_LEN'] == row['IP_LEN']: + curr_record.append('0.0') + else: + curr_record.append('1.0') + + if 'IP_CHKSUM' in row: + curr_record.append(float(row['IP_CHKSUM'])) + + if 'IP_VERSION' in row: + curr_record.extend(parse_ip_version(row['IP_VERSION'])) + + if 'IP_TOS' in row: + rescaled_ip_tos = rescale( + int(row['IP_TOS']), numeric_stats['IP_TOS']) + curr_record.append(rescaled_ip_tos) + if add_additional_features: + add_oor_feature( + numeric_stats['IP_TOS'], row['IP_TOS'], curr_record) + + if 'IP_OPT_NON_STANDARD' in row: + curr_record.append(float(row['IP_OPT_NON_STANDARD'])) + + if 'TCP_OPT_MSS' in row: + rescaled_tcp_opt_mss = rescale( + int(row['TCP_OPT_MSS']), numeric_stats['TCP_OPT_MSS']) + curr_record.append(rescaled_tcp_opt_mss) + if add_additional_features: + add_oor_feature( + numeric_stats['TCP_OPT_MSS'], row['TCP_OPT_MSS'], curr_record) + + if 'TCP_OPT_TSVAL' in row: + rescaled_tcp_opt_tsval = rescale( + int(row['TCP_OPT_TSVAL']), numeric_stats['TCP_OPT_TSVAL']) + curr_record.append(rescaled_tcp_opt_tsval) + if add_additional_features: + add_oor_feature( + numeric_stats['TCP_OPT_TSVAL'], row['TCP_OPT_TSVAL'], curr_record) + + if 'TCP_OPT_TSECR' in row: + rescaled_tcp_opt_tsecr = rescale( + int(row['TCP_OPT_TSECR']), numeric_stats['TCP_OPT_TSECR']) + curr_record.append(rescaled_tcp_opt_tsecr) + if add_additional_features: + add_oor_feature( + numeric_stats['TCP_OPT_TSECR'], row['TCP_OPT_TSECR'], curr_record) + + if 'TCP_OPT_WSCALE' in row: + rescaled_tcp_opt_wscale = rescale( + int(row['TCP_OPT_WSCALE']), numeric_stats['TCP_OPT_WSCALE']) + curr_record.append(rescaled_tcp_opt_wscale) + if add_additional_features: + add_oor_feature( + numeric_stats['TCP_OPT_WSCALE'], row['TCP_OPT_WSCALE'], curr_record) + + if 'TCP_OPT_UTO' in row: + rescaled_tcp_opt_uto = rescale( + 
int(row['TCP_OPT_UTO']), numeric_stats['TCP_OPT_UTO']) + curr_record.append(rescaled_tcp_opt_uto) + if add_additional_features: + add_oor_feature( + numeric_stats['TCP_OPT_UTO'], row['TCP_OPT_UTO'], curr_record) + + if 'TCP_OPT_MD5HEADER' in row: + curr_record.extend(parse_md5header(row['TCP_OPT_MD5HEADER'])) + + if 'TCP_OPT_NON_STANDARD' in row: + curr_record.append(float(row['TCP_OPT_NON_STANDARD'])) + + if 'TCP_TIMESTAMP' in row: + rescaled_tcp_timestamp = rescale( + float(row['TCP_TIMESTAMP']), numeric_stats['TCP_TIMESTAMP']) + curr_record.append(rescaled_tcp_timestamp) + if add_additional_features: + add_oor_feature( + numeric_stats['TCP_TIMESTAMP'], row['TCP_TIMESTAMP'], curr_record) + + if 'ARRIVAL_TIMESTAMP' in row: + rescaled_arrival_timestamp = rescale( + float(row['ARRIVAL_TIMESTAMP']), numeric_stats['ARRIVAL_TIMESTAMP']) + curr_record.append(rescaled_arrival_timestamp) + if add_additional_features: + add_oor_feature( + numeric_stats['ARRIVAL_TIMESTAMP'], row['ARRIVAL_TIMESTAMP'], curr_record) + #添加额外的学习特征HTTP特征 + vectorizer = HashingVectorizer(n_features=1) + #vectorizer = TfidfVectorizer() + + if 'HTTP_Method' in row: + curr_record.extend(vectorizer.fit_transform([str(row['HTTP_Method'])]).toarray().flatten()) + #HTTP_Version,HTTP_Name,HTTP_Value_Domain,HTTP_Url,HTTP_Path,HTTP_Params + if 'HTTP_Version' in row: + curr_record.extend(vectorizer.fit_transform([str(row['HTTP_Version'])]).toarray().flatten()) + if 'HTTP_Path' in row: + curr_record.extend(vectorizer.fit_transform([str(row['HTTP_Path'])]).toarray().flatten()) + if 'HTTP_Header' in row: + curr_record.extend(vectorizer.fit_transform([str(row['HTTP_Header'])]).toarray().flatten()) + if 'HTTP_Params' in row: + curr_record.extend(vectorizer.fit_transform([str(row['HTTP_Params'])]).toarray().flatten()) + + preprocessed_records.append(curr_record) + + return np.array(preprocessed_records, dtype=np.float32), np.array(labels, dtype=np.int) + + dataset = [] + #trimmed + dataframe = 
pandas.read_csv(path, sep=';;', header='infer',engine='python')#,usecols=range(30) + labels_stats = [] + print("Reading dataset from path: %s" % path) + + if preprocess: + trimmed_dataframe = dataframe[TRIMMED_COL_NAMES] + print("[INFO][Preprocessing] Column names: %s" % + str(list(trimmed_dataframe.columns))) + + sk_state_labels_map = {} + sk_state_labels = sorted(list(set(dataframe['SK_STATE'].tolist()))) + for i in range(len(sk_state_labels)): + sk_state_labels_map[sk_state_labels[i]] = i + + if stats is None or debug: + seq_stats = summarize(dataframe, 'SEQ') + ack_stats = summarize(dataframe, 'ACK') + urg_stats = summarize(dataframe, 'URGPTR') + dataoff_stats = summarize(dataframe, 'DATAOFF') + window_stats = summarize(dataframe, 'WINDOW') + payload_len_stats = summarize(dataframe, 'PAYLOAD_LEN') + ip_len_stats = summarize(dataframe, 'IP_LEN') + ip_ttl_stats = summarize(dataframe, 'IP_TTL') + ip_ihl_stats = summarize(dataframe, 'IP_IHL') + ip_tos_stats = summarize(dataframe, 'IP_TOS') + ip_id_stats = summarize(dataframe, 'IP_ID') + tcp_opt_mss_stats = summarize(dataframe, 'TCP_OPT_MSS') + tcp_opt_tsval_stats = summarize(dataframe, 'TCP_OPT_TSVAL') + tcp_opt_tsecr_stats = summarize(dataframe, 'TCP_OPT_TSECR') + tcp_opt_wscale_stats = summarize(dataframe, 'TCP_OPT_WSCALE') + tcp_opt_uto_stats = summarize(dataframe, 'TCP_OPT_UTO') + tcp_timestamp = summarize(dataframe, 'TCP_TIMESTAMP') + arrival_timestamp = summarize(dataframe, 'ARRIVAL_TIMESTAMP') + '''summarize(dataframe,'HTTP_Method') + summarize(dataframe,'HTTP_Version') + summarize(dataframe,'HTTP_Name') + summarize(dataframe,'HTTP_Value_Domain') + summarize(dataframe,'HTTP_Url') + summarize(dataframe,'HTTP_Path') + summarize(dataframe,'HTTP_Params')''' + #能用数字表示的统计数据 + new_numeric_stats = {"SEQ": seq_stats, "ACK": ack_stats, "URGPTR": urg_stats, + "DATAOFF": dataoff_stats, "WINDOW": window_stats, "PAYLOAD_LEN": payload_len_stats, + "IP_LEN": ip_len_stats, "IP_TTL": ip_ttl_stats, "IP_IHL": ip_ihl_stats, + 
"IP_TOS": ip_tos_stats, "IP_ID": ip_id_stats, "TCP_OPT_MSS": tcp_opt_mss_stats, + "TCP_OPT_TSVAL": tcp_opt_tsval_stats, "TCP_OPT_TSECR": tcp_opt_tsecr_stats, + "TCP_OPT_WSCALE": tcp_opt_wscale_stats, "TCP_OPT_UTO": tcp_opt_uto_stats, + "TCP_TIMESTAMP": tcp_timestamp, "ARRIVAL_TIMESTAMP": arrival_timestamp} + if debug: + print("Debug stats: %s" % str(new_numeric_stats)) + if stats is None: + numeric_stats = new_numeric_stats + else: + numeric_stats = stats + + attack_id_list = sorted( + list(set(trimmed_dataframe['ATTACK_ID'].tolist()))) + + cnt = 0 + for attack_id in attack_id_list: + if cutoff != -1: + cnt += 1 + if cnt > cutoff: + break + attack_records = trimmed_dataframe.loc[trimmed_dataframe['ATTACK_ID'] == attack_id] + preprocessed_attack_records, labels = preprocess( + attack_records, numeric_stats, sk_state_labels_map, debug=debug, add_additional_features=add_additional_features) + if seq_cutoff != -1: + seq_cutoff = min(seq_cutoff, len(labels)) + preprocessed_attack_records = preprocessed_attack_records[:seq_cutoff] + labels = labels[:seq_cutoff] + labels_stats.extend(labels) + dataset.append([preprocessed_attack_records, labels]) + + labels_stats_counter = Counter(labels_stats) + print("[INFO][Preprocessing] Label map: %s" % str(sk_state_labels_map)) + print("[INFO][Preprocessing] Label stats: %s" % + str(labels_stats_counter)) + + if shuffle: + random.shuffle(dataset) + + if split_train_test: + train_set = dataset[:-len(dataset)//TRAIN_TEST_SPLIT] + test_set = dataset[-len(dataset)//TRAIN_TEST_SPLIT:] + train_loader = torch.utils.data.DataLoader( + train_set, batch_size=batch_size, shuffle=False) + test_loader = torch.utils.data.DataLoader( + test_set, batch_size=batch_size, shuffle=False) + return train_loader, test_loader, sk_state_labels_map, numeric_stats, labels_stats_counter + else: + data_loader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=False) + return data_loader, sk_state_labels_map, numeric_stats, 
labels_stats_counter  # tail of the preceding (truncated) return statement


def pause():
    """Block until the user presses Enter."""
    input("Press Enter to continue...")


def rnn_loss_function(outputs, labels, weight=None, debug=False):
    """Return the mean cross-entropy between `outputs` and `labels`.

    Args:
        outputs: Tensor passed as the input of `F.cross_entropy`.
        labels: Tensor passed as the target of `F.cross_entropy`.
        weight: Optional per-class weight tensor; `None` means unweighted.
        debug: When true, print the shapes of both tensors first.
    """
    if debug:
        print(outputs.shape)
        print(labels.shape)
    # `weight=None` is already the default of F.cross_entropy, so a single
    # call covers both the weighted and the unweighted case.
    return F.cross_entropy(outputs, labels, weight=weight, reduction='mean')


def ae_loss_function(recon_x, x, debug=False):
    """Return the mean absolute (L1) reconstruction error of `recon_x` vs `x`."""
    if debug:
        print(recon_x.shape)
        print(x.shape)
    return F.l1_loss(recon_x, x, reduction="mean")


def get_pred(rnn_outputs):
    """Return the argmax index along dimension 2 of `rnn_outputs`."""
    _, preds = torch.max(rnn_outputs.data, 2)
    return preds


def print_per_label_accu(correct_labels, incorrect_labels, state_map):
    """Print and return per-state accuracy.

    Args:
        correct_labels: Mapping from state id to number of correct predictions.
        incorrect_labels: Mapping from state id to number of wrong predictions.
        state_map: Mapping from state name to state id.

    Returns:
        Dict keyed by state name; each value is a dict with the keys
        'correct', 'incorrect' and 'accuracy' (0.0 when the state never
        occurred).
    """
    id_to_state = {state_id: state for state, state_id in state_map.items()}
    accu_map = {}
    for state_id, state in id_to_state.items():
        correct = correct_labels[state_id] if state_id in correct_labels else 0
        incorrect = (incorrect_labels[state_id]
                     if state_id in incorrect_labels else 0)
        total = correct + incorrect
        accu_map[state] = {
            'correct': correct,
            'incorrect': incorrect,
            'accuracy': float(correct) / total if total != 0 else 0.0,
        }
    print(accu_map)
    return accu_map


def generate_ngram_seq(seq, n_gram, only_outbound, use_conn_id=False, debug=False):
    """Turn a sequence of per-packet profiles into flattened n-gram samples.

    Args:
        seq: Sequence (list) of profile tensors.
        n_gram: Number of consecutive profiles per sample.
        only_outbound: Keep only profiles whose direction field equals 0.0.
        use_conn_id: When true, element 0 of every flattened profile is a
            connection id and the direction field is element 1; the id is
            stripped from the profile and returned alongside each sample.
        debug: Unused; kept for interface compatibility.

    Returns:
        `ERR_TOO_SHORT_SEQ` when fewer than `n_gram` profiles remain,
        otherwise a list of 1-D tensors (or `(conn_id, tensor)` tuples when
        `use_conn_id` is true).

    Raises:
        AssertionError: if the profiles carry more than one connection id.
    """
    idx_conn_id = 0
    idx_direction = 1 if use_conn_id else 0
    conn_id = None

    # The original only extracted the connection id inside the outbound
    # filter, which left `conn_id` unbound for only_outbound=False.
    if only_outbound or use_conn_id:
        conn_ids = set()
        kept = []
        for profile in seq:
            flat = profile.view(-1)
            if only_outbound and flat[idx_direction] != 0.0:
                continue
            if use_conn_id:
                conn_ids.add(flat[idx_conn_id].item())
                profile = flat[idx_direction:].view(1, 1, -1)
            kept.append(profile)
        if len(conn_ids) > 1:
            # Explicit raise (same exception type as the former `assert`) so
            # the check is not stripped under `python -O`.
            raise AssertionError("[NGRAM] More than 1 conn_id in seq!")
        if conn_ids:
            conn_id = int(next(iter(conn_ids)))
        seq = kept

    if len(seq) < n_gram:
        return ERR_TOO_SHORT_SEQ

    # NOTE(review): this yields len(seq) - n_gram windows, i.e. the final
    # possible window is never produced (and len(seq) == n_gram gives an
    # empty list). It looks like an off-by-one, but it is preserved because
    # callers handle the empty list and existing scores depend on it.
    ngram_seq = []
    for start in range(len(seq) - n_gram):
        sample = torch.flatten(torch.cat(seq[start:start + n_gram]))
        ngram_seq.append((conn_id, sample) if use_conn_id else sample)
    return ngram_seq


def generate_ngram_seq_dataset(loader, n_gram, batch_size=64, debug=False, only_outbound=True):
    """Build a shuffled DataLoader of n-gram samples from every sequence in `loader`.

    Sequences for which `generate_ngram_seq` reports `ERR_TOO_SHORT_SEQ` are
    skipped.
    """
    dataset = []
    for seq in loader:
        ngram_seq = generate_ngram_seq(
            seq, n_gram, only_outbound=only_outbound)
        if ngram_seq == ERR_TOO_SHORT_SEQ:
            continue
        dataset.extend(ngram_seq)
    if debug:
        # Guarded: indexing dataset[0] used to raise on an empty dataset.
        if dataset:
            print("[INFO][Train] Shape of seq sample: %s" %
                  str(dataset[0].shape))
        print("[INFO][Train] Size of dataset: %d" % len(dataset))
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Context modes understood by the contextual-profile builders below.
_CONTEXT_MODES = (
    "baseline", "use_hn", "use_all", "only_gates", "only_hn",
    "use_all_gates", "use_gates", "use_gates_label",
)


def _state_label_vector(state_str, device):
    """Encode `state_str` as a (1, len(nf_conntrack_states) + 1) float row."""
    label_vec = torch.zeros(1, len(nf_conntrack_states) + 1, device=device)
    for state_idx, state_name in enumerate(nf_conntrack_states):
        # Substring test kept from the original code; note that e.g.
        # "SYN_SENT" also matches inside "SYN_SENT2".
        if state_name in state_str:
            label_vec[0, state_idx] = 1.0
    if 'IW' in state_str:
        label_vec[0, -1] = 1.0
    return label_vec


def _build_profile(context_mode, features, hiddenstate, gates, is_lstm,
                   pred_label, label_map, device):
    """Assemble one per-timestep profile tensor for `context_mode`."""
    features = features.detach()
    if context_mode == "baseline":
        return features

    if is_lstm:
        resetgate, inputgate, cellgate, outgate = gates
        all_gates = (resetgate.detach(), inputgate.detach(),
                     cellgate.detach(), outgate.detach())
    else:
        resetgate, inputgate = gates
        all_gates = (resetgate.detach(), inputgate.detach())
    two_gates = all_gates[:2]
    hidden = hiddenstate.detach()

    if context_mode == "use_hn":
        parts = (features, hidden)
    elif context_mode == "use_all":
        parts = (features, hidden) + all_gates
    elif context_mode == "only_gates":
        parts = two_gates
    elif context_mode == "only_hn":
        return hidden
    elif context_mode == "use_all_gates":
        if not is_lstm:
            raise ValueError(
                "context_mode 'use_all_gates' needs the four gates of an "
                "LSTM model")
        parts = (features,) + all_gates
    elif context_mode == "use_gates":
        parts = (features,) + two_gates
    elif context_mode == "use_gates_label":
        # NOTE(review): label_map is indexed by the predicted label id here,
        # so it presumably maps id -> state string; verify against callers.
        label_vec = _state_label_vector(label_map[pred_label], device)
        parts = (features, label_vec) + two_gates
    else:
        raise ValueError("Unknown context_mode: %r" % (context_mode,))
    return torch.cat(parts, dim=1)


def _collect_contextual_profiles(batches, device, rnn_model, context_mode,
                                 partition_mode, rnn_model_type, label_map):
    """Shared implementation of the two contextual-profile dataset builders.

    `batches` yields `(x, labels, x_profile)`; `x` feeds the RNN and
    `x_profile` (or `x` itself when it is None) supplies the raw features
    placed into each profile.
    """
    if context_mode not in _CONTEXT_MODES:
        raise ValueError("Unknown context_mode: %r" % (context_mode,))
    use_rnn = context_mode != 'baseline'
    if not use_rnn and partition_mode == "pred_label":
        raise ValueError(
            "partition_mode 'pred_label' needs RNN predictions and cannot "
            "be combined with context_mode 'baseline'")
    is_lstm = 'lstm' in rnn_model_type

    contextual_dataset = [] if partition_mode == "none" else {}

    for x, labels, x_profile in batches:
        x = x.to(device, dtype=torch.float)
        labels = labels.to(device)
        if x_profile is None:
            x_profile = x
        else:
            x_profile = x_profile.to(device, dtype=torch.float)
        curr_seq = []

        if use_rnn:
            outputs, gates, hn = rnn_model(x)
            preds = get_pred(outputs)

        for step in range(x.size(1)):
            features = x_profile[:, step, :]
            step_gates = hiddenstate = pred_label = gt_label = None
            if use_rnn:
                step_gates = gates[step]
                hiddenstate = hn[step]
                # .item() requires exactly one element per timestep, as in
                # the original code.
                pred_label = preds[:, step].item()
                gt_label = labels[:, step].item()
            elif partition_mode == "gt_label":
                gt_label = labels[:, step].item()

            profile = _build_profile(
                context_mode, features, hiddenstate, step_gates, is_lstm,
                pred_label, label_map, device)

            if partition_mode == "none":
                curr_seq.append(profile)
            elif partition_mode == "pred_label":
                contextual_dataset.setdefault(pred_label, []).append(profile)
            elif partition_mode == "gt_label":
                contextual_dataset.setdefault(gt_label, []).append(profile)

        if partition_mode == "none":
            contextual_dataset.append(curr_seq)

    return contextual_dataset


def generate_contextual_profile_dataset(data_loader, device, rnn_model, context_mode, partition_mode, rnn_model_type, label_map, addi_data_loader=None):
    """Build per-timestep profiles from `data_loader` using `rnn_model`.

    Args:
        data_loader: Iterable of `[x, labels]` batches; timesteps are taken
            along dimension 1 of `x`.
        device: Device the tensors are moved to.
        rnn_model: Callable returning `(outputs, gates, hn)`; not called when
            `context_mode` is "baseline".
        context_mode: One of "baseline", "use_hn", "use_all", "only_gates",
            "only_hn", "use_all_gates", "use_gates", "use_gates_label".
        partition_mode: "none" (list of per-sequence profile lists),
            "pred_label" or "gt_label" (dict keyed by label).
        rnn_model_type: String; gates are unpacked as four tensors when it
            contains 'lstm', otherwise as two.
        label_map: Lookup used by "use_gates_label" on the predicted label.
        addi_data_loader: Unused; kept for signature parity with
            `generate_contextual_profile_dataset_fused`.

    Returns:
        A list when `partition_mode` is "none", otherwise a dict (left empty
        for an unrecognised `partition_mode`, as before).

    Raises:
        ValueError: for an unknown `context_mode` or a combination that
            cannot be computed.
    """
    batches = ((x, labels, None) for [x, labels] in data_loader)
    return _collect_contextual_profiles(
        batches, device, rnn_model, context_mode, partition_mode,
        rnn_model_type, label_map)


def generate_contextual_profile_dataset_fused(data_loader, device, rnn_model, context_mode, partition_mode, rnn_model_type, label_map, addi_data_loader):
    """Like `generate_contextual_profile_dataset`, with features from a second loader.

    The RNN runs on batches of `data_loader`, while the raw features stored
    in each profile come from the batch at the same position in
    `addi_data_loader`. Iteration stops at the shorter of the two loaders.
    """
    batches = (
        (x, labels, x2)
        for ([x, labels], [x2, _]) in zip(data_loader, addi_data_loader)
    )
    return _collect_contextual_profiles(
        batches, device, rnn_model, context_mode, partition_mode,
        rnn_model_type, label_map)


def _windowed_top_loss(loss_list, max_idx, window_size=5):
    """Return the mean loss in a window grown around `max_idx`.

    The window grows by one position on each side per step until it spans at
    least `window_size`, so it can cover `window_size + 1` entries; the slice
    excludes its right edge. Lists shorter than `window_size` are averaged
    in full.
    """
    if len(loss_list) < window_size:
        return sum(loss_list) / len(loss_list)
    last = len(loss_list) - 1
    start = end = max_idx
    while end - start < window_size and (start > 0 or end < last):
        if start > 0:
            start -= 1
        if end < last:
            end += 1
    window = loss_list[start:end]
    return sum(window) / len(window)


def get_losslist(overall_data_loader, vae_model, vae_input_size, n_gram, debug=False, only_outbound=True, use_conn_id=False, draw_trend=True):
    """Compute reconstruction losses of `vae_model` over a dataset.

    Args:
        overall_data_loader: Either a dict mapping label -> loader of single
            profiles (then `vae_model` is indexed by label), or a loader of
            sequences that are split into n-grams.
        vae_model: Autoencoder callable, or a per-label mapping of callables.
        vae_input_size: Each sample is viewed as `(1, vae_input_size)`.
        n_gram, only_outbound, use_conn_id: Forwarded to
            `generate_ngram_seq`.
        debug: When true, pause with intermediate values via `input()`.
        draw_trend: When true (sequence mode only), also return per-sequence
            loss curves for sequences with more than 50 n-grams.

    Returns:
        Dict mode: `(attack_cnt, attack_test_loss, attack_loss_list)`, each a
        dict keyed by label.
        Sequence mode: `(attack_cnt, seq_cnt, attack_test_loss,
        seq_test_loss, attack_loss_list, seq_loss_list)`, followed by
        `x, y` when `draw_trend` is true.
    """
    if isinstance(overall_data_loader, dict):
        attack_test_loss = {}
        attack_cnt = {}
        attack_loss_list = {}
        for label, data_loader in overall_data_loader.items():
            label_losses = []
            for profile in data_loader:
                profile = profile.view(1, vae_input_size)
                recon_profile = vae_model[label](profile)
                label_losses.append(
                    ae_loss_function(recon_profile, profile).item())
            attack_cnt[label] = len(label_losses)
            attack_loss_list[label] = label_losses
            # Accumulate from 0.0 in order, matching the original running sum.
            attack_test_loss[label] = sum(label_losses, 0.0)
        return attack_cnt, attack_test_loss, attack_loss_list

    # Weights of the per-sequence score; only the windowed mean is active.
    r1, r2, r3, r4 = 0.0, 0.0, 0.0, 1.0

    attack_test_loss, seq_test_loss = 0, 0
    seq_cnt, attack_cnt = 0, 0
    attack_loss_list, seq_loss_list = [], []
    x, y = {}, {}
    for batch_idx, seq in enumerate(overall_data_loader):
        ngram_seq = generate_ngram_seq(
            seq, n_gram, only_outbound=only_outbound, use_conn_id=use_conn_id, debug=debug)
        if debug:
            input(ngram_seq)
        if ngram_seq == ERR_TOO_SHORT_SEQ:
            continue
        if len(ngram_seq) == 0:
            continue
        seq_cnt += len(ngram_seq)
        attack_cnt += 1

        total_loss = 0.0
        conn_id = None
        curr_loss_list = []
        for ngram in ngram_seq:
            if use_conn_id:
                conn_id, ngram = ngram
            ngram = ngram.view(1, vae_input_size)
            recon_ngram = vae_model(ngram)
            curr_loss = ae_loss_function(recon_ngram, ngram).item()
            total_loss += curr_loss
            seq_test_loss += curr_loss
            seq_loss_list.append(curr_loss)
            curr_loss_list.append(curr_loss)

        max_loss = max(curr_loss_list)
        if debug:
            # Moved after the max is known; it used to print the initial 0.0.
            input("Sample #%d max recon error: %f" % (batch_idx, max_loss))
        if draw_trend and len(curr_loss_list) > 50:
            # Without connection ids the sequence index is the curve key
            # (this used to fail on an unbound `conn_id`).
            trend_key = str(conn_id) if use_conn_id else str(batch_idx)
            x[trend_key] = list(range(1, len(curr_loss_list) + 1))
            y[trend_key] = curr_loss_list

        top_loss_idx = sorted(range(len(curr_loss_list)),
                              key=lambda i: curr_loss_list[i], reverse=True)[:5]
        max_loss_idx = top_loss_idx[0]
        windowed_mean_loss = _windowed_top_loss(
            curr_loss_list, max_loss_idx, 5)
        mean_loss = total_loss / len(ngram_seq)
        median_loss = statistics.median(curr_loss_list)
        weighted_loss = r1 * max_loss + r2 * mean_loss + \
            r3 * median_loss + r4 * windowed_mean_loss
        attack_test_loss += weighted_loss
        if debug:
            input("max_loss: %f (max_id: %d); average_loss: %f" %
                  (max_loss, max_loss_idx, weighted_loss))
        if use_conn_id:
            attack_loss_list.append(
                (weighted_loss, str(top_loss_idx), str(conn_id), len(ngram_seq)))
        else:
            attack_loss_list.append(
                (weighted_loss, str(top_loss_idx), len(ngram_seq)))

    if draw_trend:
        return attack_cnt, seq_cnt, attack_test_loss, seq_test_loss, attack_loss_list, seq_loss_list, x, y
    else:
        return attack_cnt, seq_cnt, attack_test_loss, seq_test_loss, attack_loss_list, seq_loss_list


def _finish_roc_plot(fig_path):
    """Add legend, chance diagonal and axis labels, then save to `fig_path`."""
    plt.legend(loc='lower right')
    plt.plot([0, 1], [0, 1], 'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.savefig(fig_path)


def plot_roc_curve(fpr, tpr, score, fig_path, ds_title):
    """Plot one ROC curve labelled with its AUC `score` and save it to `fig_path`."""
    try:
        plt.title('ROC Curve for %s Attack' % ds_title)
        plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % score)
        _finish_roc_plot(fig_path)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close()


def plot_roc_curve_comparison(fpr1, tpr1, fpr2, tpr2, score1, score2, fig_path, ds_title):
    """Plot a baseline ROC curve (1) against ours (2) and save it to `fig_path`."""
    try:
        plt.title('ROC Curve on %s Attack' % ds_title)
        plt.plot(fpr1, tpr1, 'grey', label='Baseline, AUC = %0.2f' %
                 score1, linestyle='dashed')
        plt.plot(fpr2, tpr2, 'b', label='Our Approach, AUC = %0.2f' % score2)
        _finish_roc_plot(fig_path)
    finally:
        plt.close()


def read_loss_list(loss_list, balance_by_label=False, deduplicate=False):
    """Read a tab-separated loss file into labels, scores and top-loss indices.

    Rows have either 4 columns (`loss, idx, length, label`) or 5 columns
    (`loss, idx, conn_id, length, label`). Rows are shuffled before parsing.

    Args:
        loss_list: Path of the file to read.
        balance_by_label: When true, print the per-label counts and truncate
            the results to the size of the smaller class.
        deduplicate: When true, drop exact duplicate lines first.

    Returns:
        `(y, scores, top_loss_lst)`, where `top_loss_lst` maps conn_id to the
        parsed `idx` column of 5-column rows. With `balance_by_label` and no
        5-column row in the file, only `(y, scores)` is returned.
    """
    # Local import so the block stays self-contained.
    import ast

    with open(loss_list, "r") as fin:
        data = fin.readlines()

    if deduplicate:
        data = list(set(data))
    random.shuffle(data)

    y = []
    scores = []
    top_loss_lst = {}
    # Tracks whether ANY row carried a conn_id; the old flag reflected only
    # the last row visited, which made the return arity order-dependent.
    saw_top_loss = False

    for row in data:
        if len(row) <= 1:
            continue
        fields = row.rstrip('\n').split("\t")
        if len(fields) == 4:
            loss, idx, leng, label = fields
        elif len(fields) == 5:
            loss, idx, conn_id, leng, label = fields
            saw_top_loss = True
            # literal_eval parses the stringified index list without
            # executing arbitrary code from the file (was `eval`).
            top_loss_lst[conn_id] = ast.literal_eval(idx)
        else:
            # Malformed rows used to block on input() and then re-append the
            # previous row's values; skip them instead.
            print("[WARN] Skipping malformed row with %d columns: %r" %
                  (len(fields), row))
            continue
        y.append(int(label))
        scores.append(float(loss))

    if not balance_by_label:
        return y, scores, top_loss_lst

    label_set = collections.Counter(y)
    attack_cnt = label_set[1]
    benign_cnt = label_set[0]
    smaller = min(attack_cnt, benign_cnt)
    print("[INFO] Attack count: %d" % attack_cnt)
    print("[INFO] Benign count: %d" % benign_cnt)
    # NOTE(review): this keeps the first `smaller` shuffled rows regardless
    # of label, so the classes are not actually balanced; left unchanged
    # because downstream results depend on it.
    if saw_top_loss:
        return y[:smaller], scores[:smaller], top_loss_lst
    return y[:smaller], scores[:smaller]


def calculate_acc(outputs, labels, debug=False):
    """Compare argmax predictions (dimension 1) of `outputs` with `labels`.

    Returns:
        `(correct_cnt, total_cnt, correct_labels, incorrect_labels)`:
        `correct_cnt` is a tensor, `total_cnt` is `labels.size(0)`, and the
        last two are the labels that were predicted correctly / incorrectly.
    """
    _, preds = torch.max(outputs.data, 1)
    correct_list = (preds == labels)
    incorrect_list = ~correct_list

    if debug:
        print(correct_list)
        print(labels)
        print(labels[correct_list])
        print(labels[incorrect_list])
        input("Press Enter to continue...")

    correct_cnt = correct_list.sum()
    total_cnt = labels.size(0)

    return correct_cnt, total_cnt, labels[correct_list], labels[incorrect_list]
