summaryrefslogtreecommitdiff
path: root/core/dnsserver/server.go
blob: ad2d99141d1412b4daab2e65fc2acb3af3199484 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
package dnsserver

import (
	"context"
	"fmt"
	"github.com/coredns/caddy"
	"github.com/miekg/dns"
	ot "github.com/opentracing/opentracing-go"
	"net"
	"ohmydns2/plugin"
	"ohmydns2/plugin/pkg/edns"
	"ohmydns2/plugin/pkg/log"
	"ohmydns2/plugin/pkg/rcode"
	"ohmydns2/plugin/pkg/request"
	"ohmydns2/plugin/pkg/reuseport"
	"ohmydns2/plugin/pkg/trace"
	"ohmydns2/plugin/pkg/transport"
	"ohmydns2/plugin/prometheus/vars"
	"runtime"
	"runtime/debug"
	"strings"
	"sync"
	"time"
)

// Server represents an instance of a server, which serves DNS requests at a
// particular address (host and port). A server is capable of serving numerous
// zones on the same address, and the listener may be stopped for graceful
// termination (POSIX only).
type Server struct {
	// Addr is the address we listen on, including the transport prefix
	// (transport.DNS + "://") that Listen and ListenPacket strip off.
	Addr string

	// server holds the underlying dns.Servers, indexed by the tcp and udp
	// constants: server[tcp] wraps a net.Listener, server[udp] a
	// net.PacketConn (a *UDPConn in our case). Entries stay nil until
	// Serve/ServePacket have run.
	server [2]*dns.Server

	// m protects server.
	m sync.Mutex

	// zones maps a zone name (Config.Zone) to every Config serving that
	// zone; ServeDNS looks entries up by suffixes of the query name.
	zones map[string][]*Config

	// dnsWg is the wait group Stop waits on; NewServer adds one barrier
	// increment that Stop releases.
	dnsWg sync.WaitGroup

	// graceTimeout is the maximum time Stop waits on dnsWg before shutting
	// the dns.Servers down.
	graceTimeout time.Duration

	// trace is the trace plugin handler found while compiling the chains, if any.
	trace trace.Trace

	// debug disables the recover() in ServeDNS.
	debug bool

	// stacktrace adds a stack trace to the log line written by that recover.
	stacktrace bool

	// classChaos allows non-INET class queries.
	classChaos bool

	// idleTimeout is the idle timeout for TCP connections.
	idleTimeout time.Duration

	// readTimeout is the read timeout for TCP.
	readTimeout time.Duration

	// writeTimeout is the write timeout for TCP.
	writeTimeout time.Duration

	// tsigSecret holds the TSIG secrets merged from all Configs; it is
	// handed to both the TCP and the UDP dns.Server.
	tsigSecret map[string]string
}

// MetadataCollector is implemented by a plugin that gathers the metadata
// functions of all metadata-providing plugins. ServeDNS calls Collect before
// the request enters the plugin chain and continues with the returned context.
type MetadataCollector interface {
	// Collect returns ctx enriched with the metadata for req.
	Collect(ctx context.Context, req request.Request) context.Context
}

