diff options
| author | shihaoyue <[email protected]> | 2024-05-31 21:33:38 +0800 |
|---|---|---|
| committer | shihaoyue <[email protected]> | 2024-05-31 21:33:38 +0800 |
| commit | 6916840f4e11a58448714faeda691b48dfaa041c (patch) | |
| tree | 5967b28e63db838f6f3018a6bf90ddd7afc8688b /server/apps | |
| parent | ab29fc5c1bd5e616ce725a6e36e972986b560606 (diff) | |
创建任务orm
Diffstat (limited to 'server/apps')
| -rw-r--r-- | server/apps/task.py | 158 |
1 files changed, 46 insertions, 112 deletions
diff --git a/server/apps/task.py b/server/apps/task.py index 4941802..6d1c234 100644 --- a/server/apps/task.py +++ b/server/apps/task.py @@ -5,9 +5,10 @@ import uuid from apiflask import APIBlueprint, Schema from apiflask.fields import String, Integer, IP, DateTime, List, Nested from apiflask.validators import OneOf +from sqlalchemy import and_ from exts import db -from model import TaskPolicy, Task +from model import Policy, TaskPolicy, Task, Agent from sqlalchemy.exc import SQLAlchemyError from .agentcomm import deliver_task from .policy import * @@ -100,120 +101,53 @@ def valid_task_info(task_param: dict): "msg": String() }) # TODO:TARGET处理 +# TODO: 需要更新接口,created_by def make_task(json_data): - # 请求参数转储到变量 - name = json_data["name"] - target = str(json_data["target"]) - agent_id = json_data["agent"] - tdomain = json_data["target_domain"] - target_rr = json_data["target_rr"] - policy_type = json_data["policy"] - scan = json_data["scan"] - policy_time = json_data["policy_time"] - run_flag = json_data["run_flag"] - run_time = json_data["run_time"] - - # 任务编号生成 - t_id = str(uuid.uuid1()) - - # 任务信息生成 - - task_info = { - "t_id": t_id, - "name": name, - "target": target, - "ag_id": agent_id, - "p_delay": policy_time, - "t_delay": run_time, - "scan": scan, - "status": "working" if run_flag == "now" else "stop", - "policy": policy_type, - "tdomain": tdomain, - "target_rr": target_rr, - "target_rtype": "" - } - - # 输入参数检查处理 - warn = valid_task_info(task_info) - if warn is not None: - return {"code": 400, "msg": warn} - if task_info["target_rr"] != "": - rr = task_info["target_rr"].split(" ") - task_info["target_rtype"], task_info["target_rr"] = rr[0], rr[1] - - # 添加表名 - task_info["tab"] = MYSQL_TAB_TASK - # 若目标为IP地址 - if is_ipaddress(target) is not None: - task_info["TARGET"] = "TARGET_IP" - # 若目标为域名,视为DoH服务 - else: - task_info["TARGET"] = "TARGET_DOH" - - # 添加状态探测节点,用于状态感知 - scan_node_sql = """ - SELECT AGENT_ID,IPADDR - FROM %s - WHERE STATUS='1' AND AGENT_TYPE='ztgz' - LIMIT 50; - """ % MYSQL_TAB_AGENT - da.cursor.execute(scan_node_sql) - # 所有在线的支持状态感知的代理 - all_nodes = da.cursor.fetchall() - # 随机选择10个 - selected_nodes = random.choices(all_nodes, k=10 if len(all_nodes) > 10 else len(all_nodes)) - # 将选择的代理节点ID和对应IPv4地址存储为字典,键为代理ID,值为代理的IPv4地址 - selected_nodes_info = {'{}'.format(n['AGENT_ID']): '{}'.format(str(n['IPADDR']).split("|")[0]) for n in - selected_nodes} - task_info["nodes"] = json.dumps(selected_nodes_info) - - # sql语句组合 - sql = """ - INSERT INTO %(tab)s ( - TASK_ID, - TASK_NAME, - AGENT_ID, - STATUS, - POLICY_DELAY, - TASK_DELAY, - TARGET_SCAN, - TARGET_DOMAIN, - TARGET_RTYPE, - TARGET_RR, - %(TARGET)s, - SCAN_AGENT_ID_LIST) - VALUES ( - '%(t_id)s', - '%(name)s', - '%(ag_id)s', - '%(status)s', - '%(p_delay)s', - '%(t_delay)s', - '%(scan)s', - '%(tdomain)s', - '%(target_rtype)s', - '%(target_rr)s', - '%(target)s', - '%(nodes)s' - ); - """ % task_info - da.conn.ping(reconnect=True) - da.cursor.execute(sql) - da.conn.commit() - + task = Task( + task_id = str(uuid.uuid1()), + task_name = json_data["name"], + agent_id = json_data["agent"], + # created_by = "Admin", + target_ip = str(json_data["target"]), + # policy = json_data["policy"], + status = "working" if json_data["run_flag"] == "now" else "stop", + policy_delay = json_data["policy_time"], + task_delay =json_data["run_time"], + target_scan = json_data["scan"], + target_domain = json_data["target_domain"], + target_rtype = "", + target_rr = json_data["target_rr"], + ) + if task.policy == "sjqp": + if task.target_rr == "" or task.target_domain == "": + return {"code": 400, "msg": "数据欺骗缺乏目标域名或注入参数"} + + if task.target_rr != "": + task.target_rtype, task.target_rr = task.target_rr.split(" ") + + # 查找所有在线的用于状态感知的agent + agents = db.session.query(Agent).filter(and_(Agent.status==True, Agent.agent_type=='ztgz')).all() + agents = random.sample(agents, min(10, len(agents))) + selected_nodes_info = {} + for agent in agents: + selected_nodes_info[agent.agent_id] = agent.ipaddr.split("|")[0] + task.SCAN_AGENT_ID_LIST = json.dumps(selected_nodes_info) + + # 插入task + try: + db.session.add(task) + db.session.commit() + except SQLAlchemyError as e: + db.session.rollback() + error(str(e)) + return {"code": 500, "msg": str(e)} + # 任务策略初始化 - tp_id, p_exe, p_param = init_task_policy(policy_type, target, t_id) - update_policy_sql = """ - UPDATE %s - SET POLICY='%s' - WHERE TASK_ID='%s'; - """ % (MYSQL_TAB_TASK, tp_id, t_id) - da.cursor.execute(update_policy_sql) - da.conn.commit() + tp_id, p_exe, p_param = init_task_policy(json_data["policy"], task.target_ip, task.task_id) - # 根据run_flag判断是否立刻执行 - if run_flag == "now": - err = deliver_task(agent_id, p_exe, p_param) + # 根据状态判断是否立刻执行 + if task.status == "working": + err = deliver_task(task.agent_id, p_exe, p_param) if err is not None: error(str(err)) return {"code": 500, "msg": str(err)} |
