#coding:utf-8
"""Evaluate a saved GraphCNN checkpoint and report classification accuracy."""
import numpy as np
import torch

from models.graphcnn import GraphCNN
from process_dataset import load_data


def pass_data_iteratively(model, graphs, minibatch_size=64):
    """Run ``model`` over ``graphs`` in mini-batches and return stacked outputs.

    The model is switched to eval mode (and left in that mode). Each batch is
    passed to the model as a plain list of graph objects, and the per-batch
    outputs are concatenated along dimension 0 in the original graph order.

    Args:
        model: callable module that accepts a list of graphs and returns a
            tensor with one row per graph.
        graphs: indexable, sized collection of graph objects.
        minibatch_size: maximum number of graphs per forward pass.

    Returns:
        A detached tensor holding the concatenated model outputs.
    """
    model.eval()
    total = len(graphs)
    outputs = []
    # Inference only: skip autograd bookkeeping instead of building a graph
    # for every batch and discarding it afterwards.
    with torch.no_grad():
        for start in range(0, total, minibatch_size):
            stop = min(start + minibatch_size, total)
            batch = [graphs[j] for j in range(start, stop)]
            outputs.append(model(batch).detach())
    return torch.cat(outputs, 0)


def _accuracy(model, graphs, device):
    """Return the fraction of ``graphs`` whose argmax prediction equals ``graph.label``."""
    output = pass_data_iteratively(model, graphs)
    # Index of the highest score in each row is the predicted class.
    pred = output.max(1, keepdim=True)[1]
    labels = torch.LongTensor([graph.label for graph in graphs]).to(device)
    correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
    return correct / float(len(graphs))


def test(args, model, device, train_graphs, test_graphs, epoch):
    """Print and return the accuracy of ``model`` on the train and test graphs.

    Args:
        args: unused; kept so existing callers keep working.
        model: model to evaluate; it is put into eval mode.
        device: device the label tensors are moved to before comparison.
        train_graphs: graphs used for the training-accuracy figure.
        test_graphs: graphs used for the test-accuracy figure.
        epoch: unused; kept so existing callers keep working.

    Returns:
        Tuple ``(acc_train, acc_test)`` of floats.
    """
    model.eval()

    acc_train = _accuracy(model, train_graphs, device)
    acc_test = _accuracy(model, test_graphs, device)

    print("accuracy train: %f test: %f" % (acc_train, acc_test))

    return acc_train, acc_test


def main():
    """Load the dataset and the ``NetModel.pth`` checkpoint, then print accuracy."""
    # Fixed seeds so repeated runs are reproducible.
    torch.manual_seed(0)
    np.random.seed(0)
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.manual_seed_all(0)
    else:
        device = torch.device("cpu")

    # NOTE(review): "store_true" looks like an argparse action name passed as
    # a value; any non-empty string is truthy, so it presumably acts as an
    # enabled flag -- confirm against load_data / GraphCNN.
    graphs, _num_classes = load_data("store_true")

    # NOTE(review): positional hyper-parameters are kept exactly as in the
    # original call; presumably they must match the configuration the
    # checkpoint was trained with -- verify against the training script.
    model = GraphCNN(5, 2, graphs[0].node_features.shape[1], 64, 2, 0.5,
                     "store_true", "sum", "sum", device).to(device)

    path = "NetModel.pth"
    # map_location lets a checkpoint saved from a GPU be loaded on a CPU-only
    # host, which is the fallback device selected above.
    model.load_state_dict(torch.load(path, map_location=device))

    acc = _accuracy(model, graphs, device)
    print("accuracy:%s" % acc)