// NewServer returns a new OhmyDNS server listening on addr and compiles the
// plugin chain of every Config in group into it.
//
// For each Config it does the following:
//   - appends the Config to the list for its zone;
//   - lets non-zero Read/Write/Idle timeouts override the defaults (later
//     Configs win);
//   - merges the Config's TSIG secrets;
//   - wraps the plugins from last to first, registering every intermediate
//     handler.
//
// While wrapping it records:
//   - the first MetadataCollector in plugin order;
//   - the first handler named "trace" it encounters;
//   - whether any plugin listed in EnableChaos is present. By default CH
//     class queries are blocked.
//
// The stacktrace flag is taken from the last Config. Debug logging is
// enabled if any Config asks for it and explicitly cleared otherwise. The
// error result is always nil.
func NewServer(addr string, group []*Config) (*Server, error) {
	srv := &Server{
		Addr:         addr,
		zones:        make(map[string][]*Config),
		graceTimeout: 5 * time.Second,
		idleTimeout:  10 * time.Second,
		readTimeout:  3 * time.Second,
		writeTimeout: 5 * time.Second,
		tsigSecret:   make(map[string]string),
	}
	log.Infof("Do53服务启动,监听地址: %v", addr)

	// sync.WaitGroup requires a positive Add to happen before Wait is
	// called. Take one increment up front as a barrier; Stop releases it
	// with the matching Done.
	srv.dnsWg.Add(1)

	for _, cfg := range group {
		if cfg.Debug {
			srv.debug = true
			log.D.Set()
		}
		srv.stacktrace = cfg.Stacktrace

		srv.zones[cfg.Zone] = append(srv.zones[cfg.Zone], cfg)

		if d := cfg.ReadTimeout; d != 0 {
			srv.readTimeout = d
		}
		if d := cfg.WriteTimeout; d != 0 {
			srv.writeTimeout = d
		}
		if d := cfg.IdleTimeout; d != 0 {
			srv.idleTimeout = d
		}

		for name, secret := range cfg.TsigSecret {
			srv.tsigSecret[name] = secret
		}

		// Build the chain inside-out: the last plugin is wrapped first, so
		// after the loop chain points at the first plugin of the list.
		var chain plugin.Handler
		for idx := len(cfg.Plugin) - 1; idx >= 0; idx-- {
			chain = cfg.Plugin[idx](chain)
			cfg.registerHandler(chain)

			// Walking backwards and overwriting means the MetadataCollector
			// closest to the front of the list is the one that sticks.
			if collector, isCollector := chain.(MetadataCollector); isCollector {
				cfg.metaCollector = collector
			}

			name := chain.Name()

			// Stash the plugin itself rather than its Tracer: the Tracer
			// is not initialized yet at this point.
			if srv.trace == nil && name == "trace" {
				if tr, isTrace := chain.(trace.Trace); isTrace {
					srv.trace = tr
				}
			}

			if _, opensChaos := EnableChaos[name]; opensChaos {
				srv.classChaos = true
			}
		}
		cfg.pluginChain = chain
	}

	// On reload, debug logging has to be switched off explicitly when no
	// Config asks for it any more.
	if !srv.debug {
		log.D.Clear()
	}

	return srv, nil
}

// Fail the build if *Server ever stops satisfying caddy.GracefulServer.
var _ caddy.GracefulServer = (*Server)(nil)

// Serve starts the TCP side of the server on the existing listener l and
// blocks until that dns.Server stops. Every request is passed to ServeDNS
// with a context carrying the server under Key{} and 0 under LoopKey{}.
// Timeouts, TSIG secrets and the per-connection query limit come from the
// Server's configuration.
// This implements the caddy.TCPServer interface.
func (s *Server) Serve(l net.Listener) error {
	handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
		ctx := context.WithValue(context.Background(), Key{}, s)
		ctx = context.WithValue(ctx, LoopKey{}, 0)
		s.ServeDNS(ctx, w, r)
	})
	// idleTimeout is read each time the dns.Server asks for it.
	idle := func() time.Duration {
		return s.idleTimeout
	}

	s.m.Lock()
	srv := &dns.Server{
		Listener:      l,
		Net:           "tcp",
		Handler:       handler,
		TsigSecret:    s.tsigSecret,
		MaxTCPQueries: tcpMaxQueries,
		ReadTimeout:   s.readTimeout,
		WriteTimeout:  s.writeTimeout,
		IdleTimeout:   idle,
	}
	s.server[tcp] = srv
	s.m.Unlock()

	return srv.ActivateAndServe()
}

