summaryrefslogtreecommitdiff
path: root/script/deeptest.py
diff options
context:
space:
mode:
Diffstat (limited to 'script/deeptest.py')
-rw-r--r--script/deeptest.py124
1 files changed, 124 insertions, 0 deletions
diff --git a/script/deeptest.py b/script/deeptest.py
new file mode 100644
index 0000000..1a2b82a
--- /dev/null
+++ b/script/deeptest.py
@@ -0,0 +1,124 @@
+from utils import *
+from os import listdir
+import argparse
+
+from scipy.optimize import brentq
+from scipy.interpolate import interp1d
+
+
+def get_localization_acc(adv_pkt_idx_dict, top_loss_idx_dict, ngram=3):
+ def ngram_idx_lst(lst, ngram):
+ new_lst = []
+ for top_loss_idx in lst:
+ new_lst.append(top_loss_idx)
+ for delta in range(1, ngram):
+ new_lst.append(top_loss_idx + delta)
+ return new_lst
+
+ hit_dict = {}
+ top_1_hit_cnt = 0
+ top_3_hit_cnt = 0
+ top_5_hit_cnt = 0
+
+ for conn_id, adv_pkt_idx_lst in adv_pkt_idx_dict.items():
+ if conn_id not in top_loss_idx_dict:
+ continue
+
+ top_loss_idx_lst = top_loss_idx_dict[conn_id]
+ if len(top_loss_idx_lst) == 0:
+ continue
+
+ adv_pkt_idx_set = set(adv_pkt_idx_lst)
+
+ top_1 = ngram_idx_lst([top_loss_idx_lst[0]], ngram)
+
+ if len(top_loss_idx_lst) >= 5:
+ top_5 = ngram_idx_lst(top_loss_idx_lst[:5], ngram)
+ else:
+ top_5 = ngram_idx_lst(top_loss_idx_lst, ngram)
+ if len(top_loss_idx_lst) >= 3:
+ top_3 = ngram_idx_lst(top_loss_idx_lst[:3], ngram)
+ else:
+ top_3 = ngram_idx_lst(top_loss_idx_lst, ngram)
+
+ top_1_set = set(top_1)
+ top_3_set = set(top_3)
+ top_5_set = set(top_5)
+
+ if len(top_1_set.intersection(adv_pkt_idx_set)) != 0:
+ top_1_hit = True
+ else:
+ top_1_hit = False
+ if len(top_3_set.intersection(adv_pkt_idx_set)) != 0:
+ top_3_hit = True
+ else:
+ top_3_hit = False
+ if len(top_5_set.intersection(adv_pkt_idx_set)) != 0:
+ top_5_hit = True
+ else:
+ top_5_hit = False
+
+ hit_dict[conn_id] = (top_1_hit, top_3_hit, top_5_hit)
+
+ for conn_id, hits in hit_dict.items():
+ if hits[0]:
+ top_1_hit_cnt += 1
+ if hits[1]:
+ top_3_hit_cnt += 1
+ if hits[2]:
+ top_5_hit_cnt += 1
+
+ top_1_hit_acc = top_1_hit_cnt / len(hit_dict)
+ top_3_hit_acc = top_3_hit_cnt / len(hit_dict)
+ top_5_hit_acc = top_5_hit_cnt / len(hit_dict)
+
+ return top_1_hit_acc, top_3_hit_acc, top_5_hit_acc
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description='Let us compute the ROC curve.')
+ parser.add_argument('--loss-list-fname', type=str,
+ help='1st path to loss list')
+ parser.add_argument('--n-gram', type=int)
+ parser.add_argument('--attack-info-fpath', type=str,
+ help='to calculate the localization perf')
+ parser.add_argument('--loss-list-dir', type=str,
+ help='1st path to loss list')
+ parser.add_argument('--fig-fpath', type=str, help='path to save figure')
+ parser.add_argument('--ds-title', type=str,
+ default='SymTCP', help='title in curve painting')
+ parser.add_argument('--balance', action='store_true',
+ default=False, help='whether to balance by labels')
+ args = parser.parse_args()
+
+ all_losslist_files = listdir(args.loss_list_dir)
+ losslist_files = {}
+ for fname in all_losslist_files:
+ #遍历文件夹下的所有重构误差损失文件,找到指定的损失文件
+ if args.loss_list_fname in fname:
+ label = fname.split('.')[-1]
+ attack_type = fname.split('.')[-9]
+ losslist_files[label] = (fname, attack_type)
+
+ #以fname:mawi_ws_ds_sorted.42.42.3.50.1000.-1.-1.22.use_gates.none.gru.large.weighted.no_addi.only_outbound.UNILABEL为例子
+ #label: UNILABEL attacktype:22
+ for label, (fname, attack_type) in losslist_files.items():
+ lossfile_fname = '/'.join([args.loss_list_dir, fname])
+ y, scores, top_loss_lst = read_loss_list(
+ lossfile_fname, balance_by_label=args.balance)
+ if len(y) == 0:
+ print(">>>>>>>>> [Label %s; Attack type %s] Not found <<<<<<<<<" % (
+ label, attack_type))
+ continue
+ idx = 0
+ conn_ids = list(top_loss_lst.keys())
+ #print(top_loss_lst)
+ for label, score in zip(y, scores): # 使用 zip 来同时迭代 y 和 scores
+ if label == 1 :
+ #阈值设定,可以调整分数的大小
+ if score > 0.0159:
+ conn_id = conn_ids[idx] # 获取当前的 conn_id
+ if conn_id in top_loss_lst: # 检查是否存在对应的 conn_id
+ print(','.join([str(conn_id),str(top_loss_lst[str(conn_id)])])) # 输出对应 conn_id 的值
+ idx += 1