"""Evaluate adversarial-packet localization accuracy from reconstruction-loss
rankings and dump connections whose anomaly score exceeds a fixed threshold."""
from utils import *
from os import listdir
import argparse
from scipy.optimize import brentq
from scipy.interpolate import interp1d


def get_localization_acc(adv_pkt_idx_dict, top_loss_idx_dict, ngram=3):
    """Return (top-1, top-3, top-5) localization hit rates.

    A connection counts as a top-k "hit" when at least one of its k
    highest-loss packet indices, expanded to an n-gram window, overlaps
    the connection's known adversarial packet indices.

    Args:
        adv_pkt_idx_dict: conn_id -> list of ground-truth adversarial
            packet indices.
        top_loss_idx_dict: conn_id -> list of packet indices ranked by
            reconstruction loss (highest loss first).
        ngram: window size; each candidate index i is expanded to
            [i, i+1, ..., i+ngram-1].

    Returns:
        Tuple of floats (top_1_acc, top_3_acc, top_5_acc). Returns
        (0.0, 0.0, 0.0) when no connection could be evaluated, instead
        of raising ZeroDivisionError as the original did.
    """
    def ngram_idx_lst(lst, ngram):
        # Expand every candidate index to its n-gram window.
        return [idx + delta for idx in lst for delta in range(ngram)]

    hit_dict = {}
    for conn_id, adv_pkt_idx_lst in adv_pkt_idx_dict.items():
        top_loss_idx_lst = top_loss_idx_dict.get(conn_id)
        # Skip connections with no loss ranking at all.
        if not top_loss_idx_lst:
            continue
        adv_pkt_idx_set = set(adv_pkt_idx_lst)
        # Slicing clamps to the list length, so no explicit length
        # checks are needed for candidate lists shorter than k.
        hit_dict[conn_id] = tuple(
            bool(set(ngram_idx_lst(top_loss_idx_lst[:k], ngram))
                 & adv_pkt_idx_set)
            for k in (1, 3, 5)
        )

    if not hit_dict:
        # Nothing evaluable: report zero accuracy rather than divide by zero.
        return 0.0, 0.0, 0.0

    total = len(hit_dict)
    top_1_hit_acc = sum(1 for h in hit_dict.values() if h[0]) / total
    top_3_hit_acc = sum(1 for h in hit_dict.values() if h[1]) / total
    top_5_hit_acc = sum(1 for h in hit_dict.values() if h[2]) / total
    return top_1_hit_acc, top_3_hit_acc, top_5_hit_acc


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Let us compute the ROC curve.')
    parser.add_argument('--loss-list-fname', type=str,
                        help='1st path to loss list')
    parser.add_argument('--n-gram', type=int)
    parser.add_argument('--attack-info-fpath', type=str,
                        help='to calculate the localization perf')
    parser.add_argument('--loss-list-dir', type=str,
                        help='1st path to loss list')
    parser.add_argument('--fig-fpath', type=str,
                        help='path to save figure')
    parser.add_argument('--ds-title', type=str, default='SymTCP',
                        help='title in curve painting')
    parser.add_argument('--balance', action='store_true', default=False,
                        help='whether to balance by labels')
    args = parser.parse_args()

    # Scan the loss-list directory for the requested loss files.
    # Example filename:
    #   mawi_ws_ds_sorted.42.42.3.50.1000.-1.-1.22.use_gates.none.gru.large.weighted.no_addi.only_outbound.UNILABEL
    # -> label is the last dot-field ("UNILABEL"), attack type is the
    #    9th field from the end ("22").
    all_losslist_files = listdir(args.loss_list_dir)
    losslist_files = {}
    for fname in all_losslist_files:
        if args.loss_list_fname in fname:
            label = fname.split('.')[-1]
            attack_type = fname.split('.')[-9]
            losslist_files[label] = (fname, attack_type)

    for label, (fname, attack_type) in losslist_files.items():
        lossfile_fname = '/'.join([args.loss_list_dir, fname])
        y, scores, top_loss_lst = read_loss_list(
            lossfile_fname, balance_by_label=args.balance)
        if len(y) == 0:
            print(">>>>>>>>> [Label %s; Attack type %s] Not found <<<<<<<<<" % (
                label, attack_type))
            continue

        conn_ids = list(top_loss_lst.keys())
        # NOTE(review): enumerate assumes one conn_id per (y, score)
        # sample, i.e. len(conn_ids) == len(y); the original advanced a
        # manual idx counter whose nesting was ambiguous -- confirm
        # against read_loss_list's output layout.
        # `sample_label` avoids shadowing the outer `label` variable.
        for idx, (sample_label, score) in enumerate(zip(y, scores)):
            # Fixed detection threshold on the anomaly score; tunable.
            if sample_label == 1 and score > 0.0159:
                conn_id = conn_ids[idx]
                # conn_id is drawn from top_loss_lst.keys(), so it is
                # always present; index with the key itself rather than
                # str(conn_id) as the original inconsistently did.
                print(','.join([str(conn_id), str(top_loss_lst[conn_id])]))