diff options
Diffstat (limited to 'script/paint_fig.py')
| -rw-r--r-- | script/paint_fig.py | 351 |
1 files changed, 351 insertions, 0 deletions
diff --git a/script/paint_fig.py b/script/paint_fig.py new file mode 100644 index 0000000..371391e --- /dev/null +++ b/script/paint_fig.py @@ -0,0 +1,351 @@ +import matplotlib as mpl +import argparse +import pprint +from math import ceil + +import matplotlib +import matplotlib.pyplot as plt + +pp = pprint.PrettyPrinter(indent=2) + +font = {'family': 'sans-serif', + 'weight': 'bold', + 'size': 14} + +matplotlib.rcParams['hatch.linewidth'] = 0.9 + +matplotlib.rc('font', **font) + +NAME_MAP = { + "Geneva_Strategy_1*max*": "Invalid Data-Offset,Bad TCP Checksum", + "Geneva_Strategy_2*max*": "Invalid Data-Offset,Low TTL", + "Geneva_Strategy_3*max*": "Invalid Data-Offset,Bad ACK Num", + "Geneva_Strategy_4*max*": "Invalid TCP WScale-Option,Invalid Data-Offset", + "Geneva_Strategy_5*max*": "Bad Payload Length,Bad TCP Checksum", + "Geneva_Strategy_6*max*": "Bad Payload Length,Low TTL", + "Geneva_Strategy_7*max*": "Bad Payload Length,Bad ACK Num", + "Geneva_Strategy_8*max*": "/,Bad Payload Length", + "Geneva_Strategy_9*max*": "Bad IP Length,/", + "Geneva_Strategy_10*max*": "Injected RST,Bad IP Length", + "Geneva_Strategy_11*max*": "Injected RST,Bad TCP Checksum", + "Geneva_Strategy_12*max*": "Injected RST,Low TTL", + "Geneva_Strategy_13*max*": "Bad TCP MD5-Option,Injected RST", + "Geneva_Strategy_14*max*": "Injected RST-ACK,Bad TCP Checksum", + "Geneva_Strategy_15*max*": "Injected RST-ACK,Low TTL", + "Geneva_Strategy_16*max*": "Bad TCP MD5-Option,Injected RST", + "Geneva_Strategy_17*max*": "Invalid Flags #1,Bad TCP Checksum", + "Geneva_Strategy_18*max*": "Invalid Flags #2,Low TTL", + "Geneva_Strategy_19*max*": "Invalid Flags #2,Bad TCP MD5-Option", + "Geneva_Strategy_23*max*": "Injected FIN,Bad IP Length", + "Geneva_Strategy_24*max*": "Injected SYN-ACK,Bad TCP MD5-Option", + "Geneva_Strategy_23*max*": "Bad TCP UTO-Option,Bad TCP MD5-Option", + "Liberate_IP_InvalidHeaderLen*max*": "Invalid IP Header Length,Max", + "Liberate_IP_InvalidHeaderLen*min*": "Invalid IP Header Length,Min", + "Liberate_IP_InvalidVersion*max*": "Invalid IP Version,Max", + "Liberate_IP_InvalidVersion*max*": "Invalid IP Version,Min", + "Liberate_IP_LongerLength*max*": "Bad IP Length (Too Long),Max", + "Liberate_IP_LongerLength*min*": "Bad IP Length (Too Long),Min", + "Liberate_IP_ShorterLength*max*": "Bad IP Length (Too Short),Max", + "Liberate_IP_ShorterLength*min*": "Bad IP Length (Too Short),Min", + "Liberate_IP_LowTTL*max*": "Low TTL,Max", + "Liberate_IP_LowTTL*min*": "Low TTL,Min", + "Liberate_IP_LowTTLRSTa*max*": "RST w/ Low TTL #1,max", + "Liberate_IP_LowTTLRSTa*min*": "RST w/ Low TTL #1,min", + "Liberate_IP_LowTTLRSTb*max*": "RST w/ Low TTL #2,max", + "Liberate_IP_LowTTLRSTb*min*": "RST w/ Low TTL #2,min", + "Liberate_TCP_ACKNotSet*max*": "Data Packet wo/ ACK Flag,Max", + "Liberate_TCP_ACKNotSet*min*": "Data Packet wo/ ACK Flag,Min", + "Liberate_TCP_InvalidDataoff*max*": "Invalid Data-Offset,Max", + "Liberate_TCP_InvalidDataoff*min*": "Invalid Data-Offset,Min", + "Liberate_TCP_InvalidFlagComb*max*": "Invalid Flags,Max", + "Liberate_TCP_InvalidFlagComb*min*": "Invalid Flags,Min", + "Liberate_TCP_WrongChksum*max*": "Bad TCP Checksum,Max", + "Liberate_TCP_WrongChksum*min*": "Bad TCP Checksum,Min", + "Liberate_TCP_WrongSEQ*max*": "Bad SEQ,Max", + "Liberate_TCP_WrongSEQ*min*": "Bad SEQ,Min", + "SymTCP_Zeek_SYNWithData": "SYN,w/ Payload,Zeek", + "SymTCP_GFW_OutOfWindowSYNData": "SYN,w/ Payload & Bad SEQ,GFW #1", + "SymTCP_GFW_RetransmittedSYNData": "SYN,w/ Payload & Bad SEQ,GFW #2", + "SymTCP_Zeek_MultipleSYN": "SYN,Multiple (SYN),Zeek", + "SymTCP_Snort_MultipleSYN": "SYN,Multiple (SYN),Snort", + "SymTCP_GFW_FINWithData": "Injected FIN,w/ Payload,GFW", + "SymTCP_Zeek_PureFIN": "Injected FIN,w/ Payload,Zeek", + "SymTCP_Zeek_PureFIN": "Injected FIN,Pure,Zeek", + "SymTCP_Snort_InWindowFIN": "Injected FIN,Pure,Snort", + "SymTCP_Snort_RSTMD5": "Injected RST,Bad TCP MD5-Option,Snort", + "SymTCP_Snort_RSTBadTimestamp": "Injected RST,Bad Timestamp,Snort", + "SymTCP_GFW_RSTBadTimestamp": "Injected RST,Bad Timestamp,GFW", + "SymTCP_Snort_PartialInWindowRST": "Injected RST,Partial In-Window,Snort", + "SymTCP_GFW_BadRST": "Injected RST,Bad TCP-Checksum/MD5-Option,GFW", + "SymTCP_Snort_InWindowRST": "Injected RST,Pure,Snort", + "SymTCP_Zeek_BadRSTFIN": "Injected RST/FIN-ACK,Bad SEQ,Zeek", + "SymTCP_Snort_FINACKBadACK": "Injected FIN-ACK,Bad ACK Num,Snort", + "SymTCP_GFW_FINACKDataBadACK": "Injected FIN-ACK,Bad ACK Num,GFW", + "SymTCP_Snort_FINACKMD5": "Injected FIN-ACK,Bad TCP MD5-Option,Snort", + "SymTCP_GFW_BadFINACKData": "Injected FIN-ACK,Bad TCP-Checksum/MD5-Option,GFW", + "SymTCP_GFW_RSTACKBadACKNum": "Injected RST-ACK,Bad ACK Num,GFW", + "SymTCP_Snort_RSTACKBadACKNum": "Injected RST-ACK,Bad ACK Num,Snort", + "SymTCP_Zeek_SEQJump": "Data Packet (ACK),Bad SEQ,Zeek", + "SymTCP_Zeek_UnderflowSEQ": "Data Packet (ACK),Underflow SEQ,Zeek", + "SymTCP_GFW_UnderflowSEQ": "Data Packet (ACK),Underflow SEQ,GFW", + "SymTCP_Zeek_DataBadACK": "Data Packet (ACK),Bad ACK Num,Zeek", + "SymTCP_Zeek_DataOverlapping": "Data Packet (ACK),Overlapping,Zeek", + "SymTCP_Zeek_DataWithoutACK": "Data Packet (ACK),wo/ ACK Flag,Zeek", + "SymTCP_GFW_DataWithoutACK": "Data Packet (ACK),wo/ ACK Flag,GFW", + "SymTCP_Snort_UrgentData": "Data Packet (ACK),w/ Urgent Pointer,Snort", + "SymTCP_GFW_BadData": "Data Packet (ACK),Bad TCP-Checksum/MD5-Option,GFW", + "SymTCP_Snort_TimeGap": "Data Packet (ACK),Bad Timestamp,Snort", +} + + +def read_and_merge_res(our_app_res_fpath, dump_fpath): + def read_data(fpath, opt=None): + with open(fpath, 'r') as fin: + d = {} + names_by_work = {"SymTCP": [], "Liberate": [], "Geneva": []} + data = fin.readlines() + del data[0] + for row in data: + if row.startswith('#'): + continue + row = row.rstrip('\n') + if opt == 'tool_log_format': + _, name, auc_roc_score, tpr001, tpr005, eer_score, top1_hit_acc, top3_hit_acc, top5_hit_acc = row.split( + ',') + data_rec = ','.join( + [auc_roc_score, tpr001, tpr005, eer_score, top1_hit_acc, top3_hit_acc, top5_hit_acc]) + if opt == 'kitsune_log_format': + name, auc_roc_score, tpr001, tpr005, eer_score = row.split( + ',') + if name.endswith("_max"): + name = name.replace('_max', '*max*') + if name.endswith("_min"): + name = name.replace('_min', '*min*') + data_rec = ','.join( + [auc_roc_score, tpr001, tpr005, eer_score]) + if "SymTCP" in name: + names_by_work["SymTCP"].append(name) + if "Liberate" in name: + names_by_work["Liberate"].append(name) + if "Geneva" in name: + names_by_work["Geneva"].append(name) + d[name] = data_rec + return d, names_by_work + + our_app_d, names_by_work = read_data( + our_app_res_fpath, opt='tool_log_format') + + dump = [] + + dump.append("#SymTCP") + for ori_name in names_by_work["SymTCP"]: + if ori_name not in NAME_MAP: + continue + name = NAME_MAP[ori_name] + dump.append( + ','.join([name, our_app_d[ori_name]])) + + dump.append("#Liberate") + for ori_name in names_by_work["Liberate"]: + if ori_name not in NAME_MAP: + continue + name = NAME_MAP[ori_name] + dump.append( + ','.join([name, our_app_d[ori_name]])) + + dump.append("#Geneva") + for ori_name in names_by_work["Geneva"]: + if ori_name not in NAME_MAP: + continue + name = NAME_MAP[ori_name] + dump.append( + ','.join([name, our_app_d[ori_name]])) + + with open(dump_fpath, 'w') as fout: + for line in dump: + fout.write('%s\n' % line) + + +def read_data(fpath): + with open(fpath, 'r') as fin: + data = fin.readlines() + + data_dict = {} + curr_tag = "" + for row in data: + row = row.rstrip('\n') + if len(row) == 0: + continue + if row.startswith('#'): + curr_tag = row[1:] + data_dict[curr_tag] = {} + else: + if curr_tag == 'SymTCP': + primary_type, secondary_type, variant, \ + cxt_res, cxt_tpr001, cxt_tpr005, cxt_eer_score, cxt_top1_hit_acc, cxt_top3_hit_acc, cxt_top5_hit_acc = row.split(',') + cxt_res, cxt_tpr001, cxt_tpr005, cxt_eer_score, cxt_top1_hit_acc, cxt_top3_hit_acc, cxt_top5_hit_acc = float(cxt_res), float(cxt_tpr001), \ + float(cxt_tpr005), float(cxt_eer_score), float(cxt_top1_hit_acc), float(cxt_top3_hit_acc), float(cxt_top5_hit_acc) + if primary_type not in data_dict[curr_tag]: + data_dict[curr_tag][primary_type] = {} + if secondary_type not in data_dict[curr_tag][primary_type]: + data_dict[curr_tag][primary_type][secondary_type] = {} + data_dict[curr_tag][primary_type][secondary_type][variant] = ( + cxt_res, cxt_tpr001, cxt_tpr005, cxt_eer_score, cxt_top1_hit_acc, cxt_top3_hit_acc, cxt_top5_hit_acc) + if curr_tag in {'Liberate', 'Geneva'}: + primary_type, secondary_type, \ + cxt_res, cxt_tpr001, cxt_tpr005, cxt_eer_score, cxt_top1_hit_acc, cxt_top3_hit_acc, cxt_top5_hit_acc = row.split( + ',') + cxt_res, cxt_tpr001, cxt_tpr005, cxt_eer_score, cxt_top1_hit_acc, cxt_top3_hit_acc, cxt_top5_hit_acc = float(cxt_res), float(cxt_tpr001), \ + float(cxt_tpr005), float(cxt_eer_score), float(cxt_top1_hit_acc), float(cxt_top3_hit_acc), float(cxt_top5_hit_acc) + if primary_type not in data_dict[curr_tag]: + data_dict[curr_tag][primary_type] = {} + data_dict[curr_tag][primary_type][secondary_type] = (cxt_res, cxt_tpr001, cxt_tpr005, cxt_eer_score, cxt_top1_hit_acc, cxt_top3_hit_acc, cxt_top5_hit_acc) + return data_dict + + +def draw(data_dict, type='detection'): + def draw_a_subplot(ax, y, title, remove_yticks=False, plot_type='detection'): + if plot_type == 'detection': + x = [i for i in range(2)] + idx_lst = [0, 3] + scores = [y[i] for i in idx_lst] + bars = ax.bar(x, scores, color=('#F96E46', '#FFE3E3')) + + hatch_lst = ('/', '/') + if plot_type == 'localization': + x = [i for i in range(1)] + idx_lst = [6] + scores = [y[i] for i in idx_lst] + bars = ax.bar(x, scores, color=('#F96E46')) + hatch_lst = ('+') + ax.set_xticklabels([]) + if remove_yticks: + ax.set_yticklabels([]) + ax.set_ylim([0.0, 1.0]) + ax.set_title(title, fontweight='bold') + + idx = 0 + for bar in bars: + yval = '%.3f' % bar.get_height() + ax.text(bar.get_x() + 0.4, 0.1, + yval, rotation='vertical', color='black', fontsize='x-large', ha='center') + if hatch_lst is not None: + bar.set_hatch(hatch_lst[idx]) + idx += 1 + return + + # for SymTCP + data = data_dict["SymTCP"] + subplot_cnt = 0 + for _, secondary in data.items(): + for _, var in secondary.items(): + for _, _ in var.items(): + subplot_cnt += 1 + if type == 'detection': + ncol = 10 + if type == 'localization': + ncol = 10 + nrow = ceil(subplot_cnt / ncol) + fig, axes = plt.subplots(nrow, ncol) + row, col = 0, 0 + cnt = 0 + for primary, secondary in data.items(): + for secondary_type, variant in secondary.items(): + if len(variant) == 1: # meaning this is 'All' case + row, col = cnt // ncol, cnt % ncol + if col != 0: + draw_a_subplot(axes[row, col], list(variant.values())[0], '%s: %s\n%s' % + (list(variant.keys())[0], primary, secondary_type), remove_yticks=True, plot_type=type) + else: + draw_a_subplot(axes[row, col], list(variant.values())[0], '%s: %s\n%s' % + (list(variant.keys())[0], primary, secondary_type), plot_type=type) + cnt += 1 + else: + for variant_name, res in variant.items(): + row, col = cnt // ncol, cnt % ncol + if col != 0: + draw_a_subplot(axes[row, col], res, '%s: %s\n%s' % + (variant_name, primary, secondary_type), remove_yticks=True, plot_type=type) + else: + draw_a_subplot(axes[row, col], res, '%s: %s\n%s' % + (variant_name, primary, secondary_type), plot_type=type) + cnt += 1 + + for i in range(nrow*ncol-subplot_cnt): + fig.delaxes(axes[-1][-(i+1)]) + plt.subplots_adjust(hspace=.3) + plt.show() + + # for Liberate + data = data_dict['Liberate'] + subplot_cnt = 0 + for _, strategy in data.items(): + for _, _ in strategy.items(): + subplot_cnt += 1 + if type == 'detection': + ncol = 10 + if type == 'localization': + ncol = 10 + nrow = ceil(subplot_cnt / ncol) + fig, axes = plt.subplots(nrow, ncol) + row, col = 0, 0 + cnt = 0 + for primary, strategy in data.items(): + for strategy, res in strategy.items(): + row, col = cnt // ncol, cnt % ncol + if col != 0: + draw_a_subplot(axes[row, col], res, '%s\n%s' % + (primary, strategy), remove_yticks=True, plot_type=type) + else: + draw_a_subplot(axes[row, col], res, '%s\n%s' % + (primary, strategy), plot_type=type) + cnt += 1 + + for i in range(nrow*ncol-subplot_cnt): + fig.delaxes(axes[-1][-(i+1)]) + plt.subplots_adjust(hspace=.3) + plt.show() + + # for Geneva + data = data_dict['Geneva'] + subplot_cnt = 0 + for _, secondary in data.items(): + for _, _ in secondary.items(): + subplot_cnt += 1 + if type == 'detection': + ncol = 10 + if type == 'localization': + ncol = 10 + nrow = ceil(subplot_cnt / ncol) + fig, axes = plt.subplots(nrow, ncol) + row, col = 0, 0 + cnt = 0 + for primary, secondary in data.items(): + for secondary, res in secondary.items(): + row, col = cnt // ncol, cnt % ncol + if col != 0: + draw_a_subplot(axes[row, col], res, '%s\n%s' % + (primary, secondary), remove_yticks=True, plot_type=type) + else: + draw_a_subplot(axes[row, col], res, '%s\n%s' % + (primary, secondary), plot_type=type) + cnt += 1 + + for i in range(nrow*ncol-subplot_cnt): + fig.delaxes(axes[-1][-(i+1)]) + plt.subplots_adjust(hspace=.3) + plt.show() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='This script generates figure for showing detection results.') + parser.add_argument('--fin-our', type=str) + parser.add_argument('--merged-res', type=str) + args = parser.parse_args() + + read_and_merge_res(args.fin_our, args.merged_res) + + data_dict = read_data(args.merged_res) + draw(data_dict) + draw(data_dict, type='localization') |
