diff options
Diffstat (limited to 'server/apps/target.py')
| -rw-r--r-- | server/apps/target.py | 205 |
1 files changed, 183 insertions, 22 deletions
diff --git a/server/apps/target.py b/server/apps/target.py index 62aca9a..ffd00f2 100644 --- a/server/apps/target.py +++ b/server/apps/target.py @@ -7,6 +7,7 @@ import threading import asyncio from operator import or_ from concurrent.futures import ThreadPoolExecutor +import time from flask import current_app import requests from apiflask import APIBlueprint, Schema @@ -122,6 +123,121 @@ class CouInfo(Schema): title = String() # 总数量 value = Integer() +# from apiflask.fields import Integer, String, List, Nested, Float + +class ClusterInfo(Schema): + lat = Float(required=True) + lng = Float(required=True) + count = Integer(required=True) + +class TargetInfo(Schema): + ipv4 = String(allow_none=True) + ipv6 = String(allow_none=True) + cou = String() + isp = String() + lat = Float() + lng = Float() + time = String() + protocol = List(String()) + protect = String() + + + +def task_monitoring(task): + with scheduler.app.app_context(): + debug("taskmonitor") + target_status = task.task_policies[-1].target_status + debug(target_status) + addr = task.target.addrv4 if task.target.addrv4 else task.target.addrv6 + + nodes_info = json.loads(task.SCAN_AGENT_ID_LIST) + + # 初始化延迟结果列表 + icmp_results = [] + tcp_results = [] + dns_results = [] + + # 执行 ICMP 查询并收集结果 + with ThreadPoolExecutor() as executor: + for id, ip_port in nodes_info.items(): + future = executor.submit(icmp_delay_query, addr, f"{ip_port}") + icmp_results.append(float(future.result())) + + # 执行 TCP 查询并收集结果 + with ThreadPoolExecutor() as executor: + for id, ip_port in nodes_info.items(): + future = executor.submit(tcp_delay_query, addr, f"{ip_port}") + tcp_results.append(float(future.result())) + + # 执行 DNS 查询并收集结果 + with ThreadPoolExecutor() as executor: + for id, ip_port in nodes_info.items(): + future = executor.submit(dns_delay_query, addr, f"{ip_port}") + dns_results.append(float(future.result())) + + # 计算平均值 + icmp_avg = sum(icmp_results) / len(icmp_results) if icmp_results else 0 + tcp_avg = sum(tcp_results) / len(tcp_results) if tcp_results else 0 + dns_avg = sum(dns_results) / len(dns_results) if dns_results else 0 + # id, ip_port = nodes_info.items()[0] + query_data = { + 'rev' : task.target.addrv4 if task.target.addrv4 else task.target.addrv6, + 'domain' : task.target_domain, + 'qtype' :'A' if task.target.addrv4 else "AAAA" + } + target_domain = get_record(query_data) + try: + first_ip = target_domain[0]["rrset"] if target_domain else None + except: + first_ip = "" + target_status = TargetStatus( + tp_id = task.task_policies[-1].tp_id, + icmp = icmp_avg, + tcp = tcp_avg, + dns = dns_avg, + recorde = first_ip, + ) + + db.session.add(target_status) + db.session.commit() + task = db.session.query(Task).get(task.task_id) + last_task_policy = task.task_policies[-1] + + # 现在可以安全地访问 target_status + target_status = last_task_policy.target_status + debug(target_status) + +def dida_task(task, ): + from .task import effective_detection, finish_task + from .task import adjust_task + with scheduler.app.app_context(): + task = db.session.query(Task).filter_by(task_id = task.task_id).first() + task_policy=db.session.query(TaskPolicy).get(task.task_policies[-1].tp_id) + task_monitoring(task) + # 如果任务没有成功 + if not effective_detection(task_policy=task_policy): + debug(task.status) + if task.status == "stopped": + adjust_task(task=task) + else: + finish_task(task) + pass + debug("didadida") + +def start_task_monitoring(task): + + with scheduler.app.app_context(): + scheduler.add_job( + func=dida_task, # 要执行的函数 + trigger='interval', # 触发器类型为间隔 + args = (task, ), # 传递给函数的参数 + id = task.task_id, # 任务的唯一标识符 + seconds = 30, # 触发器的参数,表示每 5 秒执行一次 + max_instances = 100 + ) + +def stop_task_monitoring(task): + scheduler.remove_job(task.task_id) @bp.get("/nodes") @@ -281,7 +397,7 @@ from apiflask.validators import OneOf, ContainsOnly from dns import resolver def get_record(query_data): - # 特殊协议头 + # 特殊协议头 protols = ["https", "tls"] ans = [] # 参数读取 @@ -329,7 +445,7 @@ def get_record(query_data): }) def record(query_data): ans = get_record(query_data) - return {"code": 200, 'ans': ans} + return {"code": 200, 'ans': ans} @bp.get("/") @@ -451,6 +567,10 @@ def map_info(query_data): max_lat = query_data.get("max_lat") # 获取最大纬度 min_lng = query_data.get("min_lng") # 获取最小经度 max_lng = query_data.get("max_lng") # 获取最大经度 + if zoom_level == 0: + res = db.session.query(Target.target_id, Target.lat, Target.lng).all() + res_dict = [{"target_id": row.target_id, "lat": row.lat, "lng": row.lng} for row in res] + return {"code": 200, "data": res_dict, "total": len(res_dict)} if zoom_level <= 10: # 查询目标数据 @@ -531,35 +651,85 @@ def map_info(query_data): @bp.get("/gz") @bp.doc("目标感知") @bp.input({ - "ip": IP(required=True) + "ip": String(required=True) }, location="query") @bp.output({ "code": Integer(), - "dataObject": List(Nested(TargetSchema())), + "dataObject": Nested(TargetSchema()) }) def target_GZ_API(query_data): - target_GZ(query_data["ip"]) + ip = query_data["ip"] + target = target_GZ(ip) + target_dict = { + "addrv4": target.addrv4, + "addrv6": target.addrv6, + "ipv6": target.ipv6, + "dnssec": target.dnssec, + "dot": target.dot, + "doh": target.doh, + "cou": target.cou, + "isp": target.isp, + "lat": target.lat, + "lng": target.lng, + "protect": target.protect, + "doh_domain": target.doh_domain + } return { 'code': 200, - 'MSG': "success" + 'dataObject': target_dict } def target_GZ(IP_addr): - # 获取 - csgz = db.session.query(Agent).filter_by(agent_type = 'gjst').all() + + existing_obj = db.session.query(Target).filter( + (Target.addrv4 == IP_addr) | (Target.addrv6 == IP_addr) + ).first() + if existing_obj: + return existing_obj + + ipv6 = None + ipv4 = None + + # 判断 IP 地址的版本并存储 + if 6 == is_ipaddress(IP_addr): + ipv6 = IP_addr + elif 4 == is_ipaddress(IP_addr): + ipv4 = IP_addr + + # 获取随机的 agent + csgz = db.session.query(Agent).filter_by(agent_type='gjst').all() csgz = random.choice(csgz) - url = f"http://{csgz.ipaddr}:{csgz.port}/target_gz/{IP_addr}" - protect = requests.get(url) - + # 根据 IP 地址类型构建 URL + if ipv6: + # IPv6 地址需要加上中括号 + url = f"http://{csgz.ipaddr}:{csgz.port}/target_gz/[{ipv6}]" + elif ipv4: + url = f"http://{csgz.ipaddr}:{csgz.port}/target_gz/{ipv4}" + else: + raise ValueError("Invalid IP address") + + # 发送请求 + i = 0 + while i < 30: + protect = requests.get(url) + status_code = protect.status_code + debug(f"目标感知:重试{i}次") + i+=1 + if status_code == 200: + break + else: + time.sleep(0.5) + + url = f'https://ipinfo.io/{IP_addr}/json?token=2c3db02b7ffce3' response = requests.get(url) data = response.json() # 存数据库 target = Target( - addrv4 = IP_addr, - addrv6 = False, + addrv4 = ipv4, + addrv6 = ipv6, ipv6 = (6 == is_ipaddress(IP_addr)), dnssec = json.loads(protect.text)['dnssec_enabled'], dot = False, @@ -571,15 +741,6 @@ def target_GZ(IP_addr): protect = protect.text, doh_domain = None ) - if 6 == is_ipaddress(IP_addr): - target.addrv6 = IP_addr - target.ipv6 = True - elif 4== is_ipaddress(IP_addr): - target.addrv4 = IP_addr - - existing_obj = db.session.query(Target).filter_by(addrv4 = target.addrv4).first() - if existing_obj: - db.session.delete(existing_obj) db.session.add(target) db.session.commit() return target
\ No newline at end of file |
