summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoryifei cheng <[email protected]>2023-06-26 12:35:24 +0000
committeryifei cheng <[email protected]>2023-06-26 12:35:24 +0000
commit24fb916e3b5f043414b0e79f0e9198e3d75ffd98 (patch)
tree6e699ff43ff7f772ac40d01cf8a8e91b8d2ecf0e
parent8880debd0ddb6f618d77285890de9bebdd88703f (diff)
Upload New File
-rw-r--r--attack/collectTargetData.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/attack/collectTargetData.py b/attack/collectTargetData.py
new file mode 100644
index 0000000..0cf544b
--- /dev/null
+++ b/attack/collectTargetData.py
@@ -0,0 +1,50 @@
+"""
+Date: 2022-03-09
+Desc: Collect the labels output by the target model for training the subtitute model
+"""
+
+import torch
+from TargetModel.FSNet.FSNet import FSNet
+from TargetModel.FSNet.dataset import C2Data
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
+import numpy as np
+
+def collectData(model:FSNet, dataloader, device, filename):
+ """
+
+ :param model:
+ :param dataloader:
+ :param filenale:
+ :return:
+ """
+ collectData = []
+ for batch_x, batch_y in tqdm(dataloader):
+ batch_x = batch_x.to(device)
+ batch_y = batch_y.to(device)
+ z_e = model.encode(batch_x)
+ z_d, D = model.decode(z_e)
+ z_dense = model.dense(z_e, z_d)
+ # z_dense = F.softmax(z_dense)
+ z_dense = torch.argmax(z_dense, dim=1, keepdim=True)
+ collect = torch.cat((batch_x, z_dense), dim=1)
+ collectData += collect.detach().cpu().numpy().tolist()
+ collectData = np.array(collectData)
+ np.save(filename, collectData)
+ return collectData
+
+
+
+if __name__ == '__main__':
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ state = torch.load('../modelFile/target_fsnet.pkt')
+ fsnet = FSNet(state['param'])
+ fsnet.load_state_dict(state['model_dict'])
+ fsnet = fsnet.to(device)
+ c2data = C2Data(number=200)
+ batch_size = 32
+ dataloader = DataLoader(c2data, batch_size, shuffle=True, drop_last=False)
+ data = collectData(fsnet, dataloader, device, "../adversarialData/collectionData.npy")
+ print(data.shape)