summaryrefslogtreecommitdiff
path: root/injection_probe.go
blob: 44ae277ad163a0117bad281b91a1da98daf7c628 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
package main

import (
	"flag"
	"fmt"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/miekg/dns"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
)

// dnsQuery sends a single A query for domain, with the RD (recursion desired)
// bit set, to the DNS server at ip on port 53, and returns the reply.
//
// domain is made fully qualified before sending. ip may be an IPv4 or IPv6
// literal. A host containing ':' is wrapped in brackets so the resulting
// host:port address is well-formed (the same rule net.JoinHostPort applies).
func dnsQuery(ip, domain string) (*dns.Msg, error) {
	addr := ip + ":53"
	if strings.Contains(ip, ":") {
		// IPv6 literal: "::1" must become "[::1]:53", not "::1:53".
		addr = "[" + ip + "]:53"
	}

	m := new(dns.Msg)
	m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
	m.RecursionDesired = true
	return dns.Exchange(m, addr)
}

func main() {
	domain := flag.String("domain", "", "Target domain")
	targetIP := flag.String("server", "", "server IP address")
	packetsPerSecond := flag.Int("speed", 0, "Packets per second")
	packetNum := flag.Int("num", 1000, "total packets")
	expectedResult := flag.String("truth", "", "Expected result")
	random_flag := flag.Bool("random", false, "random subdomain")
	flag.Parse()

	if *domain == "" || *targetIP == "" || *packetsPerSecond == 0 {
		flag.Usage()
		return
	}

	currentTime := strconv.FormatInt(time.Now().Unix(), 10)
	totalRequests := *packetNum
	responseTimes := make([]int64, 0, totalRequests)
	successfulRequests := 0
	validResponses := 0

	zerolog.TimeFieldFormat = zerolog.TimeFormatUnix

	var query_tasks sync.WaitGroup
	fmt.Println("testing......")
	for i := 0; i < totalRequests; i++ {
		query_tasks.Add(1)
		go func() {
			var prefix string
			start := time.Now()
			if *random_flag {
				prefix = strings.Join([]string{currentTime, strconv.Itoa(i), ""}, "-")
			} else {
				prefix = ""
			}
			query_dn := prefix + (*domain)
			res, err := dnsQuery(*targetIP, query_dn)
			elapsed := time.Since(start).Microseconds()
			if err == nil {
				successfulRequests += 1
				log.Log().Msg("gfw worked!")
				responseTimes = append(responseTimes, elapsed)
				if *expectedResult == "" {
					if *random_flag {
						if (len(res.Answer) != 1) || (res.MsgHdr.Rcode != dns.RcodeSuccess) {
							validResponses += 1
							log.Log().Msg("gfw escaped!")
							//fmt.Printf("answer: %d rcode: %d\n", len(res.Answer), res.MsgHdr.Rcode)
						}
					} else {
						if len(res.Answer) > 1 {
							validResponses += 1
							log.Log().Msg("gfw escaped!")
							//fmt.Println(elapsed)
						}
					}
				} else {
					if a, ok := res.Answer[0].(*dns.A); ok {
						//fmt.Println(a.A.String())
						if a.A.String() == *expectedResult {
							validResponses += 1
							log.Log().Msg("gfw escaped!")
							//fmt.Println(elapsed)
						}
					}
				}
			}

			query_tasks.Done()
		}()
		time.Sleep(time.Second / time.Duration(*packetsPerSecond))
	}

	query_tasks.Wait()

	var sum_rtt int64
	var max_rtt int64
	var min_rtt int64
	sum_rtt = 0
	max_rtt = 0
	min_rtt = 3000000
	for _, rtt := range responseTimes {
		sum_rtt += rtt
		if rtt < min_rtt {
			min_rtt = rtt
		}
		if rtt > max_rtt {
			max_rtt = rtt
		}
	}
	avg_rtt := float64(sum_rtt) / float64(len(responseTimes))

	valid_rate := float64(validResponses) / float64(successfulRequests)

	fmt.Println(avg_rtt, max_rtt, min_rtt, successfulRequests, validResponses, valid_rate)

}