1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
|
"""Training a SEAL-CI model."""
import torch
from utils import tab_printer
from seal import SEALCITrainer
from param_parser import parameter_parser
# Dataset root used when no ``root_d`` is given. Raw string so the Windows
# backslashes are taken literally instead of being parsed as (invalid)
# escape sequences; the resulting value is unchanged from the original.
DEFAULT_ROOT_D = r"D:\学习\mesa\毕设\研究点三\OpenWGL-main\opgl\dataset\test1/"

# Sub-directories of the dataset root to run on. ``"/"`` means a single run
# on the root itself. Pass e.g. ``[1, 2, 4, 7, 9]`` to ``main`` to average
# over per-seed sub-directories instead.
DEFAULT_SEEDS = ("/",)


def running_mean(mean, value, count):
    """Fold ``value`` into a running arithmetic mean.

    :param mean: Mean of the first ``count - 1`` observations.
    :param value: The ``count``-th observation.
    :param count: Number of observations including ``value`` (>= 1).
    :return: Mean of all ``count`` observations.

    Only uses ``+``, ``-`` and division by an int, so it also works
    element-wise on numeric array-like values.
    """
    return mean + (value - mean) / count


def main(seeds=None, root_d=DEFAULT_ROOT_D):
    """
    Parse command line parameters and read data.

    Fit and score a SEAL-CI model once per entry in ``seeds``. Then print
    the report, accuracy, F1, precision and recall averaged (arithmetic
    mean) over all runs.

    :param seeds: Iterable of sub-directory names under ``root_d``. Each run
        reads graphs from ``root_d + str(seed) + "/"`` and the hierarchical
        graph from ``root_d + str(seed) + "/edges_DoS.csv"``.
        Defaults to ``DEFAULT_SEEDS`` (a single run on ``root_d``).
    :param root_d: Dataset root directory; expected to end with a path
        separator because paths are built by plain string concatenation.
    :raises ValueError: If ``seeds`` is empty (nothing to average).
    """
    seeds = list(DEFAULT_SEEDS if seeds is None else seeds)
    if not seeds:
        raise ValueError("seeds must contain at least one entry")

    for run, seed in enumerate(seeds, start=1):
        # Fresh argument namespace per run so per-seed paths never leak
        # between iterations.
        args = parameter_parser()
        args.graphs = root_d + str(seed) + "/"
        args.hierarchical_graph = root_d + str(seed) + "/edges_DoS.csv"
        tab_printer(args)
        trainer = SEALCITrainer(args)
        (_scores, _prediction_indices, accuracy, f1,
         precision, recall, report) = trainer.fit()
        trainer.score()
        if run == 1:
            facc, ff1, fpre, frecall, freport = accuracy, f1, precision, recall, report
        else:
            # True running mean: every run gets equal weight 1/len(seeds).
            # (The previous "(old + new) / 2" over-weighted later runs.)
            facc = running_mean(facc, accuracy, run)
            ff1 = running_mean(ff1, f1, run)
            fpre = running_mean(fpre, precision, run)
            frecall = running_mean(frecall, recall, run)
            # ``n=run`` binds the current count at lambda creation time.
            freport = freport.combine(
                report, lambda old, new, n=run: running_mean(old, new, n)
            )

    print(f"_____average report_____\n{freport}")
    print(f"_____average accuracy_____\n{facc}")
    print(f"_____average f1_____\n{ff1}")
    print(f"_____average precision_____\n{fpre}")
    print(f"_____average recall_____\n{frecall}")


if __name__ == "__main__":
    main()
|