diff options
| author | yifei cheng <[email protected]> | 2023-06-26 12:26:06 +0000 |
|---|---|---|
| committer | yifei cheng <[email protected]> | 2023-06-26 12:26:06 +0000 |
| commit | da15672ba6bc118b30ec1662b92185fa742c5358 (patch) | |
| tree | 261fbd9569ea90f886d86b8457e5282748f65217 /TargetModel | |
| parent | 91da8b658706b2a194acbbd30e1641fcfac1c833 (diff) | |
Upload New File
Diffstat (limited to 'TargetModel')
| -rw-r--r-- | TargetModel/TargetIF.py | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/TargetModel/TargetIF.py b/TargetModel/TargetIF.py new file mode 100644 index 0000000..1f3a024 --- /dev/null +++ b/TargetModel/TargetIF.py @@ -0,0 +1,59 @@ +""" +Date: 2022-04-13 +Author: [email protected] +Desc: target model: Isolate Forest +""" +from sklearn.ensemble import IsolationForest +import torch +from sklearn.metrics import confusion_matrix +import joblib +from torch.utils.data import DataLoader +import numpy as np +import warnings +warnings.filterwarnings("ignore") + +class TargetIF(): + """ + """ + def __init__(self, param): + # 正则化 + self.clf = IsolationForest( + n_estimators=param['n_estimators'], + contamination=param['outliers_fraction1'], + n_jobs=-1 + ) + + def train(self, dataloader): + X = [] + y = [] + for batch_x, batch_y in dataloader: + X += batch_x.data.numpy().tolist() + y += batch_y.data.numpy().tolist() + X = np.array(X) + y = np.array(y) + # print("X.shape:{}".format(X.size)) + # print("y.shape:{}".format(y.size)) + self.clf.fit(X, y) + # print("training score:{}".format(self.clf.score(X, y))) + + def eval(self, dataloader): + X = [] + y = [] + for batch_x, batch_y in dataloader: + X += batch_x.data.numpy().tolist() + y += batch_y.data.numpy().tolist() + X = np.array(X) + y = np.array(y) + # print("X.shape:{}".format(X.size)) + # print("y.shape:{}".format(y.size)) + y_pred = self.clf.predict(X) + y_pred = y_pred.reshape(-1,1) + y_pred[y_pred == -1] = 0 + y_pred[y_pred == 1] = 1 + return y_pred, y + + def save(self, filename): + joblib.dump(self.clf, filename) + + def load(self, filename): + self.clf = joblib.load(filename)
\ No newline at end of file |
