summaryrefslogtreecommitdiff
path: root/attack/collectionDataset.py
blob: b9fc75ee90fd940e7317393c4778d1d4eddb9f33 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""
Date: 2022-03-09
Author: [email protected]
Desc: collection dataset for training the substitute model
"""
import torch
from torch.utils.data import Dataset
import numpy as np

class CollectionDataset(Dataset):
    """Map-style dataset backed by an array loaded with ``np.load``.

    Every indexed entry of the loaded array is cut at ``sequence_len``:
    the leading ``sequence_len`` elements become the sample and whatever
    follows becomes the label.

    Args:
        filename: Path (or file object) handed straight to ``np.load``.
        sequence_len: Number of leading elements of each entry that make
            up the sample. Defaults to 50.
    """

    def __init__(self, filename, sequence_len=50):
        self.sequence_len = sequence_len
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects, so this should only be pointed at trusted files.
        self.data = np.load(filename, allow_pickle=True)
        # Echo the loaded array's shape to stdout.
        print(self.data.shape)

    def __len__(self):
        """Return the size of the first axis of the loaded array."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return the ``(sample, label)`` tensor pair for entry ``index``.

        The sample is converted to ``torch.long`` and the label to
        ``torch.float``.
        """
        row = self.data[index]
        cut = self.sequence_len
        head = torch.tensor(row[:cut], dtype=torch.long)
        tail = torch.tensor(row[cut:], dtype=torch.float)
        return head, tail