blob: b9fc75ee90fd940e7317393c4778d1d4eddb9f33 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
|
"""
Date: 2022-03-09
Author: [email protected]
Desc: collection dataset for training the substitute model
"""
import torch
from torch.utils.data import Dataset
import numpy as np
class CollectionDataset(Dataset):
    """Dataset over collected rows used to train the substitute model.

    Each row of the loaded array is split at ``sequence_len``: the leading
    ``sequence_len`` entries form the input sample (cast to ``torch.long``)
    and all remaining entries form the label (cast to ``torch.float``).

    Args:
        filename: Passed directly to ``np.load``.
        sequence_len: Number of leading entries per row used as the input
            sample; defaults to 50.
    """

    def __init__(self, filename, sequence_len=50):
        self.sequence_len = sequence_len
        # SECURITY: allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code -- only load trusted files.
        # NOTE(review): presumably enabled because rows may be stored as an
        # object array of variable-length sequences; TODO confirm.
        loaded = np.load(filename, allow_pickle=True)
        self.data = loaded
        # Debug output: shape of the loaded array.
        print(loaded.shape)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair for row ``index``.

        ``sample`` is the first ``sequence_len`` entries as a ``torch.long``
        tensor (fractional values are truncated by the cast); ``label`` is
        every entry after that as a ``torch.float`` tensor. A row no longer
        than ``sequence_len`` yields an empty label.
        """
        row = self.data[index]
        cut = self.sequence_len
        inputs, targets = row[:cut], row[cut:]
        return (
            torch.tensor(inputs, dtype=torch.long),
            torch.tensor(targets, dtype=torch.float),
        )

    def __len__(self):
        """Return the number of rows (size of the array's first axis)."""
        return self.data.shape[0]
|