diff options
| author | yifei cheng <[email protected]> | 2023-06-26 12:34:43 +0000 |
|---|---|---|
| committer | yifei cheng <[email protected]> | 2023-06-26 12:34:43 +0000 |
| commit | 35e3cc36e1d2af4cb25ea6a2b7c085550abec449 (patch) | |
| tree | 6b7418c6069a54c320c5067a115354d84a59f089 | |
| parent | a2975f0250b825fdff1d1c2fab02220461841f08 (diff) | |
Upload New File
| -rw-r--r-- | attack/adversarialDataset.py | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/attack/adversarialDataset.py b/attack/adversarialDataset.py new file mode 100644 index 0000000..d3d91ee --- /dev/null +++ b/attack/adversarialDataset.py @@ -0,0 +1,41 @@ +""" +Date: 2022-03-08 +Author: [email protected] +Desc: adversarial dataset +""" + +import torch +from torch.utils.data import Dataset +import numpy as np + +class AdversarialC2Data(Dataset): + """ + adversarial sample for attack + """ + def __init__(self, filename, target_class=5, keep_target=True, norm=False): + """ + + :param filename: + :param target_class: + :param keep_target: Whether to keep target data + """ + data = np.load(filename) + if norm: + data[:,:-1] = (data[:,:-1] - np.min(data[:,:-1], axis=0)) / (np.max(data[:,:-1], axis=0) - np.min(data[:,:-1], axis=0)) + 1e-9 + print("Adversarial Dataset Load: {}".format(filename)) + if keep_target: + self.data = data + else: + self.data = np.array([x for x in data if x[-1] != target_class]) + print(self.data.shape) + + def __getitem__(self, index): + data = self.data[index] + sample = [x if x < 1600 else 1599 for x in data[:-1]] + sample = np.array(sample) + sample = torch.from_numpy(sample) + label = torch.from_numpy(data[-1:]) + return sample, label + + def __len__(self): + return self.data.shape[0] |
