From 4c63d78e4e42aece5e5eaa6e111e7d5fc20aabb8 Mon Sep 17 00:00:00 2001 From: MDK Date: Tue, 26 Sep 2023 14:12:04 +0800 Subject: the project structure modified and new features added --- cmd/cache.go | 7 +++-- cmd/record.go | 47 +++++++++++++++++++++++++++++++ cmd/upstream.go | 2 +- cmd/version.go | 20 -------------- dtool | Bin 0 -> 9689632 bytes prober/cache_prober.go | 8 ++++-- prober/record_prober.go | 38 +++++++++++++++++++++++++ prober/result_handler.go | 2 ++ prober/scheduler.go | 67 +++++++++++++++++++++++++++++++++++++++++++++ prober/version_prober.go | 1 - scheduler/scheduler.go | 68 --------------------------------------------- utils/dns_utils.go | 70 +++++++++++++++++++++++++++++++++++++++++++---- utils/httpdns_utils.go | 6 ++-- utils/input_utils.go | 2 +- utils/other_utils.go | 5 +--- 15 files changed, 235 insertions(+), 108 deletions(-) create mode 100644 cmd/record.go delete mode 100644 cmd/version.go create mode 100755 dtool create mode 100644 prober/record_prober.go create mode 100644 prober/scheduler.go delete mode 100644 prober/version_prober.go delete mode 100644 scheduler/scheduler.go diff --git a/cmd/cache.go b/cmd/cache.go index c72af16..5e17fdd 100644 --- a/cmd/cache.go +++ b/cmd/cache.go @@ -2,7 +2,6 @@ package cmd import ( "dtool/prober" - "dtool/scheduler" "dtool/utils" "github.com/spf13/cobra" @@ -11,6 +10,8 @@ import ( var query_cnt int var inputfile string var outputfile string +var goroutine_num int +var controlled_domain string var cacheCmd = &cobra.Command{ Use: "cache", Short: "cache related test", @@ -24,14 +25,16 @@ func cache_test(cmd *cobra.Command, args []string) { prober.RecursiveCacheTest(args[0], query_cnt) } } else { - scheduler.CreateTask(prober.RecursiveCacheProbe, inputfile, outputfile, 10) + prober.CreateTask(prober.RecursiveCacheProbe, controlled_domain, inputfile, outputfile, goroutine_num) } } func init() { + cacheCmd.Flags().StringVarP(&controlled_domain, "domain", "d", "echodns.xyz", "controlled domain") cacheCmd.Flags().StringVarP(&inputfile, "input", "i", "", "input file(optional)") cacheCmd.Flags().StringVarP(&outputfile, "output", "o", "", "output file(optional)") cacheCmd.MarkFlagsRequiredTogether("input", "output") cacheCmd.Flags().IntVarP(&query_cnt, "num", "n", 20, "number of queries in one test") + cacheCmd.Flags().IntVarP(&goroutine_num, "concurrency", "t", 150, "number of goroutine") rootCmd.AddCommand(cacheCmd) } diff --git a/cmd/record.go b/cmd/record.go new file mode 100644 index 0000000..9ede8d7 --- /dev/null +++ b/cmd/record.go @@ -0,0 +1,47 @@ +package cmd + +import ( + "dtool/prober" + "dtool/utils" + + "fmt" + + "github.com/spf13/cobra" +) + +var record_input string +var record_output string +var record_type string +var record_domain string +var recordCmd = &cobra.Command{ + Use: "record", + Short: "get specific record response", + Long: "get specific record response", + Run: record_probe, +} + +func record_probe(cmd *cobra.Command, args []string) { + if len(args) == 1 { + if utils.IsValidIP(args[0]) { + result, err := prober.SVCBProbeOnce(args[0], record_domain) + if err == nil { + if output_str, err := prober.OutputHandler(result); err == nil { + fmt.Println(output_str) + } + } else { + fmt.Println(err) + } + } + } else { + prober.CreateTask(prober.SVCBProbe, record_domain, record_input, record_output, 500) + } +} + +func init() { + recordCmd.Flags().StringVarP(&record_input, "input", "i", "", "input file(optional)") + recordCmd.Flags().StringVarP(&record_output, "output", "o", "", "output file(optional)") + recordCmd.MarkFlagsRequiredTogether("input", "output") + recordCmd.Flags().StringVarP(&record_type, "type", "t", "A", "request record type") + recordCmd.Flags().StringVarP(&record_domain, "domain", "d", "example.com", "requested domain") + rootCmd.AddCommand(recordCmd) +} diff --git a/cmd/upstream.go b/cmd/upstream.go index c40528b..a9f40ef 100644 --- a/cmd/upstream.go +++ b/cmd/upstream.go @@ -27,7 +27,7 @@ input target can be added as an argument or as a file func upstream(cmd *cobra.Command, args []string) { if len(args) > 1 { - panic(errors.New("too many arguments!")) + panic(errors.New("too many arguments")) } else if len(args) == 1 { if utils.IsValidIP(args[0]) { prober.Get_upstream_ip(args[0]) diff --git a/cmd/version.go b/cmd/version.go deleted file mode 100644 index 3c6712b..0000000 --- a/cmd/version.go +++ /dev/null @@ -1,20 +0,0 @@ -package cmd - -import ( - "github.com/spf13/cobra" -) - -var versionCmd = &cobra.Command{ - Use: "version", - Short: "get server version with version.bind", - Long: "get server version with version.bind chaos txt request", - Run: version, -} - -func version(cmd *cobra.Command, args []string) { - -} - -func init() { - rootCmd.AddCommand(versionCmd) -} diff --git a/dtool b/dtool new file mode 100755 index 0000000..3f38d77 Binary files /dev/null and b/dtool differ diff --git a/prober/cache_prober.go b/prober/cache_prober.go index b8dd9af..94adeb0 100644 --- a/prober/cache_prober.go +++ b/prober/cache_prober.go @@ -6,6 +6,8 @@ import ( "strconv" "strings" "time" + + "github.com/miekg/dns" ) const query_num = 20 @@ -15,7 +17,7 @@ type CacheStruct struct { dict map[int]map[string]bool } -func RecursiveCacheProbe(ip string) CacheStruct { +func RecursiveCacheProbe(ip string, sld string) CacheStruct { data := CacheStruct{ip, make(map[int]map[string]bool)} stop := 0 time_now := strconv.FormatInt(time.Now().Unix(), 10) @@ -24,7 +26,7 @@ func RecursiveCacheProbe(ip string) CacheStruct { break } subdomain := strings.Join([]string{strings.Replace(ip, ".", "-", -1), "fwd", strconv.Itoa(i), time_now}, "-") - domain := subdomain + ".echodns.xyz." + domain := dns.Fqdn(subdomain + "." + sld) res, err := utils.SendQuery(ip, domain) if err != nil { //fmt.Printf("Error sending query: %s\n", err) @@ -49,7 +51,7 @@ func RecursiveCacheProbe(ip string) CacheStruct { func RecursiveCacheTest(ip string, num int) { res := make(map[string]map[int][]string) temp := make(map[int][]string) - data := RecursiveCacheProbe(ip) + data := RecursiveCacheProbe(ip, "echodns.xyz") if len(data.dict) > 0 { for cache_id := range data.dict { for rdns := range data.dict[cache_id] { diff --git a/prober/record_prober.go b/prober/record_prober.go new file mode 100644 index 0000000..e976477 --- /dev/null +++ b/prober/record_prober.go @@ -0,0 +1,38 @@ +package prober + +import ( + "dtool/utils" +) + +type SVCBResult struct { + Ip string `json:"ip"` + Response utils.SVCBResponse `json:"response"` +} + +func SVCBProbeOnce(ip string, domain string) (SVCBResult, error) { + result := SVCBResult{Ip: ip} + res, err := utils.SendSVCBQuery(ip, domain) + if err != nil { + return result, err + } + resp, err := utils.ParseSVCBResponse(res) + if err != nil { + return result, err + } + result.Response = resp + return result, nil +} + +func SVCBProbe(ip string, domain string) SVCBResult { + result := SVCBResult{Ip: ip} + res, err := utils.SendSVCBQuery(ip, domain) + if err != nil { + return result + } + resp, err := utils.ParseSVCBResponse(res) + if err != nil { + return result + } + result.Response = resp + return result +} diff --git a/prober/result_handler.go b/prober/result_handler.go index 35f85f7..470c922 100644 --- a/prober/result_handler.go +++ b/prober/result_handler.go @@ -26,6 +26,8 @@ func OutputHandler(data interface{}) (string, error) { } result[value.target] = temp output_str, err = utils.ToJSON(result, "") + case SVCBResult: + output_str, err = utils.ToJSON(data, "") } return output_str, err } diff --git a/prober/scheduler.go b/prober/scheduler.go new file mode 100644 index 0000000..2ffe4a2 --- /dev/null +++ b/prober/scheduler.go @@ -0,0 +1,67 @@ +package prober + +import ( + "bufio" + "dtool/utils" + "fmt" + "os" + "sync" +) + +//type ProbeTask func(string) interface{} + +func output_process(output chan interface{}, file string, wg *sync.WaitGroup) { + f, err := os.Create(file) + if err != nil { + panic(err) + } + defer f.Close() + writer := bufio.NewWriter(f) + for { + if data, ok := <-output; ok { + str, err := OutputHandler(data) + if err != nil { + fmt.Printf("Error generating output: %s\n", err) + continue + } + _, err = writer.WriteString(str + "\n") + if err != nil { + fmt.Printf("Error writing file: %s\n", err) + } + } else { + break + } + } + writer.Flush() + wg.Done() +} + +func concurrent_execution[T any](fn func(string, string) T, domain string, input chan string, output chan interface{}, wg *sync.WaitGroup) { + for { + if ip, ok := <-input; ok { + data := fn(ip, domain) + output <- data + } else { + break + } + } + wg.Done() +} + +func CreateTask[T any](fn func(string, string) T, domain string, input_file string, output_file string, concurrent_num int) { + input_pool := make(chan string, 500) + output_pool := make(chan interface{}, 500) + var probe_tasks sync.WaitGroup + var store_tasks sync.WaitGroup + + go utils.RetrieveLines(input_pool, input_file) + probe_tasks.Add(concurrent_num) + for i := 0; i < concurrent_num; i++ { + go concurrent_execution(fn, domain, input_pool, output_pool, &probe_tasks) + } + store_tasks.Add(1) + go output_process(output_pool, output_file, &store_tasks) + probe_tasks.Wait() + close(output_pool) + store_tasks.Wait() +} diff --git a/prober/version_prober.go b/prober/version_prober.go deleted file mode 100644 index 8219d21..0000000 --- a/prober/version_prober.go +++ /dev/null @@ -1 +0,0 @@ -package prober diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go deleted file mode 100644 index 496851c..0000000 --- a/scheduler/scheduler.go +++ /dev/null @@ -1,68 +0,0 @@ -package scheduler - -import ( - "bufio" - "dtool/prober" - "dtool/utils" - "fmt" - "os" - "sync" -) - -//type ProbeTask func(string) interface{} - -func output_process(output chan interface{}, file string, wg *sync.WaitGroup) { - f, err := os.Create(file) - if err != nil { - panic(err) - } - defer f.Close() - writer := bufio.NewWriter(f) - for { - if data, ok := <-output; ok { - str, err := prober.OutputHandler(data) - if err != nil { - fmt.Printf("Error generating output: %s\n", err) - continue - } - _, err = writer.WriteString(str + "\n") - if err != nil { - fmt.Printf("Error writing file: %s\n", err) - } - } else { - break - } - } - writer.Flush() - wg.Done() -} - -func concurrent_execution[T any](fn func(string) T, input chan string, output chan interface{}, wg *sync.WaitGroup) { - for { - if ip, ok := <-input; ok { - data := fn(ip) - output <- data - } else { - break - } - } - wg.Done() -} - -func CreateTask[T any](fn func(string) T, input_file string, output_file string, concurrent_num int) { - input_pool := make(chan string, 500) - output_pool := make(chan interface{}, 500) - var probe_tasks sync.WaitGroup - var store_tasks sync.WaitGroup - - go utils.RetrieveLines(input_pool, input_file) - probe_tasks.Add(concurrent_num) - for i := 0; i < concurrent_num; i++ { - go concurrent_execution(fn, input_pool, output_pool, &probe_tasks) - } - store_tasks.Add(1) - go output_process(output_pool, output_file, &store_tasks) - probe_tasks.Wait() - close(output_pool) - store_tasks.Wait() -} diff --git a/utils/dns_utils.go b/utils/dns_utils.go index 9b398ca..641968a 100644 --- a/utils/dns_utils.go +++ b/utils/dns_utils.go @@ -27,6 +27,15 @@ type DomainInfo struct { NSList map[string][]string } +type SVCBRecord struct { + Target string `json:"target"` + Data map[string]string `json:"value"` +} + +type SVCBResponse struct { + Records []SVCBRecord `json:"records"` +} + func (e *WrongAnswerError) Error() string { return fmt.Sprintf("Wrong Answer: %s", e.Message) } @@ -39,11 +48,7 @@ func QuestionMaker(domain string, qclass uint16, qtype uint16) *dns.Question { // build a dns query message func QueryMaker(query QueryStruct) *dns.Msg { msg := new(dns.Msg) - if query.Id < 0 { - msg.Id = dns.Id() - } else { - msg.Id = query.Id - } + msg.Id = query.Id msg.RecursionDesired = query.RD var query_name string @@ -155,3 +160,58 @@ func SendVersionQuery(ip string) (*dns.Msg, error) { return res, err } + +func SendSVCBQuery(ip string, domain string) (*dns.Msg, error) { + addr := ip + ":53" + query := QueryStruct{Id: dns.Id(), Qname: dns.Fqdn(domain), RD: true, Qtype: dns.TypeSVCB} + m := QueryMaker(query) + + res, err := dns.Exchange(m, addr) + + return res, err +} + +func toSVCBKEY(key dns.SVCBKey) string { + switch key { + case dns.SVCB_MANDATORY: + return "mandatory" + case dns.SVCB_ALPN: + return "alpn" + case dns.SVCB_NO_DEFAULT_ALPN: + return "no_default_alpn" + case dns.SVCB_PORT: + return "port" + case dns.SVCB_IPV4HINT: + return "ipv4_hint" + case dns.SVCB_ECHCONFIG: + return "ech_config" + case dns.SVCB_IPV6HINT: + return "ipv6_hint" + case dns.SVCB_DOHPATH: + return "doh_path" + default: + return "unknown" + } +} + +func ParseSVCBResponse(msg *dns.Msg) (SVCBResponse, error) { + response := SVCBResponse{Records: make([]SVCBRecord, 0)} + if len(msg.Answer) > 0 { + for _, rr := range msg.Answer { + if svcb, ok := rr.(*dns.SVCB); ok { + record := SVCBRecord{Data: make(map[string]string)} + record.Target = svcb.Target + for _, kv := range svcb.Value { + key := toSVCBKEY(kv.Key()) + value := kv.String() + record.Data[key] = value + } + response.Records = append(response.Records, record) + } + } + } + if len(response.Records) > 0 { + return response, nil + } + return response, &WrongAnswerError{Message: "no valid SVCB records"} +} diff --git a/utils/httpdns_utils.go b/utils/httpdns_utils.go index 11c891d..ba8ce8d 100644 --- a/utils/httpdns_utils.go +++ b/utils/httpdns_utils.go @@ -5,7 +5,7 @@ import ( "crypto/des" "encoding/hex" "fmt" - "io/ioutil" + "io" "net/http" ) @@ -75,7 +75,7 @@ func SendTencentHttpdnsQuery() { } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { fmt.Printf("read content failed. Error: %s\n", err) return @@ -102,7 +102,7 @@ func SendAlicloudHttpdnsQurey() { } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { fmt.Printf("read content failed. Error: %s\n", err) return diff --git a/utils/input_utils.go b/utils/input_utils.go index a968f07..990eb15 100644 --- a/utils/input_utils.go +++ b/utils/input_utils.go @@ -24,7 +24,7 @@ func RetrieveLines(pool chan string, filename string) { s = strings.Trim(s, "\n") pool <- s cnt++ - if cnt%10 == 0 { + if cnt%1000 == 0 { fmt.Println(cnt) } } diff --git a/utils/other_utils.go b/utils/other_utils.go index 9940a0c..af0a74b 100644 --- a/utils/other_utils.go +++ b/utils/other_utils.go @@ -6,8 +6,5 @@ import ( func IsValidIP(ip string) bool { res := net.ParseIP(ip) - if res == nil { - return false - } - return true + return res != nil } -- cgit v1.2.3