diff options
| author | yifei cheng <[email protected]> | 2023-06-26 12:27:10 +0000 |
|---|---|---|
| committer | yifei cheng <[email protected]> | 2023-06-26 12:27:10 +0000 |
| commit | 0305f37dd0d54d601bb673ac5fdd5ec9b7929167 (patch) | |
| tree | c4912ac30823d4c8ca3af296f2c248fff26fd739 /TargetModel | |
| parent | 84142db33419f3c121ce1b826b761ac3f42f376c (diff) | |
Upload New File
Diffstat (limited to 'TargetModel')
| -rw-r--r-- | TargetModel/TargetSVM.py | 89 |
1 file changed, 89 insertions, 0 deletions
"""
Date: 2022-04-13
Author: [email protected]
Desc: target model: SVM
"""
from sklearn import svm
from TargetModel.FSNet.dataset import C2Data
import torch
from sklearn.metrics import confusion_matrix
import joblib
from torch.utils.data import DataLoader
import numpy as np
import warnings
warnings.filterwarnings("ignore")


class TargetSVM():
    """Target model wrapping an sklearn SVC classifier.

    param: dict with key 'kernel' selecting the SVC kernel (e.g. 'rbf').
    """

    def __init__(self, param):
        # NOTE: attribute name 'kernal' (typo) is kept so existing callers
        # that read it keep working.
        self.kernal = param['kernel']
        self.clf = svm.SVC(kernel=self.kernal)

    @staticmethod
    def _collect(dataloader):
        """Drain a DataLoader into dense numpy arrays.

        Returns (X, y) as np.ndarray; assumes each batch yields
        (features, labels) tensors — TODO confirm against C2Data.
        """
        X = []
        y = []
        for batch_x, batch_y in dataloader:
            X += batch_x.data.numpy().tolist()
            y += batch_y.data.numpy().tolist()
        return np.array(X), np.array(y)

    def train(self, dataloader):
        """Fit the SVC on every sample yielded by dataloader and print the
        training accuracy."""
        X, y = self._collect(dataloader)
        # BUG FIX: original printed X.size / y.size (total element count)
        # while labelling the value "shape"; print the actual shape.
        print("X.shape:{}".format(X.shape))
        print("y.shape:{}".format(y.shape))
        self.clf.fit(X, y)
        print("training score:{}".format(self.clf.score(X, y)))

    def eval(self, dataloader):
        """Predict on dataloader.

        Returns (y_pred, y_true) — note the order: predictions first,
        ground-truth labels second.
        """
        X, y = self._collect(dataloader)
        print("X.shape:{}".format(X.shape))
        print("y.shape:{}".format(y.shape))
        y_pred = self.clf.predict(X)
        return y_pred, y

    def save(self, filename):
        """Persist the fitted classifier to filename via joblib."""
        joblib.dump(self.clf, filename)

    def load(self, filename):
        """Load a previously saved classifier from filename."""
        self.clf = joblib.load(filename)


if __name__ == '__main__':
    param = {
        'kernel': 'rbf',
    }

    arch = "svm"
    sample_size = 580  # fixed local-variable typo 'sample_szie'
    botname = "Gozi"
    normal = "CTUNone"

    batch_size = 128

    # 20% test split; the remaining 80% is further split 80/20 into
    # train/valid (valid is unused by the SVM but kept for parity with
    # the other target models' splits).
    total_size = sample_size * 2
    test_size = int(total_size * 0.2)
    train_size = int((total_size - test_size) * 0.8)
    valid_size = total_size - test_size - train_size
    print("train data: {}".format(train_size))
    print("valid data: {}".format(valid_size))
    print("test data: {}".format(test_size))

    c2data = C2Data(botname, number=sample_size, sequenceLen=30)
    train_data, test_data = torch.utils.data.random_split(
        c2data, [train_size + valid_size, test_size])
    train_loader = DataLoader(train_data, batch_size=batch_size,
                              shuffle=True, drop_last=False)
    test_loader = DataLoader(test_data, batch_size=batch_size,
                             shuffle=True, drop_last=False)

    # BUG FIX: renamed the instance from 'svm' to 'model' — the original
    # shadowed the imported sklearn 'svm' module.
    model = TargetSVM(param)
    model.train(train_loader)
    # BUG FIX: eval() returns (y_pred, y_true); the original unpacked them
    # as (y_true, y_pred), transposing the confusion matrix.
    y_pred, y_true = model.eval(test_loader)
    print("confusion_metrix: \n{}".format(confusion_matrix(y_true, y_pred)))
    filename = "../modelFile/target_{}_{}_{}.pkt".format(arch, botname, normal)
    model.save(filename)
