diff options
| author | handingkang <[email protected]> | 2024-06-18 22:50:42 +0800 |
|---|---|---|
| committer | handingkang <[email protected]> | 2024-06-18 22:50:42 +0800 |
| commit | 1d4200da7064fe41e027f49f967b6cfe795cd88c (patch) | |
| tree | 438016f2b548e5caa455e2220bf4391e57656b08 /server/apps | |
| parent | 4339ecd79f54d1cb1cdd129be7fdfe51a9d3cb14 (diff) | |
1. 实现主控端实际调用代理探测目标时延的功能
2. 修复部分数据格式问题
Diffstat (limited to 'server/apps')
| -rw-r--r-- | server/apps/target.py | 117 | ||||
| -rw-r--r-- | server/apps/task.py | 52 |
2 files changed, 87 insertions, 82 deletions
diff --git a/server/apps/target.py b/server/apps/target.py index c056a82..9ae34e6 100644 --- a/server/apps/target.py +++ b/server/apps/target.py @@ -1,7 +1,7 @@ # 目标状态感知 # 时延测试接口 import ipaddress -import random +import json import threading from operator import or_ @@ -12,6 +12,7 @@ from apiflask.validators import OneOf from requests.exceptions import Timeout from sqlalchemy import distinct, func, case +from apps.util import debug from exts import db from model import Target, Task, Agent @@ -55,7 +56,7 @@ class TestNode(Schema): class Delay(Schema): - Id = Integer() + Id = String() CurrDelay = String() # MeanDelay=Integer() # MaxDelay=Integer() @@ -89,13 +90,13 @@ class CouInfo(Schema): def get_nodes(query_data): node_list = [] query_session = db.session - # 查询所有的isp + # 查询负责该任务的所有状态感知节点简要信息 node_data = query_session.query(Task.SCAN_AGENT_ID_LIST).filter(Task.task_id == query_data['taskid']).all() # 存在数据 if len(node_data) > 0: # 负责该任务目标的状态感知的节点 - nodes_info = node_data[0].SCAN_AGENT_ID_LIST + nodes_info = json.loads(node_data[0].SCAN_AGENT_ID_LIST) # 查询每一个节点的信息 nodes = query_session.query(Agent).filter(Agent.agent_id.in_(nodes_info.keys())).all() for node in nodes: @@ -108,10 +109,11 @@ def get_nodes(query_data): "Lat": node.lat, "Lng": node.lng }) + query_session.close() else: + query_session.close() return {"code": 500, "nodes": []} - query_session.close() # 废弃代码 # num = 10 # for i in range(num): @@ -137,9 +139,10 @@ def get_nodes(query_data): @bp.output(DelayOut) # TODO:和实际节点联调测试 def get_pernode_delay(query_data): + query_session = db.session # TODO:DoH处理 # 探测地址 - addr = query_data['ip'] + addr = query_data['target'] taskid = query_data['taskid'] scan_type = query_data['type'] @@ -147,95 +150,99 @@ def get_pernode_delay(query_data): ans = [] # 线程池 threads = [] - # # 检索探测节点信息 - # sql = """ - # SELECT SCAN_AGENT_ID_LIST as node_info - # FROM %s - # WHERE TASK_ID='%s' - # """ % (MYSQL_TAB_TASK, taskid) - # da.cursor.execute(sql) - # # 探测节点ID与地址 - # nodes = json.loads(da.cursor.fetchall()[0]["node_info"].replace('"', "\"")) - # - # for id, ip in nodes.items(): - # mythread = threading.Thread(target=task, args=[ans, addr, {'id': id, 'ip': ip}, type]) - # mythread.start() - # threads.append(mythread) - # for t in threads: - # t.join() + + # 检索探测节点信息 + # 查询负责该任务的所有状态感知节点简要信息 + node_data = query_session.query(Task.SCAN_AGENT_ID_LIST).filter(Task.task_id == taskid).all() + + # 存在数据 + if len(node_data) > 0: + # 负责该任务目标的状态感知的节点ID和IP地址+端口号组成的列表 + nodes_info = json.loads(node_data[0].SCAN_AGENT_ID_LIST) + # 向每个节点查询时延信息 + for id, ip_port in nodes_info.items(): + mythread = threading.Thread(target=task, args=[ans, addr, {'id': id, 'ip_port': ip_port}, scan_type]) + mythread.start() + threads.append(mythread) + for t in threads: + t.join() + query_session.close() + else: + query_session.close() + return {"code": 500, "delay_data": []} # 暂未部署实际代理节点,以假数据返回 - ans = [] - for i in range(10): - ans.append({"Id": str(i), "Type": scan_type, "CurrDelay": random.randint(1, 1000)}) + # ans = [] + # for i in range(10): + # ans.append({"Id": str(i), "Type": scan_type, "CurrDelay": random.randint(1, 1000)}) return {"code": 200, 'delay_data': ans} threadLock = threading.Lock() -def task(ans, addr, agent, type): +def task(ans, addr, agent, scan_type): res = 0 - if type == "icmp": - res = icmp_delay_query(addr, agent['ip']) - if type == "tcp": - res = tcp_delay_query(addr, agent['ip']) - if type == "dns": - res = dns_delay_query(addr, agent['ip']) + if scan_type == "icmp": + res = icmp_delay_query(addr, agent['ip_port']) + if scan_type == "tcp": + res = tcp_delay_query(addr, agent['ip_port']) + if scan_type == "dns": + res = dns_delay_query(addr, agent['ip_port']) threadLock.acquire() ans.append({ 'Id': agent['id'], 'CurrDelay': res, - 'Type': type}) + 'Type': scan_type}) threadLock.release() -def icmp_delay_query(target, addr): +def icmp_delay_query(target, addr_port): try: - res = requests.get(url="http://" + addr + ":2525/script/icmpdelay", params={'ip': target}, timeout=5) - print("icmp ok:" + addr + "-------" + res.text + "-------" + str(res.elapsed.total_seconds())) - icmp_delaytable[str(addr) + str(target)] = res.text + res = requests.get(url="http://" + addr_port + "/delay/icmp", params={'ip': target}, timeout=5) + debug("icmp ok:" + addr_port + "-------" + res.text + "-------" + str(res.elapsed.total_seconds())) + icmp_delaytable[str(addr_port) + str(target)] = res.text return res.text except Timeout: # 如果存在旧数据 - if str(addr) + str(target) in icmp_delaytable.keys(): + if str(addr_port) + str(target) in icmp_delaytable.keys(): pass # 不存在则设0 else: - icmp_delaytable[str(addr) + str(target)] = 0 - return icmp_delaytable[str(addr) + str(target)] + icmp_delaytable[str(addr_port) + str(target)] = 0 + return icmp_delaytable[str(addr_port) + str(target)] -def tcp_delay_query(target, addr): +def tcp_delay_query(target, addr_port): try: - res = requests.get(url="http://" + addr + ":2525/script/tcpdelay", params={'ip': target, 'port': 53}, timeout=5) - print("tcp ok:" + addr + "-------" + res.text) - tcp_delaytable[str(addr) + str(target)] = res.text + res = requests.get(url="http://" + addr_port + "/delay/tcp", params={'ip': target, 'port': 53}, timeout=5) + debug("tcp ok:" + addr_port + "-------" + res.text) + tcp_delaytable[str(addr_port) + str(target)] = res.text return res.text except Timeout: # 如果存在旧数据 - if str(addr) + str(target) in tcp_delaytable.keys(): + if str(addr_port) + str(target) in tcp_delaytable.keys(): pass # 不存在则设0 else: - tcp_delaytable[str(addr) + str(target)] = 0 - return tcp_delaytable[str(addr) + str(target)] + tcp_delaytable[str(addr_port) + str(target)] = 0 + return tcp_delaytable[str(addr_port) + str(target)] -def dns_delay_query(target, addr): +def dns_delay_query(target, addr_port): try: - res = requests.get(url="http://" + addr + ":2525/script/dnsdelay", params={'ip': target}, timeout=5) - print("dns ok:" + addr + "-------" + res.text) - dns_delaytable[str(addr) + str(target)] = res.text - return dns_delaytable[str(addr) + str(target)] + res = requests.get(url="http://" + addr_port + "/delay/dns", params={'ip': target}, timeout=5) + debug("dns ok:" + addr_port + "-------" + res.text) + dns_delaytable[str(addr_port) + str(target)] = res.text + return dns_delaytable[str(addr_port) + str(target)] except Timeout: # 如果存在旧数据 - if str(addr) + str(target) in dns_delaytable.keys(): + if str(addr_port) + str(target) in dns_delaytable.keys(): pass # 不存在则设0 else: - dns_delaytable[str(addr) + str(target)] = 0 - return dns_delaytable[str(addr) + str(target)] + dns_delaytable[str(addr_port) + str(target)] = 0 + return dns_delaytable[str(addr_port) + str(target)] # 状态感知——DNS记录测试接口 diff --git a/server/apps/task.py b/server/apps/task.py index abac1dd..6f65000 100644 --- a/server/apps/task.py +++ b/server/apps/task.py @@ -6,10 +6,9 @@ from apiflask import APIBlueprint, Schema from apiflask.fields import String, Integer, IP, DateTime, List, Nested from apiflask.validators import OneOf from sqlalchemy import and_ - -from exts import db -from model import TaskPolicy, Task, Agent from sqlalchemy.exc import SQLAlchemyError + +from model import Task, Agent from .agentcomm import deliver_task from .policy import * from .util import error @@ -17,7 +16,6 @@ from .util import error bp = APIBlueprint("任务管理接口集合", __name__, url_prefix="/task") - class TaskSchema(Schema): id = String() target = String() @@ -83,15 +81,15 @@ def valid_task_info(task_param: dict): # 期望注入记录 "target_rr": String(), # 期望策略 - "policy": String(validate=OneOf(["auto", "ddos", "sjqp"])), + "policy": String(validate=OneOf(["auto", "ddos", "sjqp"]), load_default="auto"), # 状态感知方式 - "scan": String(validate=OneOf(["auto", "icmp", "dns", "tcp", "record"])), + "scan": String(validate=OneOf(["auto", "icmp", "dns", "tcp", "record"]), load_default="auto"), # 策略切换时限,单位分钟 - "policy_time": Integer(), + "policy_time": Integer(load_default=60), # 任务执行时限,单位分钟 - "run_time": Integer(), + "run_time": Integer(load_default=600), # 运行配置 - "run_flag": String(validate=OneOf(["now", "man"])) + "run_flag": String(validate=OneOf(["now", "man"]), load_default="now") }, example={'name': "test_task", 'target': "1.2.3.4", 'agent': "8a9ces", 'target_domain': "www.google.com", 'target_rr': "NS ns.ourattack.com", 'policy': "auto", 'scan': "auto", @@ -104,35 +102,35 @@ def valid_task_info(task_param: dict): # TODO: 需要更新接口,created_by def make_task(json_data): task = Task( - task_id = str(uuid.uuid1()), - task_name = json_data["name"], - agent_id = json_data["agent"], + task_id=str(uuid.uuid1()), + task_name=json_data["name"], + agent_id=json_data["agent"], # created_by = "Admin", - target_ip = str(json_data["target"]), + target_ip=str(json_data["target"]), # policy = json_data["policy"], - status = "working" if json_data["run_flag"] == "now" else "stop", - policy_delay = json_data["policy_time"], - task_delay =json_data["run_time"], - target_scan = json_data["scan"], - target_domain = json_data["target_domain"], - target_rtype = "", - target_rr = json_data["target_rr"], + status="working" if json_data["run_flag"] == "now" else "stop", + policy_delay=json_data["policy_time"], + task_delay=json_data["run_time"], + target_scan=json_data["scan"], + target_domain=json_data["target_domain"], + target_rtype="", + target_rr=json_data["target_rr"], ) if task.policy == "sjqp": if task.target_rr == "" or task.target_domain == "": return {"code": 400, "msg": "数据欺骗缺乏目标域名或注入参数"} - + if task.target_rr != "": task.target_rtype, task.target_rr = task.target_rr.split(" ") # 查找所有在线的用于状态感知的agent - agents = db.session.query(Agent).filter(and_(Agent.status==True, Agent.agent_type=='ztgz')).all() + agents = db.session.query(Agent).filter(and_(Agent.status == True, Agent.agent_type == 'ztgz')).all() agents = random.sample(agents, min(10, len(agents))) selected_nodes_info = {} for agent in agents: - selected_nodes_info[agent.agent_id] = agent.ipaddr.split("|")[0] + selected_nodes_info[agent.agent_id] = agent.ipaddr.split("|")[0] + ":" + str(agent.port) task.SCAN_AGENT_ID_LIST = json.dumps(selected_nodes_info) - + # 插入task try: db.session.add(task) @@ -141,7 +139,7 @@ def make_task(json_data): db.session.rollback() error(str(e)) return {"code": 500, "msg": str(e)} - + # 任务策略初始化 tp_id, p_exe, p_param = init_task_policy(json_data["policy"], task.target_ip, task.task_id) @@ -214,7 +212,7 @@ def tasks_state(query_data): query = db.session.query(Task).filter().offset((page - 1) * per_page).limit(per_page) tasks = query.all() task_count = query.count() - + task_list = [] for task in tasks: task_r = {} @@ -325,4 +323,4 @@ def del_task(json_data): except SQLAlchemyError as e: db.session.rollback() error(str(e)) - return {"code": 500, "msg": str(e)}
\ No newline at end of file + return {"code": 500, "msg": str(e)} |
