summaryrefslogtreecommitdiff
path: root/server/apps
diff options
context:
space:
mode:
authorhandingkang <[email protected]>2024-06-18 22:50:42 +0800
committerhandingkang <[email protected]>2024-06-18 22:50:42 +0800
commit1d4200da7064fe41e027f49f967b6cfe795cd88c (patch)
tree438016f2b548e5caa455e2220bf4391e57656b08 /server/apps
parent4339ecd79f54d1cb1cdd129be7fdfe51a9d3cb14 (diff)
1. 实现主控端实际调用代理探测目标时延的功能
2. 修复部分数据格式问题
Diffstat (limited to 'server/apps')
-rw-r--r--server/apps/target.py117
-rw-r--r--server/apps/task.py52
2 files changed, 87 insertions, 82 deletions
diff --git a/server/apps/target.py b/server/apps/target.py
index c056a82..9ae34e6 100644
--- a/server/apps/target.py
+++ b/server/apps/target.py
@@ -1,7 +1,7 @@
# 目标状态感知
# 时延测试接口
import ipaddress
-import random
+import json
import threading
from operator import or_
@@ -12,6 +12,7 @@ from apiflask.validators import OneOf
from requests.exceptions import Timeout
from sqlalchemy import distinct, func, case
+from apps.util import debug
from exts import db
from model import Target, Task, Agent
@@ -55,7 +56,7 @@ class TestNode(Schema):
class Delay(Schema):
- Id = Integer()
+ Id = String()
CurrDelay = String()
# MeanDelay=Integer()
# MaxDelay=Integer()
@@ -89,13 +90,13 @@ class CouInfo(Schema):
def get_nodes(query_data):
node_list = []
query_session = db.session
- # 查询所有的isp
+ # 查询负责该任务的所有状态感知节点简要信息
node_data = query_session.query(Task.SCAN_AGENT_ID_LIST).filter(Task.task_id == query_data['taskid']).all()
# 存在数据
if len(node_data) > 0:
# 负责该任务目标的状态感知的节点
- nodes_info = node_data[0].SCAN_AGENT_ID_LIST
+ nodes_info = json.loads(node_data[0].SCAN_AGENT_ID_LIST)
# 查询每一个节点的信息
nodes = query_session.query(Agent).filter(Agent.agent_id.in_(nodes_info.keys())).all()
for node in nodes:
@@ -108,10 +109,11 @@ def get_nodes(query_data):
"Lat": node.lat,
"Lng": node.lng
})
+ query_session.close()
else:
+ query_session.close()
return {"code": 500, "nodes": []}
- query_session.close()
# 废弃代码
# num = 10
# for i in range(num):
@@ -137,9 +139,10 @@ def get_nodes(query_data):
@bp.output(DelayOut)
# TODO:和实际节点联调测试
def get_pernode_delay(query_data):
+ query_session = db.session
# TODO:DoH处理
# 探测地址
- addr = query_data['ip']
+ addr = query_data['target']
taskid = query_data['taskid']
scan_type = query_data['type']
@@ -147,95 +150,99 @@ def get_pernode_delay(query_data):
ans = []
# 线程池
threads = []
- # # 检索探测节点信息
- # sql = """
- # SELECT SCAN_AGENT_ID_LIST as node_info
- # FROM %s
- # WHERE TASK_ID='%s'
- # """ % (MYSQL_TAB_TASK, taskid)
- # da.cursor.execute(sql)
- # # 探测节点ID与地址
- # nodes = json.loads(da.cursor.fetchall()[0]["node_info"].replace('"', "\""))
- #
- # for id, ip in nodes.items():
- # mythread = threading.Thread(target=task, args=[ans, addr, {'id': id, 'ip': ip}, type])
- # mythread.start()
- # threads.append(mythread)
- # for t in threads:
- # t.join()
+
+ # 检索探测节点信息
+ # 查询负责该任务的所有状态感知节点简要信息
+ node_data = query_session.query(Task.SCAN_AGENT_ID_LIST).filter(Task.task_id == taskid).all()
+
+ # 存在数据
+ if len(node_data) > 0:
+ # 负责该任务目标的状态感知的节点ID和IP地址+端口号组成的列表
+ nodes_info = json.loads(node_data[0].SCAN_AGENT_ID_LIST)
+ # 向每个节点查询时延信息
+ for id, ip_port in nodes_info.items():
+ mythread = threading.Thread(target=task, args=[ans, addr, {'id': id, 'ip_port': ip_port}, scan_type])
+ mythread.start()
+ threads.append(mythread)
+ for t in threads:
+ t.join()
+ query_session.close()
+ else:
+ query_session.close()
+ return {"code": 500, "delay_data": []}
# 暂未部署实际代理节点,以假数据返回
- ans = []
- for i in range(10):
- ans.append({"Id": str(i), "Type": scan_type, "CurrDelay": random.randint(1, 1000)})
+ # ans = []
+ # for i in range(10):
+ # ans.append({"Id": str(i), "Type": scan_type, "CurrDelay": random.randint(1, 1000)})
return {"code": 200, 'delay_data': ans}
threadLock = threading.Lock()
-def task(ans, addr, agent, type):
+def task(ans, addr, agent, scan_type):
res = 0
- if type == "icmp":
- res = icmp_delay_query(addr, agent['ip'])
- if type == "tcp":
- res = tcp_delay_query(addr, agent['ip'])
- if type == "dns":
- res = dns_delay_query(addr, agent['ip'])
+ if scan_type == "icmp":
+ res = icmp_delay_query(addr, agent['ip_port'])
+ if scan_type == "tcp":
+ res = tcp_delay_query(addr, agent['ip_port'])
+ if scan_type == "dns":
+ res = dns_delay_query(addr, agent['ip_port'])
threadLock.acquire()
ans.append({
'Id': agent['id'],
'CurrDelay': res,
- 'Type': type})
+ 'Type': scan_type})
threadLock.release()
-def icmp_delay_query(target, addr):
+def icmp_delay_query(target, addr_port):
try:
- res = requests.get(url="http://" + addr + ":2525/script/icmpdelay", params={'ip': target}, timeout=5)
- print("icmp ok:" + addr + "-------" + res.text + "-------" + str(res.elapsed.total_seconds()))
- icmp_delaytable[str(addr) + str(target)] = res.text
+ res = requests.get(url="http://" + addr_port + "/delay/icmp", params={'ip': target}, timeout=5)
+ debug("icmp ok:" + addr_port + "-------" + res.text + "-------" + str(res.elapsed.total_seconds()))
+ icmp_delaytable[str(addr_port) + str(target)] = res.text
return res.text
except Timeout:
# 如果存在旧数据
- if str(addr) + str(target) in icmp_delaytable.keys():
+ if str(addr_port) + str(target) in icmp_delaytable.keys():
pass
# 不存在则设0
else:
- icmp_delaytable[str(addr) + str(target)] = 0
- return icmp_delaytable[str(addr) + str(target)]
+ icmp_delaytable[str(addr_port) + str(target)] = 0
+ return icmp_delaytable[str(addr_port) + str(target)]
-def tcp_delay_query(target, addr):
+def tcp_delay_query(target, addr_port):
try:
- res = requests.get(url="http://" + addr + ":2525/script/tcpdelay", params={'ip': target, 'port': 53}, timeout=5)
- print("tcp ok:" + addr + "-------" + res.text)
- tcp_delaytable[str(addr) + str(target)] = res.text
+ res = requests.get(url="http://" + addr_port + "/delay/tcp", params={'ip': target, 'port': 53}, timeout=5)
+ debug("tcp ok:" + addr_port + "-------" + res.text)
+ tcp_delaytable[str(addr_port) + str(target)] = res.text
return res.text
except Timeout:
# 如果存在旧数据
- if str(addr) + str(target) in tcp_delaytable.keys():
+ if str(addr_port) + str(target) in tcp_delaytable.keys():
pass
# 不存在则设0
else:
- tcp_delaytable[str(addr) + str(target)] = 0
- return tcp_delaytable[str(addr) + str(target)]
+ tcp_delaytable[str(addr_port) + str(target)] = 0
+ return tcp_delaytable[str(addr_port) + str(target)]
-def dns_delay_query(target, addr):
+def dns_delay_query(target, addr_port):
try:
- res = requests.get(url="http://" + addr + ":2525/script/dnsdelay", params={'ip': target}, timeout=5)
- print("dns ok:" + addr + "-------" + res.text)
- dns_delaytable[str(addr) + str(target)] = res.text
- return dns_delaytable[str(addr) + str(target)]
+ res = requests.get(url="http://" + addr_port + "/delay/dns", params={'ip': target}, timeout=5)
+ debug("dns ok:" + addr_port + "-------" + res.text)
+ dns_delaytable[str(addr_port) + str(target)] = res.text
+ return dns_delaytable[str(addr_port) + str(target)]
except Timeout:
# 如果存在旧数据
- if str(addr) + str(target) in dns_delaytable.keys():
+ if str(addr_port) + str(target) in dns_delaytable.keys():
pass
# 不存在则设0
else:
- dns_delaytable[str(addr) + str(target)] = 0
- return dns_delaytable[str(addr) + str(target)]
+ dns_delaytable[str(addr_port) + str(target)] = 0
+ return dns_delaytable[str(addr_port) + str(target)]
# 状态感知——DNS记录测试接口
diff --git a/server/apps/task.py b/server/apps/task.py
index abac1dd..6f65000 100644
--- a/server/apps/task.py
+++ b/server/apps/task.py
@@ -6,10 +6,9 @@ from apiflask import APIBlueprint, Schema
from apiflask.fields import String, Integer, IP, DateTime, List, Nested
from apiflask.validators import OneOf
from sqlalchemy import and_
-
-from exts import db
-from model import TaskPolicy, Task, Agent
from sqlalchemy.exc import SQLAlchemyError
+
+from model import Task, Agent
from .agentcomm import deliver_task
from .policy import *
from .util import error
@@ -17,7 +16,6 @@ from .util import error
bp = APIBlueprint("任务管理接口集合", __name__, url_prefix="/task")
-
class TaskSchema(Schema):
id = String()
target = String()
@@ -83,15 +81,15 @@ def valid_task_info(task_param: dict):
# 期望注入记录
"target_rr": String(),
# 期望策略
- "policy": String(validate=OneOf(["auto", "ddos", "sjqp"])),
+ "policy": String(validate=OneOf(["auto", "ddos", "sjqp"]), load_default="auto"),
# 状态感知方式
- "scan": String(validate=OneOf(["auto", "icmp", "dns", "tcp", "record"])),
+ "scan": String(validate=OneOf(["auto", "icmp", "dns", "tcp", "record"]), load_default="auto"),
# 策略切换时限,单位分钟
- "policy_time": Integer(),
+ "policy_time": Integer(load_default=60),
# 任务执行时限,单位分钟
- "run_time": Integer(),
+ "run_time": Integer(load_default=600),
# 运行配置
- "run_flag": String(validate=OneOf(["now", "man"]))
+ "run_flag": String(validate=OneOf(["now", "man"]), load_default="now")
}, example={'name': "test_task", 'target': "1.2.3.4", 'agent': "8a9ces", 'target_domain': "www.google.com",
'target_rr': "NS ns.ourattack.com",
'policy': "auto", 'scan': "auto",
@@ -104,35 +102,35 @@ def valid_task_info(task_param: dict):
# TODO: 需要更新接口,created_by
def make_task(json_data):
task = Task(
- task_id = str(uuid.uuid1()),
- task_name = json_data["name"],
- agent_id = json_data["agent"],
+ task_id=str(uuid.uuid1()),
+ task_name=json_data["name"],
+ agent_id=json_data["agent"],
# created_by = "Admin",
- target_ip = str(json_data["target"]),
+ target_ip=str(json_data["target"]),
# policy = json_data["policy"],
- status = "working" if json_data["run_flag"] == "now" else "stop",
- policy_delay = json_data["policy_time"],
- task_delay =json_data["run_time"],
- target_scan = json_data["scan"],
- target_domain = json_data["target_domain"],
- target_rtype = "",
- target_rr = json_data["target_rr"],
+ status="working" if json_data["run_flag"] == "now" else "stop",
+ policy_delay=json_data["policy_time"],
+ task_delay=json_data["run_time"],
+ target_scan=json_data["scan"],
+ target_domain=json_data["target_domain"],
+ target_rtype="",
+ target_rr=json_data["target_rr"],
)
if task.policy == "sjqp":
if task.target_rr == "" or task.target_domain == "":
return {"code": 400, "msg": "数据欺骗缺乏目标域名或注入参数"}
-
+
if task.target_rr != "":
task.target_rtype, task.target_rr = task.target_rr.split(" ")
# 查找所有在线的用于状态感知的agent
- agents = db.session.query(Agent).filter(and_(Agent.status==True, Agent.agent_type=='ztgz')).all()
+ agents = db.session.query(Agent).filter(and_(Agent.status == True, Agent.agent_type == 'ztgz')).all()
agents = random.sample(agents, min(10, len(agents)))
selected_nodes_info = {}
for agent in agents:
- selected_nodes_info[agent.agent_id] = agent.ipaddr.split("|")[0]
+ selected_nodes_info[agent.agent_id] = agent.ipaddr.split("|")[0] + ":" + str(agent.port)
task.SCAN_AGENT_ID_LIST = json.dumps(selected_nodes_info)
-
+
# 插入task
try:
db.session.add(task)
@@ -141,7 +139,7 @@ def make_task(json_data):
db.session.rollback()
error(str(e))
return {"code": 500, "msg": str(e)}
-
+
# 任务策略初始化
tp_id, p_exe, p_param = init_task_policy(json_data["policy"], task.target_ip, task.task_id)
@@ -214,7 +212,7 @@ def tasks_state(query_data):
query = db.session.query(Task).filter().offset((page - 1) * per_page).limit(per_page)
tasks = query.all()
task_count = query.count()
-
+
task_list = []
for task in tasks:
task_r = {}
@@ -325,4 +323,4 @@ def del_task(json_data):
except SQLAlchemyError as e:
db.session.rollback()
error(str(e))
- return {"code": 500, "msg": str(e)} \ No newline at end of file
+ return {"code": 500, "msg": str(e)}