# path: root/attack/adversarialDataset.py
# blob: d3d91ee1bdc7431f2d990a32391e44f09f6c6847 (plain)
"""
Date: 2022-03-08
Author: [email protected]
Desc: adversarial dataset
"""

import torch
from torch.utils.data import Dataset
import numpy as np

class AdversarialC2Data(Dataset):
    """
    Adversarial samples for an attack, loaded from a NumPy array file.

    Every row of the loaded 2-D array is treated as feature values followed
    by one label in the last column.
    """

    # Feature values that are not strictly below this bound are replaced by
    # ``_CLIP_LIMIT - 1`` when a sample is fetched.
    # NOTE(review): presumably the size of a downstream embedding/vocabulary
    # table -- confirm against the model that consumes these samples.
    _CLIP_LIMIT = 1600

    # Added to the per-column value range so a constant column (max == min)
    # does not cause a division by zero during normalisation.
    _NORM_EPS = 1e-9

    def __init__(self, filename, target_class=5, keep_target=True, norm=False):
        """
        Load the array and optionally normalise it and drop target rows.

        :param filename: path passed to ``numpy.load``; expected to hold a
            2-D array whose last column is the label
        :param target_class: label value that identifies target rows
        :param keep_target: Whether to keep target data; when False, rows
            whose label equals ``target_class`` are removed
        :param norm: when True, min-max scale every feature column (all
            columns except the last) to roughly ``[0, 1]``
        """
        data = np.load(filename)
        # Skip normalisation for an empty array: min/max reductions over zero
        # rows raise ValueError.
        if norm and data.shape[0] > 0:
            features = data[:, :-1]
            col_min = features.min(axis=0)
            col_range = features.max(axis=0) - col_min
            # The epsilon belongs in the denominator; adding it to the
            # quotient instead would still divide by zero on constant columns.
            # NOTE(review): if the file stores an integer dtype, this
            # in-place assignment truncates the scaled values -- confirm the
            # stored arrays are floating point when ``norm=True`` is used.
            data[:, :-1] = (features - col_min) / (col_range + self._NORM_EPS)
        print("Adversarial Dataset Load: {}".format(filename))
        if keep_target:
            self.data = data
        else:
            # A boolean mask keeps dtype and the 2-D shape even when every
            # row is filtered out (shape ``(0, n_columns)``).
            self.data = data[data[:, -1] != target_class]
        print(self.data.shape)

    def __getitem__(self, index):
        """
        Return the ``(sample, label)`` tensor pair for one row.

        :param index: integer row index into the loaded array
        :return: ``sample`` holds the feature values with everything not
            below ``_CLIP_LIMIT`` replaced by ``_CLIP_LIMIT - 1``; ``label``
            is a one-element tensor holding the last column of the row
        """
        row = self.data[index]
        features = row[:-1]
        # Same predicate as an element-wise ``x if x < limit else limit - 1``
        # but vectorised; the result is a fresh array in the stored dtype.
        clipped = np.where(
            features < self._CLIP_LIMIT, features, self._CLIP_LIMIT - 1
        )
        sample = torch.from_numpy(clipped)
        # ``row[-1:]`` is a view, so this tensor shares memory with
        # ``self.data``; in-place edits to the label would alter the dataset.
        label = torch.from_numpy(row[-1:])
        return sample, label

    def __len__(self) -> int:
        """Return the number of rows kept in the dataset."""
        return self.data.shape[0]