summaryrefslogtreecommitdiff
path: root/server/apps/target.py
diff options
context:
space:
mode:
authorshihaoyue <[email protected]>2024-09-20 09:07:10 +0800
committershihaoyue <[email protected]>2024-09-20 09:07:10 +0800
commit5d07e2a4e2f5e93c9f4699c49cbcb52c38aebbee (patch)
tree8f756f0c014cdfc87412224d9569f1e21bb5ff19 /server/apps/target.py
parent78575c5a7322693359d35c4f3d6e9d9698c5188e (diff)
# 重大更新 自动化任务,极其不稳定
Diffstat (limited to 'server/apps/target.py')
-rw-r--r--server/apps/target.py205
1 files changed, 183 insertions, 22 deletions
diff --git a/server/apps/target.py b/server/apps/target.py
index 62aca9a..ffd00f2 100644
--- a/server/apps/target.py
+++ b/server/apps/target.py
@@ -7,6 +7,7 @@ import threading
import asyncio
from operator import or_
from concurrent.futures import ThreadPoolExecutor
+import time
from flask import current_app
import requests
from apiflask import APIBlueprint, Schema
@@ -122,6 +123,121 @@ class CouInfo(Schema):
title = String()
# 总数量
value = Integer()
+# from apiflask.fields import Integer, String, List, Nested, Float
+
+class ClusterInfo(Schema):
+ lat = Float(required=True)
+ lng = Float(required=True)
+ count = Integer(required=True)
+
+class TargetInfo(Schema):
+ ipv4 = String(allow_none=True)
+ ipv6 = String(allow_none=True)
+ cou = String()
+ isp = String()
+ lat = Float()
+ lng = Float()
+ time = String()
+ protocol = List(String())
+ protect = String()
+
+
+
+def task_monitoring(task):
+ with scheduler.app.app_context():
+ debug("taskmonitor")
+ target_status = task.task_policies[-1].target_status
+ debug(target_status)
+ addr = task.target.addrv4 if task.target.addrv4 else task.target.addrv6
+
+ nodes_info = json.loads(task.SCAN_AGENT_ID_LIST)
+
+ # 初始化延迟结果列表
+ icmp_results = []
+ tcp_results = []
+ dns_results = []
+
+ # 执行 ICMP 查询并收集结果
+ with ThreadPoolExecutor() as executor:
+ for id, ip_port in nodes_info.items():
+ future = executor.submit(icmp_delay_query, addr, f"{ip_port}")
+ icmp_results.append(float(future.result()))
+
+ # 执行 TCP 查询并收集结果
+ with ThreadPoolExecutor() as executor:
+ for id, ip_port in nodes_info.items():
+ future = executor.submit(tcp_delay_query, addr, f"{ip_port}")
+ tcp_results.append(float(future.result()))
+
+ # 执行 DNS 查询并收集结果
+ with ThreadPoolExecutor() as executor:
+ for id, ip_port in nodes_info.items():
+ future = executor.submit(dns_delay_query, addr, f"{ip_port}")
+ dns_results.append(float(future.result()))
+
+ # 计算平均值
+ icmp_avg = sum(icmp_results) / len(icmp_results) if icmp_results else 0
+ tcp_avg = sum(tcp_results) / len(tcp_results) if tcp_results else 0
+ dns_avg = sum(dns_results) / len(dns_results) if dns_results else 0
+ # id, ip_port = nodes_info.items()[0]
+ query_data = {
+ 'rev' : task.target.addrv4 if task.target.addrv4 else task.target.addrv6,
+ 'domain' : task.target_domain,
+ 'qtype' :'A' if task.target.addrv4 else "AAAA"
+ }
+ target_domain = get_record(query_data)
+ try:
+ first_ip = target_domain[0]["rrset"] if target_domain else None
+ except:
+ first_ip = ""
+ target_status = TargetStatus(
+ tp_id = task.task_policies[-1].tp_id,
+ icmp = icmp_avg,
+ tcp = tcp_avg,
+ dns = dns_avg,
+ recorde = first_ip,
+ )
+
+ db.session.add(target_status)
+ db.session.commit()
+ task = db.session.query(Task).get(task.task_id)
+ last_task_policy = task.task_policies[-1]
+
+ # 现在可以安全地访问 target_status
+ target_status = last_task_policy.target_status
+ debug(target_status)
+
+def dida_task(task, ):
+ from .task import effective_detection, finish_task
+ from .task import adjust_task
+ with scheduler.app.app_context():
+ task = db.session.query(Task).filter_by(task_id = task.task_id).first()
+ task_policy=db.session.query(TaskPolicy).get(task.task_policies[-1].tp_id)
+ task_monitoring(task)
+ # 如果任务没有成功
+ if not effective_detection(task_policy=task_policy):
+ debug(task.status)
+ if task.status == "stopped":
+ adjust_task(task=task)
+ else:
+ finish_task(task)
+ pass
+ debug("didadida")
+
+def start_task_monitoring(task):
+
+ with scheduler.app.app_context():
+ scheduler.add_job(
+ func=dida_task, # 要执行的函数
+ trigger='interval', # 触发器类型为间隔
+ args = (task, ), # 传递给函数的参数
+ id = task.task_id, # 任务的唯一标识符
+ seconds = 30, # 触发器的参数,表示每 5 秒执行一次
+ max_instances = 100
+ )
+
+def stop_task_monitoring(task):
+ scheduler.remove_job(task.task_id)
@bp.get("/nodes")
@@ -281,7 +397,7 @@ from apiflask.validators import OneOf, ContainsOnly
from dns import resolver
def get_record(query_data):
- # 特殊协议头
+ # 特殊协议头
protols = ["https", "tls"]
ans = []
# 参数读取
@@ -329,7 +445,7 @@ def get_record(query_data):
})
def record(query_data):
ans = get_record(query_data)
- return {"code": 200, 'ans': ans}
+ return {"code": 200, 'ans': ans}
@bp.get("/")
@@ -451,6 +567,10 @@ def map_info(query_data):
max_lat = query_data.get("max_lat") # 获取最大纬度
min_lng = query_data.get("min_lng") # 获取最小经度
max_lng = query_data.get("max_lng") # 获取最大经度
+ if zoom_level == 0:
+ res = db.session.query(Target.target_id, Target.lat, Target.lng).all()
+ res_dict = [{"target_id": row.target_id, "lat": row.lat, "lng": row.lng} for row in res]
+ return {"code": 200, "data": res_dict, "total": len(res_dict)}
if zoom_level <= 10:
# 查询目标数据
@@ -531,35 +651,85 @@ def map_info(query_data):
@bp.get("/gz")
@bp.doc("目标感知")
@bp.input({
- "ip": IP(required=True)
+ "ip": String(required=True)
}, location="query")
@bp.output({
"code": Integer(),
- "dataObject": List(Nested(TargetSchema())),
+ "dataObject": Nested(TargetSchema())
})
def target_GZ_API(query_data):
- target_GZ(query_data["ip"])
+ ip = query_data["ip"]
+ target = target_GZ(ip)
+ target_dict = {
+ "addrv4": target.addrv4,
+ "addrv6": target.addrv6,
+ "ipv6": target.ipv6,
+ "dnssec": target.dnssec,
+ "dot": target.dot,
+ "doh": target.doh,
+ "cou": target.cou,
+ "isp": target.isp,
+ "lat": target.lat,
+ "lng": target.lng,
+ "protect": target.protect,
+ "doh_domain": target.doh_domain
+ }
return {
'code': 200,
- 'MSG': "success"
+ 'dataObject': target_dict
}
def target_GZ(IP_addr):
- # 获取
- csgz = db.session.query(Agent).filter_by(agent_type = 'gjst').all()
+
+ existing_obj = db.session.query(Target).filter(
+ (Target.addrv4 == IP_addr) | (Target.addrv6 == IP_addr)
+ ).first()
+ if existing_obj:
+ return existing_obj
+
+ ipv6 = None
+ ipv4 = None
+
+ # 判断 IP 地址的版本并存储
+ if 6 == is_ipaddress(IP_addr):
+ ipv6 = IP_addr
+ elif 4 == is_ipaddress(IP_addr):
+ ipv4 = IP_addr
+
+ # 获取随机的 agent
+ csgz = db.session.query(Agent).filter_by(agent_type='gjst').all()
csgz = random.choice(csgz)
- url = f"http://{csgz.ipaddr}:{csgz.port}/target_gz/{IP_addr}"
- protect = requests.get(url)
-
+ # 根据 IP 地址类型构建 URL
+ if ipv6:
+ # IPv6 地址需要加上中括号
+ url = f"http://{csgz.ipaddr}:{csgz.port}/target_gz/[{ipv6}]"
+ elif ipv4:
+ url = f"http://{csgz.ipaddr}:{csgz.port}/target_gz/{ipv4}"
+ else:
+ raise ValueError("Invalid IP address")
+
+ # 发送请求
+ i = 0
+ while i < 30:
+ protect = requests.get(url)
+ status_code = protect.status_code
+ debug(f"目标感知:重试{i}次")
+ i+=1
+ if status_code == 200:
+ break
+ else:
+ time.sleep(0.5)
+
+
url = f'https://ipinfo.io/{IP_addr}/json?token=2c3db02b7ffce3'
response = requests.get(url)
data = response.json()
# 存数据库
target = Target(
- addrv4 = IP_addr,
- addrv6 = False,
+ addrv4 = ipv4,
+ addrv6 = ipv6,
ipv6 = (6 == is_ipaddress(IP_addr)),
dnssec = json.loads(protect.text)['dnssec_enabled'],
dot = False,
@@ -571,15 +741,6 @@ def target_GZ(IP_addr):
protect = protect.text,
doh_domain = None
)
- if 6 == is_ipaddress(IP_addr):
- target.addrv6 = IP_addr
- target.ipv6 = True
- elif 4== is_ipaddress(IP_addr):
- target.addrv4 = IP_addr
-
- existing_obj = db.session.query(Target).filter_by(addrv4 = target.addrv4).first()
- if existing_obj:
- db.session.delete(existing_obj)
db.session.add(target)
db.session.commit()
return target \ No newline at end of file