summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoryifei cheng <[email protected]>2023-06-26 12:34:43 +0000
committeryifei cheng <[email protected]>2023-06-26 12:34:43 +0000
commit35e3cc36e1d2af4cb25ea6a2b7c085550abec449 (patch)
tree6b7418c6069a54c320c5067a115354d84a59f089
parenta2975f0250b825fdff1d1c2fab02220461841f08 (diff)
Upload New File
-rw-r--r--attack/adversarialDataset.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/attack/adversarialDataset.py b/attack/adversarialDataset.py
new file mode 100644
index 0000000..d3d91ee
--- /dev/null
+++ b/attack/adversarialDataset.py
@@ -0,0 +1,41 @@
+"""
+Date: 2022-03-08
+Desc: adversarial dataset
+"""
+
+import torch
+from torch.utils.data import Dataset
+import numpy as np
+
+class AdversarialC2Data(Dataset):
+ """
+ adversarial sample for attack
+ """
+ def __init__(self, filename, target_class=5, keep_target=True, norm=False):
+ """
+
+ :param filename:
+ :param target_class:
+ :param keep_target: Whether to keep target data
+ """
+ data = np.load(filename)
+ if norm:
+ data[:,:-1] = (data[:,:-1] - np.min(data[:,:-1], axis=0)) / (np.max(data[:,:-1], axis=0) - np.min(data[:,:-1], axis=0)) + 1e-9
+ print("Adversarial Dataset Load: {}".format(filename))
+ if keep_target:
+ self.data = data
+ else:
+ self.data = np.array([x for x in data if x[-1] != target_class])
+ print(self.data.shape)
+
+ def __getitem__(self, index):
+ data = self.data[index]
+ sample = [x if x < 1600 else 1599 for x in data[:-1]]
+ sample = np.array(sample)
+ sample = torch.from_numpy(sample)
+ label = torch.from_numpy(data[-1:])
+ return sample, label
+
+ def __len__(self):
+ return self.data.shape[0]