"""Training a SEAL-CI model."""
import torch
from utils import tab_printer
from seal import SEALCITrainer
from param_parser import parameter_parser


def _running_mean(mean, value, count):
    """Fold ``value`` into ``mean`` of the previous ``count - 1`` observations.

    Returns the arithmetic mean of all ``count`` observations, so every run
    is weighted equally. Works for any operands supporting ``+``, ``-`` and
    ``/`` (plain numbers, or the column objects passed by ``combine``).
    """
    return mean + (value - mean) / count


def main():
    """
    Parsing command line parameters, reading data.
    Fitting and scoring a SEAL-CI model once per seed directory, then
    printing the report, accuracy, f1, precision and recall averaged
    (arithmetic mean) over all runs.
    """
    # NOTE(review): the per-seed sweep over [1, 2, 4, 7, 9] is disabled.
    # Iterating the string '/' yields the single "seed" '/', so the loop runs
    # exactly once against root_d itself (paths become root_d + '//' ...).
    # Replace with the list of seed sub-directory names to sweep them again.
    seeds = '/'
    # Raw string: the Windows path contains backslashes that are not valid
    # escape sequences; r"..." gives the same value without SyntaxWarning.
    root_d = r"D:\学习\mesa\毕设\研究点三\OpenWGL-main\opgl\dataset\test1/"

    runs = 0
    for seed in seeds:
        args = parameter_parser()
        args.graphs = root_d + str(seed) + "/"
        args.hierarchical_graph = root_d + str(seed) + "/edges_DoS.csv"
        tab_printer(args)
        trainer = SEALCITrainer(args)
        scores, prediction_indices, accuracy, f1, precision, recall, report = trainer.fit()
        trainer.score()

        runs += 1
        if runs == 1:
            facc, ff1, fpre, frecall, freport = accuracy, f1, precision, recall, report
        else:
            # True incremental mean: the previous code computed
            # (old + new) / 2, which over-weights later seeds.
            facc = _running_mean(facc, accuracy, runs)
            ff1 = _running_mean(ff1, f1, runs)
            fpre = _running_mean(fpre, precision, runs)
            frecall = _running_mean(frecall, recall, runs)
            # combine is evaluated immediately, so capturing `runs` is safe.
            freport = freport.combine(
                report, lambda old, new: _running_mean(old, new, runs)
            )

    if runs == 0:
        # Nothing was trained; the averaged metrics were never bound.
        print("No runs were executed; nothing to average.")
        return

    print(f"_____average report_____\n{freport}")
    print(f"_____average accuracy_____\n{facc}")
    print(f"_____average f1_____\n{ff1}")
    print(f"_____average precision_____\n{fpre}")
    print(f"_____average recall_____\n{frecall}")


if __name__ == "__main__":
    main()