summaryrefslogtreecommitdiff
path: root/script/paint_fig.py
diff options
context:
space:
mode:
Diffstat (limited to 'script/paint_fig.py')
-rw-r--r--script/paint_fig.py351
1 files changed, 351 insertions, 0 deletions
diff --git a/script/paint_fig.py b/script/paint_fig.py
new file mode 100644
index 0000000..371391e
--- /dev/null
+++ b/script/paint_fig.py
@@ -0,0 +1,351 @@
+import matplotlib as mpl
+import argparse
+import pprint
+from math import ceil
+
+import matplotlib
+import matplotlib.pyplot as plt
+
+pp = pprint.PrettyPrinter(indent=2)
+
+font = {'family': 'sans-serif',
+ 'weight': 'bold',
+ 'size': 14}
+
+matplotlib.rcParams['hatch.linewidth'] = 0.9
+
+matplotlib.rc('font', **font)
+
+NAME_MAP = {
+ "Geneva_Strategy_1*max*": "Invalid Data-Offset,Bad TCP Checksum",
+ "Geneva_Strategy_2*max*": "Invalid Data-Offset,Low TTL",
+ "Geneva_Strategy_3*max*": "Invalid Data-Offset,Bad ACK Num",
+ "Geneva_Strategy_4*max*": "Invalid TCP WScale-Option,Invalid Data-Offset",
+ "Geneva_Strategy_5*max*": "Bad Payload Length,Bad TCP Checksum",
+ "Geneva_Strategy_6*max*": "Bad Payload Length,Low TTL",
+ "Geneva_Strategy_7*max*": "Bad Payload Length,Bad ACK Num",
+ "Geneva_Strategy_8*max*": "/,Bad Payload Length",
+ "Geneva_Strategy_9*max*": "Bad IP Length,/",
+ "Geneva_Strategy_10*max*": "Injected RST,Bad IP Length",
+ "Geneva_Strategy_11*max*": "Injected RST,Bad TCP Checksum",
+ "Geneva_Strategy_12*max*": "Injected RST,Low TTL",
+ "Geneva_Strategy_13*max*": "Bad TCP MD5-Option,Injected RST",
+ "Geneva_Strategy_14*max*": "Injected RST-ACK,Bad TCP Checksum",
+ "Geneva_Strategy_15*max*": "Injected RST-ACK,Low TTL",
+ "Geneva_Strategy_16*max*": "Bad TCP MD5-Option,Injected RST",
+ "Geneva_Strategy_17*max*": "Invalid Flags #1,Bad TCP Checksum",
+ "Geneva_Strategy_18*max*": "Invalid Flags #2,Low TTL",
+ "Geneva_Strategy_19*max*": "Invalid Flags #2,Bad TCP MD5-Option",
+ "Geneva_Strategy_23*max*": "Injected FIN,Bad IP Length",
+ "Geneva_Strategy_24*max*": "Injected SYN-ACK,Bad TCP MD5-Option",
+ "Geneva_Strategy_23*max*": "Bad TCP UTO-Option,Bad TCP MD5-Option",
+ "Liberate_IP_InvalidHeaderLen*max*": "Invalid IP Header Length,Max",
+ "Liberate_IP_InvalidHeaderLen*min*": "Invalid IP Header Length,Min",
+ "Liberate_IP_InvalidVersion*max*": "Invalid IP Version,Max",
+ "Liberate_IP_InvalidVersion*max*": "Invalid IP Version,Min",
+ "Liberate_IP_LongerLength*max*": "Bad IP Length (Too Long),Max",
+ "Liberate_IP_LongerLength*min*": "Bad IP Length (Too Long),Min",
+ "Liberate_IP_ShorterLength*max*": "Bad IP Length (Too Short),Max",
+ "Liberate_IP_ShorterLength*min*": "Bad IP Length (Too Short),Min",
+ "Liberate_IP_LowTTL*max*": "Low TTL,Max",
+ "Liberate_IP_LowTTL*min*": "Low TTL,Min",
+ "Liberate_IP_LowTTLRSTa*max*": "RST w/ Low TTL #1,max",
+ "Liberate_IP_LowTTLRSTa*min*": "RST w/ Low TTL #1,min",
+ "Liberate_IP_LowTTLRSTb*max*": "RST w/ Low TTL #2,max",
+ "Liberate_IP_LowTTLRSTb*min*": "RST w/ Low TTL #2,min",
+ "Liberate_TCP_ACKNotSet*max*": "Data Packet wo/ ACK Flag,Max",
+ "Liberate_TCP_ACKNotSet*min*": "Data Packet wo/ ACK Flag,Min",
+ "Liberate_TCP_InvalidDataoff*max*": "Invalid Data-Offset,Max",
+ "Liberate_TCP_InvalidDataoff*min*": "Invalid Data-Offset,Min",
+ "Liberate_TCP_InvalidFlagComb*max*": "Invalid Flags,Max",
+ "Liberate_TCP_InvalidFlagComb*min*": "Invalid Flags,Min",
+ "Liberate_TCP_WrongChksum*max*": "Bad TCP Checksum,Max",
+ "Liberate_TCP_WrongChksum*min*": "Bad TCP Checksum,Min",
+ "Liberate_TCP_WrongSEQ*max*": "Bad SEQ,Max",
+ "Liberate_TCP_WrongSEQ*min*": "Bad SEQ,Min",
+ "SymTCP_Zeek_SYNWithData": "SYN,w/ Payload,Zeek",
+ "SymTCP_GFW_OutOfWindowSYNData": "SYN,w/ Payload & Bad SEQ,GFW #1",
+ "SymTCP_GFW_RetransmittedSYNData": "SYN,w/ Payload & Bad SEQ,GFW #2",
+ "SymTCP_Zeek_MultipleSYN": "SYN,Multiple (SYN),Zeek",
+ "SymTCP_Snort_MultipleSYN": "SYN,Multiple (SYN),Snort",
+ "SymTCP_GFW_FINWithData": "Injected FIN,w/ Payload,GFW",
+ "SymTCP_Zeek_PureFIN": "Injected FIN,w/ Payload,Zeek",
+ "SymTCP_Zeek_PureFIN": "Injected FIN,Pure,Zeek",
+ "SymTCP_Snort_InWindowFIN": "Injected FIN,Pure,Snort",
+ "SymTCP_Snort_RSTMD5": "Injected RST,Bad TCP MD5-Option,Snort",
+ "SymTCP_Snort_RSTBadTimestamp": "Injected RST,Bad Timestamp,Snort",
+ "SymTCP_GFW_RSTBadTimestamp": "Injected RST,Bad Timestamp,GFW",
+ "SymTCP_Snort_PartialInWindowRST": "Injected RST,Partial In-Window,Snort",
+ "SymTCP_GFW_BadRST": "Injected RST,Bad TCP-Checksum/MD5-Option,GFW",
+ "SymTCP_Snort_InWindowRST": "Injected RST,Pure,Snort",
+ "SymTCP_Zeek_BadRSTFIN": "Injected RST/FIN-ACK,Bad SEQ,Zeek",
+ "SymTCP_Snort_FINACKBadACK": "Injected FIN-ACK,Bad ACK Num,Snort",
+ "SymTCP_GFW_FINACKDataBadACK": "Injected FIN-ACK,Bad ACK Num,GFW",
+ "SymTCP_Snort_FINACKMD5": "Injected FIN-ACK,Bad TCP MD5-Option,Snort",
+ "SymTCP_GFW_BadFINACKData": "Injected FIN-ACK,Bad TCP-Checksum/MD5-Option,GFW",
+ "SymTCP_GFW_RSTACKBadACKNum": "Injected RST-ACK,Bad ACK Num,GFW",
+ "SymTCP_Snort_RSTACKBadACKNum": "Injected RST-ACK,Bad ACK Num,Snort",
+ "SymTCP_Zeek_SEQJump": "Data Packet (ACK),Bad SEQ,Zeek",
+ "SymTCP_Zeek_UnderflowSEQ": "Data Packet (ACK),Underflow SEQ,Zeek",
+ "SymTCP_GFW_UnderflowSEQ": "Data Packet (ACK),Underflow SEQ,GFW",
+ "SymTCP_Zeek_DataBadACK": "Data Packet (ACK),Bad ACK Num,Zeek",
+ "SymTCP_Zeek_DataOverlapping": "Data Packet (ACK),Overlapping,Zeek",
+ "SymTCP_Zeek_DataWithoutACK": "Data Packet (ACK),wo/ ACK Flag,Zeek",
+ "SymTCP_GFW_DataWithoutACK": "Data Packet (ACK),wo/ ACK Flag,GFW",
+ "SymTCP_Snort_UrgentData": "Data Packet (ACK),w/ Urgent Pointer,Snort",
+ "SymTCP_GFW_BadData": "Data Packet (ACK),Bad TCP-Checksum/MD5-Option,GFW",
+ "SymTCP_Snort_TimeGap": "Data Packet (ACK),Bad Timestamp,Snort",
+}
+
+
+def read_and_merge_res(our_app_res_fpath, dump_fpath):
+ def read_data(fpath, opt=None):
+ with open(fpath, 'r') as fin:
+ d = {}
+ names_by_work = {"SymTCP": [], "Liberate": [], "Geneva": []}
+ data = fin.readlines()
+ del data[0]
+ for row in data:
+ if row.startswith('#'):
+ continue
+ row = row.rstrip('\n')
+ if opt == 'tool_log_format':
+ _, name, auc_roc_score, tpr001, tpr005, eer_score, top1_hit_acc, top3_hit_acc, top5_hit_acc = row.split(
+ ',')
+ data_rec = ','.join(
+ [auc_roc_score, tpr001, tpr005, eer_score, top1_hit_acc, top3_hit_acc, top5_hit_acc])
+ if opt == 'kitsune_log_format':
+ name, auc_roc_score, tpr001, tpr005, eer_score = row.split(
+ ',')
+ if name.endswith("_max"):
+ name = name.replace('_max', '*max*')
+ if name.endswith("_min"):
+ name = name.replace('_min', '*min*')
+ data_rec = ','.join(
+ [auc_roc_score, tpr001, tpr005, eer_score])
+ if "SymTCP" in name:
+ names_by_work["SymTCP"].append(name)
+ if "Liberate" in name:
+ names_by_work["Liberate"].append(name)
+ if "Geneva" in name:
+ names_by_work["Geneva"].append(name)
+ d[name] = data_rec
+ return d, names_by_work
+
+ our_app_d, names_by_work = read_data(
+ our_app_res_fpath, opt='tool_log_format')
+
+ dump = []
+
+ dump.append("#SymTCP")
+ for ori_name in names_by_work["SymTCP"]:
+ if ori_name not in NAME_MAP:
+ continue
+ name = NAME_MAP[ori_name]
+ dump.append(
+ ','.join([name, our_app_d[ori_name]]))
+
+ dump.append("#Liberate")
+ for ori_name in names_by_work["Liberate"]:
+ if ori_name not in NAME_MAP:
+ continue
+ name = NAME_MAP[ori_name]
+ dump.append(
+ ','.join([name, our_app_d[ori_name]]))
+
+ dump.append("#Geneva")
+ for ori_name in names_by_work["Geneva"]:
+ if ori_name not in NAME_MAP:
+ continue
+ name = NAME_MAP[ori_name]
+ dump.append(
+ ','.join([name, our_app_d[ori_name]]))
+
+ with open(dump_fpath, 'w') as fout:
+ for line in dump:
+ fout.write('%s\n' % line)
+
+
+def read_data(fpath):
+ with open(fpath, 'r') as fin:
+ data = fin.readlines()
+
+ data_dict = {}
+ curr_tag = ""
+ for row in data:
+ row = row.rstrip('\n')
+ if len(row) == 0:
+ continue
+ if row.startswith('#'):
+ curr_tag = row[1:]
+ data_dict[curr_tag] = {}
+ else:
+ if curr_tag == 'SymTCP':
+ primary_type, secondary_type, variant, \
+ cxt_res, cxt_tpr001, cxt_tpr005, cxt_eer_score, cxt_top1_hit_acc, cxt_top3_hit_acc, cxt_top5_hit_acc = row.split(',')
+ cxt_res, cxt_tpr001, cxt_tpr005, cxt_eer_score, cxt_top1_hit_acc, cxt_top3_hit_acc, cxt_top5_hit_acc = float(cxt_res), float(cxt_tpr001), \
+ float(cxt_tpr005), float(cxt_eer_score), float(cxt_top1_hit_acc), float(cxt_top3_hit_acc), float(cxt_top5_hit_acc)
+ if primary_type not in data_dict[curr_tag]:
+ data_dict[curr_tag][primary_type] = {}
+ if secondary_type not in data_dict[curr_tag][primary_type]:
+ data_dict[curr_tag][primary_type][secondary_type] = {}
+ data_dict[curr_tag][primary_type][secondary_type][variant] = (
+ cxt_res, cxt_tpr001, cxt_tpr005, cxt_eer_score, cxt_top1_hit_acc, cxt_top3_hit_acc, cxt_top5_hit_acc)
+ if curr_tag in {'Liberate', 'Geneva'}:
+ primary_type, secondary_type, \
+ cxt_res, cxt_tpr001, cxt_tpr005, cxt_eer_score, cxt_top1_hit_acc, cxt_top3_hit_acc, cxt_top5_hit_acc = row.split(
+ ',')
+ cxt_res, cxt_tpr001, cxt_tpr005, cxt_eer_score, cxt_top1_hit_acc, cxt_top3_hit_acc, cxt_top5_hit_acc = float(cxt_res), float(cxt_tpr001), \
+ float(cxt_tpr005), float(cxt_eer_score), float(cxt_top1_hit_acc), float(cxt_top3_hit_acc), float(cxt_top5_hit_acc)
+ if primary_type not in data_dict[curr_tag]:
+ data_dict[curr_tag][primary_type] = {}
+ data_dict[curr_tag][primary_type][secondary_type] = (cxt_res, cxt_tpr001, cxt_tpr005, cxt_eer_score, cxt_top1_hit_acc, cxt_top3_hit_acc, cxt_top5_hit_acc)
+ return data_dict
+
+
+def draw(data_dict, type='detection'):
+ def draw_a_subplot(ax, y, title, remove_yticks=False, plot_type='detection'):
+ if plot_type == 'detection':
+ x = [i for i in range(2)]
+ idx_lst = [0, 3]
+ scores = [y[i] for i in idx_lst]
+ bars = ax.bar(x, scores, color=('#F96E46', '#FFE3E3'))
+
+ hatch_lst = ('/', '/')
+ if plot_type == 'localization':
+ x = [i for i in range(1)]
+ idx_lst = [6]
+ scores = [y[i] for i in idx_lst]
+ bars = ax.bar(x, scores, color=('#F96E46'))
+ hatch_lst = ('+')
+ ax.set_xticklabels([])
+ if remove_yticks:
+ ax.set_yticklabels([])
+ ax.set_ylim([0.0, 1.0])
+ ax.set_title(title, fontweight='bold')
+
+ idx = 0
+ for bar in bars:
+ yval = '%.3f' % bar.get_height()
+ ax.text(bar.get_x() + 0.4, 0.1,
+ yval, rotation='vertical', color='black', fontsize='x-large', ha='center')
+ if hatch_lst is not None:
+ bar.set_hatch(hatch_lst[idx])
+ idx += 1
+ return
+
+ # for SymTCP
+ data = data_dict["SymTCP"]
+ subplot_cnt = 0
+ for _, secondary in data.items():
+ for _, var in secondary.items():
+ for _, _ in var.items():
+ subplot_cnt += 1
+ if type == 'detection':
+ ncol = 10
+ if type == 'localization':
+ ncol = 10
+ nrow = ceil(subplot_cnt / ncol)
+ fig, axes = plt.subplots(nrow, ncol)
+ row, col = 0, 0
+ cnt = 0
+ for primary, secondary in data.items():
+ for secondary_type, variant in secondary.items():
+ if len(variant) == 1: # meaning this is 'All' case
+ row, col = cnt // ncol, cnt % ncol
+ if col != 0:
+ draw_a_subplot(axes[row, col], list(variant.values())[0], '%s: %s\n%s' %
+ (list(variant.keys())[0], primary, secondary_type), remove_yticks=True, plot_type=type)
+ else:
+ draw_a_subplot(axes[row, col], list(variant.values())[0], '%s: %s\n%s' %
+ (list(variant.keys())[0], primary, secondary_type), plot_type=type)
+ cnt += 1
+ else:
+ for variant_name, res in variant.items():
+ row, col = cnt // ncol, cnt % ncol
+ if col != 0:
+ draw_a_subplot(axes[row, col], res, '%s: %s\n%s' %
+ (variant_name, primary, secondary_type), remove_yticks=True, plot_type=type)
+ else:
+ draw_a_subplot(axes[row, col], res, '%s: %s\n%s' %
+ (variant_name, primary, secondary_type), plot_type=type)
+ cnt += 1
+
+ for i in range(nrow*ncol-subplot_cnt):
+ fig.delaxes(axes[-1][-(i+1)])
+ plt.subplots_adjust(hspace=.3)
+ plt.show()
+
+ # for Liberate
+ data = data_dict['Liberate']
+ subplot_cnt = 0
+ for _, strategy in data.items():
+ for _, _ in strategy.items():
+ subplot_cnt += 1
+ if type == 'detection':
+ ncol = 10
+ if type == 'localization':
+ ncol = 10
+ nrow = ceil(subplot_cnt / ncol)
+ fig, axes = plt.subplots(nrow, ncol)
+ row, col = 0, 0
+ cnt = 0
+ for primary, strategy in data.items():
+ for strategy, res in strategy.items():
+ row, col = cnt // ncol, cnt % ncol
+ if col != 0:
+ draw_a_subplot(axes[row, col], res, '%s\n%s' %
+ (primary, strategy), remove_yticks=True, plot_type=type)
+ else:
+ draw_a_subplot(axes[row, col], res, '%s\n%s' %
+ (primary, strategy), plot_type=type)
+ cnt += 1
+
+ for i in range(nrow*ncol-subplot_cnt):
+ fig.delaxes(axes[-1][-(i+1)])
+ plt.subplots_adjust(hspace=.3)
+ plt.show()
+
+ # for Geneva
+ data = data_dict['Geneva']
+ subplot_cnt = 0
+ for _, secondary in data.items():
+ for _, _ in secondary.items():
+ subplot_cnt += 1
+ if type == 'detection':
+ ncol = 10
+ if type == 'localization':
+ ncol = 10
+ nrow = ceil(subplot_cnt / ncol)
+ fig, axes = plt.subplots(nrow, ncol)
+ row, col = 0, 0
+ cnt = 0
+ for primary, secondary in data.items():
+ for secondary, res in secondary.items():
+ row, col = cnt // ncol, cnt % ncol
+ if col != 0:
+ draw_a_subplot(axes[row, col], res, '%s\n%s' %
+ (primary, secondary), remove_yticks=True, plot_type=type)
+ else:
+ draw_a_subplot(axes[row, col], res, '%s\n%s' %
+ (primary, secondary), plot_type=type)
+ cnt += 1
+
+ for i in range(nrow*ncol-subplot_cnt):
+ fig.delaxes(axes[-1][-(i+1)])
+ plt.subplots_adjust(hspace=.3)
+ plt.show()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description='This script generates figure for showing detection results.')
+ parser.add_argument('--fin-our', type=str)
+ parser.add_argument('--merged-res', type=str)
+ args = parser.parse_args()
+
+ read_and_merge_res(args.fin_our, args.merged_res)
+
+ data_dict = read_data(args.merged_res)
+ draw(data_dict)
+ draw(data_dict, type='localization')