author    张世俊 <[email protected]> 2022-06-15 07:23:28 +0000
committer 张世俊 <[email protected]> 2022-06-15 07:23:28 +0000
commit    71d03015ba766b8368fa13fba66b906cd3f2f271 (patch)
tree      783673d4d5c029c8229560e388b6e410ef4fc96f
parent    935f1800b3fd6ac89da90c37db24751fd0a61b9b (diff)
Upload HGNR HGT HAN model for pirated video website detection
-rw-r--r--  graph.ipynb  1680
1 file changed, 1680 insertions(+), 0 deletions(-)
diff --git a/graph.ipynb b/graph.ipynb
new file mode 100644
index 0000000..eefcfbc
--- /dev/null
+++ b/graph.ipynb
@@ -0,0 +1,1680 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "cBo19DhI7uoD"
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "# !git clone https://github.com/joerg84/Graph_Powered_ML_Workshop.git\n",
+ "# !rsync -av Graph_Powered_ML_Workshop/ ./ --exclude=.git\n",
+ "!pip3 install dgl-cu110 -f https://data.dgl.ai/wheels/repo.html\n",
+ "# !pip3 install dgl-cu102 -f https://data.dgl.ai/wheels/repo.html\n",
+ "!pip3 install numpy\n",
+ "!pip3 install torch\n",
+ "!pip3 install networkx\n",
+ "!pip3 install matplotlib\n",
+ "!pip3 install scikit-learn\n",
+ "!pip3 install chars2vec\n",
+ "!pip3 install keras\n",
+ "!pip3 install pandas"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "OX7Fi6esQzct",
+ "outputId": "7a959969-b1b5-474f-e147-9591cb9c05d9"
+ },
+ "outputs": [],
+ "source": [
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')\n",
+ "!mkdir ./data/\n",
+ "!mkdir ./graph/\n",
+ "!mkdir ./pagerank/\n",
+ "!mkdir ./model\n",
+ "\n",
+ "!cp ./drive/MyDrive/Colab\\ Notebooks/han/data/fix_1k_labels.csv ./data/\n",
+ "!cp ./drive/MyDrive/Colab\\ Notebooks/han/data/fix_3k_labels.csv ./data/\n",
+ "!cp ./drive/MyDrive/Colab\\ Notebooks/han/data/fix_5k_labels.csv ./data/\n",
+ "!cp ./drive/MyDrive/Colab\\ Notebooks/han/data/fix_1w_labels.csv ./data/\n",
+ "!cp ./drive/MyDrive/Colab\\ Notebooks/han/data/fix_whois.csv ./data/\n",
+ "!cp ./drive/MyDrive/Colab\\ Notebooks/han/model.py ./"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 541
+ },
+ "id": "CmJZXrpVHod1",
+ "outputId": "573fe488-fe6d-4b7b-c16a-cfad16e5b8b3"
+ },
+ "outputs": [],
+ "source": [
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "XnQWwdumsSdD",
+ "outputId": "f0b3f413-b629-4d29-e678-cc758f14a4e0"
+ },
+ "outputs": [],
+ "source": [
+ "'''\n",
+ "load pirate data\n",
+ "'''\n",
+ "import re\n",
+ "from pprint import pprint\n",
+ "\n",
+ "import re\n",
+ "from pprint import pprint\n",
+ "\n",
+ "import chars2vec\n",
+ "import dgl\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "from keras_preprocessing.sequence import pad_sequences\n",
+ "from keras_preprocessing.text import Tokenizer\n",
+ "\n",
+ "\n",
+ "def get_raw_data(path):\n",
+ " raw = pd.read_csv(path)\n",
+ " reg_dic, dns_dic, email_dic = get_whois()\n",
+ " print('Load Pirate Dataset: {}'.format(path))\n",
+ " in_feat = {}\n",
+ " website = {}\n",
+ " third = {}\n",
+ " w_ip = {}\n",
+ " t_ip = {}\n",
+ " cert = {}\n",
+ " reg = {}\n",
+ " dns = {}\n",
+ " email = {}\n",
+ " wi = 0\n",
+ " ti = 0\n",
+ " wpi = 0\n",
+ " tpi = 0\n",
+ " ci = 0\n",
+ " ri = 0\n",
+ " di = 0\n",
+ " ei = 0\n",
+ " edge_w_t = []\n",
+ " edge_w_i = []\n",
+ " edge_w_c = []\n",
+ " edge_t_c = []\n",
+ " edge_t_i = []\n",
+ " edge_w_r = []\n",
+ " edge_w_d = []\n",
+ " edge_w_e = []\n",
+ " labels = []\n",
+ "\n",
+ " '''\n",
+ " raw data for website's ndoe embedding\n",
+ " '''\n",
+ " raw_reg = []\n",
+ " raw_dns = []\n",
+ " raw_email = []\n",
+ "\n",
+ " '''\n",
+ " raw data for graph's node embedding\n",
+ " '''\n",
+ " node_domain = []\n",
+ " node_third = []\n",
+ " node_w_ip = []\n",
+ " node_t_ip = []\n",
+ " node_cert = []\n",
+ " node_email = []\n",
+ " node_reg = []\n",
+ " node_dns = []\n",
+ "\n",
+ "\n",
+ " for index, row in raw.iterrows():\n",
+ " t = str(row['req']).split('/')[2].strip('www.')\n",
+ " i = str(row['ip']).strip()\n",
+ " w = str(row['website']).strip().split(':')[0]\n",
+ " c = str(row['cert']).strip()\n",
+ " if t != w:\n",
+ " if t not in third:\n",
+ " third[t] = ti\n",
+ " node_third.append(t)\n",
+ " ti += 1\n",
+ " if i not in t_ip and i != 'nan':\n",
+ " t_ip[i] = tpi\n",
+ " node_t_ip.append(i)\n",
+ " tpi += 1\n",
+ " else:\n",
+ " if i not in w_ip and i != 'nan':\n",
+ " w_ip[i] = wpi\n",
+ " node_w_ip.append(i)\n",
+ " wpi += 1\n",
+ " if w not in website:\n",
+ " labels.append(int(row['label']))\n",
+ " node_domain.append(w)\n",
+ " raw_reg.append(reg_dic[w])\n",
+ " raw_dns.append(dns_dic[w])\n",
+ " raw_email.append(email_dic[w])\n",
+ " website[w] = wi\n",
+ " wi+=1\n",
+ " if reg_dic[w] not in reg and reg_dic[w] != 'nan':\n",
+ " reg[reg_dic[w]] = ri\n",
+ " node_reg.append(reg_dic[w])\n",
+ " ri += 1\n",
+ " if dns_dic[w] != 'nan':\n",
+ " for _ in dns_dic[w].split(','):\n",
+ " __ = _.strip().lower()\n",
+ " if __ not in dns:\n",
+ " dns[__] = di\n",
+ " node_dns.append(__)\n",
+ " di += 1\n",
+ " if email_dic[w] not in email and email_dic[w] != 'nan':\n",
+ " email[email_dic[w]] = ei\n",
+ " node_email.append(email_dic[w])\n",
+ " ei += 1\n",
+ " in_feat[w] = [0, 0, 0]\n",
+ " if c not in cert and c != 'nan':\n",
+ " cert[c] = ci\n",
+ " node_cert.append(c)\n",
+ " ci += 1\n",
+ "\n",
+ " in_feat[w][0] += 1\n",
+ " if t != w:\n",
+ " edge_w_t.append(tuple([website[w], third[t]]))\n",
+ " if i != 'nan':\n",
+ " edge_t_i.append(tuple([third[t], t_ip[i]]))\n",
+ " in_feat[w][1] += 1\n",
+ " if c != 'nan':\n",
+ " edge_t_c.append(tuple([third[t], cert[c]]))\n",
+ " if t == w:\n",
+ " if i != 'nan':\n",
+ " edge_w_i.append(tuple([website[w], w_ip[i]]))\n",
+ " if reg_dic[w] != 'nan':\n",
+ " edge_w_r.append(tuple([website[w], reg[reg_dic[w]]]))\n",
+ " if dns_dic[w] != 'nan':\n",
+ " for _ in dns_dic[w].split(','):\n",
+ " __ = _.strip().lower()\n",
+ " edge_w_d.append(tuple([website[w], dns[__]]))\n",
+ " if email_dic[w] != 'nan':\n",
+ " edge_w_e.append(tuple([website[w], email[email_dic[w]]]))\n",
+ " if c != 'nan':\n",
+ " edge_w_c.append(tuple([website[w], cert[c]]))\n",
+ " in_feat[w][2] += 1\n",
+ "\n",
+ "\n",
+ " edge_w_i = set(edge_w_i)\n",
+ " edge_w_t = set(edge_w_t)\n",
+ " edge_w_c = set(edge_w_c)\n",
+ " edge_t_c = set(edge_t_c)\n",
+ " edge_t_i = set(edge_t_i)\n",
+ " edge_w_r = set(edge_w_r)\n",
+ " edge_w_d = set(edge_w_d)\n",
+ " edge_w_e = set(edge_w_e)\n",
+ "\n",
+ "\n",
+ "\n",
+ " print('Node of Websites: ', len(website))\n",
+ " print('Node of Third-party Services: ', len(third))\n",
+ " print('Node of IP: ', len(w_ip))\n",
+ " print('Node of Cert: ', len(cert))\n",
+ " print('Node of Registrant: ',len(reg))\n",
+ " print('Node of DNS: ', len(dns))\n",
+ " print('Node of Email: ', len(email))\n",
+ " print('Edge of Website to Third-party: ', len(edge_w_t))\n",
+ " print('Edge of Webiste to IP: ', len(edge_w_i))\n",
+ " print('Edge of Website to Certificate: ', len(edge_w_c))\n",
+ " print('Edge of Website to Registrant: ', len(edge_w_r))\n",
+ " print('Edge of Webiste to DNS Server: ', len(edge_w_d))\n",
+ " print('Edge of Website to Email: ', len(edge_w_e))\n",
+ " print('Edge of Third-party to Certificate: ', len(edge_t_c))\n",
+ " print('Edge of Third-party to IP: ', len(edge_t_i))\n",
+ " print('labels: ', len(labels))\n",
+ " web_tp_src = []\n",
+ " web_tp_dst = []\n",
+ " web_ip_src = []\n",
+ " web_ip_dst = []\n",
+ " web_cert_src = []\n",
+ " web_cert_dst = []\n",
+ " third_cert_src = []\n",
+ " third_cert_dst = []\n",
+ " third_ip_src = []\n",
+ " third_ip_dst = []\n",
+ " web_reg_src = []\n",
+ " web_reg_dst = []\n",
+ " web_dns_src = []\n",
+ " web_dns_dst = []\n",
+ " web_email_src = []\n",
+ " web_email_dst = []\n",
+ " for item in edge_w_t:\n",
+ " web_tp_src.append(item[0])\n",
+ " web_tp_dst.append(item[1])\n",
+ " for item in edge_w_i:\n",
+ " web_ip_src.append(item[0])\n",
+ " web_ip_dst.append(item[1])\n",
+ " for item in edge_w_c:\n",
+ " web_cert_src.append(item[0])\n",
+ " web_cert_dst.append(item[1])\n",
+ " for item in edge_t_c:\n",
+ " third_cert_src.append(item[0])\n",
+ " third_cert_dst.append(item[1])\n",
+ " for item in edge_t_i:\n",
+ " third_ip_src.append(item[0])\n",
+ " third_ip_dst.append(item[1])\n",
+ " for item in edge_w_r:\n",
+ " web_reg_src.append(item[0])\n",
+ " web_reg_dst.append(item[1])\n",
+ " for item in edge_w_d:\n",
+ " web_dns_src.append(item[0])\n",
+ " web_dns_dst.append(item[1])\n",
+ " for item in edge_w_e:\n",
+ " web_email_src.append(item[0])\n",
+ " web_email_dst.append(item[1])\n",
+ " # hg = dgl.heterograph({\n",
+ " # ('web', 'wt', 'third'): (web_tp_src, web_tp_dst),\n",
+ " # ('third', 'tw', 'web'): (web_tp_dst, web_tp_src)},{'web':len(website),'third':len(third),'ip':len(ip),'cert':len(cert)})\n",
+ " hg = dgl.heterograph({\n",
+ " ('web', 'wt', 'third'): (web_tp_src, web_tp_dst),\n",
+ " ('third', 'tw', 'web'): (web_tp_dst, web_tp_src),\n",
+ " ('web', 'wi', 'ip'): (web_ip_src, web_ip_dst),\n",
+ " ('ip', 'iw', 'web'): (web_ip_dst, web_ip_src),\n",
+ " ('web', 'wc', 'cert'): (web_cert_src, web_cert_dst),\n",
+ " ('cert', 'cw', 'web'): (web_cert_dst, web_cert_src),\n",
+ " # ('third', 'tc', 'cert'): (third_cert_src, third_cert_dst),\n",
+ " # ('cert', 'ct', 'third'): (third_cert_dst, third_cert_src),\n",
+ " # ('third', 'ti', 'ip'): (third_ip_src, third_ip_dst),\n",
+ " # ('ip', 'it', 'third'): (third_ip_dst, third_ip_src),\n",
+ " ('web', 'wr', 'reg'): (web_reg_src, web_reg_dst),\n",
+ " ('reg', 'rw', 'web'): (web_reg_dst, web_reg_src),\n",
+ " ('web', 'we', 'email'): (web_email_src, web_email_dst),\n",
+ " ('email', 'ew', 'web'): (web_email_dst, web_email_src),\n",
+ " ('web', 'wd', 'dns'): (web_dns_src, web_dns_dst),\n",
+ " ('dns', 'dw', 'web'): (web_dns_dst, web_dns_src)\n",
+ " },\n",
+ " # {'web':len(website),'ip':len(w_ip)}\n",
+ " {'web':len(website),'third':len(third),'ip':len(w_ip),'dns':len(dns),'cert':len(cert),'reg':len(reg),'email':len(email)}\n",
+ " )\n",
+ "\n",
+ " labels = torch.LongTensor(labels)\n",
+ "\n",
+ " '''\n",
+ " one-hot features\n",
+ " '''\n",
+ " reg_oh = [[0 for _ in range(len(reg))] for __ in range(len(website))]\n",
+ " email_oh = [[0 for _ in range(len(email))] for __ in range(len(website))]\n",
+ " dns_oh = [[0 for _ in range(len(dns))] for __ in range(len(website))]\n",
+ "\n",
+ "\n",
+ " for i in range(len(raw_reg)):\n",
+ " if raw_reg[i]!='nan':\n",
+ " reg_oh[i][reg[raw_reg[i]]] = 1\n",
+ " for i in range(len(raw_email)):\n",
+ " if raw_email[i] != 'nan':\n",
+ " email_oh[i][email[raw_email[i]]] = 1\n",
+ " for i in range(len(raw_dns)):\n",
+ " for _ in raw_dns[i].split(','):\n",
+ " __ = _.strip().lower()\n",
+ " if __ != 'nan':\n",
+ " dns_oh[i][dns[__]]=1\n",
+ "\n",
+ " reg_oh = np.array(reg_oh)\n",
+ " email_oh = np.array(email_oh)\n",
+ " dns_oh = np.array(dns_oh)\n",
+ "\n",
+ "\n",
+ " in_features = []\n",
+ " for item in node_domain:\n",
+ " in_features.append(in_feat[item])\n",
+ "\n",
+ " data = {'labels': labels, 'node_domain': node_domain, 'raw_reg': raw_reg,\n",
+ " 'raw_dns': raw_reg, 'raw_email': raw_email,\n",
+ " 'hg': hg, 'in_features': np.array(in_features),'node_third':node_third,'node_w_ip':node_w_ip,'node_cert':node_cert,\n",
+ " 'node_reg':node_reg,'node_dns':node_dns,'node_email':node_email\n",
+ " }\n",
+ " import pickle\n",
+ " pickle.dump(data, open('./graph/WTICRDE_1k.pkl', 'wb'))\n",
+ "\n",
+ "\n",
+ "def get_whois():\n",
+ " raw = pd.read_csv('./data/fix_whois.csv')\n",
+ " dic = {}\n",
+ " website = {}\n",
+ " email = {}\n",
+ " dns = {}\n",
+ " for web,x,d,e in zip(raw.website,raw.reg,raw.dns,raw.email):\n",
+ " dns.setdefault(web,str(d))\n",
+ " email.setdefault(web,str(e))\n",
+ " tmp = str(x).strip().strip('.')\n",
+ " if re.match(r'.*GoDaddy',tmp):\n",
+ " dic.setdefault('GoDaddy.com, LLC',0)\n",
+ " website.setdefault(web,'GoDaddy.com, LLC')\n",
+ " elif re.match(r'.*NameSilo',tmp) or re.match(r'.*Namesilo',tmp):\n",
+ " dic.setdefault('NameSilo, LLC',0)\n",
+ " website.setdefault(web, 'NameSilo, LLC')\n",
+ " elif re.match(r'.*Alibaba',tmp) or re.match(r'.*阿里巴巴',tmp) or re.match(r'.*阿里云',tmp):\n",
+ " dic.setdefault('Alibaba Cloud Computing (Beijing) Co., Ltd',0)\n",
+ " website.setdefault(web, 'Alibaba Cloud Computing (Beijing) Co., Ltd')\n",
+ " elif re.match(r'.*新网',tmp) or re.match(r'Xin',tmp):\n",
+ " dic.setdefault('Xin Net Technology Corporation', 0)\n",
+ " website.setdefault(web, 'Xin Net Technology Corporation')\n",
+ " elif re.match(r'NameCheap',tmp):\n",
+ " dic.setdefault('NameCheap, Inc',0)\n",
+ " website.setdefault(web, 'NameCheap, Inc')\n",
+ " elif re.match(r'.*DropCatch',tmp):\n",
+ " dic.setdefault('DropCatch.com, Inc',0)\n",
+ " website.setdefault(web, 'DropCatch.com, Inc')\n",
+ " elif re.match(r'.*Chengdu west',tmp) or re.match(r'.*Chengdu West',tmp) or re.match(r'.*成都西维',tmp):\n",
+ " dic.setdefault('Chengdu West Dimension Digital Technology Co., Ltd',0)\n",
+ " website.setdefault(web, 'Chengdu West Dimension Digital Technology Co., Ltd')\n",
+ " elif re.match(r'.*Xiamen 35',tmp) or re.match(r'.*厦门三五',tmp):\n",
+ " dic.setdefault('Xiamen 35.Com Technology Co., Ltd',0)\n",
+ " website.setdefault(web, 'Xiamen 35.Com Technology Co., Ltd')\n",
+ " elif re.match(r'.*1API',tmp):\n",
+ " dic.setdefault('1API GmbH',0)\n",
+ " website.setdefault(web, '1API GmbH')\n",
+ " elif re.match(r'.*中央编办',tmp):\n",
+ " dic.setdefault('China Gov',0)\n",
+ " website.setdefault(web, 'China Gov')\n",
+ " elif re.match(r'.*浙江贰',tmp):\n",
+ " dic.setdefault('Zhejiang ErEr Technology Co., Ltd',0)\n",
+ " website.setdefault(web, 'Zhejiang ErEr Technology Co., Ltd')\n",
+ " elif re.match(r'.*Dynadot',tmp):\n",
+ " dic.setdefault('Dynadot, LLC',0)\n",
+ " website.setdefault(web, 'Dynadot, LLC')\n",
+ " elif re.match(r'.*广州云',tmp):\n",
+ " dic.setdefault('Guangzhou Cloud Technology Co., Ltd',0)\n",
+ " website.setdefault(web, 'Guangzhou Cloud Technology Co., Ltd')\n",
+ " elif re.match(r'.*厦门纳网',tmp):\n",
+ " dic.setdefault('Xiamen NaNetwork Technology Co., Ltd',0)\n",
+ " website.setdefault(web, 'Xiamen NaNetwork Technology Co., Ltd')\n",
+ " elif re.match(r'.*厦门易名',tmp):\n",
+ " dic.setdefault('Xiamen Easy Name Technology Co., Ltd',0)\n",
+ " website.setdefault(web, 'Xiamen Easy Name Technology Co., Ltd')\n",
+ " elif re.match(r'.*四川域趣',tmp):\n",
+ " dic.setdefault('Sichuan Fun Technology Co., Ltd',0)\n",
+ " website.setdefault(web, 'Sichuan Fun Technology Co., Ltd')\n",
+ " elif re.match(r'.*成都飞数',tmp):\n",
+ " dic.setdefault('Chengdu Fly Digital Technology Co., Ltd',0)\n",
+ " website.setdefault(web, 'Chengdu Fly Digital Technology Co., Ltd')\n",
+ " elif re.match(r'.*北京光速',tmp):\n",
+ " dic.setdefault('Beijing LightSpeed Technology Co., Ltd',0)\n",
+ " website.setdefault(web, 'Beijing LightSpeed Technology Co., Ltd')\n",
+ " elif re.match(r'.*上海福虎',tmp):\n",
+ " dic.setdefault('Shanghai Tiger Technology Co., Ltd',0)\n",
+ " website.setdefault(web, 'Shanghai Tiger Technology Co., Ltd')\n",
+ " elif re.match(r'.*广东时代',tmp):\n",
+ " dic.setdefault('Guangdong TimeInternet Technology Co., Ltd',0)\n",
+ " website.setdefault(web, 'Guangdong TimeInternet Technology Co., Ltd')\n",
+ " elif re.match(r'.*成都世纪',tmp):\n",
+ " dic.setdefault('Chengdu Century Oriental Network Communication Co., Ltd',0)\n",
+ " website.setdefault(web, 'Chengdu Century Oriental Network Communication Co., Ltd')\n",
+ " elif re.match(r'.*北京国旭',tmp):\n",
+ " dic.setdefault('Beijing Guoxu Network Technology Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Beijing Guoxu Network Technology Co., Ltd')\n",
+ " elif re.match(r'.*北京中科',tmp):\n",
+ " dic.setdefault('Beijing Sanfront Information Technology Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Beijing Sanfront Information Technology Co., Ltd')\n",
+ " elif re.match(r'.*北京东方',tmp):\n",
+ " dic.setdefault('Beijing East Network Technology Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Beijing East Network Technology Co., Ltd')\n",
+ " elif re.match(r'.*商中在线',tmp):\n",
+ " dic.setdefault('Shangzhong Online Technology Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Shangzhong Online Technology Co., Ltd')\n",
+ " elif re.match(r'.*中企动力',tmp):\n",
+ " dic.setdefault('Zhong Qi Dong Li Technology Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Zhong Qi Dong Li Technology Co., Ltd')\n",
+ " elif re.match(r'.*厦门市中资',tmp):\n",
+ " dic.setdefault('Xiamen ChinaSource Internet Service Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Xiamen ChinaSource Internet Service Co., Ltd')\n",
+ " elif re.match(r'.*广东互易',tmp):\n",
+ " dic.setdefault('Guangdong HuYi Network Technology Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Guangdong HuYi Network Technology Co., Ltd')\n",
+ " elif re.match(r'.*遵义中域',tmp):\n",
+ " dic.setdefault('Zunyi Mid Technology Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Zunyi Mid Technology Co., Ltd')\n",
+ " elif re.match(r'.*北京首信网',tmp):\n",
+ " dic.setdefault('Beijing Trust Network Technology Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Beijing Trust Network Technology Co., Ltd')\n",
+ " elif re.match(r'.*赛尔网络',tmp):\n",
+ " dic.setdefault('Saier Network Technology Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Saier Network Technology Co., Ltd')\n",
+ " elif re.match(r'.*遵义中域',tmp):\n",
+ " dic.setdefault('Zunyi Mid Technology Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Zunyi Mid Technology Co., Ltd')\n",
+ " elif re.match(r'.*烟台帝思普',tmp):\n",
+ " dic.setdefault('Yantai DiSiPu Technology Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Yantai DiSiPu Technology Co., Ltd')\n",
+ " elif re.match(r'.*北京万维',tmp):\n",
+ " dic.setdefault('Beijing WanWei Technology Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Beijing WanWei Technology Co., Ltd')\n",
+ " elif re.match(r'.*江苏邦宁',tmp):\n",
+ " dic.setdefault('Jiangsu Bangning Science and technology Co. Ltd', 0)\n",
+ " website.setdefault(web, 'Jiangsu Bangning Science and technology Co. Ltd')\n",
+ " elif re.match(r'.*北京神州',tmp):\n",
+ " dic.setdefault('Beijing ShenZhou technology Co. Ltd', 0)\n",
+ " website.setdefault(web, 'Beijing ShenZhou technology Co. Ltd')\n",
+ " elif re.match(r'.*北京神州',tmp):\n",
+ " dic.setdefault('Beijing ShenZhou technology Co. Ltd', 0)\n",
+ " website.setdefault(web, 'Beijing ShenZhou technology Co. Ltd')\n",
+ " elif re.match(r'.*上海美橙',tmp):\n",
+ " dic.setdefault('Shanghai Meicheng Technology Information Development Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Shanghai Meicheng Technology Information Development Co., Ltd')\n",
+ " elif re.match(r'.*广东金万',tmp):\n",
+ " dic.setdefault('Guangdong JinWanBang Technology Investment Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Guangdong JinWanBang Technology Investment Co., Ltd')\n",
+ " elif re.match(r'.*易介集团',tmp):\n",
+ " dic.setdefault('YiJie Beijing technology Co. Ltd', 0)\n",
+ " website.setdefault(web, 'YiJie Beijing technology Co. Ltd')\n",
+ " elif re.match(r'.*佛山市亿动',tmp):\n",
+ " dic.setdefault('Foshan YiDong Network Co., LTD', 0)\n",
+ " website.setdefault(web, 'Foshan YiDong Network Co., LTD')\n",
+ " elif re.match(r'.*山东开创',tmp):\n",
+ " dic.setdefault('Shandong KaiChuang technology Co. Ltd', 0)\n",
+ " website.setdefault(web, 'Shandong KaiChuang technology Co. Ltd')\n",
+ " elif re.match(r'.*泛亚信息',tmp):\n",
+ " dic.setdefault('FanYa Jiangsu technology Co. Ltd', 0)\n",
+ " website.setdefault(web, 'FanYa Jiangsu technology Co. Ltd')\n",
+ " elif re.match(r'.*杭州电商',tmp):\n",
+ " dic.setdefault('Hangzhou Dianshang Internet Technology Co., LTD', 0)\n",
+ " website.setdefault(web, 'Hangzhou Dianshang Internet Technology Co., LTD')\n",
+ " elif re.match(r'.*深圳市万维',tmp):\n",
+ " dic.setdefault('Shenzhen WanWei technology Co. Ltd', 0)\n",
+ " website.setdefault(web, 'Shenzhen WanWei technology Co. Ltd')\n",
+ " elif re.match(r'.*天津追日',tmp):\n",
+ " dic.setdefault('Tianjin Zhuiri Science and Technology Development Co Ltd', 0)\n",
+ " website.setdefault(web, 'Tianjin Zhuiri Science and Technology Development Co Ltd')\n",
+ " elif re.match(r'.*北京宏网',tmp):\n",
+ " dic.setdefault('Beijing HongWang technology Co. Ltd', 0)\n",
+ " website.setdefault(web, 'Beijing HongWang technology Co. Ltd')\n",
+ " elif re.match(r'.*昆明乐网',tmp):\n",
+ " dic.setdefault('Kunming LeWang technology Co. Ltd', 0)\n",
+ " website.setdefault(web, 'Kunming LeWang technology Co. Ltd')\n",
+ " elif re.match(r'.*上海热线',tmp):\n",
+ " dic.setdefault('Shanghai ReXian technology Co. Ltd', 0)\n",
+ " website.setdefault(web, 'Shanghai ReXian technology Co. Ltd')\n",
+ " elif re.match(r'.*互联网域名',tmp):\n",
+ " dic.setdefault('Internet Domain Name System Beijing Engineering Research Center LLC (ZDNS)', 0)\n",
+ " website.setdefault(web, 'Internet Domain Name System Beijing Engineering Research Center LLC (ZDNS)')\n",
+ " elif re.match(r'.*广东耐思',tmp):\n",
+ " dic.setdefault('SGuangdong Nicenic Technology Co., Ltd. dba NiceNIC', 0)\n",
+ " website.setdefault(web, 'Guangdong Nicenic Technology Co., Ltd. dba NiceNIC')\n",
+ " elif re.match(r'.*重庆智佳',tmp):\n",
+ " dic.setdefault('Chongqing ZhiJia Technology Co., Ltd', 0)\n",
+ " website.setdefault(web, 'Chongqing ZhiJia Technology Co., Ltd')\n",
+ " else:\n",
+ " dic.setdefault(tmp,0)\n",
+ " website.setdefault(web, tmp)\n",
+ " return website,dns,email\n",
+ "\n",
+ "def domain_to_index(domain, length):\n",
+ " \"\"\"\n",
+ "\n",
+ " :param domain: the domain data\n",
+ " :param length: the csv data's length\n",
+ " :return: index_vector of the domain\n",
+ " \"\"\"\n",
+ "\n",
+ " tokenizer = Tokenizer(char_level=True)\n",
+ " tokenizer.fit_on_texts(domain)\n",
+ " word_index = tokenizer.word_index\n",
+ " sequences = tokenizer.texts_to_sequences(domain)\n",
+ "\n",
+ " # padding\n",
+ " data = pad_sequences(sequences, maxlen=30)\n",
+ " # print('Found %s unique tokens.' % len(word_index))\n",
+ " # print(data)\n",
+ " print('Start Transform Character to index')\n",
+ " return np.array(data, dtype=\"float32\").reshape(length, 30), word_index\n",
+ "\n",
+ "def load_pirate(path):\n",
+ " return get_raw_data(path)\n",
+ "\n",
+ "\n",
+ "def load_pkl(dim,path):\n",
+ " print('Load From PKL:',path)\n",
+ " import pickle\n",
+ " data = pickle.load(open(path, 'rb'))\n",
+ " hg = data['hg']\n",
+ "\n",
+ " labels = data['labels']\n",
+ "\n",
+ " '''\n",
+ " character-embedding features\n",
+ " '''\n",
+ " # features, _ = domain_to_index(raw_domain, len(raw_domain))\n",
+ " c2v_model = \\\n",
+ " chars2vec.load_model(dim)\n",
+ " website_features = c2v_model.vectorize_words(data['node_domain'])\n",
+ " web_reg_features = c2v_model.vectorize_words(data['raw_reg'])\n",
+ " web_dns_features = c2v_model.vectorize_words(data['raw_dns'])\n",
+ " web_email_features = c2v_model.vectorize_words(data['raw_email'])\n",
+ "\n",
+ " third_features = c2v_model.vectorize_words(data['node_third'])\n",
+ " ip_features = c2v_model.vectorize_words(data['node_w_ip'])\n",
+ " cert_features = c2v_model.vectorize_words(data['node_cert'])\n",
+ " reg_features = c2v_model.vectorize_words(data['node_reg'])\n",
+ " dns_features = c2v_model.vectorize_words(data['node_dns'])\n",
+ " email_features = c2v_model.vectorize_words(data['node_email'])\n",
+ "\n",
+ " features = {\n",
+ " 'web': torch.FloatTensor(website_features),\n",
+ " 'third': torch.FloatTensor(third_features),\n",
+ " 'ip': torch.FloatTensor(ip_features),\n",
+ " 'cert': torch.FloatTensor(cert_features),\n",
+ " 'email': torch.FloatTensor(email_features),\n",
+ " 'dns':torch.FloatTensor(dns_features),\n",
+ " 'reg':torch.FloatTensor(reg_features),\n",
+ " }\n",
+ "\n",
+ " pprint({\n",
+ " 'Web Features:':features['web'].shape,\n",
+ " 'Third Features:': features['third'].shape,\n",
+ " 'IP Features:': features['ip'].shape,\n",
+ " 'Cert Features:': features['cert'].shape,\n",
+ " 'Reg Features:': features['reg'].shape,\n",
+ " 'DNS Features:': features['dns'].shape,\n",
+ " 'Email Features:': features['email'].shape\n",
+ " })\n",
+ "\n",
+ " in_features = data['in_features']\n",
+ " num_classes = 2\n",
+ "\n",
+ "\n",
+ " # plot(features,labels)\n",
+ "\n",
+ " return hg, features, labels, num_classes\n",
+ "\n",
+ "load_pirate('./data/fix_1k_labels.csv')\n",
+ "# load_pkl('eng_50','./graph/WTICRDE_1w.pkl')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "t6aPVZfwurEP"
+ },
+ "outputs": [],
+ "source": [
+ "'''\n",
+ "utils.py\n",
+ "'''\n",
+ "import datetime\n",
+ "import dgl\n",
+ "import errno\n",
+ "import numpy as np\n",
+ "import os\n",
+ "import pickle\n",
+ "import random\n",
+ "import torch\n",
+ "\n",
+ "from pprint import pprint\n",
+ "from scipy import sparse\n",
+ "from scipy import io as sio\n",
+ "\n",
+ "\n",
+ "def set_random_seed(seed=0):\n",
+ " \"\"\"Set random seed.\n",
+ " Parameters\n",
+ " ----------\n",
+ " seed : int\n",
+ " Random seed to use\n",
+ " \"\"\"\n",
+ " random.seed(seed)\n",
+ " np.random.seed(seed)\n",
+ " torch.manual_seed(seed)\n",
+ " if torch.cuda.is_available():\n",
+ " torch.cuda.manual_seed(seed)\n",
+ "\n",
+ "def mkdir_p(path, log=True):\n",
+ " \"\"\"Create a directory for the specified path.\n",
+ " Parameters\n",
+ " ----------\n",
+ " path : str\n",
+ " Path name\n",
+ " log : bool\n",
+ " Whether to print result for directory creation\n",
+ " \"\"\"\n",
+ " try:\n",
+ " os.makedirs(path)\n",
+ " if log:\n",
+ " print('Created directory {}'.format(path))\n",
+ " except OSError as exc:\n",
+ " if exc.errno == errno.EEXIST and os.path.isdir(path) and log:\n",
+ " print('Directory {} already exists.'.format(path))\n",
+ " else:\n",
+ " raise\n",
+ "\n",
+ "def get_date_postfix():\n",
+ " \"\"\"Get a date based postfix for directory name.\n",
+ " Returns\n",
+ " -------\n",
+ " post_fix : str\n",
+ " \"\"\"\n",
+ " dt = datetime.datetime.now()\n",
+ " post_fix = '{}_{:02d}-{:02d}-{:02d}'.format(\n",
+ " dt.date(), dt.hour, dt.minute, dt.second)\n",
+ "\n",
+ " return post_fix\n",
+ "\n",
+ "def setup_log_dir(args, sampling=False):\n",
+ " \"\"\"Name and create directory for logging.\n",
+ " Parameters\n",
+ " ----------\n",
+ " args Configuration\n",
+ " Returns\n",
+ " -------\n",
+ " log_dir : str\n",
+ " Path for logging directory\n",
+ " sampling : bool\n",
+ " Whether we are using sampling based training\n",
+ " \"\"\"\n",
+ " date_postfix = get_date_postfix()\n",
+ " log_dir = os.path.join(\n",
+ " args['log_dir'],\n",
+ " '{}_{}'.format(args['dataset'], date_postfix))\n",
+ "\n",
+ " if sampling:\n",
+ " log_dir = log_dir + '_sampling'\n",
+ "\n",
+ " mkdir_p(log_dir)\n",
+ " return log_dir\n",
+ "\n",
+ "# The configuration below is from the paper.\n",
+ "default_configure = {\n",
+ " 'lr': 0.005, # Learning rate\n",
+ " 'num_heads': [8], # Number of attention heads for node-level attention\n",
+ " 'hidden_units': 8,\n",
+ " 'dropout': 0.6,\n",
+ " 'weight_decay': 0.001,\n",
+ " 'patience': 100\n",
+ "}\n",
+ "\n",
+ "sampling_configure = {\n",
+ " 'batch_size': 20\n",
+ "}\n",
+ "\n",
+ "def setup(args):\n",
+ " args.update(default_configure)\n",
+ " set_random_seed(args['seed'])\n",
+ " args['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ " return args\n",
+ "\n",
+ "def setup_for_sampling(args):\n",
+ " args.update(default_configure)\n",
+ " args.update(sampling_configure)\n",
+ " set_random_seed()\n",
+ " args['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
+ " return args\n",
+ "\n",
+ "def get_binary_mask(total_size, indices):\n",
+ " mask = torch.zeros(total_size)\n",
+ " mask[indices] = 1\n",
+ " return mask.byte()\n",
+ "\n",
+ "\n",
+ "def load_data(dim,dataset, remove_self_loop=False):\n",
+ " print('Load Dataset:',dataset)\n",
+ " if dataset == 'Pirate1RDE':\n",
+ " return load_pkl(dim,'graph/WTICRDE_1k.pkl')\n",
+ " elif dataset == 'Pirate3':\n",
+ " return load_pkl(dim,'graph/WTIC_3k.pkl')\n",
+ " elif dataset == 'Pirate10':\n",
+ " return load_pkl(dim,'graph/WTIC_1w.pkl')\n",
+ " elif dataset == 'Pirate10RDE':\n",
+ " return load_pkl(dim,'graph/WTICRDE_1w.pkl')\n",
+ " elif dataset == 'Pirate5RDE':\n",
+ " return load_pkl(dim,'graph/WTICRDE_5k.pkl')\n",
+ " elif dataset == 'Pirate3RDE':\n",
+ " return load_pkl(dim,'graph/WTICRDE_3k.pkl')\n",
+ " elif dataset == 'Pirate10RE':\n",
+ " return load_pkl(dim,'graph/WTICRE_1w.pkl')\n",
+ " elif dataset == 'Pirate10RD':\n",
+ " return load_pkl(dim,'graph/WTICRD_1w.pkl')\n",
+ " elif dataset == 'Pirate10DE':\n",
+ " return load_pkl(dim,'graph/WTICDE_1w.pkl')\n",
+ " elif dataset == 'Pirate10R':\n",
+ " return load_pkl(dim,'graph/WTICR_1w.pkl')\n",
+ " elif dataset == 'Pirate10WTID':\n",
+ " return load_pkl(dim,'graph/WTID_1w.pkl')\n",
+ " elif dataset == 'Pirate10WT':\n",
+ " return load_pkl(dim,'graph/WT_1w.pkl')\n",
+ " elif dataset == 'Pirate10WI':\n",
+ " return load_pkl(dim,'graph/WI_1w.pkl')\n",
+ " elif dataset == 'Pirate10WD':\n",
+ " return load_pkl(dim,'graph/WD_1w.pkl')\n",
+ " elif dataset == 'Pirate10WC':\n",
+ " return load_pkl(dim,'graph/WC_1w.pkl')\n",
+ " elif dataset == 'Pirate10WE':\n",
+ " return load_pkl(dim,'graph/WE_1w.pkl')\n",
+ " elif dataset == 'Pirate10WR':\n",
+ " return load_pkl(dim,'graph/WR_1w.pkl')\n",
+ " elif dataset == 'Pirate10NoC':\n",
+ " return load_pkl(dim,'graph/WTIRDE_1w.pkl')\n",
+ " else:\n",
+ " return NotImplementedError('Unsupported dataset {}'.format(dataset))\n",
+ "\n",
+ "class EarlyStopping(object):\n",
+ " def __init__(self, patience=10):\n",
+ " dt = datetime.datetime.now()\n",
+ " self.filename = './model/early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(\n",
+ " dt.date(), dt.hour, dt.minute, dt.second)\n",
+ " self.patience = patience\n",
+ " self.counter = 0\n",
+ " self.best_acc = None\n",
+ " self.best_loss = None\n",
+ " self.early_stop = False\n",
+ "\n",
+ " def step(self, loss, acc, model):\n",
+ " if self.best_loss is None:\n",
+ " self.best_acc = acc\n",
+ " self.best_loss = loss\n",
+ " self.save_checkpoint(model)\n",
+ " elif (loss > self.best_loss) and (acc < self.best_acc):\n",
+ " self.counter += 1\n",
+ " # print(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n",
+ " if self.counter >= self.patience:\n",
+ " self.early_stop = True\n",
+ " else:\n",
+ " if (loss <= self.best_loss) and (acc >= self.best_acc):\n",
+ " self.save_checkpoint(model)\n",
+ " self.best_loss = np.min((loss, self.best_loss))\n",
+ " self.best_acc = np.max((acc, self.best_acc))\n",
+ " self.counter = 0\n",
+ " return self.early_stop\n",
+ "\n",
+ " def save_checkpoint(self, model):\n",
+ " \"\"\"Saves model when validation loss decreases.\"\"\"\n",
+ " torch.save(model.state_dict(), self.filename)\n",
+ "\n",
+ " def load_checkpoint(self, model):\n",
+ " \"\"\"Load the latest checkpoint.\"\"\"\n",
+ " model.load_state_dict(torch.load(self.filename))\n",
+ "\n",
+ "\n",
+ "def plot(features,pred,labels):\n",
+ " import torch\n",
+ " import numpy as np\n",
+ " import matplotlib.pyplot as plt\n",
+ " import matplotlib.cm as cm\n",
+ " def PCA(data, k=2):\n",
+ " X = data\n",
+ " X_mean = torch.mean(X, 0)\n",
+ " X = X - X_mean.expand_as(X)\n",
+ " # SVD\n",
+ " U, S, V = torch.svd(torch.t(X))\n",
+ " return torch.mm(X, U[:, :k])\n",
+ "\n",
+ " X = features\n",
+ " y = []\n",
+ " for i in range(len(labels)):\n",
+ " if pred[i]==labels[i] and pred[i]==1:\n",
+ " y.append(1)\n",
+ " elif pred[i]==labels[i] and pred[i]==0:\n",
+ " y.append(2)\n",
+ " elif pred[i]!=labels[i] and pred[i]==0:\n",
+ " y.append(-1)\n",
+ " else:\n",
+ " y.append(-2)\n",
+ " y = torch.tensor(y)\n",
+ " # y = labels\n",
+ " X_pca = PCA(X)\n",
+ " pca = X_pca.numpy()\n",
+ " #\n",
+ " plt.figure()\n",
+ " color = cm.rainbow(np.linspace(0, 1, 4))\n",
+ " plt.scatter(pca[y == 1,0], pca[y == 1,1],s=5, label='True Positive', color=color[0],marker='o')\n",
+ " plt.scatter(pca[y == 2,0], pca[y == 2,1],s=5, label='True Negative', color=color[1],marker='o')\n",
+ " plt.scatter(pca[y == -1,0], pca[y == -1,1],s=100, label='False Positive', color=color[2],marker='*')\n",
+ " plt.scatter(pca[y == -2,0], pca[y == -2,1],s=100, label='False Negative', color=color[3],marker='*')\n",
+ "\n",
+ " plt.legend()\n",
+ " plt.title('PCA of Pirate 1W dataset')\n",
+ " plt.show(): dict\n",
+ "\n",
+ "\n",
+ "# load_data('eng_50','Pirate10RDE')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FyEAWshW0eDZ"
+ },
+ "outputs": [],
+ "source": [
+ "'''\n",
+ "Focal Loss.py\n",
+ "'''\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "import torch.cuda.amp as amp\n",
+ "\n",
+ "\n",
+ "##\n",
+ "# version 1: use torch.autograd\n",
+ "class FocalLossV1(nn.Module):\n",
+ "\n",
+ " def __init__(self,\n",
+ " alpha=0.25,\n",
+ " gamma=2,\n",
+ " reduction='mean',):\n",
+ " super(FocalLossV1, self).__init__()\n",
+ " self.alpha = alpha\n",
+ " self.gamma = gamma\n",
+ " self.reduction = reduction\n",
+ " self.crit = nn.BCEWithLogitsLoss(reduction='none')\n",
+ "\n",
+ " def forward(self, logits, label):\n",
+ " '''\n",
+ " Usage is same as nn.BCEWithLogits:\n",
+ " >>> criteria = FocalLossV1()\n",
+ " >>> logits = torch.randn(8, 19, 384, 384)\n",
+ " >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()\n",
+ " >>> loss = criteria(logits, lbs)\n",
+ " '''\n",
+ " probs = torch.sigmoid(logits)\n",
+ " coeff = torch.abs(label - probs).pow(self.gamma).neg()\n",
+ " log_probs = torch.where(logits >= 0,\n",
+ " F.softplus(logits, -1, 50),\n",
+ " logits - F.softplus(logits, 1, 50))\n",
+ " log_1_probs = torch.where(logits >= 0,\n",
+ " -logits + F.softplus(logits, -1, 50),\n",
+ " -F.softplus(logits, 1, 50))\n",
+ " loss = label * self.alpha * log_probs + (1. - label) * (1. - self.alpha) * log_1_probs\n",
+ " loss = loss * coeff\n",
+ "\n",
+ " if self.reduction == 'mean':\n",
+ " loss = loss.mean()\n",
+ " if self.reduction == 'sum':\n",
+ " loss = loss.sum()\n",
+ " return loss\n",
+ "\n",
+ "\n",
+ "##\n",
+ "# version 2: user derived grad computation\n",
+ "class FocalSigmoidLossFuncV2(torch.autograd.Function):\n",
+ " '''\n",
+ " compute backward directly for better numeric stability\n",
+ " '''\n",
+ " @staticmethod\n",
+ " @amp.custom_fwd(cast_inputs=torch.float32)\n",
+ " def forward(ctx, logits, label, alpha, gamma):\n",
+ " # logits = logits.float()\n",
+ "\n",
+ " probs = torch.sigmoid(logits)\n",
+ " coeff = (label - probs).abs_().pow_(gamma).neg_()\n",
+ " log_probs = torch.where(logits >= 0,\n",
+ " F.softplus(logits, -1, 50),\n",
+ " logits - F.softplus(logits, 1, 50))\n",
+ " log_1_probs = torch.where(logits >= 0,\n",
+ " -logits + F.softplus(logits, -1, 50),\n",
+ " -F.softplus(logits, 1, 50))\n",
+ " ce_term1 = log_probs.mul_(label).mul_(alpha)\n",
+ " ce_term2 = log_1_probs.mul_(1. - label).mul_(1. - alpha)\n",
+ " ce = ce_term1.add_(ce_term2)\n",
+ " loss = ce * coeff\n",
+ "\n",
+ " ctx.vars = (coeff, probs, ce, label, gamma, alpha)\n",
+ "\n",
+ " return loss\n",
+ "\n",
+ " @staticmethod\n",
+ " @amp.custom_bwd\n",
+ " def backward(ctx, grad_output):\n",
+ " '''\n",
+ " compute gradient of focal loss\n",
+ " '''\n",
+ " (coeff, probs, ce, label, gamma, alpha) = ctx.vars\n",
+ "\n",
+ " d_coeff = (label - probs).abs_().pow_(gamma - 1.).mul_(gamma)\n",
+ " d_coeff.mul_(probs).mul_(1. - probs)\n",
+ " d_coeff = torch.where(label < probs, d_coeff.neg(), d_coeff)\n",
+ " term1 = d_coeff.mul_(ce)\n",
+ "\n",
+ " d_ce = label * alpha\n",
+ " d_ce.sub_(probs.mul_((label * alpha).mul_(2).add_(1).sub_(label).sub_(alpha)))\n",
+ " term2 = d_ce.mul(coeff)\n",
+ "\n",
+ " grads = term1.add_(term2)\n",
+ " grads.mul_(grad_output)\n",
+ "\n",
+ " return grads, None, None, None\n",
+ "\n",
+ "\n",
+ "class FocalLossV2(nn.Module):\n",
+ "\n",
+ " def __init__(self,\n",
+ " alpha=0.25,\n",
+ " gamma=2,\n",
+ " reduction='mean'):\n",
+ " super(FocalLossV2, self).__init__()\n",
+ " self.alpha = alpha\n",
+ " self.gamma = gamma\n",
+ " self.reduction = reduction\n",
+ "\n",
+ " def forward(self, logits, label):\n",
+ " '''\n",
+ " Usage is same as nn.BCEWithLogits:\n",
+ " >>> criteria = FocalLossV2()\n",
+ " >>> logits = torch.randn(8, 19, 384, 384)\n",
+ " >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()\n",
+ " >>> loss = criteria(logits, lbs)\n",
+ " '''\n",
+ " loss = FocalSigmoidLossFuncV2.apply(logits, label, self.alpha, self.gamma)\n",
+ " if self.reduction == 'mean':\n",
+ " loss = loss.mean(dim=1)\n",
+ " if self.reduction == 'sum':\n",
+ " loss = loss.sum(dim=1)\n",
+ " return loss\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jVPCl_XtZ4xP"
+ },
+ "outputs": [],
+ "source": [
+ "'''\n",
+ "Renode.py\n",
+ "'''\n",
+ "import math\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "import numpy as np\n",
+ "\n",
+ "def page_rank(args,adj,filename):\n",
+ " print('Calculate Pagerank')\n",
+ " pr_prob = 1 - args['pagerank_prob']\n",
+ " A = adj.to_dense()\n",
+ " A_hat = A.to(args['device']) + torch.eye(A.size(0)).to(args['device']) # add self-loop\n",
+ " D = torch.diag(torch.sum(A_hat, 1))\n",
+ " D = D.inverse().sqrt()\n",
+ " A_hat = torch.mm(torch.mm(D, A_hat), D)\n",
+ " Pi = pr_prob * ((torch.eye(A.size(0)).to(args['device']) - (1 - pr_prob) * A_hat).inverse())\n",
+ " Pi = Pi.cpu()\n",
+ " torch.save(Pi, filename)\n",
+ " return Pi.to(args['device'])\n",
+ "\n",
+ "def get_gpr(Pi,train_node,labels):\n",
+ " gpr_matrix = [] # the class-level influence distribution\n",
+ " for iter_c in range(2):\n",
+ " iter_Pi = Pi[torch.tensor(train_node[iter_c]).long()]\n",
+ " iter_gpr = torch.mean(iter_Pi, dim=0).squeeze()\n",
+ " gpr_matrix.append(iter_gpr)\n",
+ " '''\n",
+ " gpr_matrix = [tensor[0],tensor[1]]\n",
+ " tensor[0] = []\n",
+ " temp_gpr = []\n",
+ " '''\n",
+ "\n",
+ " temp_gpr = torch.stack(gpr_matrix, dim=0)\n",
+ " temp_gpr = temp_gpr.transpose(0, 1)\n",
+ " gpr = temp_gpr\n",
+ "\n",
+ " return gpr\n",
+ " # rn_weight = get_renode_weight(gpr,) # ReNode Weight\n",
+ "\n",
+ "\n",
+ "def get_renode_weight(args,Pi,gpr,labels,train_mask):\n",
+ " ppr_matrix = Pi # personlized pagerank\n",
+ " gpr_matrix = torch.tensor(gpr).float() # class-accumulated personlized pagerank\n",
+ "\n",
+ " base_w = args['rn_base_weight']\n",
+ " scale_w = args['rn_scale_weight']\n",
+ " nnode = ppr_matrix.size(0)\n",
+ " # unlabel_mask = data.train_mask.int().ne(1) # unlabled node\n",
+ "\n",
+ " # computing the Totoro values for labeled nodes\n",
+ " '''\n",
+ " gpr_sum 对每个节点各label维度上与pagerank计算出的权重进行求和\n",
+ " gpr_rn 转置后减gpr_matrix\n",
+ " rn_matrix pagerank矩阵与每个节点计算出的类别权重进行相乘\n",
+ " '''\n",
+ " gpr_sum = torch.sum(gpr_matrix, dim=1)\n",
+ " gpr_rn = gpr_sum.unsqueeze(1) - gpr_matrix\n",
+ " rn_matrix = torch.mm(ppr_matrix, gpr_rn)\n",
+ "\n",
+ " label_matrix = F.one_hot(labels, gpr_matrix.size(1)).float()\n",
+ " # label_matrix[unlabel_mask] = 0\n",
+ "\n",
+ "\n",
+ " '''\n",
+ " rn_matrix 根据rn矩阵及label标签得到每个节点的rn值\n",
+ " '''\n",
+ " rn_matrix = torch.sum(rn_matrix * label_matrix, dim=1)\n",
+ " # rn_matrix[unlabel_mask] = rn_matrix.max() + 99 # exclude the influence of unlabeled node\n",
+ "\n",
+ "\n",
+ " '''\n",
+ " 根据节点rn值从小到大排序\n",
+ " '''\n",
+ " # computing the ReNode Weight\n",
+ " train_size = torch.sum(train_mask.int()).item()\n",
+ " totoro_list = rn_matrix.tolist()\n",
+ " id2totoro = {i: totoro_list[i] for i in range(len(totoro_list))}\n",
+ "\n",
+ " sorted_totoro = sorted(id2totoro.items(), key=lambda x: x[1], reverse=False)\n",
+ "\n",
+ " id2rank = {sorted_totoro[i][0]: i for i in range(nnode)}\n",
+ "\n",
+ " totoro_rank = [id2rank[i] for i in range(nnode)]\n",
+ "\n",
+ "\n",
+ " rn_weight = [(base_w + 0.5 * scale_w * (1 + math.cos(x * 1.0 * math.pi / (train_size - 1)))) for x in totoro_rank]\n",
+ " rn_weight = torch.from_numpy(np.array(rn_weight)).type(torch.FloatTensor).to(args['device'])\n",
+ " rn_weight = rn_weight * train_mask.float()\n",
+ "\n",
+ " # print(rn_weight.size())\n",
+ " # print('rn size: ',rn_weight.size())\n",
+ " return rn_weight\n",
+ "\n",
+ "\n",
+ "'''\n",
+ "Get PageRank\n",
+ "'''\n",
+ "\n",
+ "def hg2Pi(args,g):\n",
+ " # metapath = ['wi']\n",
+ " metapath = ['wt', 'wi', 'wd','wc','we','wr']\n",
+ " # metapath = ['wt', 'wi', 'wd','we','wr']\n",
+ " Pi = {}\n",
+ " for item in metapath:\n",
+ " if os.path.exists('./pagerank/'+args['dataset'] + item + '.pt'):\n",
+ " Pi[item] = torch.load('./pagerank/'+args['dataset'] + item + '.pt').to(args['device'])\n",
+ " else:\n",
+ " meta_g = dgl.metapath_reachable_graph(g, [item, item[::-1]])\n",
+ " Pi[item] = page_rank(args, meta_g.adjacency_matrix(),'./pagerank/'+args['dataset'] + item + '.pt')\n",
+ "\n",
+ " return Pi\n",
+ "\n",
+ "'''\n",
+ "Get ReNode Weight\n",
+ "'''\n",
+ "\n",
+ "def hg2rn(args,Pi,att_score, train_node, train_mask, labels):\n",
+ " # metapath = ['wi']\n",
+ " metapath = ['wt','wi','wd','wc','we','wr']\n",
+ " # metapath = ['wt','wi','wd','we','wr']\n",
+ " weight = {}\n",
+ " for item in metapath:\n",
+ " pi = Pi[item]*att_score[item]\n",
+ " gpr = get_gpr(pi, train_node, labels)\n",
+ " weight[item] = get_renode_weight(args, pi, gpr, labels, train_mask)\n",
+ "\n",
+ " # for k,v in weight.items():\n",
+ " # print(k,v)\n",
+ " res = torch.zeros(len(labels)).to(args['device'])\n",
+ " for _ in weight.values():\n",
+ " res = torch.add(res,_)\n",
+ " return res"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "R7bIsLs7G8eo"
+ },
+ "outputs": [],
+ "source": [
+ "'''\n",
+ "HGT train.py\n",
+ "'''\n",
+ "from pprint import pprint\n",
+ "\n",
+ "import torch\n",
+ "from sklearn.metrics import f1_score, classification_report, confusion_matrix, accuracy_score, precision_score, \\\n",
+ " recall_score, precision_recall_fscore_support\n",
+ "\n",
+ "from model import *\n",
+ "import numpy as np\n",
+ "from scipy.sparse import coo_matrix\n",
+ "\n",
+ "\n",
+ "\n",
+ "def score(logits, labels):\n",
+ " _, indices = torch.max(logits, dim=1)\n",
+ " prediction = indices.long().cpu().numpy()\n",
+ " labels = labels.cpu().numpy()\n",
+ "\n",
+ " accuracy = (prediction == labels).sum() / len(prediction)\n",
+ " micro_f1 = f1_score(labels, prediction, average='micro')\n",
+ " macro_f1 = f1_score(labels, prediction, average='macro')\n",
+ "\n",
+ " return accuracy, micro_f1, macro_f1\n",
+ "\n",
+ "\n",
+ "def hgt_evaluate(model, g, t, labels, mask, loss_func):\n",
+ " model.eval()\n",
+ " with torch.no_grad():\n",
+ " logits = model(g, t)\n",
+ " loss = loss_func(logits[mask], F.one_hot(labels[mask], num_classes=2).float())\n",
+ " accuracy, micro_f1, macro_f1 = score(logits[mask], labels[mask])\n",
+ " # indices = indices.cpu().numpy()\n",
+ " # labels = labels.cpu().numpy()\n",
+ " _, indices = torch.max(logits[mask], dim=1)\n",
+ " val_precision = precision_score(labels[mask].cpu().numpy(), indices.cpu().numpy(), average='macro')\n",
+ " val_recall = recall_score(labels[mask].cpu().numpy(), indices.cpu().numpy(), average='macro')\n",
+ "\n",
+ " return loss, accuracy, micro_f1, macro_f1, val_precision, val_recall\n",
+ "\n",
+ "\n",
+ "def hgt_test(model, g, t, labels, mask):\n",
+ " model.eval()\n",
+ " with torch.no_grad():\n",
+ " logits = model(g, t)\n",
+ " logits = logits[mask]\n",
+ " labels = labels[mask]\n",
+ " _, indices = torch.max(logits, dim=1)\n",
+ " '''\n",
+ " Plot output feature\n",
+ " '''\n",
+ " # feature_output1 = model.featuremap[mask].cpu()\n",
+ " # plot(feature_output1,indices.cpu(),labels.cpu())\n",
+ "\n",
+ " indices = indices.cpu().numpy()\n",
+ " labels = labels.cpu().numpy()\n",
+ " p,r,f1,s = precision_recall_fscore_support(labels,indices)\n",
+ "\n",
+ " print(classification_report(labels, indices, digits=4))\n",
+ " print(confusion_matrix(labels, indices))\n",
+ " return p[1],r[1],f1[1]\n",
+ "\n",
+ "\n",
+ "def get_n_params(model):\n",
+ " pp = 0\n",
+ " for p in list(model.parameters()):\n",
+ " nn = 1\n",
+ " for s in list(p.size()):\n",
+ " nn = nn * s\n",
+ " pp += nn\n",
+ " return pp\n",
+ "\n",
+ "def att_score(g,model):\n",
+ " '''\n",
+ " Using subgraph to calculate the website attention score matrix\n",
+ " '''\n",
+ " meta = model.meta\n",
+ " att_score = {}\n",
+ " for k, v in meta.items():\n",
+ " row, col = v['edges']\n",
+ " # row = row.cpu().numpy()\n",
+ " # col = col.cpu().numpy()\n",
+ " tmp = v['att_score']\n",
+ "\n",
+ " indices = torch.stack((row,col),0)\n",
+ " if k == 'wt':\n",
+ " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('web'), g.num_nodes('third')])\n",
+ " elif k == 'tw':\n",
+ " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('third'), g.num_nodes('web')])\n",
+ " elif k == 'iw':\n",
+ " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('ip'), g.num_nodes('web')])\n",
+ " elif k == 'wi':\n",
+ " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('web'), g.num_nodes('ip')])\n",
+ " elif k == 'wd':\n",
+ " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('web'), g.num_nodes('dns')])\n",
+ " elif k == 'dw':\n",
+ " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('dns'), g.num_nodes('web')])\n",
+ " elif k == 'ew':\n",
+ " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('email'), g.num_nodes('web')])\n",
+ " elif k == 'we':\n",
+ " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('web'), g.num_nodes('email')])\n",
+ " elif k == 'wr':\n",
+ " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('web'), g.num_nodes('reg')])\n",
+ " elif k == 'rw':\n",
+ " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('reg'), g.num_nodes('web')])\n",
+ " elif k == 'wc':\n",
+ " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('web'), g.num_nodes('cert')])\n",
+ " elif k == 'cw':\n",
+ " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('cert'), g.num_nodes('web')])\n",
+ " # for k, v in att_score.items():\n",
+ " # print(v.shape)\n",
+ " re_att = {}\n",
+ " re_att['wt'] = torch.mm(att_score['wt'].to_dense(), att_score['tw'].to_dense())\n",
+ " # print(re_att)\n",
+ " re_att['wi'] = torch.mm(att_score['wi'].to_dense(), att_score['iw'].to_dense())\n",
+ " re_att['wd'] = torch.mm(att_score['wd'].to_dense(), att_score['dw'].to_dense())\n",
+ " re_att['wr'] = torch.mm(att_score['wr'].to_dense(), att_score['rw'].to_dense())\n",
+ " re_att['wc'] = torch.mm(att_score['wc'].to_dense(), att_score['cw'].to_dense())\n",
+ " re_att['we'] = torch.mm(att_score['we'].to_dense(), att_score['ew'].to_dense())\n",
+ " return re_att\n",
+ "\n",
+ "def get_train_node(train_idx,labels):\n",
+ " train_node=[[],[]]\n",
+ " for i in train_idx:\n",
+ " if labels[i]==1:\n",
+ " train_node[1].append(i)\n",
+ " else:\n",
+ " train_node[0].append(i)\n",
+ " return train_node\n",
+ "\n",
+ "\n",
+ "def get_binary_mask(total_size, indices):\n",
+ " mask = torch.zeros(total_size)\n",
+ " mask[indices] = 1\n",
+ " return mask.byte()\n",
+ "\n",
+ "\n",
+ "def shuffule(hg,labels):\n",
+ " float_mask = np.random.permutation(np.linspace(0, 1, len(labels)))\n",
+ " train_idx = np.where(float_mask <= 0.6)[0]\n",
+ " val_idx = np.where((float_mask > 0.6) & (float_mask <= 0.8))[0]\n",
+ " test_idx = np.where(float_mask > 0.8)[0]\n",
+ "\n",
+ " num_nodes = hg.number_of_nodes('web')\n",
+ " train_mask = get_binary_mask(num_nodes, train_idx)\n",
+ " val_mask = get_binary_mask(num_nodes, val_idx)\n",
+ " test_mask = get_binary_mask(num_nodes, test_idx)\n",
+ "\n",
+ " print('dataset loaded')\n",
+ " pprint({\n",
+ " 'dataset': 'Pirate',\n",
+ " 'train': train_mask.sum().item() / num_nodes,\n",
+ " 'val': val_mask.sum().item() / num_nodes,\n",
+ " 'test': test_mask.sum().item() / num_nodes\n",
+ " })\n",
+ " return train_idx, val_idx, test_idx, train_mask, val_mask, test_mask\n",
+ "\n",
+ "\n",
+ "def hgt_batch(args,model, g,train_idx, train_mask, val_mask, test_mask, labels, Pi,train_node):\n",
+ " stopper = EarlyStopping(patience=args['patience'])\n",
+ " if args['loss']=='BCE':\n",
+ " loss_fcn = torch.nn.BCEWithLogitsLoss()\n",
+ " if args['loss'] =='BCERN':\n",
+ " loss_fcn = torch.nn.BCEWithLogitsLoss(reduce=False)\n",
+ " if args['loss']=='Focal' or args['loss'] =='RN':\n",
+ " loss_fcn = FocalLossV2()\n",
+ " # loss_fcn = torch.nn.CrossEntropyLoss()\n",
+ "\n",
+ " optimizer = torch.optim.AdamW(model.parameters())\n",
+ "\n",
+ " for epoch in range(args['num_epochs']):\n",
+ " model.train()\n",
+ " logits = model(g, 'web')\n",
+ "\n",
+ " if args['loss']=='RN' or args['loss']=='BCERN':\n",
+ " re_att = att_score(g,model)\n",
+ " rn_weight = hg2rn(args,Pi,re_att,train_node,train_mask,labels)\n",
+ " '''\n",
+ " BCE_LOSS\n",
+ " '''\n",
+ " if args['loss'] == 'BCE':\n",
+ " loss = loss_fcn(logits[train_mask], F.one_hot(labels[train_mask], num_classes = 2).float())\n",
+ " '''\n",
+ " Focal_LOSS\n",
+ " '''\n",
+ " if args['loss'] == 'Focal':\n",
+ " loss = loss_fcn(logits[train_mask], F.one_hot(labels[train_mask], num_classes = 2)).mean()\n",
+ "\n",
+ " '''\n",
+ " Focal_LOSS with ReNode\n",
+ " '''\n",
+ " if args['loss'] == 'RN':\n",
+ " loss = loss_fcn(logits[train_mask], F.one_hot(labels[train_mask], num_classes = 2))\n",
+ " loss = torch.sum(loss * rn_weight[train_mask].to(args['device'])) / loss.size(0)\n",
+ " '''\n",
+ " BCE_LOSS with ReNode\n",
+ " '''\n",
+ " if args['loss'] == 'BCERN':\n",
+ " loss = loss_fcn(logits[train_mask], F.one_hot(labels[train_mask], num_classes = 2).float()).mean(dim=1)\n",
+ " loss = torch.sum(loss * rn_weight[train_mask].to(args['device'])) / loss.size(0)\n",
+ "\n",
+ " optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ " train_acc, train_micro_f1, train_macro_f1 = score(logits[train_mask], labels[train_mask])\n",
+ " val_loss, val_acc, val_micro_f1, val_macro_f1, val_precision, val_recall = hgt_evaluate(model, g, 'web', labels,\n",
+ " val_mask, loss_fcn)\n",
+ " early_stop = stopper.step(val_loss.mean().data.item(), val_macro_f1, model)\n",
+ "\n",
+ " # print('Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | '\n",
+ " # 'Val Loss {:.4f} | Val Macro Precison {:.4f} | Val Macro Recall {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(\n",
+ " # epoch + 1, loss.item(), train_micro_f1, train_macro_f1, val_loss.mean().item(), val_precision, val_recall,\n",
+ " # val_micro_f1, val_macro_f1))\n",
+ "\n",
+ " if early_stop:\n",
+ " break\n",
+ "\n",
+ " stopper.load_checkpoint(model)\n",
+ " print('Seed',args['seed'])\n",
+ " return hgt_test(model, g, 'web', labels, test_mask)\n",
+ "\n",
+ "\n",
+ "def hgt_train(args, G, features, labels, num_classes):\n",
+ " \n",
+ " train_idx, val_idx, test_idx, train_mask, val_mask, test_mask = shuffule(G,labels)\n",
+ "\n",
+ " Pi = hg2Pi(args,G)\n",
+ " # rn_weight = []\n",
+ " train_node = get_train_node(train_idx,labels)\n",
+ "\n",
+ " node_dict = {}\n",
+ " edge_dict = {}\n",
+ " for ntype in G.ntypes:\n",
+ " node_dict[ntype] = len(node_dict)\n",
+ " for etype in G.etypes:\n",
+ " edge_dict[etype] = len(edge_dict)\n",
+ " G.edges[etype].data['id'] = torch.ones(G.number_of_edges(etype), dtype=torch.long) * edge_dict[etype]\n",
+ "\n",
+ " # Random initialize input feature\n",
+ " for ntype in G.ntypes:\n",
+ " G.nodes[ntype].data['inp'] = features[ntype]\n",
+ " # if ntype in ['web','third','ip','dns']:\n",
+ " # G.nodes[ntype].data['inp'] = features[ntype]\n",
+ " # else:\n",
+ " # emb = nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), features['third'].shape[1]), requires_grad=False)\n",
+ " # nn.init.xavier_uniform_(emb)\n",
+ " # G.nodes[ntype].data['inp'] = emb\n",
+ "\n",
+ " G = G.to(args['device'])\n",
+ "\n",
+ " labels = labels.to(args['device'])\n",
+ " train_mask = train_mask.to(args['device'])\n",
+ " val_mask = val_mask.to(args['device'])\n",
+ " test_mask = test_mask.to(args['device'])\n",
+ "\n",
+ " # model = HGT(G,\n",
+ " # node_dict, edge_dict,\n",
+ " # n_inp=features['web'].shape[1],\n",
+ " # n_hid=args['n_hid'],\n",
+ " # n_out=labels.max().item() + 1,\n",
+ " # n_layers=args['n_layers'],\n",
+ " # n_heads=4,\n",
+ " # use_norm=True).to(args['device'])\n",
+ "\n",
+ " # print('Training HGT with #param: %d' % (get_n_params(model)))\n",
+ " # return hgt_batch(args,model, G,train_idx, train_mask, val_mask, test_mask, labels, Pi,train_node)\n",
+ "\n",
+ " model = HeteroRGCN(G,\n",
+ " in_size=features['web'].shape[1],\n",
+ " hidden_size=args['n_hid'],\n",
+ " out_size=labels.max().item() + 1).to(args['device'])\n",
+ " print('Training RGCN with #param: %d' % (get_n_params(model)))\n",
+ " return hgt_batch(args,model, G,train_idx, train_mask, val_mask, test_mask, labels, Pi,train_node)\n",
+ "\n",
+ " # model = HGT(G,\n",
+ " # node_dict, edge_dict,\n",
+ " # n_inp=features['web'].shape[1],\n",
+ " # n_hid=args['n_hid'],\n",
+ " # n_out=labels.max().item() + 1,\n",
+ " # n_layers=0,\n",
+ " # n_heads=4).to(args['device'])\n",
+ " # print('Training MLP with #param: %d' % (get_n_params(model)))\n",
+ " # hgt_batch(model, G,train_mask,val_mask,test_mask,labels,rn_weight)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jozCMUjGHG_n"
+ },
+ "outputs": [],
+ "source": [
+ "'''\n",
+ "HGT main.py\n",
+ "'''\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
+ "import csv\n",
+ "import numpy as np\n",
+ "result = []\n",
+ "\n",
+ "args = {'hetero':True}\n",
+ "args['clip'] = 1.0\n",
+ "args['max_lr'] = 1e-3\n",
+ "args['n_hid'] = 256\n",
+ "args['num_epochs']=200\n",
+ "args['dataset'] = 'Pirate1RDE'\n",
+ "args['n_layers'] = 2\n",
+ "args['char_dim'] = 'eng_200'\n",
+ "\n",
+ "'''\n",
+ "renode params\n",
+ "'''\n",
+ "args['pagerank_prob'] = 0.85\n",
+ "args['rn_base_weight'] = 0.5\n",
+ "args['rn_scale_weight'] = 1.0\n",
+ "G, features, labels, num_classes = load_data(args['char_dim'], args['dataset'])\n",
+ "print(G)\n",
+ "\n",
+ "for i in range(100):\n",
+ " args['seed'] = i\n",
+ " '''\n",
+ " loss\n",
+ " '''\n",
+ " args['loss'] = 'BCE'\n",
+ " # args['loss'] = 'Focal'\n",
+ " # args['loss'] = 'RN'\n",
+ " # args['loss'] = 'BCERN'\n",
+ " args = setup(args)\n",
+ "\n",
+ " print(args)\n",
+ " # main(args)\n",
+ " tmp = list(hgt_train(args,G, features, labels, num_classes))\n",
+ " with open('./'+args['dataset']+'_'+args['loss']+'.csv','a+',encoding='utf-8') as csvfile:\n",
+ " write = csv.writer(csvfile)\n",
+ " write.writerow(tmp)\n",
+ " result.append(tmp)\n",
+ "for item in result:\n",
+ " print(item)\n",
+ "tmp = np.array(result)\n",
+ "ma = np.max(tmp,0)\n",
+ "me = np.mean(tmp,0)\n",
+ "mi = np.min(tmp,0)\n",
+ "print('Precision: {:.4}, +{:.4}, -{:.4}'.format(me[0],ma[0]-me[0],me[0]-mi[0]))\n",
+ "print('Recall: {:.4}, +{:.4}, -{:.4}'.format(me[1],ma[1]-me[1],me[1]-mi[1]))\n",
+ "print('F1: {:.4}, +{:.4}, -{:.4}'.format(me[2],ma[2]-me[2],me[2]-mi[2]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "wb-YeP-nsYNX"
+ },
+ "outputs": [],
+ "source": [
+ "'''\n",
+ "HAN train.py\n",
+ "'''\n",
+ "import torch\n",
+ "from sklearn.metrics import f1_score, classification_report, confusion_matrix, accuracy_score, precision_score, \\\n",
+ " recall_score\n",
+ "\n",
+ "def han_evaluate(model, g, features, labels, mask, loss_func):\n",
+ " model.eval()\n",
+ " with torch.no_grad():\n",
+ " logits = model(g, features)\n",
+ " loss = loss_func(logits[mask], F.one_hot(labels[mask], num_classes = 2).float())\n",
+ " accuracy, micro_f1, macro_f1 = score(logits[mask], labels[mask])\n",
+ "\n",
+ " return loss, accuracy, micro_f1, macro_f1\n",
+ "\n",
+ "def han_test(model, g, features, labels, mask):\n",
+ " model.eval()\n",
+ " with torch.no_grad():\n",
+ " logits = model(g, features)\n",
+ " logits = logits[mask]\n",
+ " labels = labels[mask]\n",
+ " _, indices = torch.max(logits, dim=1)\n",
+ " '''\n",
+ " Plot output feature\n",
+ " '''\n",
+ " # feature_output1 = model.featuremap[mask].cpu()\n",
+ " # plot(feature_output1,indices.cpu(),labels.cpu())\n",
+ "\n",
+ " indices = indices.cpu().numpy()\n",
+ " labels = labels.cpu().numpy()\n",
+ " p,r,f1,s = precision_recall_fscore_support(labels,indices)\n",
+ "\n",
+ " print(classification_report(labels, indices, digits=4))\n",
+ " print(confusion_matrix(labels, indices))\n",
+ " return p[1],r[1],f1[1]\n",
+ "\n",
+ "\n",
+ "def han_train(args, g, features, labels, num_classes):\n",
+ " # If args['hetero'] is True, g would be a heterogeneous graph.\n",
+ " # Otherwise, it will be a list of homogeneous graphs.\n",
+ " \n",
+ " train_idx, val_idx, test_idx, train_mask, val_mask, test_mask = shuffule(g,labels)\n",
+ "\n",
+ "\n",
+ " if hasattr(torch, 'BoolTensor'):\n",
+ " train_mask = train_mask.bool()\n",
+ " val_mask = val_mask.bool()\n",
+ " test_mask = test_mask.bool()\n",
+ "\n",
+ " features = features['web'].to(args['device'])\n",
+ "\n",
+ " labels = labels.to(args['device'])\n",
+ " train_mask = train_mask.to(args['device'])\n",
+ " val_mask = val_mask.to(args['device'])\n",
+ " test_mask = test_mask.to(args['device'])\n",
+ "\n",
+ " from model import HAN\n",
+ " model = HAN(meta_paths=[['wr', 'rw'],['wd', 'dw']],\n",
+ " in_size=features.shape[1],\n",
+ " hidden_size=args['hidden_units'],\n",
+ " out_size=num_classes,\n",
+ " # out_size=num_classes,\n",
+ " num_heads=args['num_heads'],\n",
+ " dropout=args['dropout']).to(args['device'])\n",
+ " g = g.to(args['device'])\n",
+ "\n",
+ "\n",
+ " stopper = EarlyStopping(patience=args['patience'])\n",
+ " loss_fcn = torch.nn.BCEWithLogitsLoss()\n",
+ " optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],\n",
+ " weight_decay=args['weight_decay'])\n",
+ "\n",
+ " for epoch in range(args['num_epochs']):\n",
+ " model.train()\n",
+ " logits = model(g, features)\n",
+ " # loss = loss_fcn(logits[train_mask], labels[train_mask])\n",
+ " loss = loss_fcn(logits[train_mask], F.one_hot(labels[train_mask], num_classes = 2).float())\n",
+ "\n",
+ " optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ " train_acc, train_micro_f1, train_macro_f1 = score(logits[train_mask], labels[train_mask])\n",
+ " val_loss, val_acc, val_micro_f1, val_macro_f1 = han_evaluate(model, g, features, labels, val_mask, loss_fcn)\n",
+ " early_stop = stopper.step(val_loss.mean().data.item(), val_macro_f1, model)\n",
+ "\n",
+ " # print('Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | '\n",
+ " # 'Val Loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(\n",
+ " # epoch + 1, loss.item(), train_micro_f1, train_macro_f1, val_loss.mean().item(), val_micro_f1, val_macro_f1))\n",
+ "\n",
+ " if early_stop:\n",
+ " break\n",
+ "\n",
+ " stopper.load_checkpoint(model)\n",
+ " print('Seed',args['seed'])\n",
+ " return hgt_test(model, g, features, labels, test_mask)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ORw6BM_pTO6Q"
+ },
+ "outputs": [],
+ "source": [
+ "'''\n",
+ "HAN main.py\n",
+ "'''\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
+ "import csv\n",
+ "import numpy as np\n",
+ "result = []\n",
+ "\n",
+ "args = {'hetero':True}\n",
+ "args['clip'] = 1.0\n",
+ "args['max_lr'] = 1e-3\n",
+ "args['n_hid'] = 256\n",
+ "args['num_epochs']=200\n",
+ "args['dataset'] = 'Pirate1RDE'\n",
+ "args['n_layers'] = 2\n",
+ "args['char_dim'] = 'eng_200'\n",
+ "\n",
+ "'''\n",
+ "renode params\n",
+ "'''\n",
+ "args['pagerank_prob'] = 0.85\n",
+ "args['rn_base_weight'] = 0.5\n",
+ "args['rn_scale_weight'] = 1.0\n",
+ "G, features, labels, num_classes = load_data(args['char_dim'], args['dataset'])\n",
+ "print(G)\n",
+ "\n",
+ "for i in range(100):\n",
+ " args['seed'] = i\n",
+ " '''\n",
+ " loss\n",
+ " '''\n",
+ " args['loss'] = 'BCE'\n",
+ " # args['loss'] = 'Focal'\n",
+ " # args['loss'] = 'RN'\n",
+ " # args['loss'] = 'BCERN'\n",
+ " args = setup(args)\n",
+ "\n",
+ " print(args)\n",
+ " # main(args)\n",
+ " tmp = list(han_train(args,G, features, labels, num_classes))\n",
+ " with open('./'+args['dataset']+'_'+args['loss']+'.csv','a+',encoding='utf-8') as csvfile:\n",
+ " write = csv.writer(csvfile)\n",
+ " write.writerow(tmp)\n",
+ " result.append(tmp)\n",
+ "for item in result:\n",
+ " print(item)\n",
+ "tmp = np.array(result)\n",
+ "ma = np.max(tmp,0)\n",
+ "me = np.mean(tmp,0)\n",
+ "mi = np.min(tmp,0)\n",
+ "print('Precision: {:.4}, +{:.4}, -{:.4}'.format(me[0],ma[0]-me[0],me[0]-mi[0]))\n",
+ "print('Recall: {:.4}, +{:.4}, -{:.4}'.format(me[1],ma[1]-me[1],me[1]-mi[1]))\n",
+ "print('F1: {:.4}, +{:.4}, -{:.4}'.format(me[2],ma[2]-me[2],me[2]-mi[2]))"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "graph.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}