summaryrefslogtreecommitdiff
path: root/server/apps
diff options
context:
space:
mode:
authorshihaoyue <[email protected]>2024-05-31 21:33:38 +0800
committershihaoyue <[email protected]>2024-05-31 21:33:38 +0800
commit6916840f4e11a58448714faeda691b48dfaa041c (patch)
tree5967b28e63db838f6f3018a6bf90ddd7afc8688b /server/apps
parentab29fc5c1bd5e616ce725a6e36e972986b560606 (diff)
创建任务orm
Diffstat (limited to 'server/apps')
-rw-r--r--server/apps/task.py158
1 files changed, 46 insertions, 112 deletions
diff --git a/server/apps/task.py b/server/apps/task.py
index 4941802..6d1c234 100644
--- a/server/apps/task.py
+++ b/server/apps/task.py
@@ -5,9 +5,10 @@ import uuid
from apiflask import APIBlueprint, Schema
from apiflask.fields import String, Integer, IP, DateTime, List, Nested
from apiflask.validators import OneOf
+from sqlalchemy import and_
from exts import db
-from model import TaskPolicy, Task
+from model import Policy, TaskPolicy, Task, Agent
from sqlalchemy.exc import SQLAlchemyError
from .agentcomm import deliver_task
from .policy import *
@@ -100,120 +101,53 @@ def valid_task_info(task_param: dict):
"msg": String()
})
# TODO:TARGET处理
+# TODO: 需要更新接口,created_by
def make_task(json_data):
- # 请求参数转储到变量
- name = json_data["name"]
- target = str(json_data["target"])
- agent_id = json_data["agent"]
- tdomain = json_data["target_domain"]
- target_rr = json_data["target_rr"]
- policy_type = json_data["policy"]
- scan = json_data["scan"]
- policy_time = json_data["policy_time"]
- run_flag = json_data["run_flag"]
- run_time = json_data["run_time"]
-
- # 任务编号生成
- t_id = str(uuid.uuid1())
-
- # 任务信息生成
-
- task_info = {
- "t_id": t_id,
- "name": name,
- "target": target,
- "ag_id": agent_id,
- "p_delay": policy_time,
- "t_delay": run_time,
- "scan": scan,
- "status": "working" if run_flag == "now" else "stop",
- "policy": policy_type,
- "tdomain": tdomain,
- "target_rr": target_rr,
- "target_rtype": ""
- }
-
- # 输入参数检查处理
- warn = valid_task_info(task_info)
- if warn is not None:
- return {"code": 400, "msg": warn}
- if task_info["target_rr"] != "":
- rr = task_info["target_rr"].split(" ")
- task_info["target_rtype"], task_info["target_rr"] = rr[0], rr[1]
-
- # 添加表名
- task_info["tab"] = MYSQL_TAB_TASK
- # 若目标为IP地址
- if is_ipaddress(target) is not None:
- task_info["TARGET"] = "TARGET_IP"
- # 若目标为域名,视为DoH服务
- else:
- task_info["TARGET"] = "TARGET_DOH"
-
- # 添加状态探测节点,用于状态感知
- scan_node_sql = """
- SELECT AGENT_ID,IPADDR
- FROM %s
- WHERE STATUS='1' AND AGENT_TYPE='ztgz'
- LIMIT 50;
- """ % MYSQL_TAB_AGENT
- da.cursor.execute(scan_node_sql)
- # 所有在线的支持状态感知的代理
- all_nodes = da.cursor.fetchall()
- # 随机选择10个
- selected_nodes = random.choices(all_nodes, k=10 if len(all_nodes) > 10 else len(all_nodes))
- # 将选择的代理节点ID和对应IPv4地址存储为字典,键为代理ID,值为代理的IPv4地址
- selected_nodes_info = {'{}'.format(n['AGENT_ID']): '{}'.format(str(n['IPADDR']).split("|")[0]) for n in
- selected_nodes}
- task_info["nodes"] = json.dumps(selected_nodes_info)
-
- # sql语句组合
- sql = """
- INSERT INTO %(tab)s (
- TASK_ID,
- TASK_NAME,
- AGENT_ID,
- STATUS,
- POLICY_DELAY,
- TASK_DELAY,
- TARGET_SCAN,
- TARGET_DOMAIN,
- TARGET_RTYPE,
- TARGET_RR,
- %(TARGET)s,
- SCAN_AGENT_ID_LIST)
- VALUES (
- '%(t_id)s',
- '%(name)s',
- '%(ag_id)s',
- '%(status)s',
- '%(p_delay)s',
- '%(t_delay)s',
- '%(scan)s',
- '%(tdomain)s',
- '%(target_rtype)s',
- '%(target_rr)s',
- '%(target)s',
- '%(nodes)s'
- );
- """ % task_info
- da.conn.ping(reconnect=True)
- da.cursor.execute(sql)
- da.conn.commit()
-
+ task = Task(
+ task_id = str(uuid.uuid1()),
+ task_name = json_data["name"],
+ agent_id = json_data["agent"],
+ # created_by = "Admin",
+ target_ip = str(json_data["target"]),
+ # policy = json_data["policy"],
+ status = "working" if json_data["run_flag"] == "now" else "stop",
+ policy_delay = json_data["policy_time"],
+ task_delay =json_data["run_time"],
+ target_scan = json_data["scan"],
+ target_domain = json_data["target_domain"],
+ target_rtype = "",
+ target_rr = json_data["target_rr"],
+ )
+ if task.policy == "sjqp":
+ if task.target_rr == "" or task.target_domain == "":
+ return {"code": 400, "msg": "数据欺骗缺乏目标域名或注入参数"}
+
+ if task.target_rr != "":
+ task.target_rtype, task.target_rr = task.target_rr.split(" ")
+
+ # 查找所有在线的用于状态感知的agent
+ agents = db.session.query(Agent).filter(and_(Agent.status==True, Agent.agent_type=='ztgz')).all()
+ agents = random.sample(agents, min(10, len(agents)))
+ selected_nodes_info = {}
+ for agent in agents:
+ selected_nodes_info[agent.agent_id] = agent.ipaddr.split("|")[0]
+ task.SCAN_AGENT_ID_LIST = json.dumps(selected_nodes_info)
+
+ # 插入task
+ try:
+ db.session.add(task)
+ db.session.commit()
+ except SQLAlchemyError as e:
+ db.session.rollback()
+ error(str(e))
+ return {"code": 500, "msg": str(e)}
+
# 任务策略初始化
- tp_id, p_exe, p_param = init_task_policy(policy_type, target, t_id)
- update_policy_sql = """
- UPDATE %s
- SET POLICY='%s'
- WHERE TASK_ID='%s';
- """ % (MYSQL_TAB_TASK, tp_id, t_id)
- da.cursor.execute(update_policy_sql)
- da.conn.commit()
+ tp_id, p_exe, p_param = init_task_policy(json_data["policy"], task.target_ip, task.task_id)
- # 根据run_flag判断是否立刻执行
- if run_flag == "now":
- err = deliver_task(agent_id, p_exe, p_param)
+ # 根据状态判断是否立刻执行
+ if task.status == "working":
+ err = deliver_task(task.agent_id, p_exe, p_param)
if err is not None:
error(str(err))
return {"code": 500, "msg": str(err)}