// ServePacket starts the UDP side of the server on the existing packet
// connection p and blocks until that dns.Server stops. Requests reach
// ServeDNS with the same context values as in Serve (Key{} -> s,
// LoopKey{} -> 0).
// This implements the caddy.UDPServer interface.
func (s *Server) ServePacket(p net.PacketConn) error {
	handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
		ctx := context.WithValue(context.Background(), Key{}, s)
		ctx = context.WithValue(ctx, LoopKey{}, 0)
		s.ServeDNS(ctx, w, r)
	})

	s.m.Lock()
	srv := &dns.Server{
		PacketConn: p,
		Net:        "udp",
		Handler:    handler,
		TsigSecret: s.tsigSecret,
	}
	s.server[udp] = srv
	s.m.Unlock()

	return srv.ActivateAndServe()
}

// Listen opens a TCP listener (via reuseport) on s.Addr with its leading
// transport.DNS + "://" prefix sliced off. On failure the listener result
// is nil.
// Listen implements the caddy.TCPServer interface.
func (s *Server) Listen() (net.Listener, error) {
	hostPort := s.Addr[len(transport.DNS+"://"):]

	ln, err := reuseport.Listen("tcp", hostPort)
	if err != nil {
		return nil, err
	}
	return ln, nil
}

// WrapListener implements the caddy.GracefulServer interface. No wrapping is
// needed here, so the listener is returned unchanged.
func (s *Server) WrapListener(l net.Listener) net.Listener {
	wrapped := l
	return wrapped
}

// ListenPacket opens a UDP packet connection (via reuseport) on s.Addr with
// its leading transport.DNS + "://" prefix sliced off. On failure the
// connection result is nil.
// ListenPacket implements the caddy.UDPServer interface.
func (s *Server) ListenPacket() (net.PacketConn, error) {
	hostPort := s.Addr[len(transport.DNS+"://"):]

	pc, err := reuseport.ListenPacket("udp", hostPort)
	if err != nil {
		return nil, err
	}
	return pc, nil
}

// Stop stops the server. It blocks until the server is
// totally stopped. On POSIX systems, it will wait for
// connections to close (up to a max timeout of a few
// seconds); on Windows it will close the listener
// immediately.
// This implements Caddy.Stopper interface.
func (s *Server) Stop() (err error) {
	if runtime.GOOS != "windows" {
		// force connections to close after timeout
		done := make(chan struct{})
		go func() {
			s.dnsWg.Done() // decrement our initial increment used as a barrier
			s.dnsWg.Wait()
			close(done)
		}()

		// Wait for remaining connections to finish or
		// force them all to close after timeout
		select {
		case <-time.After(s.graceTimeout):
		case <-done:
		}
	}

	// Close the listener now; this stops the server without delay
	s.m.Lock()
	for _, s1 := range s.server {
		// We might not have started and initialized the full set of servers
		if s1 != nil {
			err = s1.Shutdown()
		}
	}
	s.m.Unlock()
	return
}

// Address reports the address the server was created with (s.Addr).
// Together with Stop it implements caddy.GracefulServer.
func (s *Server) Address() string {
	addr := s.Addr
	return addr
}

