summaryrefslogtreecommitdiff
path: root/src/main.py
blob: f3c6535a7c0b258e7420d75eb9ac8dd433fdacd9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Training a SEAL-CI model."""

import torch
from utils import tab_printer
from seal import SEALCITrainer
from param_parser import parameter_parser

def main(seeds=("/",), root_d=None):
    """
    Parse command line parameters and read data, then fit and score a SEAL-CI
    model once per entry in ``seeds``. Print the metrics averaged over all runs.

    :param seeds: Iterable of sub-directory names under ``root_d``. Each run
        reads graphs from ``root_d + str(seed) + "/"`` and the hierarchical
        graph from ``edges_DoS.csv`` inside that directory. The default
        ``("/",)`` performs a single run on ``root_d`` itself; this is what the
        script previously did. Pass e.g. ``(1, 2, 4, 7, 9)`` to average over
        per-seed sub-directories.
    :param root_d: Dataset root directory, ending with a path separator.
        Defaults to the original hard-coded local dataset path.
    :raises ValueError: If ``seeds`` yields no entries (nothing to average).
    """
    if root_d is None:
        # Raw string so the Windows backslashes are never read as escape
        # sequences. The value is identical to the old non-raw literal.
        root_d = r"D:\学习\mesa\毕设\研究点三\OpenWGL-main\opgl\dataset\test1/"

    n_runs = 0
    for seed in seeds:
        args = parameter_parser()
        run_dir = root_d + str(seed) + "/"
        args.graphs = run_dir
        args.hierarchical_graph = run_dir + "edges_DoS.csv"
        tab_printer(args)
        trainer = SEALCITrainer(args)
        _, _, accuracy, f1, precision, recall, report = trainer.fit()
        trainer.score()
        n_runs += 1
        if n_runs == 1:
            facc, ff1, fpre, frecall, freport = accuracy, f1, precision, recall, report
        else:
            # Incremental arithmetic mean: m_n = m_{n-1} + (x_n - m_{n-1}) / n.
            # For n == 2 this equals (m + x) / 2. Unlike repeated halving, it
            # weights every run equally for any number of runs. Rebinding
            # (not +=) avoids mutating the first run's objects in place.
            facc = facc + (accuracy - facc) / n_runs
            ff1 = ff1 + (f1 - ff1) / n_runs
            fpre = fpre + (precision - fpre) / n_runs
            frecall = frecall + (recall - frecall) / n_runs
            # combine() calls the lambda immediately, so n_runs is the
            # current run count (no late-binding issue).
            freport = freport.combine(
                report, lambda s1, s2: s1 + (s2 - s1) / n_runs
            )

    if n_runs == 0:
        raise ValueError("seeds must contain at least one entry")

    print(f"_____average report_____\n{freport}")
    print(f"_____average accuracy_____\n{facc}")
    print(f"_____average f1_____\n{ff1}")
    print(f"_____average precision_____\n{fpre}")
    print(f"_____average recall_____\n{frecall}")


if __name__ == "__main__":
    main()