Diffstat (limited to 'script/test.py')
| -rw-r--r-- | script/test.py | 275 |
1 files changed, 275 insertions, 0 deletions
diff --git a/script/test.py b/script/test.py
new file mode 100644
index 0000000..e3025fb
--- /dev/null
+++ b/script/test.py
@@ -0,0 +1,275 @@
+from utils import nf_conntrack_states
+from utils import get_losslist
+from utils import read_dataset, generate_ngram_seq, generate_contextual_profile_dataset, generate_contextual_profile_dataset_fused
+from utils import AEModel, GRUModel, GRUCell
+import argparse
+import torch
+from torch import nn, optim
+from torch.autograd import Variable
+import torch.nn.functional as F
+import numpy as np
+import pandas
+import random
+import time
+import pickle
+from os import path
+
+import matplotlib.pyplot as plt
+import matplotlib
+
+font = {'family': 'normal',
+        'weight': 'bold',
+        'size': 16}
+
+matplotlib.rc('font', **font)
+
+
+ERR_TOO_SHORT_SEQ = -1
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='GRU for learning benign contextual profiles')
+    parser.add_argument('--attack-dataset', type=str,
+                        help='path to positive dataset file')
+    parser.add_argument('--benign-dataset', type=str,
+                        help='path to negative dataset file')
+    parser.add_argument('--dataset-stats', type=str, help='path to stats file')
+    parser.add_argument('--loss-list-fpath', type=str,
+                        help='path to dump loss list')
+    parser.add_argument('--rnn-model', type=str, help='path to RNN model file')
+    parser.add_argument('--vae-model', type=str, help='path to VAE model file')
+    parser.add_argument('--rnn-hidden-size', type=int,
+                        help='hidden state size')
+    parser.add_argument('--input-size', type=int, help='size of RNN input')
+    parser.add_argument('--device', type=str, help='device for testing')
+    parser.add_argument('--seed', type=int, default=1,
+                        metavar='S', help='random seed (default: 1)')
+    parser.add_argument('--batch-size', type=int, default=1,
+                        help='batch size for training and testing')
+    parser.add_argument('--cutoff', type=int, default=-1,
+                        help='cutoff for rnn training (default: -1)')
+    parser.add_argument('--error-thres', type=float,
+                        help='threshold of reconstruction error')
+    parser.add_argument('--n-gram', type=int, default=3,
+                        help='n-gram for training/testing the autoencoder (default: 3)')
+    parser.add_argument('--debug', action="store_true",
+                        help='enables debugging information')
+    parser.add_argument('--context-mode', type=str,
+                        default='use_gates', help='type of profile')
+    parser.add_argument('--partition-mode', type=str,
+                        default='none', help='type of partitioning')
+    parser.add_argument('--rnn-model-type', type=str,
+                        default='gru', help='type of RNN model')
+    parser.add_argument('--extra-features', type=str,
+                        help='whether to include post-mortem features.')
+    parser.add_argument('--conn-dir', type=str,
+                        help='direction of connection to play with.')
+    parser.add_argument('--use-conn-id', action='store_true', default=True,
+                        help='use connection ids to track adv pkts.')
+    parser.add_argument('--paint-trend', action='store_true', default=False)
+    args = parser.parse_args()
+
+    if args.seed:
+        torch.manual_seed(args.seed)
+        random.seed(int(args.seed))
+
+    device = torch.device(args.device if args.device else "cpu")
+
+    with open(args.dataset_stats, 'rb') as fin:
+        stats_info = pickle.load(fin)
+
+    print("[INFO] Stats used for dataset:")
+    print(stats_info)
+    stats = stats_info['stats']
+    label_map = stats_info['label_map']
+    reversed_label_map = {}
+    for label, label_id in label_map.items():
+        reversed_label_map[label_id] = label
+
+    attack_test_loader, _, _, cnt_map = read_dataset(
+        args.attack_dataset, batch_size=args.batch_size, preprocess=True,
+        cutoff=args.cutoff, split_train_test=False, stats=stats, debug=True)
+    benign_test_loader, _, _, _ = read_dataset(
+        args.benign_dataset, batch_size=args.batch_size, preprocess=True,
+        cutoff=args.cutoff, split_train_test=False, stats=stats, debug=True)
+
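+    # Wall-clock timing: 'duration' below brackets everything from the extra
+    # dataset loading through loss computation, while separate timers measure
+    # the feature-extraction and loss phases used for the per-packet averages.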
+    start_timestamp = time.time()
+    print("[INFO] Starting timing: %f" % start_timestamp)
+
+    if args.extra_features == 'all_addi':
+        addi_attack_test_loader, _, _, _ = read_dataset(
+            args.attack_dataset, batch_size=args.batch_size, preprocess=True,
+            cutoff=args.cutoff, split_train_test=False, stats=stats, debug=True,
+            add_additional_features=True, use_conn_id=args.use_conn_id)
+        addi_benign_test_loader, _, _, _ = read_dataset(
+            args.benign_dataset, batch_size=args.batch_size, preprocess=True,
+            cutoff=args.cutoff, split_train_test=False, stats=stats, debug=True,
+            add_additional_features=True, use_conn_id=args.use_conn_id)
+    else:  # modified: load the same datasets without the additional features
+        addi_attack_test_loader, _, _, _ = read_dataset(
+            args.attack_dataset, batch_size=args.batch_size, preprocess=True,
+            cutoff=args.cutoff, split_train_test=False, stats=stats, debug=True,
+            add_additional_features=False, use_conn_id=args.use_conn_id)
+        addi_benign_test_loader, _, _, _ = read_dataset(
+            args.benign_dataset, batch_size=args.batch_size, preprocess=True,
+            cutoff=args.cutoff, split_train_test=False, stats=stats, debug=True,
+            add_additional_features=False, use_conn_id=args.use_conn_id)
+
+    input_size = args.input_size
+    hidden_size = args.rnn_hidden_size
+    batch_size = 1
+    if 'bi_' in args.rnn_model_type:
+        rnn_bidirectional = True
+    else:
+        rnn_bidirectional = False
+
+    rnn_model = torch.load(args.rnn_model)
+    rnn_model.eval()  # Set to eval mode since this is the testing phase.
+
+    if args.conn_dir == 'only_outbound':
+        only_outbound = True
+    else:
+        only_outbound = False
+
+    if args.extra_features == 'all_addi':
+        start_feature_ext_ts = time.time()
+        attack_contextual_dataset = generate_contextual_profile_dataset_fused(
+            attack_test_loader, device, rnn_model, context_mode=args.context_mode,
+            partition_mode=args.partition_mode, rnn_model_type=args.rnn_model_type,
+            label_map=reversed_label_map, addi_data_loader=addi_attack_test_loader)
+        finish_feature_ext_ts = time.time()
+        benign_contextual_dataset = generate_contextual_profile_dataset_fused(
+            benign_test_loader, device, rnn_model, context_mode=args.context_mode,
+            partition_mode=args.partition_mode, rnn_model_type=args.rnn_model_type,
+            label_map=reversed_label_map, addi_data_loader=addi_benign_test_loader)
+        num_addi_features = 15
+    else:
+        # modified: swapped in the fused variant here as well
+        start_feature_ext_ts = time.time()
+        attack_contextual_dataset = generate_contextual_profile_dataset_fused(
+            attack_test_loader, device, rnn_model, context_mode=args.context_mode,
+            partition_mode=args.partition_mode, rnn_model_type=args.rnn_model_type,
+            label_map=reversed_label_map, addi_data_loader=addi_attack_test_loader)
+        finish_feature_ext_ts = time.time()
+        benign_contextual_dataset = generate_contextual_profile_dataset_fused(
+            benign_test_loader, device, rnn_model, context_mode=args.context_mode,
+            partition_mode=args.partition_mode, rnn_model_type=args.rnn_model_type,
+            label_map=reversed_label_map, addi_data_loader=addi_benign_test_loader)
+        num_addi_features = 0
+
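+    # The VAE consumes flattened n-grams of per-packet context profiles, so its
+    # input width is the per-step profile width times args.n_gram; the profile
+    # width depends on which GRU internals (hidden state h_n, gate activations)
+    # and extra features are fused with the raw packet features.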
+    if args.context_mode == "baseline":
+        vae_input_size = (input_size + num_addi_features) * args.n_gram
+    elif args.context_mode == "use_hn":
+        vae_input_size = (input_size + hidden_size) * args.n_gram
+    elif args.context_mode == "use_all":
+        vae_input_size = (input_size + hidden_size * 5) * args.n_gram
+    elif args.context_mode == "only_gates":
+        vae_input_size = (hidden_size * 2) * args.n_gram
+    elif args.context_mode == "only_hn":
+        vae_input_size = hidden_size * args.n_gram
+    elif args.context_mode == "use_all_gates":
+        vae_input_size = (input_size + hidden_size * 4) * args.n_gram
+    elif args.context_mode == "use_gates":
+        vae_input_size = (input_size + num_addi_features +
+                          hidden_size * 2) * args.n_gram
+    elif args.context_mode == "use_gates_label":
+        vae_input_size = (input_size + num_addi_features + hidden_size *
+                          2 + len(nf_conntrack_states) + 1) * args.n_gram
+
+    if args.partition_mode == 'none':
+        vae_model = torch.load(args.vae_model)
+    else:
+        vae_model = {}
+        new_label_map = {}
+        for label, label_id in label_map.items():
+            model_fpath = "%s.%s" % (
+                args.vae_model, str(reversed_label_map[label_id]))
+            if path.isfile(model_fpath):
+                vae_model[label] = torch.load(model_fpath)
+                new_label_map[label_id] = label
+            else:
+                print("[ERROR] Model file %s not found" % model_fpath)
+        label_map = new_label_map
+
+    if args.partition_mode == "none":
+        attack_profile_loader = torch.utils.data.DataLoader(
+            attack_contextual_dataset, batch_size=batch_size, shuffle=False)
+        benign_profile_loader = torch.utils.data.DataLoader(
+            benign_contextual_dataset, batch_size=batch_size, shuffle=False)
+    else:
+        attack_profile_loader, benign_profile_loader = {}, {}
+        for label_id, label in label_map.items():
+            if label_id in attack_contextual_dataset:
+                attack_profile_loader[label] = torch.utils.data.DataLoader(
+                    attack_contextual_dataset[label_id], batch_size=batch_size, shuffle=False)
+            if label_id in benign_contextual_dataset:
+                benign_profile_loader[label] = torch.utils.data.DataLoader(
+                    benign_contextual_dataset[label_id], batch_size=batch_size, shuffle=False)
+
+    if args.partition_mode == "none":
+        start_loss_ts = time.time()
+        attack_cnt, attack_seq_cnt, attack_test_loss, attack_seq_test_loss, attack_loss_list, attack_seq_loss_list, attack_x, attack_y = get_losslist(
+            attack_profile_loader, vae_model, vae_input_size, args.n_gram,
+            debug=args.debug, only_outbound=only_outbound, use_conn_id=args.use_conn_id)
+        finish_loss_ts = time.time()
+        benign_cnt, benign_seq_cnt, benign_test_loss, benign_seq_test_loss, benign_loss_list, benign_seq_loss_list, benign_x, benign_y = get_losslist(
+            benign_profile_loader, vae_model, vae_input_size, args.n_gram,
+            debug=args.debug, only_outbound=only_outbound, use_conn_id=args.use_conn_id)
+        if args.paint_trend:
+            for conn_id in attack_x.keys() & benign_x.keys():
+                attk_x, attk_y = attack_x[conn_id], attack_y[conn_id]
+                begn_x, begn_y = benign_x[conn_id], benign_y[conn_id]
+                plt.plot(attk_x, attk_y, color='red', linewidth=3,
+                         label='Adversarial')
+                plt.plot(begn_x, begn_y, color='green',
+                         linewidth=3, label='Benign')
+                plt.ylim((0.0, 0.06))
+                plt.xlim((0, 60))
+                plt.xlabel("Index # of Context Profile",
+                           fontsize=20, fontweight='bold')
+                plt.ylabel("Reconstruction Error",
+                           fontsize=20, fontweight='bold')
+                plt.legend(loc='upper right')
+                plt.tight_layout()
+                plt.show()
+    else:
+        # Mirror the timing bookkeeping of the unpartitioned branch so the
+        # duration report below does not hit a NameError.
+        start_loss_ts = time.time()
+        attack_cnt, attack_test_loss, attack_loss_list = get_losslist(
+            attack_profile_loader, vae_model, vae_input_size, args.n_gram,
+            debug=args.debug, only_outbound=only_outbound)
+        finish_loss_ts = time.time()
+        benign_cnt, benign_test_loss, benign_loss_list = get_losslist(
+            benign_profile_loader, vae_model, vae_input_size, args.n_gram,
+            debug=args.debug, only_outbound=only_outbound)
+
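+    # Throughput summary: pkt_cnt is the total packet count reported by
+    # read_dataset for the attack set; conn_cnt is its number of connections.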
+    end_timestamp = time.time()
+    print("[INFO] Ending timing: %f" % end_timestamp)
+    duration = end_timestamp - start_timestamp
+    feature_ext_duration = finish_feature_ext_ts - start_feature_ext_ts
+    loss_duration = finish_loss_ts - start_loss_ts
+    pkt_cnt = sum(list(cnt_map.values()))
+    conn_cnt = len(attack_test_loader)
+
+    print("[INFO] Total # of connections: %d; # of packets: %d; total elapsed time: %f; time for feature extraction: %f; time for computing loss: %f" % (
+        conn_cnt, pkt_cnt, duration, feature_ext_duration, loss_duration))
+    print("[INFO] Average processing time per packet: %f" %
+          ((feature_ext_duration + loss_duration) / pkt_cnt))
+    print("[INFO] Average processing time per connection: %f" %
+          ((feature_ext_duration + loss_duration) / conn_cnt))
+
+    if args.partition_mode == "none":
+        with open(args.loss_list_fpath + '.UNILABEL', 'w') as fout:
+            for (loss, idx, conn_id, leng) in attack_loss_list:
+                fout.write(
+                    '\t'.join([str(loss), str(idx), str(conn_id), str(leng), '1']) + '\n')
+            # Benign rows use the same five-column layout, with label 0.
+            for (loss, idx, conn_id, leng) in benign_loss_list:
+                fout.write(
+                    '\t'.join([str(loss), str(idx), str(conn_id), str(leng), '0']) + '\n')
+    else:
+        losslist_files = {}
+        for _, label in label_map.items():
+            losslist_files[label] = open(
+                '%s.%s' % (args.loss_list_fpath, label), 'w')
+
+        for label, loss_list in attack_loss_list.items():
+            for (loss, idx) in loss_list:
+                losslist_files[label].write("%f,%s,%s\n" % (loss, idx, '1'))
+        for label, loss_list in benign_loss_list.items():
+            for (loss, idx) in loss_list:
+                losslist_files[label].write("%f,%s,%s\n" % (loss, idx, '0'))
+
+        for label, f in losslist_files.items():
+            f.close()
+
+    if args.partition_mode == "none":
+        print("Number of connections: %d | %d" % (attack_cnt, benign_cnt))
+        print("Number of sequences: %d | %d" %
+              (attack_seq_cnt, benign_seq_cnt))
+        print('Per-connection average loss: {:.4f} | {:.4f}'.format(
+            attack_test_loss/attack_cnt, benign_test_loss/benign_cnt))
+        print('Per-seq average loss: {:.4f} | {:.4f}'.format(
+            attack_seq_test_loss/attack_seq_cnt, benign_seq_test_loss/benign_seq_cnt))
+    else:
+        for label, _ in attack_loss_list.items():
+            print("----- Label %s -----" % label)
+            print("Number of connections: %d | %d" %
+                  (attack_cnt[label], benign_cnt[label]))
+            print('Per-connection average loss: {:.4f} | {:.4f}'.format(
+                attack_test_loss[label]/attack_cnt[label], benign_test_loss[label]/benign_cnt[label]))
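+
+# Example invocation (all paths and sizes below are illustrative placeholders,
+# not values shipped with the repo):
+#   python script/test.py \
+#       --attack-dataset attack.pkl --benign-dataset benign.pkl \
+#       --dataset-stats stats.pkl --rnn-model gru.pt --vae-model vae.pt \
+#       --rnn-hidden-size 32 --input-size 16 --n-gram 3 \
+#       --context-mode use_gates --partition-mode none \
+#       --loss-list-fpath losslist --device cpu
+# With --partition-mode none, scores land in losslist.UNILABEL as
+# tab-separated rows: loss, profile index, conn_id, seq length, label.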