// ServeDNS is the entry point for every request to the address that is
// bound to. It acts as a multiplexer on the query name, so that the correct
// zone (configuration and plugin stack) handles the request.
//
// Routing: starting from the full lower-cased qname, it strips one label at
// a time. It hands the request to the first Config of the first (longest)
// matching zone whose filter funcs all pass.
//
// DS queries keep searching towards the root, so that a parent zone we also
// serve can answer them. If nothing matches, the root zone "." is tried as a
// last resort; otherwise the request is REFUSED.
//
// Error replies:
//   - no question section: SERVFAIL;
//   - non-INET class while no chaos-enabling plugin is loaded: REFUSED;
//   - a zone without plugins: REFUSED;
//   - wrong EDNS version: the reply built by edns.Version;
//   - a nil message: dropped without a reply, since there is nothing to
//     answer.
//
// Unless debug is set, panics below this point are recovered, logged,
// counted in vars.Panic and answered with SERVFAIL.
func (s *Server) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) {
	// A nil message cannot be answered. errorAndMetricsFunc would
	// dereference it in answer.SetRcode, and the recover below is not
	// installed yet, so replying would crash the serving goroutine.
	if r == nil {
		return
	}
	// The default dns.Mux checks the question section size, but we have our
	// own mux here. Without a question there is nothing to route on.
	if len(r.Question) == 0 {
		errorAndMetricsFunc(s.Addr, w, r, dns.RcodeServerFailure)
		return
	}

	if !s.debug {
		defer func() {
			// In case the user doesn't enable error plugin, we still
			// need to make sure that we stay alive up here
			if rec := recover(); rec != nil {
				if s.stacktrace {
					log.Errorf("Recovered from panic in server: %q %v\n%s", s.Addr, rec, string(debug.Stack()))
				} else {
					log.Errorf("Recovered from panic in server: %q %v", s.Addr, rec)
				}
				vars.Panic.Inc()
				errorAndMetricsFunc(s.Addr, w, r, dns.RcodeServerFailure)
			}
		}()
	}

	// Only INET queries are served unless a plugin listed in EnableChaos is loaded.
	if !s.classChaos && r.Question[0].Qclass != dns.ClassINET {
		errorAndMetricsFunc(s.Addr, w, r, dns.RcodeRefused)
		return
	}

	if m, err := edns.Version(r); err != nil { // Wrong EDNS version, return at once.
		w.WriteMsg(m)
		return
	}

	// Wrap the response writer in a ScrubWriter so we automatically make the reply fit in the client's buffer.
	w = request.NewScrubWriter(r, w)

	q := strings.ToLower(r.Question[0].Name)
	var (
		off       int
		end       bool
		dshandler *Config
	)

	// Try q, then q minus its first label, and so on, so the longest
	// matching zone is found first.
	for {
		if z, ok := s.zones[q[off:]]; ok {
			for _, h := range z {
				if h.pluginChain == nil { // zone defined, but has not got any plugins
					errorAndMetricsFunc(s.Addr, w, r, dns.RcodeRefused)
					return
				}

				if h.metaCollector != nil {
					// Collect metadata now, so it can be used before we send a request down the plugin chain.
					ctx = h.metaCollector.Collect(ctx, request.Request{Req: r, W: w})
				}

				// If all filter funcs pass, use this config.
				if passAllFilterFuncs(ctx, h.FilterFuncs, &request.Request{Req: r, W: w}) {
					if h.ViewName != "" {
						// if there was a view defined for this Config, set the view name in the context
						ctx = context.WithValue(ctx, ViewKey{}, h.ViewName)
					}
					if r.Question[0].Qtype != dns.TypeDS {
						rc, _ := h.pluginChain.ServeDNS(ctx, w, r)
						// Write an error reply when, per plugin.ClientWrite,
						// the chain did not write one itself.
						if !plugin.ClientWrite(rc) {
							errorFunc(s.Addr, w, r, rc)
						}
						return
					}
					// The type is DS, keep the handler, but keep on searching as maybe we are serving
					// the parent as well and the DS should be routed to it - this will probably *misroute* DS
					// queries to a possibly grand parent, but there is no way for us to know at this point
					// if there is an actual delegation from grandparent -> parent -> zone.
					// In all fairness: direct DS queries should not be needed.
					dshandler = h
				}
			}
		}
		off, end = dns.NextLabel(q, off)
		if end {
			break
		}
	}

	if r.Question[0].Qtype == dns.TypeDS && dshandler != nil && dshandler.pluginChain != nil {
		// DS request, and we found a zone, use the handler for the query.
		rc, _ := dshandler.pluginChain.ServeDNS(ctx, w, r)
		if !plugin.ClientWrite(rc) {
			errorFunc(s.Addr, w, r, rc)
		}
		return
	}

	// Wildcard match, if we have found nothing try the root zone as a last resort.
	if z, ok := s.zones["."]; ok {
		for _, h := range z {
			if h.pluginChain == nil {
				continue
			}

			if h.metaCollector != nil {
				// Collect metadata now, so it can be used before we send a request down the plugin chain.
				ctx = h.metaCollector.Collect(ctx, request.Request{Req: r, W: w})
			}

			// If all filter funcs pass, use this config.
			if passAllFilterFuncs(ctx, h.FilterFuncs, &request.Request{Req: r, W: w}) {
				if h.ViewName != "" {
					// if there was a view defined for this Config, set the view name in the context
					ctx = context.WithValue(ctx, ViewKey{}, h.ViewName)
				}
				rc, _ := h.pluginChain.ServeDNS(ctx, w, r)
				if !plugin.ClientWrite(rc) {
					errorFunc(s.Addr, w, r, rc)
				}
				return
			}
		}
	}

	// Still here? Error out with REFUSED.
	errorAndMetricsFunc(s.Addr, w, r, dns.RcodeRefused)
}

