{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "cBo19DhI7uoD" }, "outputs": [], "source": [ "%%capture\n", "# !git clone https://github.com/joerg84/Graph_Powered_ML_Workshop.git\n", "# !rsync -av Graph_Powered_ML_Workshop/ ./ --exclude=.git\n", "!pip3 install dgl-cu110 -f https://data.dgl.ai/wheels/repo.html\n", "# !pip3 install dgl-cu102 -f https://data.dgl.ai/wheels/repo.html\n", "!pip3 install numpy\n", "!pip3 install torch\n", "!pip3 install networkx\n", "!pip3 install matplotlib\n", "!pip3 install scikit-learn\n", "!pip3 install chars2vec\n", "!pip3 install keras\n", "!pip3 install pandas" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "OX7Fi6esQzct", "outputId": "7a959969-b1b5-474f-e147-9591cb9c05d9" }, "outputs": [], "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')\n", "!mkdir ./data/\n", "!mkdir ./graph/\n", "!mkdir ./pagerank/\n", "!mkdir ./model\n", "\n", "!cp ./drive/MyDrive/Colab\\ Notebooks/han/data/fix_1k_labels.csv ./data/\n", "!cp ./drive/MyDrive/Colab\\ Notebooks/han/data/fix_3k_labels.csv ./data/\n", "!cp ./drive/MyDrive/Colab\\ Notebooks/han/data/fix_5k_labels.csv ./data/\n", "!cp ./drive/MyDrive/Colab\\ Notebooks/han/data/fix_1w_labels.csv ./data/\n", "!cp ./drive/MyDrive/Colab\\ Notebooks/han/data/fix_whois.csv ./data/\n", "!cp ./drive/MyDrive/Colab\\ Notebooks/han/model.py ./" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 541 }, "id": "CmJZXrpVHod1", "outputId": "573fe488-fe6d-4b7b-c16a-cfad16e5b8b3" }, "outputs": [], "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "XnQWwdumsSdD", "outputId": "f0b3f413-b629-4d29-e678-cc758f14a4e0" }, "outputs": [], "source": [ "'''\n", "load pirate data\n", "'''\n", "import re\n", "from pprint import pprint\n", "\n", "import re\n", "from pprint import pprint\n", "\n", "import chars2vec\n", "import dgl\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "from keras_preprocessing.sequence import pad_sequences\n", "from keras_preprocessing.text import Tokenizer\n", "\n", "\n", "def get_raw_data(path):\n", " raw = pd.read_csv(path)\n", " reg_dic, dns_dic, email_dic = get_whois()\n", " print('Load Pirate Dataset: {}'.format(path))\n", " in_feat = {}\n", " website = {}\n", " third = {}\n", " w_ip = {}\n", " t_ip = {}\n", " cert = {}\n", " reg = {}\n", " dns = {}\n", " email = {}\n", " wi = 0\n", " ti = 0\n", " wpi = 0\n", " tpi = 0\n", " ci = 0\n", " ri = 0\n", " di = 0\n", " ei = 0\n", " edge_w_t = []\n", " edge_w_i = []\n", " edge_w_c = []\n", " edge_t_c = []\n", " edge_t_i = []\n", " edge_w_r = []\n", " edge_w_d = []\n", " edge_w_e = []\n", " labels = []\n", "\n", " '''\n", " raw data for website's ndoe embedding\n", " '''\n", " raw_reg = []\n", " raw_dns = []\n", " raw_email = []\n", "\n", " '''\n", " raw data for graph's node embedding\n", " '''\n", " node_domain = []\n", " node_third = []\n", " node_w_ip = []\n", " node_t_ip = []\n", " node_cert = []\n", " node_email = []\n", " node_reg = []\n", " node_dns = []\n", "\n", "\n", " for index, row in raw.iterrows():\n", " t = str(row['req']).split('/')[2].strip('www.')\n", " i = str(row['ip']).strip()\n", " w = str(row['website']).strip().split(':')[0]\n", " c = str(row['cert']).strip()\n", " if t 
!= w:\n", " if t not in third:\n", " third[t] = ti\n", " node_third.append(t)\n", " ti += 1\n", " if i not in t_ip and i != 'nan':\n", " t_ip[i] = tpi\n", " node_t_ip.append(i)\n", " tpi += 1\n", " else:\n", " if i not in w_ip and i != 'nan':\n", " w_ip[i] = wpi\n", " node_w_ip.append(i)\n", " wpi += 1\n", " if w not in website:\n", " labels.append(int(row['label']))\n", " node_domain.append(w)\n", " raw_reg.append(reg_dic[w])\n", " raw_dns.append(dns_dic[w])\n", " raw_email.append(email_dic[w])\n", " website[w] = wi\n", " wi+=1\n", " if reg_dic[w] not in reg and reg_dic[w] != 'nan':\n", " reg[reg_dic[w]] = ri\n", " node_reg.append(reg_dic[w])\n", " ri += 1\n", " if dns_dic[w] != 'nan':\n", " for _ in dns_dic[w].split(','):\n", " __ = _.strip().lower()\n", " if __ not in dns:\n", " dns[__] = di\n", " node_dns.append(__)\n", " di += 1\n", " if email_dic[w] not in email and email_dic[w] != 'nan':\n", " email[email_dic[w]] = ei\n", " node_email.append(email_dic[w])\n", " ei += 1\n", " in_feat[w] = [0, 0, 0]\n", " if c not in cert and c != 'nan':\n", " cert[c] = ci\n", " node_cert.append(c)\n", " ci += 1\n", "\n", " in_feat[w][0] += 1\n", " if t != w:\n", " edge_w_t.append(tuple([website[w], third[t]]))\n", " if i != 'nan':\n", " edge_t_i.append(tuple([third[t], t_ip[i]]))\n", " in_feat[w][1] += 1\n", " if c != 'nan':\n", " edge_t_c.append(tuple([third[t], cert[c]]))\n", " if t == w:\n", " if i != 'nan':\n", " edge_w_i.append(tuple([website[w], w_ip[i]]))\n", " if reg_dic[w] != 'nan':\n", " edge_w_r.append(tuple([website[w], reg[reg_dic[w]]]))\n", " if dns_dic[w] != 'nan':\n", " for _ in dns_dic[w].split(','):\n", " __ = _.strip().lower()\n", " edge_w_d.append(tuple([website[w], dns[__]]))\n", " if email_dic[w] != 'nan':\n", " edge_w_e.append(tuple([website[w], email[email_dic[w]]]))\n", " if c != 'nan':\n", " edge_w_c.append(tuple([website[w], cert[c]]))\n", " in_feat[w][2] += 1\n", "\n", "\n", " edge_w_i = set(edge_w_i)\n", " edge_w_t = set(edge_w_t)\n", " edge_w_c = set(edge_w_c)\n", " edge_t_c = set(edge_t_c)\n", " edge_t_i = set(edge_t_i)\n", " edge_w_r = set(edge_w_r)\n", " edge_w_d = set(edge_w_d)\n", " edge_w_e = set(edge_w_e)\n", "\n", "\n", "\n", " print('Node of Websites: ', len(website))\n", " print('Node of Third-party Services: ', len(third))\n", " print('Node of IP: ', len(w_ip))\n", " print('Node of Cert: ', len(cert))\n", " print('Node of Registrant: ',len(reg))\n", " print('Node of DNS: ', len(dns))\n", " print('Node of Email: ', len(email))\n", " print('Edge of Website to Third-party: ', len(edge_w_t))\n", " print('Edge of Webiste to IP: ', len(edge_w_i))\n", " print('Edge of Website to Certificate: ', len(edge_w_c))\n", " print('Edge of Website to Registrant: ', len(edge_w_r))\n", " print('Edge of Webiste to DNS Server: ', len(edge_w_d))\n", " print('Edge of Website to Email: ', len(edge_w_e))\n", " print('Edge of Third-party to Certificate: ', len(edge_t_c))\n", " print('Edge of Third-party to IP: ', len(edge_t_i))\n", " print('labels: ', len(labels))\n", " web_tp_src = []\n", " web_tp_dst = []\n", " web_ip_src = []\n", " web_ip_dst = []\n", " web_cert_src = []\n", " web_cert_dst = []\n", " third_cert_src = []\n", " third_cert_dst = []\n", " third_ip_src = []\n", " third_ip_dst = []\n", " web_reg_src = []\n", " web_reg_dst = []\n", " web_dns_src = []\n", " web_dns_dst = []\n", " web_email_src = []\n", " web_email_dst = []\n", " for item in edge_w_t:\n", " web_tp_src.append(item[0])\n", " web_tp_dst.append(item[1])\n", " for item in edge_w_i:\n", " 
web_ip_src.append(item[0])\n", " web_ip_dst.append(item[1])\n", " for item in edge_w_c:\n", " web_cert_src.append(item[0])\n", " web_cert_dst.append(item[1])\n", " for item in edge_t_c:\n", " third_cert_src.append(item[0])\n", " third_cert_dst.append(item[1])\n", " for item in edge_t_i:\n", " third_ip_src.append(item[0])\n", " third_ip_dst.append(item[1])\n", " for item in edge_w_r:\n", " web_reg_src.append(item[0])\n", " web_reg_dst.append(item[1])\n", " for item in edge_w_d:\n", " web_dns_src.append(item[0])\n", " web_dns_dst.append(item[1])\n", " for item in edge_w_e:\n", " web_email_src.append(item[0])\n", " web_email_dst.append(item[1])\n", " # hg = dgl.heterograph({\n", " # ('web', 'wt', 'third'): (web_tp_src, web_tp_dst),\n", " # ('third', 'tw', 'web'): (web_tp_dst, web_tp_src)},{'web':len(website),'third':len(third),'ip':len(ip),'cert':len(cert)})\n", " hg = dgl.heterograph({\n", " ('web', 'wt', 'third'): (web_tp_src, web_tp_dst),\n", " ('third', 'tw', 'web'): (web_tp_dst, web_tp_src),\n", " ('web', 'wi', 'ip'): (web_ip_src, web_ip_dst),\n", " ('ip', 'iw', 'web'): (web_ip_dst, web_ip_src),\n", " ('web', 'wc', 'cert'): (web_cert_src, web_cert_dst),\n", " ('cert', 'cw', 'web'): (web_cert_dst, web_cert_src),\n", " # ('third', 'tc', 'cert'): (third_cert_src, third_cert_dst),\n", " # ('cert', 'ct', 'third'): (third_cert_dst, third_cert_src),\n", " # ('third', 'ti', 'ip'): (third_ip_src, third_ip_dst),\n", " # ('ip', 'it', 'third'): (third_ip_dst, third_ip_src),\n", " ('web', 'wr', 'reg'): (web_reg_src, web_reg_dst),\n", " ('reg', 'rw', 'web'): (web_reg_dst, web_reg_src),\n", " ('web', 'we', 'email'): (web_email_src, web_email_dst),\n", " ('email', 'ew', 'web'): (web_email_dst, web_email_src),\n", " ('web', 'wd', 'dns'): (web_dns_src, web_dns_dst),\n", " ('dns', 'dw', 'web'): (web_dns_dst, web_dns_src)\n", " },\n", " # {'web':len(website),'ip':len(w_ip)}\n", " {'web':len(website),'third':len(third),'ip':len(w_ip),'dns':len(dns),'cert':len(cert),'reg':len(reg),'email':len(email)}\n", " )\n", "\n", " labels = torch.LongTensor(labels)\n", "\n", " '''\n", " one-hot features\n", " '''\n", " reg_oh = [[0 for _ in range(len(reg))] for __ in range(len(website))]\n", " email_oh = [[0 for _ in range(len(email))] for __ in range(len(website))]\n", " dns_oh = [[0 for _ in range(len(dns))] for __ in range(len(website))]\n", "\n", "\n", " for i in range(len(raw_reg)):\n", " if raw_reg[i]!='nan':\n", " reg_oh[i][reg[raw_reg[i]]] = 1\n", " for i in range(len(raw_email)):\n", " if raw_email[i] != 'nan':\n", " email_oh[i][email[raw_email[i]]] = 1\n", " for i in range(len(raw_dns)):\n", " for _ in raw_dns[i].split(','):\n", " __ = _.strip().lower()\n", " if __ != 'nan':\n", " dns_oh[i][dns[__]]=1\n", "\n", " reg_oh = np.array(reg_oh)\n", " email_oh = np.array(email_oh)\n", " dns_oh = np.array(dns_oh)\n", "\n", "\n", " in_features = []\n", " for item in node_domain:\n", " in_features.append(in_feat[item])\n", "\n", " data = {'labels': labels, 'node_domain': node_domain, 'raw_reg': raw_reg,\n", " 'raw_dns': raw_reg, 'raw_email': raw_email,\n", " 'hg': hg, 'in_features': np.array(in_features),'node_third':node_third,'node_w_ip':node_w_ip,'node_cert':node_cert,\n", " 'node_reg':node_reg,'node_dns':node_dns,'node_email':node_email\n", " }\n", " import pickle\n", " pickle.dump(data, open('./graph/WTICRDE_1k.pkl', 'wb'))\n", "\n", "\n", "def get_whois():\n", " raw = pd.read_csv('./data/fix_whois.csv')\n", " dic = {}\n", " website = {}\n", " email = {}\n", " dns = {}\n", " for web,x,d,e in 
zip(raw.website,raw.reg,raw.dns,raw.email):\n",
"        dns.setdefault(web,str(d))\n",
"        email.setdefault(web,str(e))\n",
"        tmp = str(x).strip().strip('.')\n",
"        # Normalize the raw registrant string (Chinese or English) to one canonical\n",
"        # registrar name; patterns are tried in order, falling back to the raw string.\n",
"        registrar_patterns = [\n",
"            (r'.*GoDaddy', 'GoDaddy.com, LLC'),\n",
"            (r'.*NameSilo|.*Namesilo', 'NameSilo, LLC'),\n",
"            (r'.*Alibaba|.*阿里巴巴|.*阿里云', 'Alibaba Cloud Computing (Beijing) Co., Ltd'),\n",
"            (r'.*新网|Xin', 'Xin Net Technology Corporation'),\n",
"            (r'NameCheap', 'NameCheap, Inc'),\n",
"            (r'.*DropCatch', 'DropCatch.com, Inc'),\n",
"            (r'.*Chengdu west|.*Chengdu West|.*成都西维', 'Chengdu West Dimension Digital Technology Co., Ltd'),\n",
"            (r'.*Xiamen 35|.*厦门三五', 'Xiamen 35.Com Technology Co., Ltd'),\n",
"            (r'.*1API', '1API GmbH'),\n",
"            (r'.*中央编办', 'China Gov'),\n",
"            (r'.*浙江贰', 'Zhejiang ErEr Technology Co., Ltd'),\n",
"            (r'.*Dynadot', 'Dynadot, LLC'),\n",
"            (r'.*广州云', 'Guangzhou Cloud Technology Co., Ltd'),\n",
"            (r'.*厦门纳网', 'Xiamen NaNetwork Technology Co., Ltd'),\n",
"            (r'.*厦门易名', 'Xiamen Easy Name Technology Co., Ltd'),\n",
"            (r'.*四川域趣', 'Sichuan Fun Technology Co., Ltd'),\n",
"            (r'.*成都飞数', 'Chengdu Fly Digital Technology Co., Ltd'),\n",
"            (r'.*北京光速', 'Beijing LightSpeed Technology Co., Ltd'),\n",
"            (r'.*上海福虎', 'Shanghai Tiger Technology Co., Ltd'),\n",
"            (r'.*广东时代', 'Guangdong TimeInternet Technology Co., Ltd'),\n",
"            (r'.*成都世纪', 'Chengdu Century Oriental Network Communication Co., Ltd'),\n",
"            (r'.*北京国旭', 'Beijing Guoxu Network Technology Co., Ltd'),\n",
"            (r'.*北京中科', 'Beijing Sanfront Information Technology Co., Ltd'),\n",
"            (r'.*北京东方', 'Beijing East Network Technology Co., Ltd'),\n",
"            (r'.*商中在线', 'Shangzhong Online Technology Co., Ltd'),\n",
"            (r'.*中企动力', 'Zhong Qi Dong Li Technology Co., Ltd'),\n",
"            (r'.*厦门市中资', 'Xiamen ChinaSource Internet Service Co., Ltd'),\n",
"            (r'.*广东互易', 'Guangdong HuYi Network Technology Co., Ltd'),\n",
"            (r'.*遵义中域', 'Zunyi Mid Technology Co., Ltd'),\n",
"            (r'.*北京首信网', 'Beijing Trust Network Technology Co., Ltd'),\n",
"            (r'.*赛尔网络', 'Saier Network Technology Co., Ltd'),\n",
"            (r'.*烟台帝思普', 'Yantai DiSiPu Technology Co., Ltd'),\n",
"            (r'.*北京万维', 'Beijing WanWei Technology Co., Ltd'),\n",
"            (r'.*江苏邦宁', 'Jiangsu Bangning Science and technology Co. Ltd'),\n",
"            (r'.*北京神州', 'Beijing ShenZhou technology Co. Ltd'),\n",
"            (r'.*上海美橙', 'Shanghai Meicheng Technology Information Development Co., Ltd'),\n",
"            (r'.*广东金万', 'Guangdong JinWanBang Technology Investment Co., Ltd'),\n",
"            (r'.*易介集团', 'YiJie Beijing technology Co. Ltd'),\n",
"            (r'.*佛山市亿动', 'Foshan YiDong Network Co., LTD'),\n",
"            (r'.*山东开创', 'Shandong KaiChuang technology Co. Ltd'),\n",
"            (r'.*泛亚信息', 'FanYa Jiangsu technology Co. Ltd'),\n",
"            (r'.*杭州电商', 'Hangzhou Dianshang Internet Technology Co., LTD'),\n",
"            (r'.*深圳市万维', 'Shenzhen WanWei technology Co. Ltd'),\n",
"            (r'.*天津追日', 'Tianjin Zhuiri Science and Technology Development Co Ltd'),\n",
"            (r'.*北京宏网', 'Beijing HongWang technology Co. Ltd'),\n",
"            (r'.*昆明乐网', 'Kunming LeWang technology Co. Ltd'),\n",
"            (r'.*上海热线', 'Shanghai ReXian technology Co. Ltd'),\n",
"            (r'.*互联网域名', 'Internet Domain Name System Beijing Engineering Research Center LLC (ZDNS)'),\n",
"            (r'.*广东耐思', 'Guangdong Nicenic Technology Co., Ltd. dba NiceNIC'),\n",
"            (r'.*重庆智佳', 'Chongqing ZhiJia Technology Co., Ltd')\n",
"        ]\n",
"        for pattern, name in registrar_patterns:\n",
"            if re.match(pattern, tmp):\n",
"                dic.setdefault(name, 0)\n",
"                website.setdefault(web, name)\n",
"                break\n",
"        else:\n",
"            dic.setdefault(tmp, 0)\n",
"            website.setdefault(web, tmp)\n",
"    return website,dns,email\n",
"\n",
"def domain_to_index(domain, length):\n",
"    \"\"\"Convert domain strings to fixed-length character-index vectors.\n",
"\n",
"    :param domain: the domain data\n",
"    :param length: the csv data's length\n",
"    :return: index_vector of the domain\n",
"    \"\"\"\n",
"\n",
"    tokenizer = Tokenizer(char_level=True)\n",
"    tokenizer.fit_on_texts(domain)\n",
"    word_index = tokenizer.word_index\n",
"    sequences = tokenizer.texts_to_sequences(domain)\n",
"\n",
"    # padding\n",
"    data = pad_sequences(sequences, maxlen=30)\n",
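"    # Note (added for clarity): `data` is an integer matrix of shape (len(domain), 30);\n",
"    # with the Keras pad_sequences defaults, shorter domains are left-padded with 0 and\n",
"    # longer ones keep only their last 30 characters.\n",
"    # print('Found %s unique tokens.' 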
% len(word_index))\n", " # print(data)\n", " print('Start Transform Character to index')\n", " return np.array(data, dtype=\"float32\").reshape(length, 30), word_index\n", "\n", "def load_pirate(path):\n", " return get_raw_data(path)\n", "\n", "\n", "def load_pkl(dim,path):\n", " print('Load From PKL:',path)\n", " import pickle\n", " data = pickle.load(open(path, 'rb'))\n", " hg = data['hg']\n", "\n", " labels = data['labels']\n", "\n", " '''\n", " character-embedding features\n", " '''\n", " # features, _ = domain_to_index(raw_domain, len(raw_domain))\n", " c2v_model = \\\n", " chars2vec.load_model(dim)\n", " website_features = c2v_model.vectorize_words(data['node_domain'])\n", " web_reg_features = c2v_model.vectorize_words(data['raw_reg'])\n", " web_dns_features = c2v_model.vectorize_words(data['raw_dns'])\n", " web_email_features = c2v_model.vectorize_words(data['raw_email'])\n", "\n", " third_features = c2v_model.vectorize_words(data['node_third'])\n", " ip_features = c2v_model.vectorize_words(data['node_w_ip'])\n", " cert_features = c2v_model.vectorize_words(data['node_cert'])\n", " reg_features = c2v_model.vectorize_words(data['node_reg'])\n", " dns_features = c2v_model.vectorize_words(data['node_dns'])\n", " email_features = c2v_model.vectorize_words(data['node_email'])\n", "\n", " features = {\n", " 'web': torch.FloatTensor(website_features),\n", " 'third': torch.FloatTensor(third_features),\n", " 'ip': torch.FloatTensor(ip_features),\n", " 'cert': torch.FloatTensor(cert_features),\n", " 'email': torch.FloatTensor(email_features),\n", " 'dns':torch.FloatTensor(dns_features),\n", " 'reg':torch.FloatTensor(reg_features),\n", " }\n", "\n", " pprint({\n", " 'Web Features:':features['web'].shape,\n", " 'Third Features:': features['third'].shape,\n", " 'IP Features:': features['ip'].shape,\n", " 'Cert Features:': features['cert'].shape,\n", " 'Reg Features:': features['reg'].shape,\n", " 'DNS Features:': features['dns'].shape,\n", " 'Email Features:': features['email'].shape\n", " })\n", "\n", " in_features = data['in_features']\n", " num_classes = 2\n", "\n", "\n", " # plot(features,labels)\n", "\n", " return hg, features, labels, num_classes\n", "\n", "load_pirate('./data/fix_1k_labels.csv')\n", "# load_pkl('eng_50','./graph/WTICRDE_1w.pkl')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t6aPVZfwurEP" }, "outputs": [], "source": [ "'''\n", "utils.py\n", "'''\n", "import datetime\n", "import dgl\n", "import errno\n", "import numpy as np\n", "import os\n", "import pickle\n", "import random\n", "import torch\n", "\n", "from pprint import pprint\n", "from scipy import sparse\n", "from scipy import io as sio\n", "\n", "\n", "def set_random_seed(seed=0):\n", " \"\"\"Set random seed.\n", " Parameters\n", " ----------\n", " seed : int\n", " Random seed to use\n", " \"\"\"\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " torch.manual_seed(seed)\n", " if torch.cuda.is_available():\n", " torch.cuda.manual_seed(seed)\n", "\n", "def mkdir_p(path, log=True):\n", " \"\"\"Create a directory for the specified path.\n", " Parameters\n", " ----------\n", " path : str\n", " Path name\n", " log : bool\n", " Whether to print result for directory creation\n", " \"\"\"\n", " try:\n", " os.makedirs(path)\n", " if log:\n", " print('Created directory {}'.format(path))\n", " except OSError as exc:\n", " if exc.errno == errno.EEXIST and os.path.isdir(path) and log:\n", " print('Directory {} already exists.'.format(path))\n", " else:\n", " raise\n", "\n", "def 
get_date_postfix():\n",
"    \"\"\"Get a date based postfix for directory name.\n",
"    Returns\n",
"    -------\n",
"    post_fix : str\n",
"    \"\"\"\n",
"    dt = datetime.datetime.now()\n",
"    post_fix = '{}_{:02d}-{:02d}-{:02d}'.format(\n",
"        dt.date(), dt.hour, dt.minute, dt.second)\n",
"\n",
"    return post_fix\n",
"\n",
"def setup_log_dir(args, sampling=False):\n",
"    \"\"\"Name and create directory for logging.\n",
"    Parameters\n",
"    ----------\n",
"    args : dict\n",
"        Configuration\n",
"    sampling : bool\n",
"        Whether we are using sampling based training\n",
"    Returns\n",
"    -------\n",
"    log_dir : str\n",
"        Path for logging directory\n",
"    \"\"\"\n",
"    date_postfix = get_date_postfix()\n",
"    log_dir = os.path.join(\n",
"        args['log_dir'],\n",
"        '{}_{}'.format(args['dataset'], date_postfix))\n",
"\n",
"    if sampling:\n",
"        log_dir = log_dir + '_sampling'\n",
"\n",
"    mkdir_p(log_dir)\n",
"    return log_dir\n",
"\n",
"# The configuration below is from the paper.\n",
"default_configure = {\n",
"    'lr': 0.005,             # Learning rate\n",
"    'num_heads': [8],        # Number of attention heads for node-level attention\n",
"    'hidden_units': 8,\n",
"    'dropout': 0.6,\n",
"    'weight_decay': 0.001,\n",
"    'patience': 100\n",
"}\n",
"\n",
"sampling_configure = {\n",
"    'batch_size': 20\n",
"}\n",
"\n",
"def setup(args):\n",
"    args.update(default_configure)\n",
"    set_random_seed(args['seed'])\n",
"    args['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"    return args\n",
"\n",
"def setup_for_sampling(args):\n",
"    args.update(default_configure)\n",
"    args.update(sampling_configure)\n",
"    set_random_seed()\n",
"    args['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
"    return args\n",
"\n",
"def get_binary_mask(total_size, indices):\n",
"    mask = torch.zeros(total_size)\n",
"    mask[indices] = 1\n",
"    return mask.byte()\n",
"\n",
"\n",
"def load_data(dim, dataset, remove_self_loop=False):\n",
"    print('Load Dataset:', dataset)\n",
"    pkl_paths = {\n",
"        'Pirate1RDE': 'graph/WTICRDE_1k.pkl',\n",
"        'Pirate3': 'graph/WTIC_3k.pkl',\n",
"        'Pirate10': 'graph/WTIC_1w.pkl',\n",
"        'Pirate10RDE': 'graph/WTICRDE_1w.pkl',\n",
"        'Pirate5RDE': 'graph/WTICRDE_5k.pkl',\n",
"        'Pirate3RDE': 'graph/WTICRDE_3k.pkl',\n",
"        'Pirate10RE': 'graph/WTICRE_1w.pkl',\n",
"        'Pirate10RD': 'graph/WTICRD_1w.pkl',\n",
"        'Pirate10DE': 'graph/WTICDE_1w.pkl',\n",
"        'Pirate10R': 'graph/WTICR_1w.pkl',\n",
"        'Pirate10WTID': 'graph/WTID_1w.pkl',\n",
"        'Pirate10WT': 'graph/WT_1w.pkl',\n",
"        'Pirate10WI': 'graph/WI_1w.pkl',\n",
"        'Pirate10WD': 'graph/WD_1w.pkl',\n",
"        'Pirate10WC': 'graph/WC_1w.pkl',\n",
"        'Pirate10WE': 'graph/WE_1w.pkl',\n",
"        'Pirate10WR': 'graph/WR_1w.pkl',\n",
"        'Pirate10NoC': 'graph/WTIRDE_1w.pkl'\n",
"    }\n",
"    if dataset not in pkl_paths:\n",
"        raise NotImplementedError('Unsupported dataset {}'.format(dataset))\n",
"    return load_pkl(dim, pkl_paths[dataset])\n",
"\n",
"class EarlyStopping(object):\n",
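"    \"\"\"Early-stopping helper: counts epochs in which validation loss and accuracy\n",
"    both get worse; after `patience` such epochs training is stopped. The best\n",
"    model weights are checkpointed to a timestamped file under ./model/.\"\"\"\n",
"\n",
"    def 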
__init__(self, patience=10):\n", " dt = datetime.datetime.now()\n", " self.filename = './model/early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(\n", " dt.date(), dt.hour, dt.minute, dt.second)\n", " self.patience = patience\n", " self.counter = 0\n", " self.best_acc = None\n", " self.best_loss = None\n", " self.early_stop = False\n", "\n", " def step(self, loss, acc, model):\n", " if self.best_loss is None:\n", " self.best_acc = acc\n", " self.best_loss = loss\n", " self.save_checkpoint(model)\n", " elif (loss > self.best_loss) and (acc < self.best_acc):\n", " self.counter += 1\n", " # print(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n", " if self.counter >= self.patience:\n", " self.early_stop = True\n", " else:\n", " if (loss <= self.best_loss) and (acc >= self.best_acc):\n", " self.save_checkpoint(model)\n", " self.best_loss = np.min((loss, self.best_loss))\n", " self.best_acc = np.max((acc, self.best_acc))\n", " self.counter = 0\n", " return self.early_stop\n", "\n", " def save_checkpoint(self, model):\n", " \"\"\"Saves model when validation loss decreases.\"\"\"\n", " torch.save(model.state_dict(), self.filename)\n", "\n", " def load_checkpoint(self, model):\n", " \"\"\"Load the latest checkpoint.\"\"\"\n", " model.load_state_dict(torch.load(self.filename))\n", "\n", "\n", "def plot(features,pred,labels):\n", " import torch\n", " import numpy as np\n", " import matplotlib.pyplot as plt\n", " import matplotlib.cm as cm\n", " def PCA(data, k=2):\n", " X = data\n", " X_mean = torch.mean(X, 0)\n", " X = X - X_mean.expand_as(X)\n", " # SVD\n", " U, S, V = torch.svd(torch.t(X))\n", " return torch.mm(X, U[:, :k])\n", "\n", " X = features\n", " y = []\n", " for i in range(len(labels)):\n", " if pred[i]==labels[i] and pred[i]==1:\n", " y.append(1)\n", " elif pred[i]==labels[i] and pred[i]==0:\n", " y.append(2)\n", " elif pred[i]!=labels[i] and pred[i]==0:\n", " y.append(-1)\n", " else:\n", " y.append(-2)\n", " y = torch.tensor(y)\n", " # y = labels\n", " X_pca = PCA(X)\n", " pca = X_pca.numpy()\n", " #\n", " plt.figure()\n", " color = cm.rainbow(np.linspace(0, 1, 4))\n", " plt.scatter(pca[y == 1,0], pca[y == 1,1],s=5, label='True Positive', color=color[0],marker='o')\n", " plt.scatter(pca[y == 2,0], pca[y == 2,1],s=5, label='True Negative', color=color[1],marker='o')\n", " plt.scatter(pca[y == -1,0], pca[y == -1,1],s=100, label='False Positive', color=color[2],marker='*')\n", " plt.scatter(pca[y == -2,0], pca[y == -2,1],s=100, label='False Negative', color=color[3],marker='*')\n", "\n", " plt.legend()\n", " plt.title('PCA of Pirate 1W dataset')\n", " plt.show(): dict\n", "\n", "\n", "# load_data('eng_50','Pirate10RDE')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FyEAWshW0eDZ" }, "outputs": [], "source": [ "'''\n", "Focal Loss.py\n", "'''\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.cuda.amp as amp\n", "\n", "\n", "##\n", "# version 1: use torch.autograd\n", "class FocalLossV1(nn.Module):\n", "\n", " def __init__(self,\n", " alpha=0.25,\n", " gamma=2,\n", " reduction='mean',):\n", " super(FocalLossV1, self).__init__()\n", " self.alpha = alpha\n", " self.gamma = gamma\n", " self.reduction = reduction\n", " self.crit = nn.BCEWithLogitsLoss(reduction='none')\n", "\n", " def forward(self, logits, label):\n", " '''\n", " Usage is same as nn.BCEWithLogits:\n", " >>> criteria = FocalLossV1()\n", " >>> logits = torch.randn(8, 19, 384, 384)\n", " >>> lbs = torch.randint(0, 2, (8, 19, 384, 
384)).float()\n", " >>> loss = criteria(logits, lbs)\n", " '''\n", " probs = torch.sigmoid(logits)\n", " coeff = torch.abs(label - probs).pow(self.gamma).neg()\n", " log_probs = torch.where(logits >= 0,\n", " F.softplus(logits, -1, 50),\n", " logits - F.softplus(logits, 1, 50))\n", " log_1_probs = torch.where(logits >= 0,\n", " -logits + F.softplus(logits, -1, 50),\n", " -F.softplus(logits, 1, 50))\n", " loss = label * self.alpha * log_probs + (1. - label) * (1. - self.alpha) * log_1_probs\n", " loss = loss * coeff\n", "\n", " if self.reduction == 'mean':\n", " loss = loss.mean()\n", " if self.reduction == 'sum':\n", " loss = loss.sum()\n", " return loss\n", "\n", "\n", "##\n", "# version 2: user derived grad computation\n", "class FocalSigmoidLossFuncV2(torch.autograd.Function):\n", " '''\n", " compute backward directly for better numeric stability\n", " '''\n", " @staticmethod\n", " @amp.custom_fwd(cast_inputs=torch.float32)\n", " def forward(ctx, logits, label, alpha, gamma):\n", " # logits = logits.float()\n", "\n", " probs = torch.sigmoid(logits)\n", " coeff = (label - probs).abs_().pow_(gamma).neg_()\n", " log_probs = torch.where(logits >= 0,\n", " F.softplus(logits, -1, 50),\n", " logits - F.softplus(logits, 1, 50))\n", " log_1_probs = torch.where(logits >= 0,\n", " -logits + F.softplus(logits, -1, 50),\n", " -F.softplus(logits, 1, 50))\n", " ce_term1 = log_probs.mul_(label).mul_(alpha)\n", " ce_term2 = log_1_probs.mul_(1. - label).mul_(1. - alpha)\n", " ce = ce_term1.add_(ce_term2)\n", " loss = ce * coeff\n", "\n", " ctx.vars = (coeff, probs, ce, label, gamma, alpha)\n", "\n", " return loss\n", "\n", " @staticmethod\n", " @amp.custom_bwd\n", " def backward(ctx, grad_output):\n", " '''\n", " compute gradient of focal loss\n", " '''\n", " (coeff, probs, ce, label, gamma, alpha) = ctx.vars\n", "\n", " d_coeff = (label - probs).abs_().pow_(gamma - 1.).mul_(gamma)\n", " d_coeff.mul_(probs).mul_(1. 
- probs)\n", " d_coeff = torch.where(label < probs, d_coeff.neg(), d_coeff)\n", " term1 = d_coeff.mul_(ce)\n", "\n", " d_ce = label * alpha\n", " d_ce.sub_(probs.mul_((label * alpha).mul_(2).add_(1).sub_(label).sub_(alpha)))\n", " term2 = d_ce.mul(coeff)\n", "\n", " grads = term1.add_(term2)\n", " grads.mul_(grad_output)\n", "\n", " return grads, None, None, None\n", "\n", "\n", "class FocalLossV2(nn.Module):\n", "\n", " def __init__(self,\n", " alpha=0.25,\n", " gamma=2,\n", " reduction='mean'):\n", " super(FocalLossV2, self).__init__()\n", " self.alpha = alpha\n", " self.gamma = gamma\n", " self.reduction = reduction\n", "\n", " def forward(self, logits, label):\n", " '''\n", " Usage is same as nn.BCEWithLogits:\n", " >>> criteria = FocalLossV2()\n", " >>> logits = torch.randn(8, 19, 384, 384)\n", " >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()\n", " >>> loss = criteria(logits, lbs)\n", " '''\n", " loss = FocalSigmoidLossFuncV2.apply(logits, label, self.alpha, self.gamma)\n", " if self.reduction == 'mean':\n", " loss = loss.mean(dim=1)\n", " if self.reduction == 'sum':\n", " loss = loss.sum(dim=1)\n", " return loss\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jVPCl_XtZ4xP" }, "outputs": [], "source": [ "'''\n", "Renode.py\n", "'''\n", "import math\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "import numpy as np\n", "\n", "def page_rank(args,adj,filename):\n", " print('Calculate Pagerank')\n", " pr_prob = 1 - args['pagerank_prob']\n", " A = adj.to_dense()\n", " A_hat = A.to(args['device']) + torch.eye(A.size(0)).to(args['device']) # add self-loop\n", " D = torch.diag(torch.sum(A_hat, 1))\n", " D = D.inverse().sqrt()\n", " A_hat = torch.mm(torch.mm(D, A_hat), D)\n", " Pi = pr_prob * ((torch.eye(A.size(0)).to(args['device']) - (1 - pr_prob) * A_hat).inverse())\n", " Pi = Pi.cpu()\n", " torch.save(Pi, filename)\n", " return Pi.to(args['device'])\n", "\n", "def get_gpr(Pi,train_node,labels):\n", " gpr_matrix = [] # the class-level influence distribution\n", " for iter_c in range(2):\n", " iter_Pi = Pi[torch.tensor(train_node[iter_c]).long()]\n", " iter_gpr = torch.mean(iter_Pi, dim=0).squeeze()\n", " gpr_matrix.append(iter_gpr)\n", " '''\n", " gpr_matrix = [tensor[0],tensor[1]]\n", " tensor[0] = []\n", " temp_gpr = []\n", " '''\n", "\n", " temp_gpr = torch.stack(gpr_matrix, dim=0)\n", " temp_gpr = temp_gpr.transpose(0, 1)\n", " gpr = temp_gpr\n", "\n", " return gpr\n", " # rn_weight = get_renode_weight(gpr,) # ReNode Weight\n", "\n", "\n", "def get_renode_weight(args,Pi,gpr,labels,train_mask):\n", " ppr_matrix = Pi # personlized pagerank\n", " gpr_matrix = torch.tensor(gpr).float() # class-accumulated personlized pagerank\n", "\n", " base_w = args['rn_base_weight']\n", " scale_w = args['rn_scale_weight']\n", " nnode = ppr_matrix.size(0)\n", " # unlabel_mask = data.train_mask.int().ne(1) # unlabled node\n", "\n", " # computing the Totoro values for labeled nodes\n", " '''\n", " gpr_sum 对每个节点各label维度上与pagerank计算出的权重进行求和\n", " gpr_rn 转置后减gpr_matrix\n", " rn_matrix pagerank矩阵与每个节点计算出的类别权重进行相乘\n", " '''\n", " gpr_sum = torch.sum(gpr_matrix, dim=1)\n", " gpr_rn = gpr_sum.unsqueeze(1) - gpr_matrix\n", " rn_matrix = torch.mm(ppr_matrix, gpr_rn)\n", "\n", " label_matrix = F.one_hot(labels, gpr_matrix.size(1)).float()\n", " # label_matrix[unlabel_mask] = 0\n", "\n", "\n", " '''\n", " rn_matrix 根据rn矩阵及label标签得到每个节点的rn值\n", " '''\n", " rn_matrix = torch.sum(rn_matrix * label_matrix, dim=1)\n", " # rn_matrix[unlabel_mask] = rn_matrix.max() + 
99 # exclude the influence of unlabeled node\n", "\n", "\n", " '''\n", " 根据节点rn值从小到大排序\n", " '''\n", " # computing the ReNode Weight\n", " train_size = torch.sum(train_mask.int()).item()\n", " totoro_list = rn_matrix.tolist()\n", " id2totoro = {i: totoro_list[i] for i in range(len(totoro_list))}\n", "\n", " sorted_totoro = sorted(id2totoro.items(), key=lambda x: x[1], reverse=False)\n", "\n", " id2rank = {sorted_totoro[i][0]: i for i in range(nnode)}\n", "\n", " totoro_rank = [id2rank[i] for i in range(nnode)]\n", "\n", "\n", " rn_weight = [(base_w + 0.5 * scale_w * (1 + math.cos(x * 1.0 * math.pi / (train_size - 1)))) for x in totoro_rank]\n", " rn_weight = torch.from_numpy(np.array(rn_weight)).type(torch.FloatTensor).to(args['device'])\n", " rn_weight = rn_weight * train_mask.float()\n", "\n", " # print(rn_weight.size())\n", " # print('rn size: ',rn_weight.size())\n", " return rn_weight\n", "\n", "\n", "'''\n", "Get PageRank\n", "'''\n", "\n", "def hg2Pi(args,g):\n", " # metapath = ['wi']\n", " metapath = ['wt', 'wi', 'wd','wc','we','wr']\n", " # metapath = ['wt', 'wi', 'wd','we','wr']\n", " Pi = {}\n", " for item in metapath:\n", " if os.path.exists('./pagerank/'+args['dataset'] + item + '.pt'):\n", " Pi[item] = torch.load('./pagerank/'+args['dataset'] + item + '.pt').to(args['device'])\n", " else:\n", " meta_g = dgl.metapath_reachable_graph(g, [item, item[::-1]])\n", " Pi[item] = page_rank(args, meta_g.adjacency_matrix(),'./pagerank/'+args['dataset'] + item + '.pt')\n", "\n", " return Pi\n", "\n", "'''\n", "Get ReNode Weight\n", "'''\n", "\n", "def hg2rn(args,Pi,att_score, train_node, train_mask, labels):\n", " # metapath = ['wi']\n", " metapath = ['wt','wi','wd','wc','we','wr']\n", " # metapath = ['wt','wi','wd','we','wr']\n", " weight = {}\n", " for item in metapath:\n", " pi = Pi[item]*att_score[item]\n", " gpr = get_gpr(pi, train_node, labels)\n", " weight[item] = get_renode_weight(args, pi, gpr, labels, train_mask)\n", "\n", " # for k,v in weight.items():\n", " # print(k,v)\n", " res = torch.zeros(len(labels)).to(args['device'])\n", " for _ in weight.values():\n", " res = torch.add(res,_)\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R7bIsLs7G8eo" }, "outputs": [], "source": [ "'''\n", "HGT train.py\n", "'''\n", "from pprint import pprint\n", "\n", "import torch\n", "from sklearn.metrics import f1_score, classification_report, confusion_matrix, accuracy_score, precision_score, \\\n", " recall_score, precision_recall_fscore_support\n", "\n", "from model import *\n", "import numpy as np\n", "from scipy.sparse import coo_matrix\n", "\n", "\n", "\n", "def score(logits, labels):\n", " _, indices = torch.max(logits, dim=1)\n", " prediction = indices.long().cpu().numpy()\n", " labels = labels.cpu().numpy()\n", "\n", " accuracy = (prediction == labels).sum() / len(prediction)\n", " micro_f1 = f1_score(labels, prediction, average='micro')\n", " macro_f1 = f1_score(labels, prediction, average='macro')\n", "\n", " return accuracy, micro_f1, macro_f1\n", "\n", "\n", "def hgt_evaluate(model, g, t, labels, mask, loss_func):\n", " model.eval()\n", " with torch.no_grad():\n", " logits = model(g, t)\n", " loss = loss_func(logits[mask], F.one_hot(labels[mask], num_classes=2).float())\n", " accuracy, micro_f1, macro_f1 = score(logits[mask], labels[mask])\n", " # indices = indices.cpu().numpy()\n", " # labels = labels.cpu().numpy()\n", " _, indices = torch.max(logits[mask], dim=1)\n", " val_precision = precision_score(labels[mask].cpu().numpy(), 
indices.cpu().numpy(), average='macro')\n", " val_recall = recall_score(labels[mask].cpu().numpy(), indices.cpu().numpy(), average='macro')\n", "\n", " return loss, accuracy, micro_f1, macro_f1, val_precision, val_recall\n", "\n", "\n", "def hgt_test(model, g, t, labels, mask):\n", " model.eval()\n", " with torch.no_grad():\n", " logits = model(g, t)\n", " logits = logits[mask]\n", " labels = labels[mask]\n", " _, indices = torch.max(logits, dim=1)\n", " '''\n", " Plot output feature\n", " '''\n", " # feature_output1 = model.featuremap[mask].cpu()\n", " # plot(feature_output1,indices.cpu(),labels.cpu())\n", "\n", " indices = indices.cpu().numpy()\n", " labels = labels.cpu().numpy()\n", " p,r,f1,s = precision_recall_fscore_support(labels,indices)\n", "\n", " print(classification_report(labels, indices, digits=4))\n", " print(confusion_matrix(labels, indices))\n", " return p[1],r[1],f1[1]\n", "\n", "\n", "def get_n_params(model):\n", " pp = 0\n", " for p in list(model.parameters()):\n", " nn = 1\n", " for s in list(p.size()):\n", " nn = nn * s\n", " pp += nn\n", " return pp\n", "\n", "def att_score(g,model):\n", " '''\n", " Using subgraph to calculate the website attention score matrix\n", " '''\n", " meta = model.meta\n", " att_score = {}\n", " for k, v in meta.items():\n", " row, col = v['edges']\n", " # row = row.cpu().numpy()\n", " # col = col.cpu().numpy()\n", " tmp = v['att_score']\n", "\n", " indices = torch.stack((row,col),0)\n", " if k == 'wt':\n", " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('web'), g.num_nodes('third')])\n", " elif k == 'tw':\n", " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('third'), g.num_nodes('web')])\n", " elif k == 'iw':\n", " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('ip'), g.num_nodes('web')])\n", " elif k == 'wi':\n", " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('web'), g.num_nodes('ip')])\n", " elif k == 'wd':\n", " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('web'), g.num_nodes('dns')])\n", " elif k == 'dw':\n", " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('dns'), g.num_nodes('web')])\n", " elif k == 'ew':\n", " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('email'), g.num_nodes('web')])\n", " elif k == 'we':\n", " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('web'), g.num_nodes('email')])\n", " elif k == 'wr':\n", " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('web'), g.num_nodes('reg')])\n", " elif k == 'rw':\n", " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('reg'), g.num_nodes('web')])\n", " elif k == 'wc':\n", " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('web'), g.num_nodes('cert')])\n", " elif k == 'cw':\n", " att_score[k] = torch.sparse_coo_tensor(indices,tmp,[g.num_nodes('cert'), g.num_nodes('web')])\n", " # for k, v in att_score.items():\n", " # print(v.shape)\n", " re_att = {}\n", " re_att['wt'] = torch.mm(att_score['wt'].to_dense(), att_score['tw'].to_dense())\n", " # print(re_att)\n", " re_att['wi'] = torch.mm(att_score['wi'].to_dense(), att_score['iw'].to_dense())\n", " re_att['wd'] = torch.mm(att_score['wd'].to_dense(), att_score['dw'].to_dense())\n", " re_att['wr'] = torch.mm(att_score['wr'].to_dense(), att_score['rw'].to_dense())\n", " re_att['wc'] = torch.mm(att_score['wc'].to_dense(), att_score['cw'].to_dense())\n", " re_att['we'] = torch.mm(att_score['we'].to_dense(), att_score['ew'].to_dense())\n", " return 
re_att\n", "\n", "def get_train_node(train_idx,labels):\n", " train_node=[[],[]]\n", " for i in train_idx:\n", " if labels[i]==1:\n", " train_node[1].append(i)\n", " else:\n", " train_node[0].append(i)\n", " return train_node\n", "\n", "\n", "def get_binary_mask(total_size, indices):\n", " mask = torch.zeros(total_size)\n", " mask[indices] = 1\n", " return mask.byte()\n", "\n", "\n", "def shuffule(hg,labels):\n", " float_mask = np.random.permutation(np.linspace(0, 1, len(labels)))\n", " train_idx = np.where(float_mask <= 0.6)[0]\n", " val_idx = np.where((float_mask > 0.6) & (float_mask <= 0.8))[0]\n", " test_idx = np.where(float_mask > 0.8)[0]\n", "\n", " num_nodes = hg.number_of_nodes('web')\n", " train_mask = get_binary_mask(num_nodes, train_idx)\n", " val_mask = get_binary_mask(num_nodes, val_idx)\n", " test_mask = get_binary_mask(num_nodes, test_idx)\n", "\n", " print('dataset loaded')\n", " pprint({\n", " 'dataset': 'Pirate',\n", " 'train': train_mask.sum().item() / num_nodes,\n", " 'val': val_mask.sum().item() / num_nodes,\n", " 'test': test_mask.sum().item() / num_nodes\n", " })\n", " return train_idx, val_idx, test_idx, train_mask, val_mask, test_mask\n", "\n", "\n", "def hgt_batch(args,model, g,train_idx, train_mask, val_mask, test_mask, labels, Pi,train_node):\n", " stopper = EarlyStopping(patience=args['patience'])\n", " if args['loss']=='BCE':\n", " loss_fcn = torch.nn.BCEWithLogitsLoss()\n", " if args['loss'] =='BCERN':\n", " loss_fcn = torch.nn.BCEWithLogitsLoss(reduce=False)\n", " if args['loss']=='Focal' or args['loss'] =='RN':\n", " loss_fcn = FocalLossV2()\n", " # loss_fcn = torch.nn.CrossEntropyLoss()\n", "\n", " optimizer = torch.optim.AdamW(model.parameters())\n", "\n", " for epoch in range(args['num_epochs']):\n", " model.train()\n", " logits = model(g, 'web')\n", "\n", " if args['loss']=='RN' or args['loss']=='BCERN':\n", " re_att = att_score(g,model)\n", " rn_weight = hg2rn(args,Pi,re_att,train_node,train_mask,labels)\n", " '''\n", " BCE_LOSS\n", " '''\n", " if args['loss'] == 'BCE':\n", " loss = loss_fcn(logits[train_mask], F.one_hot(labels[train_mask], num_classes = 2).float())\n", " '''\n", " Focal_LOSS\n", " '''\n", " if args['loss'] == 'Focal':\n", " loss = loss_fcn(logits[train_mask], F.one_hot(labels[train_mask], num_classes = 2)).mean()\n", "\n", " '''\n", " Focal_LOSS with ReNode\n", " '''\n", " if args['loss'] == 'RN':\n", " loss = loss_fcn(logits[train_mask], F.one_hot(labels[train_mask], num_classes = 2))\n", " loss = torch.sum(loss * rn_weight[train_mask].to(args['device'])) / loss.size(0)\n", " '''\n", " BCE_LOSS with ReNode\n", " '''\n", " if args['loss'] == 'BCERN':\n", " loss = loss_fcn(logits[train_mask], F.one_hot(labels[train_mask], num_classes = 2).float()).mean(dim=1)\n", " loss = torch.sum(loss * rn_weight[train_mask].to(args['device'])) / loss.size(0)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " train_acc, train_micro_f1, train_macro_f1 = score(logits[train_mask], labels[train_mask])\n", " val_loss, val_acc, val_micro_f1, val_macro_f1, val_precision, val_recall = hgt_evaluate(model, g, 'web', labels,\n", " val_mask, loss_fcn)\n", " early_stop = stopper.step(val_loss.mean().data.item(), val_macro_f1, model)\n", "\n", " # print('Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | '\n", " # 'Val Loss {:.4f} | Val Macro Precison {:.4f} | Val Macro Recall {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(\n", " # epoch + 1, loss.item(), train_micro_f1, 
train_macro_f1, val_loss.mean().item(), val_precision, val_recall,\n", " # val_micro_f1, val_macro_f1))\n", "\n", " if early_stop:\n", " break\n", "\n", " stopper.load_checkpoint(model)\n", " print('Seed',args['seed'])\n", " return hgt_test(model, g, 'web', labels, test_mask)\n", "\n", "\n", "def hgt_train(args, G, features, labels, num_classes):\n", " \n", " train_idx, val_idx, test_idx, train_mask, val_mask, test_mask = shuffule(G,labels)\n", "\n", " Pi = hg2Pi(args,G)\n", " # rn_weight = []\n", " train_node = get_train_node(train_idx,labels)\n", "\n", " node_dict = {}\n", " edge_dict = {}\n", " for ntype in G.ntypes:\n", " node_dict[ntype] = len(node_dict)\n", " for etype in G.etypes:\n", " edge_dict[etype] = len(edge_dict)\n", " G.edges[etype].data['id'] = torch.ones(G.number_of_edges(etype), dtype=torch.long) * edge_dict[etype]\n", "\n", " # Random initialize input feature\n", " for ntype in G.ntypes:\n", " G.nodes[ntype].data['inp'] = features[ntype]\n", " # if ntype in ['web','third','ip','dns']:\n", " # G.nodes[ntype].data['inp'] = features[ntype]\n", " # else:\n", " # emb = nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), features['third'].shape[1]), requires_grad=False)\n", " # nn.init.xavier_uniform_(emb)\n", " # G.nodes[ntype].data['inp'] = emb\n", "\n", " G = G.to(args['device'])\n", "\n", " labels = labels.to(args['device'])\n", " train_mask = train_mask.to(args['device'])\n", " val_mask = val_mask.to(args['device'])\n", " test_mask = test_mask.to(args['device'])\n", "\n", " # model = HGT(G,\n", " # node_dict, edge_dict,\n", " # n_inp=features['web'].shape[1],\n", " # n_hid=args['n_hid'],\n", " # n_out=labels.max().item() + 1,\n", " # n_layers=args['n_layers'],\n", " # n_heads=4,\n", " # use_norm=True).to(args['device'])\n", "\n", " # print('Training HGT with #param: %d' % (get_n_params(model)))\n", " # return hgt_batch(args,model, G,train_idx, train_mask, val_mask, test_mask, labels, Pi,train_node)\n", "\n", " model = HeteroRGCN(G,\n", " in_size=features['web'].shape[1],\n", " hidden_size=args['n_hid'],\n", " out_size=labels.max().item() + 1).to(args['device'])\n", " print('Training RGCN with #param: %d' % (get_n_params(model)))\n", " return hgt_batch(args,model, G,train_idx, train_mask, val_mask, test_mask, labels, Pi,train_node)\n", "\n", " # model = HGT(G,\n", " # node_dict, edge_dict,\n", " # n_inp=features['web'].shape[1],\n", " # n_hid=args['n_hid'],\n", " # n_out=labels.max().item() + 1,\n", " # n_layers=0,\n", " # n_heads=4).to(args['device'])\n", " # print('Training MLP with #param: %d' % (get_n_params(model)))\n", " # hgt_batch(model, G,train_mask,val_mask,test_mask,labels,rn_weight)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jozCMUjGHG_n" }, "outputs": [], "source": [ "'''\n", "HGT main.py\n", "'''\n", "import warnings\n", "warnings.filterwarnings(\"ignore\", category=UserWarning)\n", "import csv\n", "import numpy as np\n", "result = []\n", "\n", "args = {'hetero':True}\n", "args['clip'] = 1.0\n", "args['max_lr'] = 1e-3\n", "args['n_hid'] = 256\n", "args['num_epochs']=200\n", "args['dataset'] = 'Pirate1RDE'\n", "args['n_layers'] = 2\n", "args['char_dim'] = 'eng_200'\n", "\n", "'''\n", "renode params\n", "'''\n", "args['pagerank_prob'] = 0.85\n", "args['rn_base_weight'] = 0.5\n", "args['rn_scale_weight'] = 1.0\n", "G, features, labels, num_classes = load_data(args['char_dim'], args['dataset'])\n", "print(G)\n", "\n", "for i in range(100):\n", " args['seed'] = i\n", " '''\n", " loss\n", " '''\n", " args['loss'] = 'BCE'\n", " # 
args['loss'] = 'Focal'\n", " # args['loss'] = 'RN'\n", " # args['loss'] = 'BCERN'\n", " args = setup(args)\n", "\n", " print(args)\n", " # main(args)\n", " tmp = list(hgt_train(args,G, features, labels, num_classes))\n", " with open('./'+args['dataset']+'_'+args['loss']+'.csv','a+',encoding='utf-8') as csvfile:\n", " write = csv.writer(csvfile)\n", " write.writerow(tmp)\n", " result.append(tmp)\n", "for item in result:\n", " print(item)\n", "tmp = np.array(result)\n", "ma = np.max(tmp,0)\n", "me = np.mean(tmp,0)\n", "mi = np.min(tmp,0)\n", "print('Precision: {:.4}, +{:.4}, -{:.4}'.format(me[0],ma[0]-me[0],me[0]-mi[0]))\n", "print('Recall: {:.4}, +{:.4}, -{:.4}'.format(me[1],ma[1]-me[1],me[1]-mi[1]))\n", "print('F1: {:.4}, +{:.4}, -{:.4}'.format(me[2],ma[2]-me[2],me[2]-mi[2]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wb-YeP-nsYNX" }, "outputs": [], "source": [ "'''\n", "HAN train.py\n", "'''\n", "import torch\n", "from sklearn.metrics import f1_score, classification_report, confusion_matrix, accuracy_score, precision_score, \\\n", " recall_score\n", "\n", "def han_evaluate(model, g, features, labels, mask, loss_func):\n", " model.eval()\n", " with torch.no_grad():\n", " logits = model(g, features)\n", " loss = loss_func(logits[mask], F.one_hot(labels[mask], num_classes = 2).float())\n", " accuracy, micro_f1, macro_f1 = score(logits[mask], labels[mask])\n", "\n", " return loss, accuracy, micro_f1, macro_f1\n", "\n", "def han_test(model, g, features, labels, mask):\n", " model.eval()\n", " with torch.no_grad():\n", " logits = model(g, features)\n", " logits = logits[mask]\n", " labels = labels[mask]\n", " _, indices = torch.max(logits, dim=1)\n", " '''\n", " Plot output feature\n", " '''\n", " # feature_output1 = model.featuremap[mask].cpu()\n", " # plot(feature_output1,indices.cpu(),labels.cpu())\n", "\n", " indices = indices.cpu().numpy()\n", " labels = labels.cpu().numpy()\n", " p,r,f1,s = precision_recall_fscore_support(labels,indices)\n", "\n", " print(classification_report(labels, indices, digits=4))\n", " print(confusion_matrix(labels, indices))\n", " return p[1],r[1],f1[1]\n", "\n", "\n", "def han_train(args, g, features, labels, num_classes):\n", " # If args['hetero'] is True, g would be a heterogeneous graph.\n", " # Otherwise, it will be a list of homogeneous graphs.\n", " \n", " train_idx, val_idx, test_idx, train_mask, val_mask, test_mask = shuffule(g,labels)\n", "\n", "\n", " if hasattr(torch, 'BoolTensor'):\n", " train_mask = train_mask.bool()\n", " val_mask = val_mask.bool()\n", " test_mask = test_mask.bool()\n", "\n", " features = features['web'].to(args['device'])\n", "\n", " labels = labels.to(args['device'])\n", " train_mask = train_mask.to(args['device'])\n", " val_mask = val_mask.to(args['device'])\n", " test_mask = test_mask.to(args['device'])\n", "\n", " from model import HAN\n", " model = HAN(meta_paths=[['wr', 'rw'],['wd', 'dw']],\n", " in_size=features.shape[1],\n", " hidden_size=args['hidden_units'],\n", " out_size=num_classes,\n", " # out_size=num_classes,\n", " num_heads=args['num_heads'],\n", " dropout=args['dropout']).to(args['device'])\n", " g = g.to(args['device'])\n", "\n", "\n", " stopper = EarlyStopping(patience=args['patience'])\n", " loss_fcn = torch.nn.BCEWithLogitsLoss()\n", " optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],\n", " weight_decay=args['weight_decay'])\n", "\n", " for epoch in range(args['num_epochs']):\n", " model.train()\n", " logits = model(g, features)\n", " # loss = 
loss_fcn(logits[train_mask], labels[train_mask])\n", " loss = loss_fcn(logits[train_mask], F.one_hot(labels[train_mask], num_classes = 2).float())\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " train_acc, train_micro_f1, train_macro_f1 = score(logits[train_mask], labels[train_mask])\n", " val_loss, val_acc, val_micro_f1, val_macro_f1 = han_evaluate(model, g, features, labels, val_mask, loss_fcn)\n", " early_stop = stopper.step(val_loss.mean().data.item(), val_macro_f1, model)\n", "\n", " # print('Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | '\n", " # 'Val Loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(\n", " # epoch + 1, loss.item(), train_micro_f1, train_macro_f1, val_loss.mean().item(), val_micro_f1, val_macro_f1))\n", "\n", " if early_stop:\n", " break\n", "\n", " stopper.load_checkpoint(model)\n", " print('Seed',args['seed'])\n", " return hgt_test(model, g, features, labels, test_mask)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ORw6BM_pTO6Q" }, "outputs": [], "source": [ "'''\n", "HAN main.py\n", "'''\n", "import warnings\n", "warnings.filterwarnings(\"ignore\", category=UserWarning)\n", "import csv\n", "import numpy as np\n", "result = []\n", "\n", "args = {'hetero':True}\n", "args['clip'] = 1.0\n", "args['max_lr'] = 1e-3\n", "args['n_hid'] = 256\n", "args['num_epochs']=200\n", "args['dataset'] = 'Pirate1RDE'\n", "args['n_layers'] = 2\n", "args['char_dim'] = 'eng_200'\n", "\n", "'''\n", "renode params\n", "'''\n", "args['pagerank_prob'] = 0.85\n", "args['rn_base_weight'] = 0.5\n", "args['rn_scale_weight'] = 1.0\n", "G, features, labels, num_classes = load_data(args['char_dim'], args['dataset'])\n", "print(G)\n", "\n", "for i in range(100):\n", " args['seed'] = i\n", " '''\n", " loss\n", " '''\n", " args['loss'] = 'BCE'\n", " # args['loss'] = 'Focal'\n", " # args['loss'] = 'RN'\n", " # args['loss'] = 'BCERN'\n", " args = setup(args)\n", "\n", " print(args)\n", " # main(args)\n", " tmp = list(han_train(args,G, features, labels, num_classes))\n", " with open('./'+args['dataset']+'_'+args['loss']+'.csv','a+',encoding='utf-8') as csvfile:\n", " write = csv.writer(csvfile)\n", " write.writerow(tmp)\n", " result.append(tmp)\n", "for item in result:\n", " print(item)\n", "tmp = np.array(result)\n", "ma = np.max(tmp,0)\n", "me = np.mean(tmp,0)\n", "mi = np.min(tmp,0)\n", "print('Precision: {:.4}, +{:.4}, -{:.4}'.format(me[0],ma[0]-me[0],me[0]-mi[0]))\n", "print('Recall: {:.4}, +{:.4}, -{:.4}'.format(me[1],ma[1]-me[1],me[1]-mi[1]))\n", "print('F1: {:.4}, +{:.4}, -{:.4}'.format(me[2],ma[2]-me[2],me[2]-mi[2]))" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "graph.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }