diff options
| author | yifei cheng <[email protected]> | 2023-06-26 12:35:24 +0000 |
|---|---|---|
| committer | yifei cheng <[email protected]> | 2023-06-26 12:35:24 +0000 |
| commit | 24fb916e3b5f043414b0e79f0e9198e3d75ffd98 (patch) | |
| tree | 6e699ff43ff7f772ac40d01cf8a8e91b8d2ecf0e | |
| parent | 8880debd0ddb6f618d77285890de9bebdd88703f (diff) | |
Upload New File
| -rw-r--r-- | attack/collectTargetData.py | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/attack/collectTargetData.py b/attack/collectTargetData.py new file mode 100644 index 0000000..0cf544b --- /dev/null +++ b/attack/collectTargetData.py @@ -0,0 +1,50 @@ +""" +Date: 2022-03-09 +Author: [email protected] +Desc: Collect the labels output by the target model for training the subtitute model +""" + +import torch +from TargetModel.FSNet.FSNet import FSNet +from TargetModel.FSNet.dataset import C2Data +from tqdm import tqdm +from torch.utils.data import DataLoader +import torch.nn.functional as F +import numpy as np + +def collectData(model:FSNet, dataloader, device, filename): + """ + + :param model: + :param dataloader: + :param filenale: + :return: + """ + collectData = [] + for batch_x, batch_y in tqdm(dataloader): + batch_x = batch_x.to(device) + batch_y = batch_y.to(device) + z_e = model.encode(batch_x) + z_d, D = model.decode(z_e) + z_dense = model.dense(z_e, z_d) + # z_dense = F.softmax(z_dense) + z_dense = torch.argmax(z_dense, dim=1, keepdim=True) + collect = torch.cat((batch_x, z_dense), dim=1) + collectData += collect.detach().cpu().numpy().tolist() + collectData = np.array(collectData) + np.save(filename, collectData) + return collectData + + + +if __name__ == '__main__': + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + state = torch.load('../modelFile/target_fsnet.pkt') + fsnet = FSNet(state['param']) + fsnet.load_state_dict(state['model_dict']) + fsnet = fsnet.to(device) + c2data = C2Data(number=200) + batch_size = 32 + dataloader = DataLoader(c2data, batch_size, shuffle=True, drop_last=False) + data = collectData(fsnet, dataloader, device, "../adversarialData/collectionData.npy") + print(data.shape) |