// passAllFilterFuncs reports whether every filter in filterFuncs accepts req.
// Evaluation stops at the first filter that rejects; an empty or nil list
// accepts every request.
func passAllFilterFuncs(ctx context.Context, filterFuncs []FilterFunc, req *request.Request) bool {
	for i := range filterFuncs {
		if accepted := filterFuncs[i](ctx, req); !accepted {
			return false
		}
	}
	return true
}

// OnStartupComplete prints the zones served by this server (as rendered by
// startUpZones) to stdout. It prints nothing when Quiet is set or when there
// is nothing to report.
func (s *Server) OnStartupComplete() {
	if !Quiet {
		if listing := startUpZones("", s.Addr, s.zones); listing != "" {
			fmt.Print(listing)
		}
	}
}

// Tracer returns the Tracer of the trace plugin found in NewServer, or nil
// when no trace plugin is configured.
func (s *Server) Tracer() ot.Tracer {
	if t := s.trace; t != nil {
		return t.Tracer()
	}
	return nil
}

// errorFunc answers the DNS request r on w with an empty reply that carries
// the response code rc. The reply goes through request.Request.SizeAndDo
// before it is written. The server argument is not used, and the WriteMsg
// error is ignored.
func errorFunc(server string, w dns.ResponseWriter, r *dns.Msg, rc int) {
	reply := new(dns.Msg)
	reply.SetRcode(r, rc)

	req := request.Request{W: w, Req: r}
	req.SizeAndDo(reply)

	w.WriteMsg(reply)
}

// errorAndMetricsFunc behaves like errorFunc: it answers r on w with an
// empty reply carrying rc, passed through request.Request.SizeAndDo. Before
// writing, it also records the reply with vars.Report for the given server
// label, under vars.Dropped and with no plugin name. Because the start time
// passed is time.Now(), the reported duration is effectively zero.
func errorAndMetricsFunc(server string, w dns.ResponseWriter, r *dns.Msg, rc int) {
	reply := new(dns.Msg)
	reply.SetRcode(r, rc)

	req := request.Request{W: w, Req: r}
	req.SizeAndDo(reply)

	code := rcode.ToString(rc)
	vars.Report(server, req, vars.Dropped, "", code, "" /* plugin */, reply.Len(), time.Now())

	w.WriteMsg(reply)
}

// Indices into Server.server.
const (
	tcp = iota // the stream server built by Serve
	udp        // the packet server built by ServePacket
)

// tcpMaxQueries is used as dns.Server.MaxTCPQueries in Serve; in miekg/dns
// -1 means no limit on queries per TCP connection.
const tcpMaxQueries = -1

// Key is the context key under which Serve/ServePacket store the *Server
// handling the request.
type Key struct{}

// LoopKey is the context key used to detect server-wide loops; Serve and
// ServePacket initialize it to 0.
type LoopKey struct{}

// ViewKey is the context key holding the view name of the Config chosen by
// ServeDNS, when that Config defines one.
type ViewKey struct{}

var (
	// EnableChaos lists the plugin names that, when present in any chain,
	// make the server accept CH (non-INET) class queries. Those are refused
	// by default.
	EnableChaos = map[string]struct{}{
		"proxy":   {},
		"forward": {},
		"chaos":   {},
	}

	// Quiet suppresses the informative zone listing printed by
	// OnStartupComplete.
	Quiet bool
)