Diffstat (limited to 'script/train.py')
-rw-r--r--  script/train.py  356
1 file changed, 356 insertions(+), 0 deletions(-)
diff --git a/script/train.py b/script/train.py
new file mode 100644
index 0000000..68364ac
--- /dev/null
+++ b/script/train.py
@@ -0,0 +1,356 @@
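+# Trains an RNN (GRU or LSTM) to predict per-step connection-state labels, then
+# optionally trains an autoencoder (AE) on "contextual profiles" derived from the
+# RNN's hidden states and gate activations (see --context-mode).
+#
+# Example invocation (paths and values below are hypothetical):
+#   python train.py --train-dataset ../data/train.pkl --input-size 12 \
+#       --rnn-model-type gru --rnn-num-epochs 5 --train-ae-model --save-model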
+import argparse
+import torch
+from torch import nn, optim
+from torch.autograd import Variable
+import torch.nn.functional as F
+import numpy as np
+import pandas
+import random
+import time
+import pickle
+from collections import Counter
+
+from utils import GRUModel, LSTMModel, rnn_loss_function, AEModel, ae_loss_function
+from utils import read_dataset, generate_ngram_seq, generate_contextual_profile_dataset, \
+ generate_contextual_profile_dataset_fused, generate_ngram_seq_dataset, calculate_acc, print_per_label_accu
+from preprocess_dataset import nf_conntrack_states
+
+ERR_TOO_SHORT_SEQ = -1
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ description='GRU for learning benign contextual profiles')
+    parser.add_argument('--train-dataset', type=str,
+                        help='path to training dataset file')
+ parser.add_argument('--loss-list-fpath', type=str,
+ help='path to dump loss list')
+ parser.add_argument('--device', type=str, help='device for training')
+ parser.add_argument('--seed', type=int, default=1,
+ metavar='S', help='random seed (default: 1)')
+ parser.add_argument('--rnn-num-epochs', type=int, default=5,
+ help='number of epochs for RNN training (default: 5)')
+ parser.add_argument('--vae-num-epochs', type=int, default=5,
+ help='number of epochs for vae training (default: 5)')
+ parser.add_argument('--cutoff', type=int, default=-1,
+ help='cutoff for rnn training (default: -1)')
+ parser.add_argument('--ae-batch-size', type=int, default=128,
+ help='AE batch size for AE training (default: 128)')
+ parser.add_argument('--input-size', type=int,
+ help='Raw input size (no default)')
+ # NOTE: Not sure how to make the batch size for training RNN larger than 1
+ parser.add_argument('--rnn-batch-size', type=int, default=1,
+ help='batch size for RNN training (default: 1)')
+ parser.add_argument('--rnn-hidden-size', type=int, default=5,
+ help='hidden state size for RNN (default: 5)')
+ parser.add_argument('--ae-bottleneck-size', type=int, default=5,
+ help='bottleneck size for AE (default: 5)')
+ parser.add_argument('--n-gram', type=int, default=3,
+ help='n-gram for training/testing the autoencoder (default: 3)')
+ parser.add_argument('--train-seq-cutoff', type=int, default=-1,
+ help='seq cutoff for training phase (default: -1)')
+ parser.add_argument('--rnn-train-checkpoint', type=int, default=100000,
+                        help='how many samples between training/testing checkpoints (default: 100000)')
+ parser.add_argument('--ae-train-checkpoint', type=int, default=10000,
+                        help='how many batches between AE training checkpoints (default: 10000)')
+ parser.add_argument('--error-thres', type=float,
+ help='reconstruction error for autoencoder')
+ parser.add_argument('--debug', action="store_true",
+ help='enables debugging information')
+ parser.add_argument('--train-ae-model', action="store_true",
+ help='whether to train AE model')
+ parser.add_argument('--save-model', action="store_true",
+ help='persists the models for future use')
+ parser.add_argument('--loss-weighting', type=str, default="none",
+ help='weights the loss w.r.t. different labels')
+ parser.add_argument('--context-mode', type=str, default="use_gates",
+ help='type of contextual profile to use')
+ parser.add_argument('--ae-model-type', type=str,
+ default="mid", help='type of AE model to use')
+ parser.add_argument('--partition-mode', type=str,
+ default="none", help='type of partitioning to use')
+    parser.add_argument('--rnn-model', type=str,
+                        help='path to a saved RNN model to load')
+ parser.add_argument('--save-model-suffix', type=str,
+ help='for marking saved models')
+    parser.add_argument('--test-dataset', type=str,
+                        help='path to test dataset file')
+ parser.add_argument('--rnn-num-layers', type=int, default=1,
+                        help='number of stacked GRU layers (default: 1)')
+ parser.add_argument('--rnn-model-type', type=str,
+ default="gru", help='which RNN model to use.')
+ parser.add_argument('--extra-features', type=str,
+                        help='whether to include post-mortem features.')
+ parser.add_argument('--conn-dir', type=str,
+ help='which direction to build contextual profile.')
+ args = parser.parse_args()
+
+ print("[INFO][Train] All arguments:")
+ print(str(args))
+
+ torch.manual_seed(args.seed)
+
+ device = torch.device(args.device if args.device else "cpu")
+
+ if args.loss_weighting == 'weighted':
+ use_loss_weights = True
+ else:
+ use_loss_weights = False
+
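+    # read_dataset returns the training loader plus the label->id map, dataset
+    # stats (reused for the test set and persisted below), and per-label counts.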
+ train_loader, train_state_map, stats, label_count = read_dataset(
+ args.train_dataset, batch_size=args.rnn_batch_size, preprocess=True, cutoff=args.cutoff, seq_cutoff=args.train_seq_cutoff, split_train_test=False)
+ if args.test_dataset:
+ test_loader, test_state_map, _, test_label_count = read_dataset(
+ args.test_dataset, batch_size=1, preprocess=True, cutoff=args.cutoff, seq_cutoff=args.train_seq_cutoff, split_train_test=False, stats=stats)
+
+ if len(train_state_map) < len(test_state_map):
+            print("[ERROR] Test set contains more classes than the training set:")
+ print(str(train_state_map))
+ print(str(label_count))
+ print(str(test_state_map))
+ print(str(test_label_count))
+ exit()
+
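+    # Inverse-frequency class weights: each label gets weight 1.0/count, so rare
+    # connection states contribute proportionally more to the weighted loss
+    # (e.g. counts {0: 900, 1: 100} -> weights [1/900, 1/100]).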
+ loss_weights = [0.0] * (int(max(list(label_count.keys()))) + 1)
+ for label_id, count in label_count.items():
+ loss_weights[int(label_id)] = 1.0 / count
+ stats_dump = {"label_map": train_state_map,
+ "stats": stats, "loss_weights": loss_weights}
+ loss_weights = torch.FloatTensor(loss_weights).to(device)
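+    # Invert the label map so that train_state_map maps label id -> label name.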
+ reversed_train_state_map = {}
+ for label, label_id in train_state_map.items():
+ reversed_train_state_map[label_id] = label
+ train_state_map = reversed_train_state_map
+ print("[INFO] Stats used for the dataset")
+ print(stats_dump)
+ with open(args.train_dataset + '.stats', 'wb') as fout:
+ pickle.dump(stats_dump, fout)
+
+ input_size = args.input_size
+ hidden_size = args.rnn_hidden_size
+ num_class = len(train_state_map)
+ output_size = num_class
+ batch_size = 1
+ num_layers = args.rnn_num_layers
+ if 'bi_' in args.rnn_model_type:
+ rnn_bidirectional = True
+ else:
+ rnn_bidirectional = False
+
+ if not args.rnn_model:
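+        # NOTE: BiLSTMTestModel (used for --rnn-model-type=test_lstm) is not
+        # imported above; it presumably needs to be exposed by utils before
+        # that branch can run.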
+ if args.rnn_model_type == 'test_lstm':
+ rnn_model = BiLSTMTestModel(
+ input_size, hidden_size, num_layers, num_class, device).to(device)
+ elif 'gru' in args.rnn_model_type:
+ rnn_model = GRUModel(input_size, hidden_size, output_size,
+ num_layers, device, rnn_bidirectional).to(device)
+ elif 'lstm' in args.rnn_model_type:
+ rnn_model = LSTMModel(input_size, hidden_size, output_size,
+ num_layers, device, rnn_bidirectional).to(device)
+ rnn_learning_rate = 1e-4
+ rnn_optimizer = torch.optim.Adam(
+ rnn_model.parameters(), lr=rnn_learning_rate)
+
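+        # RNN training loop: each batch is a single sequence (batch size 1);
+        # the model produces one output per timestep and the loss is taken
+        # over all timesteps against the per-step labels.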
+ for epoch in range(args.rnn_num_epochs):
+ print("[INFO][Train] ============ Epoch: %d ============" % epoch)
+ train_average_loss = 0.0
+ for batch_idx, [x, labels] in enumerate(train_loader):
+ x = x.to(device, dtype=torch.float)
+ labels = labels.to(
+ device, dtype=torch.long).view(labels.size(1))
+
+ rnn_optimizer.zero_grad()
+
+ if args.rnn_model_type == 'test_lstm':
+ outputs = rnn_model(x)
+ elif 'gru' in args.rnn_model_type:
+ outputs, _, _ = rnn_model(x)
+ elif 'lstm' in args.rnn_model_type:
+ outputs, _, _ = rnn_model(x)
+ outputs = outputs.view(x.size(1), num_class)
+ if use_loss_weights:
+ loss = rnn_loss_function(
+ outputs, labels, weight=loss_weights)
+ else:
+ loss = rnn_loss_function(outputs, labels)
+ train_average_loss += loss.item()
+
+ loss.backward()
+ rnn_optimizer.step()
+
+ if (batch_idx > 0 and batch_idx % args.rnn_train_checkpoint == 0) or batch_idx == len(train_loader) - 1:
+ print("[INFO][Train] Sample idx: %d; Training loss: %f" %
+ (batch_idx, train_average_loss / (batch_idx + 1)))
+
+ if not args.test_dataset:
+ continue
+
+                    rnn_model.eval()  # Setting model to eval mode to disable Dropout layers
+ correct, total = 0, 0
+ test_average_loss = 0.0
+ correct_labels, incorrect_labels = [], []
+ for test_idx, [test_x, test_labels] in enumerate(test_loader):
+ test_x = test_x.to(device, dtype=torch.float)
+ test_labels = test_labels.to(
+ device, dtype=torch.long).view(test_labels.size(1))
+ test_x_size = test_labels.size(0)
+
+ if args.rnn_model_type == 'test_lstm':
+ test_outputs = rnn_model(test_x)
+ elif 'gru' in args.rnn_model_type:
+ test_outputs, _, _ = rnn_model(test_x)
+ elif 'lstm' in args.rnn_model_type:
+ test_outputs, _, _ = rnn_model(test_x)
+ test_outputs = test_outputs.view(
+ test_x.size(1), num_class)
+ if use_loss_weights:
+ test_loss = rnn_loss_function(
+ test_outputs, test_labels, weight=loss_weights)
+ else:
+ test_loss = rnn_loss_function(
+ test_outputs, test_labels)
+ test_average_loss += test_loss.item()
+
+ curr_correct, curr_total, corr_labels, incorr_labels = calculate_acc(
+ test_outputs, test_labels)
+ correct_labels.extend(corr_labels.tolist())
+ incorrect_labels.extend(incorr_labels.tolist())
+ total += curr_total
+ correct += curr_correct
+
+ _ = print_per_label_accu(Counter(correct_labels), Counter(
+ incorrect_labels), test_state_map)
+ test_average_loss /= len(test_loader)
+ accuracy = float(correct) / total
+
+ print('[INFO][Test] Testing loss: {}. Overall testing accuracy: {}'.format(
+ test_average_loss, accuracy))
+                    rnn_model.train()  # Now returning to train mode
+ else:
+ rnn_model = torch.load(args.rnn_model).to(device)
+
+ if args.save_model and not args.rnn_model:
+ if args.save_model_suffix:
+ torch.save(rnn_model, "../model/rnn_model.pt.%s" %
+ args.save_model_suffix)
+ else:
+ torch.save(rnn_model, "../model/rnn_model.pt.%s" %
+ str(int(time.time())))
+
+ if not args.train_ae_model:
+ exit()
+
+ # Reload the training set if we need to include additional features
+ if args.conn_dir == 'only_outbound':
+ only_outbound = True
+ else:
+ only_outbound = False
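+    # With the RNN frozen in eval mode, replay the training data through it and
+    # collect per-step context (inputs, hidden states and/or gate activations,
+    # depending on --context-mode) as the dataset for the autoencoder.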
+ rnn_model.eval() # switching to eval mode
+ if args.extra_features == 'all_addi':
+ addi_train_loader, _, _, _ = read_dataset(args.train_dataset, batch_size=args.rnn_batch_size, preprocess=True,
+                                                   cutoff=args.cutoff, seq_cutoff=args.train_seq_cutoff, split_train_test=False, add_additional_features=True)  # adjusted (was False)
+ contextual_dataset = generate_contextual_profile_dataset_fused(train_loader, device, rnn_model, context_mode=args.context_mode,
+ partition_mode=args.partition_mode, rnn_model_type=args.rnn_model_type, label_map=train_state_map, addi_data_loader=addi_train_loader)
+ num_addi_features = 15
+ else:
+ # Now is the time to save final contextual profile
+ contextual_dataset = generate_contextual_profile_dataset(
+ train_loader, device, rnn_model, context_mode=args.context_mode, partition_mode=args.partition_mode, rnn_model_type=args.rnn_model_type, label_map=train_state_map)
+ num_addi_features = 0
+
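+    # The AE consumes an n-gram of per-step feature vectors, so its input size is
+    # (per-step width) * n_gram, where the per-step width depends on which signals
+    # --context-mode concatenates (raw input, hidden state, gate activations, ...).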
+ if args.context_mode == "baseline":
+ vae_input_size = (input_size + num_addi_features) * args.n_gram
+ elif args.context_mode == "use_hn":
+ vae_input_size = (input_size + hidden_size) * args.n_gram
+ elif args.context_mode == "use_all":
+ vae_input_size = (input_size + hidden_size * 5) * args.n_gram
+ elif args.context_mode == "only_gates":
+ vae_input_size = (hidden_size * 2) * args.n_gram
+ elif args.context_mode == "only_hn":
+ vae_input_size = hidden_size * args.n_gram
+ elif args.context_mode == "use_all_gates":
+ vae_input_size = (input_size + hidden_size * 4) * args.n_gram
+ elif args.context_mode == "use_gates":
+ vae_input_size = (input_size + num_addi_features +
+ hidden_size * 2) * args.n_gram
+ elif args.context_mode == "use_gates_label":
+ vae_input_size = (input_size + num_addi_features + hidden_size *
+ 2 + len(nf_conntrack_states) + 1) * args.n_gram
+ print("[INFO][Train] AE input size: %d" % vae_input_size)
+
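+    # With --partition-mode none a single AE is trained over all profiles;
+    # otherwise one AE and optimizer are instantiated per connection-state label.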
+ vae_learning_rate = 1e-3
+ if args.partition_mode == 'none':
+ vae_model = AEModel(
+ vae_input_size, args.ae_bottleneck_size, args.ae_model_type).to(device)
+ vae_optimizer = optim.Adam(
+ vae_model.parameters(), lr=vae_learning_rate)
+ else:
+ vae_models = {}
+ vae_optimizers = {}
+ for label in list(contextual_dataset.keys()):
+ vae_models[label] = AEModel(
+ vae_input_size, args.ae_bottleneck_size, args.ae_model_type).to(device)
+ vae_optimizers[label] = optim.Adam(
+ vae_models[label].parameters(), lr=vae_learning_rate)
+
+ if args.partition_mode == 'none':
+ profile_loader = torch.utils.data.DataLoader(
+ contextual_dataset, batch_size=args.rnn_batch_size, shuffle=True)
+ seq_profile_loader = generate_ngram_seq_dataset(
+ profile_loader, args.n_gram, batch_size=args.ae_batch_size, debug=False, only_outbound=only_outbound)
+ else:
+ profile_loaders = {}
+ for label, dataset in contextual_dataset.items():
+ profile_loaders[label] = torch.utils.data.DataLoader(
+ contextual_dataset[label], batch_size=args.ae_batch_size, shuffle=True)
+
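+    # AE training: minimize the reconstruction error of each (n-gram) profile;
+    # in partitioned mode the same loop runs independently for every label.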
+ for epoch in range(args.vae_num_epochs):
+ if args.partition_mode == "none": # Partitioning is off
+ train_loss = 0.0
+ seq_cnt = 0
+ for batch_idx, ngram in enumerate(seq_profile_loader):
+ ngram = ngram.view(-1, vae_input_size)
+ vae_optimizer.zero_grad()
+ recon_ngram = vae_model(ngram)
+ loss = ae_loss_function(ngram, recon_ngram)
+ loss.backward()
+ train_loss += loss.item()
+ vae_optimizer.step()
+ if (batch_idx != 0 and batch_idx % args.ae_train_checkpoint == 0) or batch_idx == len(seq_profile_loader) - 1:
+ print("[INFO][Train] Training checkpoint: batch #%d" %
+ batch_idx)
+ print('[INFO] ====> Epoch: {}; Average loss: {:.4f}'.format(
+ epoch, train_loss/(batch_idx+1)))
+ else: # Partitioning is on
+ train_loss = {}
+ for label, label_profile_loader in profile_loaders.items():
+ train_loss[label] = 0.0
+ for batch_idx, profile in enumerate(label_profile_loader):
+ profile = profile.view(-1, vae_input_size)
+ vae_optimizers[label].zero_grad()
+ recon_profile = vae_models[label](profile)
+ loss = ae_loss_function(profile, recon_profile)
+ loss.backward()
+ train_loss[label] += loss.item()
+ vae_optimizers[label].step()
+ if (batch_idx != 0 and batch_idx % args.ae_train_checkpoint == 0) or batch_idx == len(profile_loaders[label]) - 1:
+ print(
+ "[INFO][Train] Training checkpoint: batch #%d" % batch_idx)
+ print('[INFO] ====> [Label: %s] Epoch: %d; Average loss: %f' % (
+ reversed_train_state_map[label], epoch, train_loss[label]/(batch_idx+1)))
+
+ if args.save_model:
+ if args.partition_mode == 'none':
+ if args.save_model_suffix:
+ torch.save(vae_model, "../model/vae_model.pt.%s" %
+ args.save_model_suffix)
+ else:
+ torch.save(vae_model, "../model/vae_model.pt.%s" %
+ str(int(time.time())))
+ else:
+ for label, vae_model in vae_models.items():
+ if args.save_model_suffix:
+ torch.save(vae_model, "../model/vae_model.pt.%s.%s" %
+ (args.save_model_suffix, str(reversed_train_state_map[label])))
+ else:
+ torch.save(vae_model, "../model/vae_model.pt.%s.%s" %
+ (str(int(time.time())), str(reversed_train_state_map[label])))